def fp32_accum_hook()

in src/nanotron/optim/gradient_accumulator.py [0:0]


    def fp32_accum_hook(state: FP32GradBucketManager, bucket: GradBucket) -> torch.futures.Future[torch.Tensor]:
        # nonlocal s
        # DDP groups grads into GradBuckets. This hook is called throughout the bwd pass,
        # once each bucket is ready, so that communication overlaps with computation.
        # See https://pytorch.org/docs/stable/ddp_comm_hooks.html#what-does-a-communication-hook-operate-on for more details.
        dp_cp_pg = state.dp_cp_pg
        accumulator = state.accumulator
        param_id_to_name = state.param_id_to_name

        # Add new incoming gradient
        # with torch.cuda.stream(s):
        for param, grad in zip(bucket.parameters(), bucket.gradients()):
            name = param_id_to_name[id(param)]
            fp32_grad_buffer = accumulator.get_grad_buffer(name)
            fp32_grad_buffer.add_(grad.view_as(fp32_grad_buffer))

        # Sync gradients across data parallel ranks. If the DP group has a single rank
        # there is nothing to reduce, so return an already-completed future.
        if dp_cp_pg.size() == 1:
            fut = torch.futures.Future()
            fut.set_result(bucket.buffer())
            return fut

        if reduce_scatter:
            # The reduce-scatter path is not supported yet; the code below is kept as a
            # sketch and is currently unreachable.
            raise NotImplementedError("Not implemented")
            assert hasattr(accumulator, "param_name_to_offsets")
            grad_buffer_tensor_list = [
                accumulator.get_grad_buffer(param_id_to_name[id(param)]).view(-1) for param in bucket.parameters()
            ]
            device = grad_buffer_tensor_list[0].device
            dtype = grad_buffer_tensor_list[0].dtype
            output_tensor_list = [
                grad_buffer[slice(*accumulator.param_name_to_offsets[param_id_to_name[id(param)]])]
                if param_id_to_name[id(param)] in accumulator.param_name_to_offsets
                else torch.empty(0, dtype=dtype, device=device)
                for grad_buffer, param in zip(grad_buffer_tensor_list, bucket.parameters())
            ]
            input_tensor_lists = [
                torch.split(grad_buffer, split_size_or_sections=len(grad_buffer) // dp_cp_pg.size())
                for grad_buffer in grad_buffer_tensor_list
            ]
            dist.reduce_scatter_coalesced(
                output_tensor_list=output_tensor_list,
                input_tensor_lists=input_tensor_lists,
                op=reduce_op,
                group=dp_cp_pg,
                async_op=True,
            )
        else:
            grad_buffer_tensor_list = [
                accumulator.get_grad_buffer(param_id_to_name[id(param)]).view(-1) for param in bucket.parameters()
            ]
            accumulator.fp32_grads_allreduce_handle = dist.all_reduce_coalesced(
                grad_buffer_tensor_list, group=dp_cp_pg, async_op=True, op=reduce_op
            )
            # Don't wait on this handle during the rest of the backward pass; it is stored
            # on the accumulator so it can be waited on later, before the fp32 grads are used.

        # with torch.cuda.stream(s):
        fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
        half_grad_bucket = bucket.buffer()
        fut.set_result(half_grad_bucket)
        return fut  # We don't care about the new half grad values, so we return the old ones
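
For context, a minimal sketch of how a hook with this signature is typically attached to a DDP-wrapped model via PyTorch's register_comm_hook. The wiring below is an assumption for illustration only: the FP32GradBucketManager construction mirrors the fields the hook reads (dp_cp_pg, accumulator, param_id_to_name), but the real setup lives elsewhere in gradient_accumulator.py, where fp32_accum_hook is produced by a factory.

    # Illustrative sketch, not the actual nanotron wiring.
    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    def attach_fp32_accum_hook(ddp_model: DDP, accumulator, dp_cp_pg) -> None:
        # Map each parameter's id() to its (unwrapped) name so the hook can look up
        # the matching fp32 gradient buffer on the accumulator.
        param_id_to_name = {id(p): name for name, p in ddp_model.module.named_parameters()}
        state = FP32GradBucketManager(
            dp_cp_pg=dp_cp_pg,
            accumulator=accumulator,
            param_id_to_name=param_id_to_name,
        )
        # DDP invokes the hook once per GradBucket during backward, which is what lets
        # the coalesced all-reduce overlap with the remaining computation.
        ddp_model.register_comm_hook(state=state, hook=fp32_accum_hook)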
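
Because the coalesced all-reduce is launched with async_op=True, the handle stored in accumulator.fp32_grads_allreduce_handle has to be completed before the fp32 gradients are consumed. A hedged sketch of that ordering follows; wait_for_fp32_grads is a hypothetical helper name, not part of the module.

    def wait_for_fp32_grads(accumulator) -> None:
        # Hypothetical helper: block until the coalesced all-reduce launched by the hook
        # has completed, so the fp32 gradient buffers are valid for the optimizer step.
        handle = getattr(accumulator, "fp32_grads_allreduce_handle", None)
        if handle is not None:
            handle.wait()
            accumulator.fp32_grads_allreduce_handle = None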