| import torch.cuda | |
class Metric:
    """
    Collect wall-time measurements under a label and print their quantiles.

    Measurements are stored in insertion order in ``self.measurements``;
    ``report`` prints the 0%, 10%, ..., 90% quantiles of everything
    collected so far.
    """

    def __init__(self, label):
        """Create an empty metric that is reported under ``label``."""
        self.label = label
        self.measurements = []

    def collect(self, measurement):
        """Append a single measurement to the collected series."""
        self.measurements.append(measurement)

    def get_measurements(self):
        """Return a shallow copy of the measurements collected so far."""
        return self.measurements[:]

    def report(self):
        """
        Print the label followed by the 0%..90% quantiles of the measurements.

        The quantile levels are ``torch.arange(10) / 10.0``, so the maximum
        (100%) is not part of the report. When nothing has been collected
        yet, a placeholder is printed instead of raising.
        """
        if not self.measurements:
            # torch.quantile rejects an empty input tensor, so report the
            # absence of data explicitly rather than crashing the caller.
            print(self.label, "(no measurements)")
            return
        # torch.quantile only accepts float/double input, and the quantile
        # levels must share the input dtype. Integer measurements would
        # otherwise produce an int64 tensor, so convert to the default
        # float dtype, which is also what ``arange(10) / 10.0`` yields.
        values = torch.tensor(self.measurements, dtype=torch.get_default_dtype())
        levels = torch.arange(10) / 10.0
        print(
            self.label,
            torch.quantile(values, levels),
        )
def monitor_method_cuda_wall_times(metric, obj, methodname):
    """
    Patch ``obj.<methodname>`` so that every call is timed with CUDA events.

    The attribute named ``methodname`` on ``obj`` (an object or a class) is
    replaced by a wrapper that records a CUDA event before and after the
    original call, synchronizes the device, passes the elapsed time between
    the two events (as returned by ``Event.elapsed_time``) to
    ``metric.collect`` and then calls ``metric.report()``. The original
    return value, or exception, is propagated unchanged; the measurement is
    taken in both cases.

    For instance:

    >>> metric = Metric('!LNORM')
    >>> monitor_method_cuda_wall_times(metric, LayerNorm, 'forward')
    """
    import functools

    oldmeth = getattr(obj, methodname)

    # Keep the wrapped method's name/docstring visible on the replacement.
    @functools.wraps(oldmeth)
    def newmeth(*args, **kw):
        # Use a fresh pair of events for every call: sharing one pair across
        # calls lets a nested or re-entrant invocation re-record the start
        # event and corrupt the timing of the call that is still running.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        try:
            return oldmeth(*args, **kw)
        finally:
            end_event.record()
            # Both events must have completed before elapsed_time is queried.
            torch.cuda.synchronize()
            elapsed = start_event.elapsed_time(end_event)
            metric.collect(elapsed)
            metric.report()

    setattr(obj, methodname, newmeth)