in smdebug/pytorch/hook.py [0:0]
def forward_pre_hook(self, module, inputs):
    # Write the gradients of the past step if the writer is still available.
    if self.writer is not None:
        self._close_writers()
    self._close_tb_writer()
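
    # (Hedged, illustrative sketch of the kind of user code that can still add or modify
    #  collections between hook construction and the first forward call -- the API names
    #  below are assumptions and should be checked against the installed smdebug version:
    #
    #      hook = smd.Hook(out_dir=out_dir, include_collections=["custom"])
    #      hook.get_collection("custom").include("relu*")
    #      hook.register_module(net)
    #  )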
    if not self.prepared_collections:
        # at this point we need all collections to be ready
        # this may not be the case at creation of hook
        # as user's code after hook might add collections
        self._prepare_collections()
        self.prepared_collections = True

    self._increment_step()

    # Preparing for step metrics.
    # The last operation can be a forward pass (an eval loop is running, or the module's
    # forward was called multiple times, e.g. an RNN), or it can be a backward pass
    # (the training backward loop just finished and we are back at forward).
    # Log all outstanding forward and backward events before starting this step.
    self.log_outstanding_timeline_metrics()

    self.step_event = self._TraceEventData(
        phase="Step:" + str(self.mode),
        op_name="Step:" + str(self.mode),
        start_time=time.time(),
        dur=0,  # end time of step_event is updated by every forward or backward event after this
        pid=os.getpid(),
        step_num=str(self.mode_steps[self.mode]),
    )
    self.parent_forward_event = self._TraceEventData(
        phase="Forward",
        op_name=module._module_name,
        start_time=time.time(),
        dur=0,  # end time of parent_forward_event is updated by every forward event after this
        pid=os.getpid(),
        step_num=str(self.mode_steps[self.mode]),
    )

    self.profiler_config_parser.load_config()
    self.profiler_config_parser.handle_step_start_python_profiling(self.mode, self.step)

    if (
        self.autograd_profiler_enabled
        and not self.profiler_config_parser.config.detailed_profiling_config.disabled
    ):
        self._collect_torch_profiling_data_if_profiler_enabled()

    # Should we re-enable profiling for this step?
    if (
        self.profiler_config_parser.should_save_metrics(
            MetricsCategory.DETAILED_PROFILING, self.step
        )
        and not self.autograd_profiler_enabled
    ):
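        # The branching below exists because torch.autograd.ProfilerConfig takes a
        # different number of positional arguments in each of these torch releases, and
        # the 1.8 path uses torch.autograd._enable_profiler_legacy instead of
        # _enable_profiler. The is_pt_1_x() helpers are smdebug utilities; a minimal
        # sketch of such a check (an assumption, not smdebug's actual implementation):
        #
        #     def is_pt_1_5():
        #         return torch.__version__.startswith("1.5.")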
        self.autograd_profiler_enabled = True
        if is_pt_1_5():
            torch.autograd._enable_profiler(
                torch.autograd.ProfilerConfig(self.profiler, False)
            )
            self.start_profiler_time_us = time.time() * CONVERT_TO_MICROSECS
        elif is_pt_1_7():
            torch.autograd._enable_profiler(
                torch.autograd.ProfilerConfig(self.profiler, False, False, False)
            )
            self.start_profiler_time_us = time.time() * CONVERT_TO_MICROSECS
        elif is_pt_1_8():
            torch.autograd._enable_profiler_legacy(
                torch.autograd.ProfilerConfig(self.profiler, False, False, False, False)
            )
            self.start_profiler_time_us = time.time() * CONVERT_TO_MICROSECS
        elif is_pt_1_6():
            torch.autograd._enable_profiler(
                torch.autograd.ProfilerConfig(self.profiler, False, False)
            )
            self.start_profiler_time_us = time.time() * CONVERT_TO_MICROSECS
        else:
            self.logger.warn(
                f"Detailed profiling using the autograd profiler is not supported "
                f"for torch version {torch.__version__}"
            )
            self.autograd_profiler_enabled = False

    if self.is_smdataparallel_profiling:
        # Stop smdataparallel profiling that was started for an earlier step; it is
        # restarted below only if this step is configured to save smdataparallel metrics.
        stop_smdataparallel_profiler(
            smdataparallel, self.profiler_config_parser.config.local_path
        )
        self.is_smdataparallel_profiling = False

    if self.profiler_config_parser.should_save_metrics(
        MetricsCategory.SMDATAPARALLEL_PROFILING, self.step
    ):
        start_smdataparallel_profiler(
            smdataparallel, self.profiler_config_parser.config.local_path
        )
        self.is_smdataparallel_profiling = True

    if self._get_collections_to_save_for_step():
        self._initialize_writers()
        self._log_params(module)

    if self.last_saved_step is not None and not self.exported_collections:
        self.export_collections()
        self.exported_collections = True

    self.first_forward_submodule_name = None
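

# --- Illustrative usage sketch (not part of hook.py) ---
# A forward pre-hook like the one above is a callable taking (module, inputs) that
# PyTorch invokes just before Module.forward() runs; it is attached with
# nn.Module.register_forward_pre_hook(). The module and hook names below are
# hypothetical and only demonstrate the registration pattern, not smdebug's own
# registration code.
import torch
import torch.nn as nn


def print_pre_hook(module, inputs):
    # Runs just before module.forward(*inputs), the same point at which
    # smdebug's forward_pre_hook above is invoked.
    print(f"entering forward of {module.__class__.__name__}")


if __name__ == "__main__":
    net = nn.Linear(4, 2)
    handle = net.register_forward_pre_hook(print_pre_hook)
    net(torch.randn(1, 4))  # prints "entering forward of Linear" before forward runs
    handle.remove()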