# training/trainer.py
def print_model_summary(model: torch.nn.Module, log_dir: str = ""):
    """
    Print the model and the number of parameters in the model.

    Several packages (e.g. https://github.com/sksq96/pytorch-summary and
    https://github.com/nmhkahn/torchsummaryX) present this information in a nice
    table, but they require an example `input` (since they also report output
    sizes). Our models are complex, and a single input is too restrictive, so we
    only log parameter counts and the module structure here.
    """
    # Only log from the main process in distributed runs.
    if get_rank() != 0:
        return
    # Extra kwargs forwarded to model.parameters(); kept empty here.
    param_kwargs = {}
    trainable_parameters = sum(
        p.numel() for p in model.parameters(**param_kwargs) if p.requires_grad
    )
    total_parameters = sum(p.numel() for p in model.parameters(**param_kwargs))
    non_trainable_parameters = total_parameters - trainable_parameters
    logging.info("==" * 10)
    logging.info(f"Summary for model {type(model)}")
    logging.info(f"Model is {model}")
    logging.info(f"\tTotal parameters {get_human_readable_count(total_parameters)}")
    logging.info(
        f"\tTrainable parameters {get_human_readable_count(trainable_parameters)}"
    )
    logging.info(
        f"\tNon-Trainable parameters {get_human_readable_count(non_trainable_parameters)}"
    )
    logging.info("==" * 10)
    # Optionally dump the full module structure to <log_dir>/model.txt.
    if log_dir:
        output_fpath = os.path.join(log_dir, "model.txt")
        with g_pathmgr.open(output_fpath, "w") as f:
            print(model, file=f)
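

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes this module's
# own imports resolve get_rank(), get_human_readable_count(), and g_pathmgr, and
# that get_rank() returns 0 outside of a distributed run; the toy model and the
# /tmp/run_logs directory below are purely illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    logging.basicConfig(level=logging.INFO)

    # Toy model with one frozen layer so trainable and non-trainable counts differ.
    demo_model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
    for p in demo_model[0].parameters():
        p.requires_grad = False

    demo_log_dir = "/tmp/run_logs"
    os.makedirs(demo_log_dir, exist_ok=True)
    print_model_summary(demo_model, log_dir=demo_log_dir)  # also writes model.txt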