in chatlearn/models/megatron/memory_manager/trainer_v1v2.py [0:0]
def build_grad_buffers(self):
    """
    Build grad buffers and related tensors (`main_grad` views and, when the
    distributed optimizer is used, `param_buffers`).
    """
    if not self._grad_buffers_freed:
        log_rank_0('build_grad_buffers called but grad buffers are already built, ignoring.')
        return

    optimizer = self._optimizer
    params_dtype = self._params_dtype
    grad_dtype_to_params = self._grad_dtype_to_params

    # Re-allocate data of grad_buffers, including data of buckets, see Megatron DistributedDataParallel#__init__.
    # Also set `main_grad` for parameters.
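    # In Megatron, `buffer.get(shape, start_index)` returns a view of
    # `buffer.data[start_index:start_index + shape.numel()]` reshaped to
    # `shape`, so everything built below aliases the new flat tensor instead
    # of copying it.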
    for dtype, buffer in self.get_grad_buffers().items():
        numel_padded = self._grad_buffers_numels[dtype]
        buffer.data = torch.zeros(
            numel_padded,
            dtype=dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )
        if self._megatron_version == MegatronVersion.V2:
            for bucket, numel in zip(buffer.buckets, self._grad_buffers_bucket_sizes[dtype]):
                bucket.data = buffer.get(torch.Size([numel]), bucket.offset)
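
        # Walk parameters in reverse so every `main_grad` view lands at the
        # same offset Megatron assigned originally (grad buffers are laid out
        # in backprop, i.e. reverse-parameter, order).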
        params = grad_dtype_to_params[dtype]
        data_start_index = 0
        for param in params[::-1]:
            if not param.requires_grad:
                continue
            this_numel = param.data.nelement()
            data_end_index = data_start_index + this_numel
            param.main_grad = buffer.get(param.data.shape, data_start_index)
            data_start_index = data_end_index
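
    # param_buffers only exist on Megatron's DistributedOptimizer, so there is
    # nothing more to rebuild without it.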
    if not self._use_distributed_optimizer:
        self._grad_buffers_freed = False
        return

    # Re-allocate param_buffers, see Megatron DistributedOptimizer#__init__.
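    # The param_buffers alias the grad buffers' storage viewed as params_dtype;
    # Megatron's distributed optimizer uses them to all-gather updated
    # parameters without allocating extra memory.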
    optimizer.param_buffers = []
    for _, _ in enumerate(optimizer.models):
        current_param_buffers = {}
        for dtype, grad_buffer in self.get_grad_buffers().items():
            current_param_buffers[dtype] = []
            if self._megatron_version == MegatronVersion.V2:
                for bucket in grad_buffer.buckets:
                    # `Storage._untyped()` was renamed to `untyped()` in newer
                    # PyTorch releases; try the old name and fall back.
                    try:
                        storage = bucket.data.storage()._untyped()
                    # pylint: disable-next=bare-except
                    except:
                        storage = bucket.data.storage().untyped()
                    param_buffer = torch.tensor([], dtype=params_dtype, device=bucket.data.device).set_(storage)
                    param_buffer = param_buffer[bucket.offset : bucket.offset + bucket.data.numel()]
                    current_param_buffers[dtype].append(param_buffer)
            elif self._megatron_version == MegatronVersion.V1:
                try:
                    storage = grad_buffer.data.storage()._untyped()
                # pylint: disable-next=bare-except
                except:
                    storage = grad_buffer.data.storage().untyped()
                param_buffer = torch.tensor([], dtype=params_dtype, device=grad_buffer.data.device).set_(storage)
                param_buffer = param_buffer[: grad_buffer.numel_padded]
                # V1 keeps a single flat grad buffer per dtype, so param_buffers
                # maps dtype -> tensor instead of dtype -> list of buckets.
                current_param_buffers[dtype] = param_buffer
        optimizer.param_buffers.append(current_param_buffers)

    self._grad_buffers_freed = False
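
# Illustrative usage sketch (not part of this file): this manager is meant to
# be driven as a free/build pair, releasing grad/param buffer memory while the
# trainer is idle and re-creating it before the next optimizer step.
# `manager` and `free_grad_buffers` are assumed names, inferred from the
# `_grad_buffers_freed` flag used above.
#
#   manager.free_grad_buffers()    # drop buffer.data and param_buffers, mark buffers as freed
#   ...                            # run other models that need the freed GPU memory
#   manager.build_grad_buffers()   # re-allocate storage and re-alias main_grad / param_buffers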