in src/nanotron/parallel/tensor_parallel/functional.py [0:0]
def backward(ctx, grad_output):
    tensor, weight = ctx.saved_tensors
    group = ctx.group
    use_bias = ctx.use_bias
    tp_mode = ctx.tp_mode
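
    # Compute the input, weight and bias gradients, overlapping the TP collectives
    # (input all-gather, input-gradient reduce-scatter / all-reduce) with the local matmuls.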
    handle1: Optional[dist.Work] = None
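    # In REDUCE_SCATTER mode with tp_recompute_allgather, the forward pass did not keep
    # the full (unsharded) input, so re-gather it asynchronously here: the weight
    # gradient below needs the whole input, not just this rank's shard.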
    if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER and ctx.tp_recompute_allgather:
        sharded_batch_size, *rest_size = tensor.shape
        if group is None:
            group = dist.distributed_c10d._get_default_group()
        if group.size() == 1:
            total_tensor = tensor
        else:
            unsharded_batch_size = sharded_batch_size * group.size()
            unsharded_tensor = MemoryBuffer().get(
                "allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype
            )
            handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True)
            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # gather is scheduled before the tensor gradient computation
            total_tensor = unsharded_tensor
    else:
        total_tensor = tensor
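
    # Input gradient: each rank holds only a shard of the weight along the output
    # dimension, so this product is a partial result that the reduce-scatter /
    # all-reduce below sums across the TP group.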
    grad_tensor = grad_output.matmul(weight)

    # Doing gather + slicing during the NeMo forward pass can make this tensor
    # not be contiguous. PyTorch only checks if the tensor is contiguous, and only
    # clones it if it's not contiguous:
    # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
    grad_output = grad_output.contiguous()
    # Convert the tensor shapes to 2D for execution compatibility
    grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1]
    total_tensor_first_dims, total_tensor_last_dim = total_tensor.shape[:-1], total_tensor.shape[-1]
    grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
    total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim)

    handle2: Optional[dist.Work] = None
    if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
        if group.size() == 1:
            sub_grad_tensor = grad_tensor
        else:
            sub_grad_tensor = torch.empty(
                ctx.tensor_shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False
            )
            # reduce_scatter
            handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True)
            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # reduce scatter is scheduled before the weight gradient computation
    elif tp_mode is TensorParallelLinearMode.ALL_REDUCE:
        # Asynchronous all-reduce
        handle2 = dist.all_reduce(grad_tensor, group=group, async_op=True)
        # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
        # all-reduce is scheduled before the weight gradient computation
    else:
        raise ValueError(f"Got unexpected mode: {tp_mode}.")
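
    # Bias gradient: sum grad_output over the flattened token (batch * sequence) dimension.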
    grad_bias = grad_output.sum(dim=0) if use_bias else None

    if handle1 is not None:
        handle1.wait()
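
    # The wait above ensures the re-gathered input is ready; the weight gradient
    # dL/dW = dL/dY^T @ X then overlaps with the still-running reduce-scatter /
    # all-reduce of the input gradient.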
    # TODO @thomasw21: This sounds like we don't have the optimal physical layout
    grad_weight = grad_output.t().matmul(total_tensor)

    if handle2 is not None:
        handle2.wait()

    if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
        return sub_grad_tensor, grad_weight, grad_bias, None, None, None
    elif tp_mode is TensorParallelLinearMode.ALL_REDUCE:
        return grad_tensor, grad_weight, grad_bias, None, None, None
    else:
        raise ValueError(f"Got unexpected mode: {tp_mode}.")
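
The collectives above are launched with async_op=True and only wait()-ed right before their result is consumed, so the NCCL kernels overlap with the independent weight-gradient matmul (CUDA_DEVICE_MAX_CONNECTIONS=1 keeps the kernel launch order deterministic). A minimal sketch of that overlap idiom for the ALL_REDUCE case, assuming torch.distributed is already initialized; the function name is illustrative and not part of nanotron's API:

import torch.distributed as dist


def overlapped_linear_backward_sketch(grad_output, weight, total_tensor, group):
    # Hypothetical helper, not nanotron API: shows only the async-collective overlap.
    # Partial input gradient; the all-reduce sums the per-rank contributions.
    grad_tensor = grad_output.matmul(weight)
    handle = dist.all_reduce(grad_tensor, group=group, async_op=True)
    # Independent work (the weight gradient) runs while the collective is in flight.
    grad_weight = grad_output.t().matmul(total_tensor)
    # Block only once the reduced input gradient is actually needed.
    handle.wait()
    return grad_tensor, grad_weight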