def backward()

in torchrec/distributed/comm_ops.py [0:0]


    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
        myreq = ctx.myreq
        a2ai = ctx.a2ai
        pg = ctx.pg
        world_size = dist.get_world_size(pg)
        my_rank = dist.get_rank(pg)
        dim_sum_per_rank = a2ai.dim_sum_per_rank

        D_local_sum = dim_sum_per_rank[my_rank]
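        # D_local_sum is the pooled-embedding width owned by this rank. grad_output
        # arrives as (B_local, D_global_sum) when embedding dims are mixed, or as
        # (B_local, T_global, D) when every table shares embedding dim D.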
        if a2ai.mixed_dim:
            (B_local, D_global_sum) = grad_output.shape
            sharded_grad_input_sizes = (world_size, B_local, D_local_sum)
        else:
            (B_local, T_global, D) = grad_output.shape
            D_global_sum = T_global * D
            grad_output = grad_output.view(B_local, -1)
            T_local = D_local_sum // D
            sharded_grad_input_sizes = (world_size, B_local, T_local, D)
        assert sum(dim_sum_per_rank) == D_global_sum

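        # Regroup the gradient columns by destination rank so each rank's chunk is
        # contiguous in the flattened buffer handed to all_to_all_single.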
        sharded_grad_output = _recat_pooled_embedding_grad_out(
            grad_output.contiguous(),
            dim_sum_per_rank,
        )

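        # Receive buffer: one slice of this rank's columns (B_local * D_local_sum
        # elements) from every rank.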
        sharded_grad_input = torch.empty(
            sharded_grad_input_sizes, device=grad_output.device, dtype=grad_output.dtype
        )
        with record_function("## alltoall_bwd_single ##"):
            req = dist.all_to_all_single(
                output=sharded_grad_input,
                input=sharded_grad_output,
                output_split_sizes=None,
                input_split_sizes=[
                    B_local * D_rank_sum for D_rank_sum in dim_sum_per_rank
                ],
                group=pg,
                async_op=True,
            )
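        # Hand the async request and receive buffer to the shared request object;
        # they are waited on and consumed by the matching backward stage.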
        myreq.req = req
        myreq.tensor = sharded_grad_input
        # Note - this mismatch is by design! We return sharded_grad_output to allow PyTorch shape matching to proceed correctly.
        return (None, None, sharded_grad_output)
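
For intuition, here is a minimal single-process sketch of the same shape bookkeeping. It assumes _recat_pooled_embedding_grad_out effectively slices the gradient columns by owning rank and flattens each slice (a simplified stand-in, not the actual TorchRec kernel); the sizes dim_sum_per_rank = [8, 4, 12] and B_local = 2 are made up for illustration.

    import torch

    # Made-up sizes: three ranks whose sharded tables produce pooled outputs of
    # width 8, 4, and 12 respectively.
    dim_sum_per_rank = [8, 4, 12]
    world_size = len(dim_sum_per_rank)
    B_local = 2
    D_global_sum = sum(dim_sum_per_rank)  # 24

    # Gradient w.r.t. the concatenated pooled output held on this rank.
    grad_output = torch.randn(B_local, D_global_sum)

    # Simplified stand-in for the recat step: slice the gradient columns by owning
    # rank and flatten each slice, so the send buffer is one contiguous run of
    # B_local * dim elements per destination rank.
    chunks = torch.split(grad_output, dim_sum_per_rank, dim=1)
    sharded_grad_output = torch.cat([c.reshape(-1) for c in chunks])

    # These are exactly the input_split_sizes passed to all_to_all_single above.
    input_split_sizes = [B_local * d for d in dim_sum_per_rank]  # [16, 8, 24]
    assert sharded_grad_output.numel() == sum(input_split_sizes)

    # The receive buffer (world_size, B_local, D_local_sum) splits evenly along
    # dim 0, one chunk per sending rank, which is why output_split_sizes stays None.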