def forward()

in src/nanotron/parallel/tensor_parallel/functional.py [0:0]


    def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather):
        ctx.use_bias = bias is not None
        ctx.tp_mode = tp_mode
        ctx.group = group
        ctx.tp_recompute_allgather = tp_recompute_allgather
        ctx.tensor_shape = tensor.size()

        if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
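            # In ALL_REDUCE mode the input is already replicated across the TP group, so the
            # forward pass is just a local linear; any cross-rank reduction happens in the
            # backward pass (not shown here).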
            gathered_tensor = tensor
            ctx.save_for_backward(tensor, weight)
            return F.linear(gathered_tensor, weight, bias)
        elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
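            # In REDUCE_SCATTER mode the input arrives sharded along its first (batch/sequence)
            # dimension, so it has to be all-gathered before the column-parallel matmul.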
            if group is None:
                group = dist.distributed_c10d._get_default_group()
            group_size = group.size()
            current_rank = dist.get_rank(group)
            if group_size == 1:
                gathered_tensor = tensor
                ctx.save_for_backward(tensor, weight)
                return F.linear(gathered_tensor, weight, bias)
            else:
                # `tensor` is not always contiguous; make it contiguous before the AllGather and the .view() calls below
                # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317
                tensor = tensor.contiguous()
                # save_for_backward is deferred until after the AllGather; what gets saved
                # depends on tp_recompute_allgather (see below)

                # TODO @thomasw21: gather along another dimension
                sharded_batch_size, *intermediate_size, hidden_size = tensor.shape
                gathered_batch_size = sharded_batch_size * group.size()
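                # Example (hypothetical shapes): with group_size=4 and a local `tensor` of shape
                # (2, 16, 1024), the gathered input has shape (8, 16, 1024).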

                if tp_recompute_allgather:
                    gathered_tensor = MemoryBuffer().get(
                        "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype
                    )
                else:
                    gathered_tensor = torch.empty(
                        gathered_batch_size,
                        *intermediate_size,
                        hidden_size,
                        device=tensor.device,
                        dtype=tensor.dtype,
                        requires_grad=False,
                    )
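                # With tp_recompute_allgather, the gathered input lives in a reusable MemoryBuffer
                # and only the local shard is saved for backward (see below), which lowers
                # activation memory at the cost of re-doing the all-gather in the backward pass
                # (not shown here); otherwise a fresh tensor is allocated and the full gathered
                # input is kept.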

                handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)

                # Compute the local shard of column_linear at the same time as the AllGather:
                # each rank can immediately multiply the input shard it already holds by its own
                # weight shard while the remaining input shards are still in flight.
                # We assume that rank 0 holds input shard x0, rank 1 holds x1, etc.
                # inputs:  x0 x1 x2 x3
                # rank 0:  X  -  -  -
                # rank 1:  -  X  -  -
                # rank 2:  -  -  X  -
                # rank 3:  -  -  -  X
                # We call the corresponding shard of the output "same_device_shard"
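                # Concrete example (hypothetical numbers): with group_size=4 and
                # sharded_batch_size=2, rank 1 fills rows 2:4 of gathered_output here, and
                # rows 0:2 and 4:8 once the AllGather has completed.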
                output_size = weight.shape[0]
                gathered_output = torch.empty(
                    gathered_batch_size,
                    *intermediate_size,
                    output_size,
                    device=tensor.device,
                    dtype=tensor.dtype,
                    requires_grad=tensor.requires_grad,
                )
                before_shard, same_device_shard, after_shard = torch.split(
                    gathered_output,
                    split_size_or_sections=[
                        sharded_batch_size * current_rank,
                        sharded_batch_size,
                        sharded_batch_size * (group_size - current_rank - 1),
                    ],
                    dim=0,
                )
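                # For rank r with local batch size b and group size G, the split sizes are
                # [r*b, b, (G-r-1)*b], which sum to the gathered batch size G*b.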
                first_dims = math.prod([sharded_batch_size, *intermediate_size])
                if bias is None:
                    torch.mm(
                        input=tensor.view(first_dims, hidden_size),
                        mat2=weight.t(),
                        out=same_device_shard.view(first_dims, output_size),
                    )
                else:
                    torch.addmm(
                        input=bias[None, :],
                        mat1=tensor.view(first_dims, hidden_size),
                        mat2=weight.t(),
                        out=same_device_shard.view(first_dims, output_size),
                    )
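                # Writing through `out=` into the views returned by torch.split fills the matching
                # rows of gathered_output in place, so no copy is needed afterwards.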

                # Wait for the AllGather to finish before touching the gathered tensor
                handle.wait()
                if tp_recompute_allgather:
                    ctx.save_for_backward(tensor, weight)
                else:
                    ctx.save_for_backward(gathered_tensor, weight)

                # Compute the output rows for all the other input shards obtained from the AllGather
                # inputs:  x0 x1 x2 x3
                # rank 0:  -  X  X  X
                # rank 1:  X  -  X  X
                # rank 2:  X  X  -  X
                # rank 3:  X  X  X  -
                # These rows are not contiguous in gathered_output (e.g. on ranks 1 and 2 they are
                # split by "same_device_shard"), so they are computed as two pieces: "before_shard"
                # and "after_shard". For rank 0 "before_shard" is empty; for rank 3 "after_shard" is empty.
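                # E.g. for rank 1 of 4 with sharded_batch_size=2, "before_shard" is computed from
                # gathered_tensor[0:2] and "after_shard" from gathered_tensor[4:8].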
                if before_shard.numel() > 0:
                    first_dims = math.prod(before_shard.shape[:-1])
                    if bias is None:
                        torch.mm(
                            input=gathered_tensor[: sharded_batch_size * current_rank].view(first_dims, hidden_size),
                            mat2=weight.t(),
                            out=before_shard.view(first_dims, output_size),
                        )
                    else:
                        torch.addmm(
                            input=bias[None, :],
                            mat1=gathered_tensor[: sharded_batch_size * current_rank].view(first_dims, hidden_size),
                            mat2=weight.t(),
                            out=before_shard.view(first_dims, output_size),
                        )
                if after_shard.numel() > 0:
                    first_dims = math.prod(after_shard.shape[:-1])
                    if bias is None:
                        torch.mm(
                            input=gathered_tensor[sharded_batch_size * (current_rank + 1) :].view(
                                first_dims, hidden_size
                            ),
                            mat2=weight.t(),
                            out=after_shard.view(first_dims, output_size),
                        )
                    else:
                        torch.addmm(
                            input=bias[None, :],
                            mat1=gathered_tensor[sharded_batch_size * (current_rank + 1) :].view(
                                first_dims, hidden_size
                            ),
                            mat2=weight.t(),
                            out=after_shard.view(first_dims, output_size),
                        )

                return gathered_output
        else:
            raise ValueError(f"Got unexpected mode: {tp_mode}.")
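
For intuition, here is a minimal single-process sketch of the same before/same/after split (hypothetical shapes; no process group, MemoryBuffer, or autograd involved). It fills the three output shards separately, exactly as the forward above does, and checks the result against a plain F.linear over the already-gathered input; in the real code the "same_device" part is computed while the all-gather is still in flight.

    import torch
    import torch.nn.functional as F

    group_size, current_rank = 4, 1                        # pretend we are rank 1 of 4
    sharded_batch_size, hidden_size, output_size = 2, 8, 6

    # Stand-in for the result of all_gather_into_tensor; `local` is this rank's shard.
    gathered = torch.randn(group_size * sharded_batch_size, hidden_size)
    local = gathered[current_rank * sharded_batch_size : (current_rank + 1) * sharded_batch_size]
    weight = torch.randn(output_size, hidden_size)

    gathered_output = torch.empty(group_size * sharded_batch_size, output_size)
    before, same, after = torch.split(
        gathered_output,
        [
            sharded_batch_size * current_rank,
            sharded_batch_size,
            sharded_batch_size * (group_size - current_rank - 1),
        ],
        dim=0,
    )
    torch.mm(local, weight.t(), out=same)    # the part that overlaps with the all-gather
    torch.mm(gathered[: sharded_batch_size * current_rank], weight.t(), out=before)
    torch.mm(gathered[sharded_batch_size * (current_rank + 1) :], weight.t(), out=after)

    assert torch.allclose(gathered_output, F.linear(gathered, weight))

With a bias, the same pattern simply swaps torch.mm for torch.addmm, as the function above does.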