in optimum/neuron/utils/training_utils.py
def _get_model_param_count(model: Union[torch.nn.Module, "NxDPPModel"]):
    """Counts the number of parameters of the model, returning (trainable_param_count, all_param_count)."""
    import torch_xla.core.xla_model as xm
    from neuronx_distributed.parallel_layers.parallel_state import (
        get_pipeline_model_parallel_group,
        get_pipeline_model_parallel_rank,
        get_pipeline_model_parallel_size,
        get_tensor_model_parallel_size,
        model_parallel_is_initialized,
    )
    from neuronx_distributed.pipeline import NxDPPModel
    from neuronx_distributed.pipeline.partition import analyze_shared_weights_across_stages
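
    # For pipeline-parallel models, only the parameters of the local stage are visible.
    # Weights shared across stages (e.g. tied embeddings) are mapped to the single
    # pipeline rank that should count them, so they are not counted more than once.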
    if isinstance(model, NxDPPModel):
        named_parameters = model.local_named_parameters()
        shared = analyze_shared_weights_across_stages(model.traced_model, model.partitions)
        shared_parameters_across_pipeline_stages = {
            t[0]: t[1] for shared_parameter_info in shared for t in shared_parameter_info
        }
    else:
        named_parameters = model.named_parameters()
        shared_parameters_across_pipeline_stages = {}

    # We make sure `named_parameters` is not an iterator because we are going to iterate over it twice.
    named_parameters = list(named_parameters)
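
    # Parallel degrees are only meaningful once torch.distributed and the model-parallel
    # state are initialized; otherwise fall back to a single-device view.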
    if torch.distributed.is_initialized() and model_parallel_is_initialized():
        tp_size = get_tensor_model_parallel_size()
        pp_size = get_pipeline_model_parallel_size()
        pp_rank = get_pipeline_model_parallel_rank()
    else:
        tp_size = 1
        pp_size = 1
        pp_rank = 0
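
    # Per-parameter counting rules: a parameter shared across pipeline stages is counted
    # only on the rank that owns it; a tensor-parallel parameter only stores a 1/tp_size
    # shard locally, so its local numel is scaled back up by tp_size; for bitsandbytes
    # `Params4bit`, numel() reports packed storage elements, hence the `* 2 * num_bytes`
    # correction (two 4-bit values per storage byte).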
    def numel(parameter_name, parameter) -> int:
        should_count_param = shared_parameters_across_pipeline_stages.get(parameter_name, pp_rank) == pp_rank

        num_elements = parameter.numel()
        if getattr(parameter, "tensor_model_parallel", False):
            num_elements *= tp_size

        if parameter.__class__.__name__ == "Params4bit":
            if hasattr(parameter, "element_size"):
                num_bytes = parameter.element_size()
            elif not hasattr(parameter, "quant_storage"):
                num_bytes = 1
            else:
                num_bytes = parameter.quant_storage.itemsize
            num_elements = num_elements * 2 * num_bytes

        return num_elements if should_count_param else 0
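
    # Example (illustrative numbers): a tensor-parallel weight whose local shard has shape
    # (1024, 512) under tp_size=8 is reported as 1024 * 512 * 8 = 4,194,304 parameters,
    # i.e. the size of the full, unsharded weight.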
    def reduce_param_count_over_pp_ranks(param_count: int):
        param_count = torch.tensor(param_count, dtype=torch.float32).to(xm.xla_device())
        param_count = xm.all_reduce(xm.REDUCE_SUM, param_count, groups=get_pipeline_model_parallel_group(as_list=True))
        xm.mark_step()
        param_count = int(param_count.detach().cpu().item())
        return param_count
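
    # Each rank counts only the parameters it holds; with pipeline parallelism the per-rank
    # counts are then summed over the pipeline-parallel group so every rank reports the
    # size of the full model.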
    all_param_count = sum(numel(n, p) for n, p in named_parameters)
    trainable_param_count = sum(numel(n, p) for n, p in named_parameters if p.requires_grad)

    if pp_size > 1:
        all_param_count = reduce_param_count_over_pp_ranks(all_param_count)
        trainable_param_count = reduce_param_count_over_pp_ranks(trainable_param_count)

    return trainable_param_count, all_param_count
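
# Usage sketch (illustrative, not part of this file): assuming torch_xla and
# neuronx_distributed are importable (e.g. on a Trainium host) and no distributed
# state has been initialized, the helper falls back to tp_size = pp_size = 1:
#
#     model = torch.nn.Linear(10, 10)
#     trainable, total = _get_model_param_count(model)
#     assert trainable == total == 110  # 10 * 10 weights + 10 biases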