def forward_hook()

in code/custom_hook.py [0:0]


    def forward_hook(self, module, inputs, outputs):
        module_name = self.module_maps[module]   
        self._write_inputs(module_name, inputs)

        outputs.register_hook(self.backward_hook(module_name + "_output"))

        #record running mean and var of BatchNorm layers
        if isinstance(module, torch.nn.BatchNorm2d):
            self._write_outputs(module_name + ".running_mean", module.running_mean)
            self._write_outputs(module_name + ".running_var", module.running_var)

        self._write_outputs(module_name, outputs)
        self.last_saved_step = self.step