in code/custom_hook.py [0:0]
def forward_hook(self, module, inputs, outputs):
    """Record a module's inputs and outputs and arrange gradient capture.

    The signature matches what ``torch.nn.Module.register_forward_hook``
    passes to a hook; presumably that is how this method is installed
    (confirm against the registration code).

    Args:
        module: The module whose forward pass just ran. It is looked up in
            ``self.module_maps`` to obtain the name under which everything
            is written.
        inputs: The module's forward inputs, passed unchanged to
            ``self._write_inputs``.
        outputs: The module's forward result, passed unchanged to
            ``self._write_outputs``.
    """
    module_name = self.module_maps[module]
    self._write_inputs(module_name, inputs)

    # Tensor.register_hook raises RuntimeError on a tensor that does not
    # require grad (e.g. a forward pass under torch.no_grad()), and
    # non-tensor outputs such as tuples have no register_hook at all.
    # Only attach the gradient hook when a gradient can actually flow, so
    # the rest of the recording below still happens in those cases.
    if isinstance(outputs, torch.Tensor) and outputs.requires_grad:
        outputs.register_hook(self.backward_hook(module_name + "_output"))

    # Record running mean and var of BatchNorm layers.
    # NOTE(review): both buffers are None when the layer was built with
    # track_running_stats=False -- confirm _write_outputs tolerates that.
    if isinstance(module, torch.nn.BatchNorm2d):
        self._write_outputs(module_name + ".running_mean", module.running_mean)
        self._write_outputs(module_name + ".running_var", module.running_var)

    self._write_outputs(module_name, outputs)
    self.last_saved_step = self.step