in chatlearn/models/megatron/memory_manager/trainer_v1v2.py
def onload_weights(self):
"""
onload weights
"""
if not self._weights_offloaded:
log_rank_0('Call onload_weights when already onloaded. Ignore it.')
return
optimizer = self._optimizer
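    # Copy each flattened weight group from its offloaded storage back into its GPU buffer.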
    for flat_weights in self._group_flat_weights:
        flat_weights.copy_to_gpu_buffer()

    model = self._model
    # Re-register grad acc hooks, see Megatron DistributedDataParallel#__init__.
    model.grad_accs = []
    for param in model.module.parameters():
        if param.requires_grad:
            # Expand so we get access to grad_fn.
            param_tmp = param.expand_as(param)
            # Get the gradient accumulator function.
            grad_acc = param_tmp.grad_fn.next_functions[0][0]
            if self._megatron_version == MegatronVersion.V2:
                grad_acc.register_hook(model._make_param_hook(param, model.param_to_grad_buffer))
            elif self._megatron_version == MegatronVersion.V1:
                grad_acc.register_hook(model._make_param_hook(param))
            # Keep a reference so the accumulator object (and its hook) is not garbage collected.
            model.grad_accs.append(grad_acc)

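    # Without the distributed optimizer there are no optimizer shard views to rebuild.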
    if not self._use_distributed_optimizer:
        self._weights_offloaded = False
        return

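    # With the distributed optimizer, rebuild the parameter shard groups that reference the onloaded weight buffers.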
    shard_float16_groups = optimizer.shard_float16_groups
    shard_fp32_groups = optimizer.shard_fp32_groups
    param_gbuf_map = optimizer.model_param_gbuf_map
    opt_group_ranges = optimizer.opt_group_ranges
    model_gbuf_ranges = optimizer.model_gbuf_ranges

    # Rebuild shard_float16_groups and shard_fp32_groups,
    # see Megatron DistributedOptimizer#build_model_and_main_param_groups.
    for group_range in opt_group_ranges:
        shard_float16_params_this_group = []
        shard_fp32_params_this_group = []
        shard_float16_groups.append(shard_float16_params_this_group)
        shard_fp32_groups.append(shard_fp32_params_this_group)

        for model_param in group_range["params"]:
            assert model_param.requires_grad

            # Locate this param's shard range inside its gradient buffer.
            if self._megatron_version == MegatronVersion.V2:
                model_index, dtype, bucket_index = param_gbuf_map[model_param]
                gbuf_range = model_gbuf_ranges[model_index][dtype][bucket_index]
            elif self._megatron_version == MegatronVersion.V1:
                model_index, dtype = param_gbuf_map[model_param]
                gbuf_range = model_gbuf_ranges[model_index][dtype]
            param_range = gbuf_range["param_map"][model_param]["param"]

            # fp16, bf16 params.
            if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
                shard_model_param = model_param.detach().view(-1)[param_range.start : param_range.end]
                tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param)
                if hasattr(model_param, 'shared'):
                    shard_model_param.shared = model_param.shared
                shard_float16_params_this_group.append(shard_model_param)
            # fp32 params.
            elif model_param.type() == 'torch.cuda.FloatTensor':
                shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
                tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param)
                if hasattr(model_param, 'shared'):
                    shard_model_param.shared = model_param.shared
                shard_fp32_params_this_group.append(shard_model_param)
            else:
                raise TypeError(
                    'Wrapped parameters must be one of '
                    'torch.cuda.FloatTensor, '
                    'torch.cuda.HalfTensor, or '
                    'torch.cuda.BFloat16Tensor. '
                    'Received {}'.format(model_param.type())
                )

    self._weights_offloaded = False
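

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of trainer_v1v2.py): one way a caller
# could drive this memory manager around a training phase. It assumes a paired
# offload_weights() method, which is implied by the _weights_offloaded flag but
# not shown above; `memory_manager`, `trainer`, and `train_step` are
# illustrative names, not chatlearn APIs.
# ---------------------------------------------------------------------------
def run_training_phase(memory_manager, trainer, batches):
    # Bring weights (and, with the distributed optimizer, the rebuilt shard
    # groups) back onto the GPU before training.
    memory_manager.onload_weights()
    for batch in batches:
        trainer.train_step(batch)
    # Release GPU memory again so other models (e.g. inference engines) can use it.
    memory_manager.offload_weights()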