def forward_pre_hook()

in smdebug/pytorch/hook.py


    def forward_pre_hook(self, module, inputs):
        # Flush and close the writers of the previous step, if still open, so
        # that data saved during that step (e.g. gradients) is written out.
        if self.writer is not None:
            self._close_writers()
        self._close_tb_writer()

        if not self.prepared_collections:
            # at this point we need all collections to be ready
            # this may not be the case at creation of hook
            # as user's code after hook might add collections
            self._prepare_collections()
            self.prepared_collections = True

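        # Advance the step counters (the global step and the per-mode step).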
        self._increment_step()

        ## preparing for step metrics
        # The last operation may have been a forward pass (the eval loop is
        # running, or forward was called multiple times in one step, as with
        # an RNN module), or a backward pass (the training loop just finished
        # backward and we are back at forward).

        # we will log all outstanding forward and backward events
        self.log_outstanding_timeline_metrics()

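        # Open trace events for the new step and for this module's forward pass.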
        self.step_event = self._TraceEventData(
            phase="Step:" + str(self.mode),
            op_name="Step:" + str(self.mode),
            start_time=time.time(),
            dur=0,  # end time is updated on every subsequent forward or backward event
            pid=os.getpid(),
            step_num=str(self.mode_steps[self.mode]),
        )
        self.parent_forward_event = self._TraceEventData(
            phase="Forward",
            op_name=module._module_name,
            start_time=time.time(),
            dur=0,  # end time is updated on every subsequent forward event
            pid=os.getpid(),
            step_num=str(self.mode_steps[self.mode]),
        )

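        # Re-read the profiler config so changes made while the job runs take
        # effect, then update step-based Python profiling for the new step.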
        self.profiler_config_parser.load_config()
        self.profiler_config_parser.handle_step_start_python_profiling(self.mode, self.step)

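        # If the autograd profiler was enabled for an earlier step, collect
        # the detailed profiling data it has gathered so far.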
        if (
            self.autograd_profiler_enabled
            and not self.profiler_config_parser.config.detailed_profiling_config.disabled
        ):
            self._collect_torch_profiling_data_if_profiler_enabled()

        # should we re-enable profiling for this step?
        if (
            self.profiler_config_parser.should_save_metrics(
                MetricsCategory.DETAILED_PROFILING, self.step
            )
            and not self.autograd_profiler_enabled
        ):
            self.autograd_profiler_enabled = True
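            # torch.autograd.ProfilerConfig takes a different number of
            # arguments in each supported torch release, hence the
            # per-version branches below.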
            if is_pt_1_5():
                torch.autograd._enable_profiler(
                    torch.autograd.ProfilerConfig(self.profiler, False)
                )
                self.start_profiler_time_us = time.time() * CONVERT_TO_MICROSECS
            elif is_pt_1_6():
                torch.autograd._enable_profiler(
                    torch.autograd.ProfilerConfig(self.profiler, False, False)
                )
                self.start_profiler_time_us = time.time() * CONVERT_TO_MICROSECS
            elif is_pt_1_7():
                torch.autograd._enable_profiler(
                    torch.autograd.ProfilerConfig(self.profiler, False, False, False)
                )
                self.start_profiler_time_us = time.time() * CONVERT_TO_MICROSECS
            elif is_pt_1_8():
                torch.autograd._enable_profiler_legacy(
                    torch.autograd.ProfilerConfig(self.profiler, False, False, False, False)
                )
                self.start_profiler_time_us = time.time() * CONVERT_TO_MICROSECS
            else:
                self.logger.warning(
                    f"Detailed profiling using the autograd profiler is not supported "
                    f"for torch version {torch.__version__}"
                )
                self.autograd_profiler_enabled = False

        if self.is_smdataparallel_profiling:
            # Stop smdataparallel profiling that was started in an earlier step.
            stop_smdataparallel_profiler(
                smdataparallel, self.profiler_config_parser.config.local_path
            )
        self.is_smdataparallel_profiling = False
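        # Restart smdataparallel profiling if its metrics should be saved for
        # this step.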
        if self.profiler_config_parser.should_save_metrics(
            MetricsCategory.SMDATAPARALLEL_PROFILING, self.step
        ):
            start_smdataparallel_profiler(
                smdataparallel, self.profiler_config_parser.config.local_path
            )
            self.is_smdataparallel_profiling = True

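        # Initialize writers and log the module's parameters only if some
        # collection is configured to be saved at this step.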
        if self._get_collections_to_save_for_step():
            self._initialize_writers()
            self._log_params(module)

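        # Once at least one step has been saved, export the collection
        # definitions exactly once.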
        if self.last_saved_step is not None and not self.exported_collections:
            self.export_collections()
            self.exported_collections = True

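        # Reset the record of the first submodule seen in the forward pass.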
        self.first_forward_submodule_name = None
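
For context, PyTorch invokes a hook like this immediately before every call to a module's forward(). Below is a minimal, hypothetical sketch of that mechanism using PyTorch's standard register_forward_pre_hook API (smdebug performs this registration internally when its hook is attached to a module); the print statement merely stands in for the bookkeeping above:

    import torch
    import torch.nn as nn

    def forward_pre_hook(module, inputs):
        # PyTorch calls this with the module and its positional inputs
        # before forward() runs; smdebug uses this point to close the
        # previous step's writers and begin a new step.
        print(f"entering {module.__class__.__name__} with {len(inputs)} input(s)")

    model = nn.Linear(4, 2)
    handle = model.register_forward_pre_hook(forward_pre_hook)
    model(torch.randn(1, 4))  # the hook fires before forward executes
    handle.remove()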