in picotron/utils.py [0:0]
import torch
import torch.distributed as dist

# Assumed import path for the process-group manager singleton referenced below.
import picotron.process_group_manager as pgm


def get_num_params(model):
    """Calculate the total number of parameters, accounting for tensor parallelism (TP)
    and pipeline parallelism (PP).

    TP: parameters in attention/mlp/embed/final_proj are sharded, so the local count
        is multiplied by tp_world_size.
    PP: parameter counts must be summed across pipeline stages.
    DP: parameters are replicated, so they are only counted once.

    Note:
        FSDP: parameters are sharded across data-parallel ranks.
    """
    tp_world_size = pgm.process_group_manager.tp_world_size

    # Count the parameters held by the current PP rank.
    local_num_params = 0
    for name, param in model.named_parameters():
        # Parameters split across TP ranks.
        # TODO: LayerNorm is also split across TP ranks for sequence parallelism.
        if any(tp_keyword in name.lower() for tp_keyword in ['attention', 'mlp', 'embed', 'final_proj']):
            local_num_params += param.numel() * tp_world_size
        else:
            # Parameters replicated across TP ranks (layer norms, biases).
            local_num_params += param.numel()

    # Sum the per-stage counts across all PP ranks.
    param_counts = torch.tensor(local_num_params, device='cuda')
    dist.all_reduce(param_counts, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.pp_group)
    return param_counts.item()
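
A minimal usage sketch, not part of utils.py: it assumes torch.distributed and pgm.process_group_manager have already been initialized, and that `model` is the locally built, TP/PP-sharded module on this rank; the print format is illustrative.

    # Illustrative only: process groups and `model` are assumed to be set up elsewhere.
    num_params = get_num_params(model)
    if dist.get_rank() == 0:
        print(f"Total parameters: {num_params / 1e9:.2f}B")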