def _get_model_param_count()

in optimum/neuron/utils/training_utils.py


# Module-level imports assumed by this excerpt (the extracted span does not show them).
from typing import TYPE_CHECKING, Union

import torch

if TYPE_CHECKING:
    from neuronx_distributed.pipeline import NxDPPModel


def _get_model_param_count(model: Union[torch.nn.Module, "NxDPPModel"]):
    """Counts the parameters of the model, accounting for tensor and pipeline parallelism.

    Returns a `(trainable_param_count, all_param_count)` tuple.
    """
    import torch_xla.core.xla_model as xm
    from neuronx_distributed.parallel_layers.parallel_state import (
        get_pipeline_model_parallel_group,
        get_pipeline_model_parallel_rank,
        get_pipeline_model_parallel_size,
        get_tensor_model_parallel_size,
        model_parallel_is_initialized,
    )
    from neuronx_distributed.pipeline import NxDPPModel
    from neuronx_distributed.pipeline.partition import analyze_shared_weights_across_stages

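    # With pipeline parallelism each rank only sees the parameters of its local stage(s).
    # Weights tied across stages (e.g. shared embeddings) would otherwise be counted once
    # per stage, so record the single pipeline rank that should count each shared weight.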
    if isinstance(model, NxDPPModel):
        named_parameters = model.local_named_parameters()
        shared = analyze_shared_weights_across_stages(model.traced_model, model.partitions)
        shared_parameters_across_pipeline_stages = {
            t[0]: t[1] for shared_parameter_info in shared for t in shared_parameter_info
        }
    else:
        named_parameters = model.named_parameters()
        shared_parameters_across_pipeline_stages = {}

    # We make sure `named_parameters` is not an iterator because we are going to iterate over it twice.
    named_parameters = list(named_parameters)

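    # Fall back to a single-rank view when torch.distributed / model parallelism is not
    # initialized (e.g. plain single-device runs).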
    if torch.distributed.is_initialized() and model_parallel_is_initialized():
        tp_size = get_tensor_model_parallel_size()
        pp_size = get_pipeline_model_parallel_size()
        pp_rank = get_pipeline_model_parallel_rank()
    else:
        tp_size = 1
        pp_size = 1
        pp_rank = 0

    def numel(parameter_name, parameter) -> int:
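        # A weight shared across pipeline stages is only counted on the rank recorded for
        # it above; parameters that are not shared default to the current rank and are
        # therefore always counted.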
        should_count_param = shared_parameters_across_pipeline_stages.get(parameter_name, pp_rank) == pp_rank

        num_elements = parameter.numel()
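        # Tensor-parallel layers only hold a 1/tp_size shard locally, so scale the local
        # count back up to the full, unsharded parameter count.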
        if getattr(parameter, "tensor_model_parallel", False):
            num_elements *= tp_size

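        # bitsandbytes 4-bit parameters are stored packed (two 4-bit values per byte of the
        # quantized storage), so recover the logical element count from the storage size.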
        if parameter.__class__.__name__ == "Params4bit":
            if hasattr(parameter, "element_size"):
                num_bytes = parameter.element_size()
            elif not hasattr(parameter, "quant_storage"):
                num_bytes = 1
            else:
                num_bytes = parameter.quant_storage.itemsize
            num_elements = num_elements * 2 * num_bytes

        return num_elements if should_count_param else 0

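    # Each pipeline rank only counted its local parameters above, so sum the partial counts
    # across the pipeline-parallel group; the all-reduce runs on the XLA device and the
    # result is brought back to the host as a Python int.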
    def reduce_param_count_over_pp_ranks(param_count: int):
        param_count = torch.tensor(param_count, dtype=torch.float32).to(xm.xla_device())
        param_count = xm.all_reduce(xm.REDUCE_SUM, param_count, groups=get_pipeline_model_parallel_group(as_list=True))
        xm.mark_step()
        param_count = int(param_count.detach().cpu().item())
        return param_count

    all_param_count = sum(numel(n, p) for n, p in named_parameters)
    trainable_param_count = sum(numel(n, p) for n, p in named_parameters if p.requires_grad)
    if pp_size > 1:
        all_param_count = reduce_param_count_over_pp_ranks(all_param_count)
        trainable_param_count = reduce_param_count_over_pp_ranks(trainable_param_count)

    return trainable_param_count, all_param_count
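

# Minimal usage sketch (hypothetical, not part of the original module): without
# torch.distributed / model parallelism initialized, the helper falls back to
# tp_size = pp_size = 1 and simply counts the module's parameters. It still requires
# a Neuron/XLA environment, since the helper imports torch_xla and neuronx_distributed.
def _example_param_count_usage():
    model = torch.nn.Linear(16, 4)  # any torch.nn.Module works here
    trainable, total = _get_model_param_count(model)
    print(f"trainable params: {trainable:,} || all params: {total:,}")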