in picotron/tensor_parallel/tp_communications.py [0:0]
def forward(ctx, x):
    # Relies on module-level imports of torch, torch.distributed as dist, and the
    # process-group manager `pgm` used throughout picotron.
    # No-op when tensor parallelism is not actually sharding anything.
    if pgm.process_group_manager.tp_world_size == 1:
        return x
    last_dim = x.dim() - 1
    # Need contiguous tensors for collectives -> https://github.com/pytorch/pytorch/blob/main/torch/distributed/nn/functional.py#L321
    x = x.contiguous()
    # Pre-allocate one buffer per TP rank; this rank's slot points at its own shard.
    tensor_list = [torch.empty_like(x) for _ in range(pgm.process_group_manager.tp_world_size)]
    tensor_list[pgm.process_group_manager.tp_rank] = x
    # Collect every rank's shard across the tensor-parallel group.
    dist.all_gather(tensor_list, x, group=pgm.process_group_manager.tp_group)
    # Reassemble the full activation by concatenating the shards along the last dimension.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()
    return output
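
For context, here is a minimal standalone sketch of the same pattern, not code from the picotron repository: it reproduces the gather-along-the-last-dimension step with plain torch.distributed on the CPU "gloo" backend, replacing pgm's group bookkeeping with the default world group, and shows that the matching backward of such a gather is a split that keeps only the local rank's shard. The shapes, port, and process-spawning scaffolding are illustrative assumptions.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    # Assumed rendezvous settings for a local two-process demo.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank holds a (2, 3) shard; gathering along the last dim yields (2, 3 * world_size).
    x = torch.full((2, 3), float(rank)).contiguous()
    tensor_list = [torch.empty_like(x) for _ in range(world_size)]
    tensor_list[rank] = x
    dist.all_gather(tensor_list, x)  # default group = all ranks
    output = torch.cat(tensor_list, dim=-1).contiguous()
    print(f"rank {rank}: gathered shape {tuple(output.shape)}")

    # The backward of this gather is a split: each rank keeps only its own chunk
    # of the upstream gradient.
    grad_output = torch.ones_like(output)
    grad_x = torch.chunk(grad_output, world_size, dim=-1)[rank].contiguous()
    assert grad_x.shape == x.shape

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)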