in modules/SwissArmyTransformer/sat/mpu/operation.py [0:0]
def mp_split_model_rank0(model, model_full, use_node_group=True):
    """
    This function splits the full model on the source rank and distributes
    the partitions to the other ranks. It takes less memory when the world
    size is large.
    """
    group = get_node_group() if use_node_group else get_model_parallel_group()
    src = get_node_src_rank() if use_node_group else get_model_parallel_src_rank()
    local_world_size = get_node_world_size() if use_node_group else get_model_parallel_world_size()
    def iter_repartition(new_model, module):
        for (new_name, sub_new_model), (name, sub_module) in zip(new_model.named_children(), module.named_children()):
            if isinstance(sub_module, (ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)):
                # Model-parallel layers: split the full weight into per-rank shards,
                # keep this rank's shard locally and send the others point-to-point.
                new_weights, new_biases = sub_module.partition()
                for i in range(local_world_size):
                    if i == 0:
                        sub_new_model.weight.data.copy_(new_weights[src % len(new_weights)])
                    else:
                        torch.distributed.send(new_weights[(src + i) % len(new_weights)].cuda(), src + i)
                if new_biases:
                    for i in range(local_world_size):
                        if i == 0:
                            sub_new_model.bias.data.copy_(new_biases[src % len(new_biases)])
                        else:
                            torch.distributed.send(new_biases[(src + i) % len(new_biases)].cuda(), src + i)
            else:
                # Non-parallel modules: copy the parameters locally, then broadcast
                # them to the rest of the group.
                for (nn, np), (n, p) in zip(sub_new_model.named_parameters(recurse=False), sub_module.named_parameters(recurse=False)):
                    np.data.copy_(torch.clone(p.data).detach())
                    torch.distributed.broadcast(np.data, src, group=group)
            iter_repartition(sub_new_model, sub_module)
    iter_repartition(model, model_full)
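
For this to work, every non-source rank must walk the same module tree and accept either a point-to-point shard (for the parallel layers) or the collective broadcast (for everything else). Below is a minimal sketch of that receiving side; the function name mp_split_model_receive_sketch and the bias check are illustrative assumptions, not taken from the source file.

# Hypothetical receiving-side counterpart (illustrative, not the library's API).
# Assumes parameters are already allocated with partitioned shapes and that a
# partitioned layer exposes a bias exactly when partition() yields biases.
def mp_split_model_receive_sketch(model, use_node_group=True):
    group = get_node_group() if use_node_group else get_model_parallel_group()
    src = get_node_src_rank() if use_node_group else get_model_parallel_src_rank()
    def iter_receive(module):
        for name, sub_module in module.named_children():
            if isinstance(sub_module, (ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)):
                # The source rank sends this rank's shard with torch.distributed.send,
                # so a matching recv into the pre-allocated parameter suffices.
                torch.distributed.recv(sub_module.weight.data, src)
                if getattr(sub_module, 'bias', None) is not None:
                    torch.distributed.recv(sub_module.bias.data, src)
            else:
                # Non-parallel parameters arrive via the collective broadcast,
                # which every rank in the group must join.
                for n, p in sub_module.named_parameters(recurse=False):
                    torch.distributed.broadcast(p.data, src, group=group)
            iter_receive(sub_module)
    iter_receive(model)

The mirrored traversal order is what keeps the sends, recvs, and broadcasts matched up across ranks: both sides visit children in the same named_children() order, so no tags or extra metadata are needed.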