in vissl/hooks/__init__.py [0:0]
def default_hook_generator(cfg: AttrDict) -> List[ClassyHook]:
    """
    Build the list of hooks used in training, based on the user configuration.

    Some hooks are always added; the others are added only when the matching
    configuration flag is enabled.

    Optional hooks:
        - performance stats hook (HOOKS.PERF_STATS.MONITOR_PERF_STATS),
        - loss specific hooks, added by ``add_loss_hooks`` from ``cfg.LOSS``,
        - model complexity hook
          (HOOKS.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY = True),
        - GPU stats hook (HOOKS.LOG_GPU_STATS),
        - GPU memory hooks (HOOKS.MEMORY_SUMMARY.PRINT_MEMORY_SUMMARY and
          HOOKS.MEMORY_SUMMARY.DUMP_MEMORY_ON_EXCEPTION),
        - Tensorboard hook (HOOKS.TENSORBOARD_SETUP.USE_TENSORBOARD),
        - gradient clipping hook (MODEL.GRAD_CLIP.USE_GRAD_CLIP),
        - CUDA synchronize hook and profiling hook, when their own
          ``is_enabled`` check passes,
        - model output mask hook (METERS.model_output_mask),
        - NaN checking hooks (HOOKS.CHECK_NAN).

    Args:
        cfg (AttrDict): the configuration. This function reads the HOOKS, LOSS,
            MODEL, PROFILING, DISTRIBUTED and METERS sections.

    Returns:
        hooks (List[ClassyHook]): the hooks that will be used, in the order in
            which they were added.

    Raises:
        AssertionError: if Tensorboard is requested but is not available.
    """
    hooks = []

    # conditionally add hooks based on use-case
    perf_stats_cfg = cfg.HOOKS.PERF_STATS
    if perf_stats_cfg.MONITOR_PERF_STATS:
        # a non-positive frequency is handed to the hook as None
        perf_stat_freq = (
            perf_stats_cfg.PERF_STAT_FREQUENCY
            if perf_stats_cfg.PERF_STAT_FREQUENCY > 0
            else None
        )
        hooks.append(LogPerfTimeMetricsHook(perf_stat_freq))

    # add the loss hooks based on the loss being used
    hooks = add_loss_hooks(hooks, cfg.LOSS, cfg)

    if cfg.HOOKS.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY:
        hooks.append(SSLModelComplexityHook())
    if cfg.HOOKS.LOG_GPU_STATS:
        hooks.append(LogGpuStatsHook())

    memory_cfg = cfg.HOOKS.MEMORY_SUMMARY
    if memory_cfg.PRINT_MEMORY_SUMMARY:
        hooks.append(LogGpuMemoryHook(memory_cfg.LOG_ITERATION_NUM))
    if memory_cfg.DUMP_MEMORY_ON_EXCEPTION:
        hooks.append(DumpMemoryOnException())

    if cfg.HOOKS.TENSORBOARD_SETUP.USE_TENSORBOARD:
        # The literals below are concatenated into one message, so each piece
        # must carry its own trailing space.
        assert is_tensorboard_available(), (
            "Tensorboard must be installed to use it. "
            "Please install tensorboard using: "
            "If pip environment: `pip install tensorboard` "
            "If using conda and you prefer conda install of tensorboard: "
            "`conda install -c conda-forge tensorboard`"
        )
        hooks.append(get_tensorboard_hook(cfg))

    grad_clip_cfg = cfg.MODEL.GRAD_CLIP
    if grad_clip_cfg.USE_GRAD_CLIP:
        hooks.append(
            GradClipHook(
                norm_type=grad_clip_cfg.NORM_TYPE,
                max_norm=grad_clip_cfg.MAX_NORM,
            )
        )

    # hooks that are used irrespective of workflow type
    # a non-positive frequency is handed to the hook as None
    rolling_btime_freq = (
        perf_stats_cfg.ROLLING_BTIME_FREQ
        if perf_stats_cfg.ROLLING_BTIME_FREQ > 0
        else None
    )

    if CudaSynchronizeHook.is_enabled(cfg.MODEL):
        hooks.append(CudaSynchronizeHook())

    if ProfilingHook.is_enabled(cfg.PROFILING):
        hooks.append(ProfilingHook(profiling_config=cfg.PROFILING))

    # number of nodes multiplied by the number of processes on each node
    world_size = cfg.DISTRIBUTED.NUM_NODES * cfg.DISTRIBUTED.NUM_PROC_PER_NODE
    checkpoint_folder = get_checkpoint_folder(cfg)

    hooks.extend(
        [
            SetDataSamplerEpochHook(),
            FreezeParametersHook(),
            LogLossMetricsCheckpointHook(world_size),
            LogLossLrEtaHook(checkpoint_folder, rolling_btime_freq),
        ]
    )

    if cfg.METERS.model_output_mask:
        hooks.append(ModelOutputMaskHook())

    if cfg.HOOKS.CHECK_NAN:
        hooks.extend([CheckNanLossHook(), CheckNanModelOutputHook()])

    return hooks