in utils/benchmark_logger.py [0:0]
def log_timing(self, name):
    """Return a decorator that times a call and records it on a ``LogItem``.

    The wrapped function is bracketed by ``self.timer.start()`` and
    ``self.timer.stop()``. The first ``LogItem`` found among the call's
    arguments (positional first, then keyword values) receives the value
    returned by ``stop()`` as ``elapsed_time``, or ``0`` when its
    ``additional`` field equals ``'overlap'``. The item is then appended via
    ``self.comm_log.add_comm_log`` and, on distributed rank 0, written to the
    module ``logger``.

    Args:
        name: Label for the timed operation. It is accepted for interface
            compatibility but is not read by the visible code.

    Returns:
        A decorator. The function it produces returns whatever the wrapped
        function returns and keeps its metadata (``__name__``, ``__doc__``,
        ``__wrapped__``).
    """
    # Local import so this method does not rely on the module's import block.
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            self.timer.start()
            result = func(*args, **kwargs)
            # NOTE(review): the variable name suggests milliseconds; confirm
            # against the timer implementation.
            elapsed_time_ms = self.timer.stop()

            # A default of None keeps a bare StopIteration from escaping when
            # the call carries no LogItem.
            log_item = next(
                (
                    item
                    for item in (*args, *kwargs.values())
                    if isinstance(item, LogItem)
                ),
                None,
            )
            if log_item is None:
                # Nothing to record the timing on; still hand back the result
                # of the call that already ran.
                return result

            # Items marked 'overlap' are recorded with zero elapsed time.
            if log_item.additional == 'overlap':
                log_item.elapsed_time = 0
            else:
                log_item.elapsed_time = elapsed_time_ms
            self.comm_log.add_comm_log(log_item)
            # Only rank 0 emits the log line.
            if torch.distributed.get_rank() == 0:
                logger.info(log_item.view_as_ds_log())
            return result

        return wrapper

    return decorator