in src/nanotron/optim/gradient_accumulator.py [0:0]
def fp32_accum_hook(state: FP32GradBucketManager, bucket: GradBucket) -> torch.futures.Future[torch.Tensor]:
    """Accumulate a ready DDP bucket into the accumulator's grad buffers and launch their reduction.

    DDP groups grads in GradBuckets. This hook is called throughout the bwd pass, once each
    bucket is ready, to overlap communication with computation.
    See https://pytorch.org/docs/stable/ddp_comm_hooks.html#what-does-a-communication-hook-operate-on
    for more details.

    Args:
        state: Hook state. Its ``dp_cp_pg`` (process group), ``accumulator`` and
            ``param_id_to_name`` (maps ``id(param)`` to a parameter name) attributes are read.
        bucket: The bucket whose parameters and gradients are ready.

    Returns:
        An already-completed future holding ``bucket.buffer()``. The bucket buffer itself is
        not reduced here; only the accumulator's buffers are.

    Raises:
        NotImplementedError: If ``reduce_scatter`` is truthy and ``dp_cp_pg`` does not have
            exactly one rank. The incoming gradients have already been accumulated locally
            by the time this is raised.

    Note:
        ``reduce_scatter``, ``reduce_op`` and ``dist`` are not defined in this function; they
        are presumably provided by an enclosing scope / module imports (confirm against the
        surrounding file).
    """
    dp_cp_pg = state.dp_cp_pg
    accumulator = state.accumulator
    param_id_to_name = state.param_id_to_name

    # Add new incoming gradient
    for param, grad in zip(bucket.parameters(), bucket.gradients()):
        fp32_grad_buffer = accumulator.get_grad_buffer(param_id_to_name[id(param)])
        fp32_grad_buffer.add_(grad.view_as(fp32_grad_buffer))

    # sync across dp (nothing to communicate when the group has a single rank)
    if dp_cp_pg.size() != 1:
        if reduce_scatter:
            # The reduce-scatter path is not supported. The statements that used to follow
            # this raise could never execute (and referenced an undefined `dp_pg`), so they
            # have been removed.
            raise NotImplementedError("Not implemented")

        grad_buffer_tensor_list = [
            accumulator.get_grad_buffer(param_id_to_name[id(param)]).view(-1) for param in bucket.parameters()
        ]
        # Launched asynchronously. The handle is stored on the accumulator, replacing whatever
        # handle a previous call stored there.
        accumulator.fp32_grads_allreduce_handle = dist.all_reduce_coalesced(
            grad_buffer_tensor_list, group=dp_cp_pg, async_op=True, op=reduce_op
        )

    # we shouldn't wait for the async reduction for the rest of the backward, so the future
    # handed back is completed immediately.
    fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
    fut.set_result(bucket.buffer())
    return fut  # We don't care about the new half grad values, so we return the old ones