src/nanotron/parallel/tensor_parallel/functional.py
def backward(ctx, grad_output):
tensor, weight = ctx.saved_tensors
group = ctx.group
use_bias = ctx.use_bias
handle: Optional[dist.Work] = None
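# grad_output arrives sharded along the first (batch) dim; gather the full batch
# asynchronously, since grad_weight and the remote shards of the input grad need it.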
sharded_batch_size, *rest_size = grad_output.shape
if group.size() == 1:
total_grad_output = grad_output
else:
unsharded_batch_size = sharded_batch_size * group.size()
total_grad_output = MemoryBuffer().get(
"allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype
)
# Gather + slicing during the forward pass (as done in NeMo) can make this tensor
# non-contiguous. Calling .contiguous() here is cheap when it already is contiguous:
# PyTorch only clones the tensor if it is not:
# https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
grad_output = grad_output.contiguous()
handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True)
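# The all-gather runs in the background; the matmul for this rank's local shard below overlaps with it.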
# grad_output: [b/n, s, h_out]
# total_grad_output: [b, s, h_out]
# weight: [h_out, h_in/n]
# total_grad_tensor: [b, s, h_in/n]
sharded_batch_size, *rest_size_grad_output = grad_output.shape
rest_size_grad_tensor = rest_size_grad_output[:-1] + [weight.shape[1]]
if group.size() == 1:
total_grad_tensor = grad_output.matmul(weight)
else:
unsharded_batch_size = sharded_batch_size * group.size()
total_grad_tensor = torch.empty(
unsharded_batch_size,
*rest_size_grad_tensor,
device=grad_output.device,
dtype=grad_output.dtype,
requires_grad=False,
)
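# Split total_grad_tensor into three views along the batch dim (shards owned by
# lower ranks, this rank's shard, shards owned by higher ranks) so each matmul
# can write its result in place.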
before_shard_grad_tensor, same_device_shard_grad_tensor, after_shard_grad_tensor = torch.split(
total_grad_tensor,
split_size_or_sections=[
sharded_batch_size * dist.get_rank(group),
sharded_batch_size,
sharded_batch_size * (group.size() - dist.get_rank(group) - 1),
],
dim=0,
)
# compute the grad_tensor shard corresponding to this rank's local grad_output
torch.mm(
input=grad_output.view(-1, grad_output.shape[-1]),
mat2=weight,
out=same_device_shard_grad_tensor.view(-1, weight.shape[1]),
)
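# The remaining shards need the gathered grad_output, so wait for the all-gather to finish.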
if handle is not None:
handle.wait()
before_shard_grad_output, _, after_shard_grad_output = torch.split(
total_grad_output,
split_size_or_sections=[
sharded_batch_size * dist.get_rank(group),
sharded_batch_size,
sharded_batch_size * (group.size() - dist.get_rank(group) - 1),
],
dim=0,
)
# grad_tensor shards for the batch slices owned by ranks before this one
if before_shard_grad_tensor.numel() > 0:
torch.mm(
input=before_shard_grad_output.view(-1, before_shard_grad_output.shape[-1]),
mat2=weight,
out=before_shard_grad_tensor.view(-1, weight.shape[1]),
)
# grad_tensor shards for the batch slices owned by ranks after this one
if after_shard_grad_tensor.numel() > 0:
torch.mm(
input=after_shard_grad_output.view(-1, after_shard_grad_output.shape[-1]),
mat2=weight,
out=after_shard_grad_tensor.view(-1, weight.shape[1]),
)
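# grad_weight and grad_bias are accumulated over the full (gathered) batch, so
# flatten both the saved input and total_grad_output to 2D first.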
# Convert the tensor shapes to 2D for execution compatibility
tensor = tensor.contiguous()
tensor_first_dims, tensor_last_dim = tensor.shape[:-1], tensor.shape[-1]
tensor = tensor.view(math.prod(tensor_first_dims), tensor_last_dim)
# Same 2D flattening for total_grad_output
total_grad_output_first_dims, total_grad_output_last_dim = (
total_grad_output.shape[:-1],
total_grad_output.shape[-1],
)
total_grad_output = total_grad_output.view(math.prod(total_grad_output_first_dims), total_grad_output_last_dim)
# TODO @thomasw21: This sounds like we don't have the optimal physical layout
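# grad_weight: [h_out, h_in/n] = total_grad_output.T ([h_out, b*s]) @ tensor ([b*s, h_in/n])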
grad_weight = total_grad_output.t().matmul(tensor)
grad_bias = total_grad_output.sum(dim=0) if use_bias else None
return total_grad_tensor, grad_weight, grad_bias, None, None