def mp_split_model_rank0()

in modules/SwissArmyTransformer/sat/mpu/operation.py


def mp_split_model_rank0(model, model_full, use_node_group=True):
    """
    This function loads partitions from rank 0.
    It takes less memory when world size is large.
    """
    group = get_node_group() if use_node_group else get_model_parallel_group()
    src = get_node_src_rank() if use_node_group else get_model_parallel_src_rank()
    local_world_size = get_node_world_size() if use_node_group else get_model_parallel_world_size()
    def iter_repartition(new_model, module):
        for (new_name, sub_new_model), (name, sub_module) in zip(new_model.named_children(), module.named_children()):
            if isinstance(sub_module, (ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)):
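                # Split this model-parallel layer into one slice per rank:
                # keep the source rank's slice locally and send the rest.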
                new_weights, new_biases = sub_module.partition()
                for i in range(local_world_size):
                    if i == 0:
                        sub_new_model.weight.data.copy_(new_weights[src%len(new_weights)])
                    else:
                        torch.distributed.send(new_weights[(src+i)%len(new_weights)].cuda(), src+i)
                if new_biases:
                    for i in range(local_world_size):
                        if i == 0:
                            sub_new_model.bias.data.copy_(new_biases[src%len(new_biases)])
                        else:
                            torch.distributed.send(new_biases[(src+i)%len(new_biases)].cuda(), src+i)
            else:
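                # Non-parallel parameters are identical on every rank:
                # copy them on the source rank, then broadcast to the group.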
                for (nn, np), (n, p) in zip(sub_new_model.named_parameters(recurse=False), sub_module.named_parameters(recurse=False)):
                    np.data.copy_(torch.clone(p.data).detach())
                    torch.distributed.broadcast(np.data, src, group=group)
            iter_repartition(sub_new_model, sub_module)
    iter_repartition(model, model_full)
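
For context, a minimal usage sketch of how this sender-side helper is typically paired with a receiving call on the other ranks. The constructors build_partitioned_model and build_full_model are placeholders, mp_split_model_receive is assumed to be the receiving counterpart in the same module, and the helpers are assumed to be re-exported from sat.mpu; the rank check uses the model-parallel rank, i.e. the use_node_group=False case, to keep the sketch short.

from sat.mpu import (get_model_parallel_rank,
                     mp_split_model_rank0, mp_split_model_receive)

def load_with_rank0_split(build_partitioned_model, build_full_model):
    # Every rank allocates only its own shard-sized model.
    model = build_partitioned_model()
    if get_model_parallel_rank() == 0:
        # Only rank 0 pays for the unpartitioned weights.
        model_full = build_full_model()
        mp_split_model_rank0(model, model_full, use_node_group=False)
        del model_full
    else:
        # Assumed receiving counterpart: mirrors the sends/broadcasts above.
        mp_split_model_receive(model)
    return model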