in workload_applyer.py:
def _apply_reduce_scatter(self, item):
    group = self.comm_group_info[item.comm_group]
    # msg_size is in bytes; assumes 2-byte elements (e.g. fp16/bf16).
    num_elements = item.msg_size // 2
    # Pad the element count up to a multiple of the group size so the
    # input splits evenly across ranks, as reduce_scatter_tensor requires.
    padding_size = (
        (group.size() - num_elements % group.size())
        if num_elements % group.size()
        else 0
    )
    num_elements = num_elements + padding_size
    input_tensor = torch.narrow(self.buffer, 0, 0, num_elements)
    output_tensor_size = input_tensor.numel() // group.size()
    group_rank = torch.distributed.get_group_rank(group, self.rank)
    # Each rank's output is its own shard, taken as a view into the
    # same buffer at the rank's offset.
    output_tensor = torch.narrow(
        input_tensor, 0, group_rank * output_tensor_size, output_tensor_size
    )
    return torch.distributed.reduce_scatter_tensor(
        output_tensor, input_tensor, group=group, async_op=False
    )
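
The only non-obvious part is the size arithmetic: msg_size counts bytes, each element is assumed to occupy 2 bytes, and the element count is rounded up so it divides evenly across the group. A minimal sketch of that layout computation, runnable without an initialized process group (the helper name shard_layout and its parameters are illustrative, not from the source):

    # Sketch of the padding/shard arithmetic above, assuming msg_size is a
    # byte count and a 2-byte dtype (e.g. fp16). Hypothetical helper.
    def shard_layout(msg_size_bytes: int, world_size: int, rank: int):
        num_elements = msg_size_bytes // 2        # bytes -> 2-byte elements
        remainder = num_elements % world_size
        padding = (world_size - remainder) if remainder else 0
        num_elements += padding                   # now divisible by world_size
        shard = num_elements // world_size        # per-rank output size
        return num_elements, shard, rank * shard  # total, shard size, offset

    # 10 bytes -> 5 elements; world_size=4 pads to 8, so each rank gets a
    # 2-element shard and rank 1's shard starts at offset 2.
    print(shard_layout(10, 4, 1))  # (8, 2, 2)

The collective itself then needs an initialized process group (NCCL in multi-GPU runs): reduce_scatter_tensor sums input_tensor element-wise across ranks and writes each rank's shard of the result into its output_tensor.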