in torchrec/distributed/comm_ops.py [0:0]
def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
    myreq = ctx.myreq
    a2ai = ctx.a2ai
    pg = ctx.pg
    world_size = dist.get_world_size(pg)
    my_rank = dist.get_rank(pg)
    dim_sum_per_rank = a2ai.dim_sum_per_rank
    # Total embedding dimension owned by this rank.
    D_local_sum = dim_sum_per_rank[my_rank]
    if a2ai.mixed_dim:
        (B_local, D_global_sum) = grad_output.shape
        sharded_grad_input_sizes = (world_size, B_local, D_local_sum)
    else:
        (B_local, T_global, D) = grad_output.shape
        D_global_sum = T_global * D
        grad_output = grad_output.view(B_local, -1)
        T_local = D_local_sum // D
        sharded_grad_input_sizes = (world_size, B_local, T_local, D)
    assert sum(dim_sum_per_rank) == D_global_sum

    # Regroup the gradient so that the columns owned by each rank form one
    # contiguous chunk, matching input_split_sizes below.
    sharded_grad_output = _recat_pooled_embedding_grad_out(
        grad_output.contiguous(),
        dim_sum_per_rank,
    )

    sharded_grad_input = torch.empty(
        sharded_grad_input_sizes, device=grad_output.device, dtype=grad_output.dtype
    )
    with record_function("## alltoall_bwd_single ##"):
        req = dist.all_to_all_single(
            output=sharded_grad_input,
            input=sharded_grad_output,
            # Output chunks are evenly sized, so the default (even) split is used.
            output_split_sizes=None,
            input_split_sizes=[
                B_local * D_rank_sum for D_rank_sum in dim_sum_per_rank
            ],
            group=pg,
            async_op=True,
        )

    # Hand the in-flight request and its destination buffer to the paired
    # autograd node, which waits on them before consuming the gradient.
    myreq.req = req
    myreq.tensor = sharded_grad_input
    # Note - this mismatch is by design! We return sharded_grad_output to allow
    # PyTorch shape matching to proceed correctly.
    return (None, None, sharded_grad_output)
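
A minimal sketch of the gradient re-layout this backward relies on, assuming a hypothetical pure-Python toy_recat_grad_out in place of the fbgemm-backed _recat_pooled_embedding_grad_out, plus toy values for B_local and dim_sum_per_rank. It only illustrates why each rank's chunk has B_local * dim_sum_per_rank[r] elements (the input_split_sizes above); it is not the kernel the module actually calls.

import torch

def toy_recat_grad_out(grad_output: torch.Tensor, dim_sum_per_rank):
    # grad_output: (B_local, D_global_sum) on this rank,
    # with D_global_sum == sum(dim_sum_per_rank).
    # Return a flat tensor laid out as [rank0 chunk | rank1 chunk | ...], where
    # rank r's chunk is the (B_local, dim_sum_per_rank[r]) column slice, flattened.
    chunks = torch.split(grad_output, dim_sum_per_rank, dim=1)
    return torch.cat([c.contiguous().view(-1) for c in chunks])

B_local = 2
dim_sum_per_rank = [3, 1]  # world_size == 2; rank 0 owns 3 embedding dims, rank 1 owns 1
grad_output = torch.arange(
    B_local * sum(dim_sum_per_rank), dtype=torch.float32
).view(B_local, -1)
flat = toy_recat_grad_out(grad_output, dim_sum_per_rank)
# The per-rank chunk lengths match the input_split_sizes used in the backward above.
input_split_sizes = [B_local * d for d in dim_sum_per_rank]  # [6, 2]
assert [c.numel() for c in torch.split(flat, input_split_sizes)] == input_split_sizes

Once dist.all_to_all_single scatters those chunks, each rank receives one (B_local, D_local_sum) slab per peer into its (world_size, B_local, D_local_sum) buffer; because every received slab has the same size, output_split_sizes can stay None (even split along dim 0).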