# step7_pipeline_parallel_afab/tensor_parallel.py
import torch
import torch.distributed as dist

import process_group_manager as pgm  # project module holding the global ProcessGroupManager (import path assumed)


def forward(ctx, input):
    # Forward of a torch.autograd.Function: all-gather the activation shards
    # across the tensor-parallel group and rebuild the full tensor.
    if pgm.process_group_manager.tp_world_size == 1:
        # Tensor parallelism is disabled: nothing to gather.
        return input
    last_dim = input.dim() - 1
    # Need contiguous tensors for collectives -> https://github.com/pytorch/pytorch/blob/main/torch/distributed/nn/functional.py#L321
    input = input.contiguous()
    # One receive buffer per TP rank; this rank's slot can reuse the local shard.
    tensor_list = [torch.empty_like(input) for _ in range(pgm.process_group_manager.tp_world_size)]
    tensor_list[pgm.process_group_manager.tp_rank] = input
    dist.all_gather(tensor_list, input, group=pgm.process_group_manager.tp_group)
    # Concatenate the gathered shards along the last dimension.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()
    return output
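

# A minimal sketch of the matching backward, assuming this forward belongs to a
# torch.autograd.Function that gathers activations sharded along the last
# dimension (the excerpt only shows the forward; names here are illustrative).
# The gradient of an all-gather is a split: each rank keeps only the slice of
# grad_output that corresponds to the shard it contributed.
def backward(ctx, grad_output):
    if pgm.process_group_manager.tp_world_size == 1:
        return grad_output
    last_dim = grad_output.dim() - 1
    # Shards are assumed equal-sized: each rank owns 1/tp_world_size of the last dim.
    chunk_size = grad_output.size(last_dim) // pgm.process_group_manager.tp_world_size
    return torch.split(grad_output, chunk_size, dim=last_dim)[pgm.process_group_manager.tp_rank].contiguous()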