in chatlearn/runtime/decorator.py [0:0]
def timeit(func, func_name):
    """Wrap a model method with timing, optional NVTX ranges and profiler stepping.

    Args:
        func: the unbound method to wrap; called as ``func(self, *args, **kwargs)``.
        func_name: name used for the NVTX range, the timer key passed to
            ``self.timers`` and the profiler filter below.

    Returns:
        A wrapper ``inner(self, *args, **kwargs)`` that returns ``func``'s result.

    Behavior of the wrapper:
        * When ``self.runtime_args.nsys`` is set, the call is enclosed in an NVTX
          range named ``func_name`` (popped even if ``func`` raises).
        * On the last rank (``self.is_last_rank()``), the call is timed with
          ``self.timers(func_name)``. If the timer is already running (e.g. a
          decorated subclass method calling a decorated base-class method),
          the outermost call's start time is kept and only that outermost call
          stops the timer. The timer is stopped even if ``func`` raises.
        * For ``forward_step``/``train_step`` on replica 0 with a profiler
          attached, the profiler is stepped for iterations 1 and 2 and stopped
          at iteration 3.
    """
    def inner(self, *args, **kwargs):
        # Read once so that push and pop are always balanced.
        use_nsys = self.runtime_args.nsys
        if use_nsys:
            nvtx.range_push(func_name)
        try:
            if self.is_last_rank():
                timer = self.timers(func_name)
                # For a class inherited from base, the wrapper may be entered
                # several times for one logical call; only the call that starts
                # the timer is allowed to stop it, so the first start time wins.
                started_here = not timer.started_
                if started_here:
                    timer.start()
                try:
                    ret = func(self, *args, **kwargs)
                finally:
                    # Never leave the timer running after an exception, otherwise
                    # later calls would skip start() and mis-measure.
                    if started_here:
                        timer.stop()
            else:
                ret = func(self, *args, **kwargs)

            if self.profiler is not None and self.replica_id == 0 \
                    and func_name in ("forward_step", "train_step"):
                if 0 < self._iteration <= 2:
                    self.profiler.step()
                elif self._iteration == 3:
                    self.profiler.stop()
        finally:
            if use_nsys:
                nvtx.range_pop()
        return ret
    return inner