in src/nanotron/parallel/tensor_parallel/functional.py [0:0]
def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather):
ctx.use_bias = bias is not None
ctx.tp_mode = tp_mode
ctx.group = group
ctx.tp_recompute_allgather = tp_recompute_allgather
ctx.tensor_shape = tensor.size()
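    # In ALL_REDUCE mode every rank already holds the full input, so the linear can be
    # applied directly without any gather.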
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
gathered_tensor = tensor
ctx.save_for_backward(tensor, weight)
return F.linear(gathered_tensor, weight, bias)
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
group_size = group.size()
current_rank = dist.get_rank(group)
if group_size == 1:
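            # Nothing to gather when the group has a single rank; fall back to a plain linear.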
gathered_tensor = tensor
ctx.save_for_backward(tensor, weight)
return F.linear(gathered_tensor, weight, bias)
else:
            # `tensor` may not be contiguous, see:
# https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317
tensor = tensor.contiguous()
# ctx.save_for_backward(tensor, weight)
# TODO @thomasw21: gather along another dimension
sharded_batch_size, *intermediate_size, hidden_size = tensor.shape
if group is None:
group = dist.distributed_c10d._get_default_group()
gathered_batch_size = sharded_batch_size * group.size()
if tp_recompute_allgather:
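                # The gathered activations are written into a shared MemoryBuffer; they are
                # not saved for backward in this mode (see ctx.save_for_backward below), so
                # the buffer can safely be reused.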
gathered_tensor = MemoryBuffer().get(
"allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype
)
else:
gathered_tensor = torch.empty(
gathered_batch_size,
*intermediate_size,
hidden_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=False,
)
handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)
            # Compute one shard of column_linear while the AllGather is in flight:
            # the matmul of the shard we already hold with the current rank's weight.
            # We assume that rank 0 holds w0, rank 1 holds w1, etc.
# weights: w0 w1 w2 w3
# rank 0: X - - -
# rank 1: - X - -
# rank 2: - - X -
# rank 3: - - - X
# We call the corresponding shard of output "same_device_shard"
output_size = weight.shape[0]
gathered_output = torch.empty(
gathered_batch_size,
*intermediate_size,
output_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
)
before_shard, same_device_shard, after_shard = torch.split(
gathered_output,
split_size_or_sections=[
sharded_batch_size * current_rank,
sharded_batch_size,
sharded_batch_size * (group_size - current_rank - 1),
],
dim=0,
)
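            # The three views partition `gathered_output` along the batch dimension into the
            # rows before this rank's slice, this rank's slice, and the rows after it, so the
            # matmuls below can write their results in place.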
first_dims = math.prod([sharded_batch_size, *intermediate_size])
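            # (torch.mm / torch.addmm expect 2-D operands, hence the flattening of batch and
            # intermediate dimensions into `first_dims`.)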
if bias is None:
torch.mm(
input=tensor.view(first_dims, hidden_size),
mat2=weight.t(),
out=same_device_shard.view(first_dims, output_size),
)
else:
torch.addmm(
input=bias[None, :],
mat1=tensor.view(first_dims, hidden_size),
mat2=weight.t(),
out=same_device_shard.view(first_dims, output_size),
)
            # Wait for the AllGather to complete before reading the remote shards of `gathered_tensor`
handle.wait()
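            # With recompute enabled only the local input shard is saved and the AllGather is
            # redone in backward; otherwise the full gathered tensor is kept for backward.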
if tp_recompute_allgather:
ctx.save_for_backward(tensor, weight)
else:
ctx.save_for_backward(gathered_tensor, weight)
            # Compute all the other shards, obtained from the AllGather
# weights: w0 w1 w2 w3
# rank 0: - X X X
# rank 1: X - X X
# rank 2: X X - X
# rank 3: X X X -
            # These slices may not be contiguous with each other (e.g. for r1 and r2)
            # because they are separated by "same_device_shard", so "before_shard" and
            # "after_shard" are computed separately.
            # For r0, "before_shard" is empty; for r3, "after_shard" is empty.
if before_shard.numel() > 0:
first_dims = math.prod(before_shard.shape[:-1])
if bias is None:
torch.mm(
input=gathered_tensor[: sharded_batch_size * current_rank].view(first_dims, hidden_size),
mat2=weight.t(),
out=before_shard.view(first_dims, output_size),
)
else:
torch.addmm(
input=bias[None, :],
mat1=gathered_tensor[: sharded_batch_size * current_rank].view(first_dims, hidden_size),
mat2=weight.t(),
out=before_shard.view(first_dims, output_size),
)
if after_shard.numel() > 0:
first_dims = math.prod(after_shard.shape[:-1])
if bias is None:
torch.mm(
input=gathered_tensor[sharded_batch_size * (current_rank + 1) :].view(
first_dims, hidden_size
),
mat2=weight.t(),
out=after_shard.view(first_dims, output_size),
)
else:
torch.addmm(
input=bias[None, :],
mat1=gathered_tensor[sharded_batch_size * (current_rank + 1) :].view(
first_dims, hidden_size
),
mat2=weight.t(),
out=after_shard.view(first_dims, output_size),
)
return gathered_output
else:
raise ValueError(f"Got unexpected mode: {tp_mode}.")
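
For intuition, here is a minimal single-process sketch of the same bookkeeping (hypothetical sizes, no torch.distributed; the AllGather is simulated with torch.cat): it splits the output buffer into before_shard / same_device_shard / after_shard for one simulated rank and checks the result against a plain F.linear over the gathered input.

import torch
import torch.nn.functional as F

# Hypothetical sizes, for illustration only.
group_size, current_rank = 4, 1
sharded_batch_size, hidden_size, output_size = 2, 8, 6

shards = [torch.randn(sharded_batch_size, hidden_size) for _ in range(group_size)]
weight = torch.randn(output_size, hidden_size)

# Stand-in for the result of the AllGather.
gathered_tensor = torch.cat(shards, dim=0)
gathered_output = torch.empty(sharded_batch_size * group_size, output_size)

before_shard, same_device_shard, after_shard = torch.split(
    gathered_output,
    [
        sharded_batch_size * current_rank,
        sharded_batch_size,
        sharded_batch_size * (group_size - current_rank - 1),
    ],
    dim=0,
)

# The local shard can be processed before the gather completes ...
torch.mm(shards[current_rank], weight.t(), out=same_device_shard)
# ... and the remaining rows once the gathered tensor is available.
torch.mm(gathered_tensor[: sharded_batch_size * current_rank], weight.t(), out=before_shard)
torch.mm(gathered_tensor[sharded_batch_size * (current_rank + 1) :], weight.t(), out=after_shard)

torch.testing.assert_close(gathered_output, F.linear(gathered_tensor, weight))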