def build_grad_buffers()

in chatlearn/models/megatron/memory_manager/trainer_v1v2.py [0:0]


    def build_grad_buffers(self):
        """
        Build the grad buffers and related tensors: bucket views, parameter
        `main_grad` views, and, for the distributed optimizer, param_buffers.
        """
        if not self._grad_buffers_freed:
            log_rank_0('build_grad_buffers called while grad buffers are already built. Ignoring.')
            return

        optimizer = self._optimizer
        params_dtype = self._params_dtype
        grad_dtype_to_params = self._grad_dtype_to_params

        # Re-allocate the data of the grad buffers, including the data of their buckets;
        # see Megatron DistributedDataParallel#__init__. Also restore the parameters' `main_grad` views.
        for dtype, buffer in self.get_grad_buffers().items():
            numel_padded = self._grad_buffers_numels[dtype]
            buffer.data = torch.zeros(
                numel_padded,
                dtype=dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )

            if self._megatron_version == MegatronVersion.V2:
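                # V2 grad buffers are bucketed: restore each bucket's data as a view
                # into the freshly re-allocated flat buffer at its recorded offset.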
                for bucket, numel in zip(buffer.buckets, self._grad_buffers_bucket_sizes[dtype]):
                    bucket.data = buffer.get(torch.Size([numel]), bucket.offset)

            params = grad_dtype_to_params[dtype]
            data_start_index = 0
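            # Walk params in reverse and fill from index 0: Megatron lays the grad
            # buffer out back-to-front (back-prop order is assumed to be the reverse
            # of the parameter order), so this reproduces the original slot layout.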
            for param in params[::-1]:
                if not param.requires_grad:
                    continue
                this_numel = param.data.nelement()
                data_end_index = data_start_index + this_numel
                param.main_grad = buffer.get(param.data.shape, data_start_index)
                data_start_index = data_end_index

        if not self._use_distributed_optimizer:
            self._grad_buffers_freed = False
            return

        # Re-allocate param_buffers; see Megatron DistributedOptimizer#__init__.
        # Each param_buffer is a params_dtype view over the corresponding grad buffer's
        # (or bucket's) untyped storage, so parameters and gradients share memory.
        optimizer.param_buffers = []
        for _ in optimizer.models:
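            # Build one dict of param buffers per model in optimizer.models, keyed by grad dtype.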
            current_param_buffers = {}
            for dtype, grad_buffer in self.get_grad_buffers().items():
                current_param_buffers[dtype] = []
                if self._megatron_version == MegatronVersion.V2:
                    for bucket in grad_buffer.buckets:
                        try:
                            # Older torch exposes the raw bytes via the private _untyped().
                            storage = bucket.data.storage()._untyped()
                        # pylint: disable-next=bare-except
                        except:
                            # Newer torch removed _untyped() in favor of the public untyped().
                            storage = bucket.data.storage().untyped()

                        param_buffer = torch.tensor([], dtype=params_dtype, device=bucket.data.device).set_(storage)
                        param_buffer = param_buffer[bucket.offset : bucket.offset + bucket.data.numel()]
                        current_param_buffers[dtype].append(param_buffer)
                elif self._megatron_version == MegatronVersion.V1:
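                    # V1 keeps a single flat grad buffer per dtype: alias the whole
                    # buffer and trim the view to numel_padded.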
                    try:
                        storage = grad_buffer.data.storage()._untyped()
                    # pylint: disable-next=bare-except
                    except:
                        storage = grad_buffer.data.storage().untyped()
                    param_buffer = torch.tensor([], dtype=params_dtype, device=grad_buffer.data.device).set_(storage)
                    param_buffer = param_buffer[: grad_buffer.numel_padded]
                    current_param_buffers[dtype] = param_buffer
            optimizer.param_buffers.append(current_param_buffers)

        self._grad_buffers_freed = False
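
The two memory tricks above can be seen in isolation in the sketch below. It uses a toy FlatBuffer class as a stand-in for Megatron's per-dtype grad buffer, made-up parameter shapes, and float32 grads with float16 params; it illustrates the technique under those assumptions and is not the project's actual code.

import torch


class FlatBuffer:
    """Toy stand-in for a flat per-dtype grad buffer."""

    def __init__(self, numel, dtype):
        self.data = torch.zeros(numel, dtype=dtype)

    def get(self, shape, start_index):
        # Return a view of `shape` starting at `start_index` into the flat buffer.
        end_index = start_index + int(torch.Size(shape).numel())
        return self.data[start_index:end_index].view(shape)


params = [torch.nn.Parameter(torch.randn(4, 4)), torch.nn.Parameter(torch.randn(8))]
grad_buffer = FlatBuffer(sum(p.numel() for p in params), dtype=torch.float32)

# Hand out main_grad views in reverse parameter order, filling from index 0,
# matching the back-to-front layout rebuilt in build_grad_buffers above.
start = 0
for p in params[::-1]:
    p.main_grad = grad_buffer.get(p.shape, start)
    start += p.numel()

# Alias a params-dtype tensor over the same untyped storage, so parameter data and
# gradients can share one physical allocation (the distributed-optimizer trick).
try:
    storage = grad_buffer.data.storage()._untyped()  # older torch releases
except AttributeError:
    storage = grad_buffer.data.storage().untyped()   # newer torch releases

param_buffer = torch.tensor([], dtype=torch.float16, device=grad_buffer.data.device).set_(storage)
param_buffer = param_buffer[: grad_buffer.data.numel()]
assert param_buffer.data_ptr() == grad_buffer.data.data_ptr()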