def backward()

in src/nanotron/parallel/tensor_parallel/functional.py


    def backward(ctx, grad_output):
        tensor, weight = ctx.saved_tensors
        group = ctx.group
        use_bias = ctx.use_bias

        handle: Optional[dist.Work] = None

        # grad_output arrives sharded along its first dimension across the tensor-parallel group.
        sharded_batch_size, *rest_size = grad_output.shape

        if group.size() == 1:
            total_grad_output = grad_output
        else:
            unsharded_batch_size = sharded_batch_size * group.size()

            total_grad_output = MemoryBuffer().get(
                "allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype
            )
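            # MemoryBuffer() caches the buffer under the "allgather2" key, so the gathered
            # grad_output is written into a reused allocation rather than a fresh one each step.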

            # Doing gather + slicing during the NeMo forward pass can make this tensor
            # not be contiguous. PyTorch only checks if the tensor is contiguous, and only
            # clones it if it's not contiguous:
            # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
            grad_output = grad_output.contiguous()

            handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True)
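            # The all-gather is launched asynchronously so the matmul on the local
            # grad_output shard below can overlap with the communication; handle.wait()
            # is only called once the remote shards are actually needed.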

        # grad_output:        [b/n, s, h_out]  (local shard)
        # total_grad_output:  [b, s, h_out]
        # weight:             [h_out, h_in/n]
        # total_grad_tensor:  [b, s, h_in/n]
        sharded_batch_size, *rest_size_grad_output = grad_output.shape
        rest_size_grad_tensor = rest_size_grad_output[:-1] + [weight.shape[1]]

        if group.size() == 1:
            total_grad_tensor = grad_output.matmul(weight)
        else:
            unsharded_batch_size = sharded_batch_size * group.size()
            total_grad_tensor = torch.empty(
                unsharded_batch_size,
                *rest_size_grad_tensor,
                device=grad_output.device,
                dtype=grad_output.dtype,
                requires_grad=False,
            )
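            # torch.split returns views into total_grad_tensor: the rows belonging to ranks
            # before this one, this rank's own rows, and the rows of ranks after it. Writing
            # the matmul results into these views (via out=) fills the buffer without copies.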
            before_shard_grad_tensor, same_device_shard_grad_tensor, after_shard_grad_tensor = torch.split(
                total_grad_tensor,
                split_size_or_sections=[
                    sharded_batch_size * dist.get_rank(group),
                    sharded_batch_size,
                    sharded_batch_size * (group.size() - dist.get_rank(group) - 1),
                ],
                dim=0,
            )
            # Compute this rank's shard of the input gradient from the local grad_output
            # while the all-gather above is still in flight.
            torch.mm(
                input=grad_output.view(-1, grad_output.shape[-1]),
                mat2=weight,
                out=same_device_shard_grad_tensor.view(-1, weight.shape[1]),
            )

            if handle is not None:
                handle.wait()
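            # total_grad_output now holds the grad_output shards gathered from every rank.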

            before_shard_grad_output, _, after_shard_grad_output = torch.split(
                total_grad_output,
                split_size_or_sections=[
                    sharded_batch_size * dist.get_rank(group),
                    sharded_batch_size,
                    sharded_batch_size * (group.size() - dist.get_rank(group) - 1),
                ],
                dim=0,
            )

            # Gradient for the batch rows gathered from ranks before this one
            if before_shard_grad_tensor.numel() > 0:
                torch.mm(
                    input=before_shard_grad_output.view(-1, before_shard_grad_output.shape[-1]),
                    mat2=weight,
                    out=before_shard_grad_tensor.view(-1, weight.shape[1]),
                )
            # Gradient for the batch rows gathered from ranks after this one
            if after_shard_grad_tensor.numel() > 0:
                torch.mm(
                    input=after_shard_grad_output.view(-1, after_shard_grad_output.shape[-1]),
                    mat2=weight,
                    out=after_shard_grad_tensor.view(-1, weight.shape[1]),
                )

        # Flatten the saved input to 2D so the weight gradient reduces to a single matmul
        tensor = tensor.contiguous()
        tensor_first_dims, tensor_last_dim = tensor.shape[:-1], tensor.shape[-1]
        tensor = tensor.view(math.prod(tensor_first_dims), tensor_last_dim)

        # Flatten total_grad_output to 2D for the same reason
        total_grad_output_first_dims, total_grad_output_last_dim = (
            total_grad_output.shape[:-1],
            total_grad_output.shape[-1],
        )
        total_grad_output = total_grad_output.view(math.prod(total_grad_output_first_dims), total_grad_output_last_dim)

        # TODO @thomasw21: This sounds like we don't have the optimal physical layout
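        # grad_weight: [h_out, h_in/n] = total_grad_output^T [h_out, b*s] @ tensor [b*s, h_in/n]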
        grad_weight = total_grad_output.t().matmul(tensor)
        grad_bias = total_grad_output.sum(dim=0) if use_bias else None

        return total_grad_tensor, grad_weight, grad_bias, None, None
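
The else-branch above relies on torch.split returning views into the preallocated total_grad_tensor, so each torch.mm(..., out=...) lands its result directly in the right rows of the buffer with no concatenation or copy afterwards. Below is a minimal single-process sketch of that pattern with made-up sizes and a fixed rank; it only illustrates the buffer-splitting trick, not the distributed code itself.

    import torch

    # Hypothetical sizes: 3 "ranks", each owning 2 rows of the batch dimension.
    world_size, shard, h_out, h_in = 3, 2, 8, 4
    rank = 1

    grad_output = torch.randn(world_size * shard, h_out)  # as if already all-gathered
    weight = torch.randn(h_out, h_in)

    # Preallocate the full gradient buffer and split it into three views:
    # rows owned by ranks before this one, this rank's rows, rows of ranks after it.
    total_grad_tensor = torch.empty(world_size * shard, h_in)
    before, local, after = torch.split(
        total_grad_tensor,
        [shard * rank, shard, shard * (world_size - rank - 1)],
        dim=0,
    )

    # Each matmul writes straight into its slice of the preallocated buffer.
    torch.mm(grad_output[: shard * rank], weight, out=before)
    torch.mm(grad_output[shard * rank : shard * (rank + 1)], weight, out=local)
    torch.mm(grad_output[shard * (rank + 1) :], weight, out=after)

    assert torch.allclose(total_grad_tensor, grad_output @ weight)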