# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer Memery manager for Megatron V1 and V2"""
from typing import List, Optional

import torch

from chatlearn.models.megatron.memory_manager.base_trainer import BaseTrainerMemoryManager
from chatlearn.utils.flat_tensors import BucketizedFlatTensors
from chatlearn.utils.logger import log_rank_0
from chatlearn.utils.megatron_import_helper import tensor_parallel
from chatlearn.utils.megatron_import_memory_helper import MegatronVersion, check_megatron_versions

check_megatron_versions([MegatronVersion.V1, MegatronVersion.V2])

__all__ = ['TrainerMemoryManagerV1V2']

class TrainerMemoryManagerV1V2(BaseTrainerMemoryManager):
"""
Memory manager for Megatron V1 and V2 trainer modules.
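
    A minimal usage sketch (the constructor arguments shown are illustrative):

        manager = TrainerMemoryManagerV1V2(
            model,
            optimizer,
            use_distributed_optimizer=True,
            accumulate_allreduce_grads_in_fp32=True,
            params_dtype=torch.bfloat16,
            bucket_size_mb=0,
        )
        # Release GPU memory between training phases ...
        manager.offload_weights()
        manager.free_grad_buffers()
        # ... and restore it before the next training step.
        manager.onload_weights()
        manager.build_grad_buffers()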
"""

    def __init__(
self,
model,
optimizer,
use_distributed_optimizer,
accumulate_allreduce_grads_in_fp32,
params_dtype,
bucket_size_mb=0,
):
super().__init__(
model,
optimizer,
use_distributed_optimizer,
accumulate_allreduce_grads_in_fp32,
params_dtype,
bucket_size_mb,
)
self._weights_offloaded = False
self._grad_buffers_freed = False
self._grad_dtype_to_params = self._get_grad_dtype_to_params(model, accumulate_allreduce_grads_in_fp32)
self._group_flat_weights: Optional[List[BucketizedFlatTensors]] = None
self._grad_buffers_numels = None
self._grad_buffers_bucket_sizes = None

    def get_grad_buffers(self):
        """Return the model's gradient buffers, keyed by gradient dtype."""
        if self._megatron_version == MegatronVersion.V2:
            return self._model.grad_buffers
        # MegatronVersion.V1 stores the buffers in a private attribute.
        return self._model._grad_buffers

    @staticmethod
def _get_grad_dtype_to_params(model, accumulate_allreduce_grads_in_fp32):
# Group parameters by their gradient type.
grad_dtype_to_params = {}
for _, param in model.module.named_parameters():
if param.requires_grad and getattr(param, 'allreduce', True):
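                # Reset Megatron's flag recording whether this param's grad has
                # already been folded into `main_grad`.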
param.grad_added_to_main_grad = False
dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype
params = grad_dtype_to_params.get(dtype, [])
params.append(param)
grad_dtype_to_params[dtype] = params
return grad_dtype_to_params

    def offload_weights(self):
        """
        Offload model weights from GPU memory to the primary (CPU) store.
        """
if self._weights_offloaded:
            log_rank_0('offload_weights called while weights are already offloaded; ignoring.')
return
optimizer = self._optimizer
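        # The shard groups hold views into the model params and are rebuilt in
        # `onload_weights`, so they can simply be dropped here.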
if self._use_distributed_optimizer:
optimizer.shard_float16_groups.clear()
optimizer.shard_fp32_groups.clear()
if self._group_flat_weights is None:
if self._use_distributed_optimizer:
self._group_flat_weights = self._flat_param_groups(
[
optimizer.model_float16_groups,
optimizer.model_fp32_groups,
],
)
else:
self._group_flat_weights = self._flat_param_groups(
[
optimizer.float16_groups,
optimizer.fp32_from_fp32_groups,
],
)
for flat_weights in self._group_flat_weights:
flat_weights.copy_to_primary_store()
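        # Drop the grad accumulator hooks; they are re-registered in `onload_weights`.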
self._model.grad_accs.clear()
self._weights_offloaded = True

    def onload_weights(self):
        """
        Onload model weights back to GPU memory and, for the distributed
        optimizer, rebuild the shard parameter groups.
        """
if not self._weights_offloaded:
            log_rank_0('onload_weights called while weights are already onloaded; ignoring.')
return
optimizer = self._optimizer
for flat_weights in self._group_flat_weights:
flat_weights.copy_to_gpu_buffer()
model = self._model
# Re-register grad acc hooks, see Megatron DistributedDataParallel#__init__.
model.grad_accs = []
for param in model.module.parameters():
if param.requires_grad:
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
if self._megatron_version == MegatronVersion.V2:
grad_acc.register_hook(model._make_param_hook(param, model.param_to_grad_buffer))
elif self._megatron_version == MegatronVersion.V1:
grad_acc.register_hook(model._make_param_hook(param))
model.grad_accs.append(grad_acc)
if not self._use_distributed_optimizer:
self._weights_offloaded = False
return
shard_float16_groups = optimizer.shard_float16_groups
shard_fp32_groups = optimizer.shard_fp32_groups
param_gbuf_map = optimizer.model_param_gbuf_map
opt_group_ranges = optimizer.opt_group_ranges
model_gbuf_ranges = optimizer.model_gbuf_ranges
# Rebuild shard_float16_groups and shard_fp32_groups,
# see Megatron DistributedOptimizer#build_model_and_main_param_groups.
        for group_range in opt_group_ranges:
shard_float16_params_this_group = []
shard_fp32_params_this_group = []
shard_float16_groups.append(shard_float16_params_this_group)
shard_fp32_groups.append(shard_fp32_params_this_group)
for model_param in group_range["params"]:
assert model_param.requires_grad
if self._megatron_version == MegatronVersion.V2:
model_index, dtype, bucket_index = param_gbuf_map[model_param]
gbuf_range = model_gbuf_ranges[model_index][dtype][bucket_index]
param_range = gbuf_range["param_map"][model_param]["param"]
elif self._megatron_version == MegatronVersion.V1:
model_index, dtype = param_gbuf_map[model_param]
gbuf_range = model_gbuf_ranges[model_index][dtype]
param_range = gbuf_range["param_map"][model_param]["param"]
# fp16, bf16 params.
if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
shard_model_param = model_param.detach().view(-1)[param_range.start : param_range.end]
tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
shard_float16_params_this_group.append(shard_model_param)
# fp32 params.
elif model_param.type() == 'torch.cuda.FloatTensor':
shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
shard_fp32_params_this_group.append(shard_model_param)
tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
else:
raise TypeError(
'Wrapped parameters must be one of '
'torch.cuda.FloatTensor, '
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'.format(model_param.type())
)
self._weights_offloaded = False

    def free_grad_buffers(self):
        """
        Free gradient buffers and related tensors to release GPU memory.
        """
if self._grad_buffers_freed:
            log_rank_0('free_grad_buffers called while grad buffers are already freed; ignoring.')
return
optimizer = self._optimizer
grad_dtype_to_params = self._grad_dtype_to_params
        # Zeroing the grads here is necessary, though the exact reason is unclear.
optimizer.zero_grad(True)
if self._use_distributed_optimizer:
# Release param_buffers because they share storage with grad_buffers.
# Note: param_buffers are only available in DistributedOptimizer.
optimizer.param_buffers.clear()
# Release grad_buffers, including buckets in GradBuffer for newer Megatron version.
# Release `main_grad` of parameters.
self._grad_buffers_numels = {}
self._grad_buffers_bucket_sizes = {}
for dtype, buffer in self.get_grad_buffers().items():
for p in grad_dtype_to_params[dtype]:
del p.main_grad
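            # Remember the padded sizes so the buffers can be rebuilt with the
            # same layout in `build_grad_buffers`.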
self._grad_buffers_numels[dtype] = buffer.numel_padded
if self._megatron_version == MegatronVersion.V2:
bucket_sizes = []
for bucket in buffer.buckets:
bucket_sizes.append(bucket.data.numel())
bucket.data = None
self._grad_buffers_bucket_sizes[dtype] = bucket_sizes
buffer.data = None
self._grad_buffers_freed = True

    def build_grad_buffers(self):
        """
        Rebuild the gradient buffers and related tensors freed by
        `free_grad_buffers`.
        """
if not self._grad_buffers_freed:
            log_rank_0('build_grad_buffers called while grad buffers are already built; ignoring.')
return
optimizer = self._optimizer
params_dtype = self._params_dtype
grad_dtype_to_params = self._grad_dtype_to_params
# Re-allocate data of grad_buffers, including data of buckets, see Megatron DistributedDataParallel#__init__.
# Also set `main_grad` for parameters.
for dtype, buffer in self.get_grad_buffers().items():
numel_padded = self._grad_buffers_numels[dtype]
buffer.data = torch.zeros(
numel_padded,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
if self._megatron_version == MegatronVersion.V2:
for bucket, numel in zip(buffer.buckets, self._grad_buffers_bucket_sizes[dtype]):
bucket.data = buffer.get(torch.Size([numel]), bucket.offset)
params = grad_dtype_to_params[dtype]
data_start_index = 0
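            # Megatron maps params into the grad buffer in reverse order of
            # `module.parameters()`, so iterate in reverse to reproduce the layout.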
for param in params[::-1]:
if not param.requires_grad:
continue
this_numel = param.data.nelement()
data_end_index = data_start_index + this_numel
param.main_grad = buffer.get(param.data.shape, data_start_index)
data_start_index = data_end_index
if not self._use_distributed_optimizer:
self._grad_buffers_freed = False
return
# Re-allocate param_buffers, see Megatron DistributedOptimizer#__init__.
optimizer.param_buffers = []
        for _ in optimizer.models:
current_param_buffers = {}
for dtype, grad_buffer in self.get_grad_buffers().items():
current_param_buffers[dtype] = []
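                # Each param buffer aliases the grad buffer's untyped storage,
                # reinterpreted as `params_dtype` (weights and grads share storage).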
if self._megatron_version == MegatronVersion.V2:
for bucket in grad_buffer.buckets:
                        try:
                            # Older PyTorch exposes `_untyped()`; it was renamed to `untyped()`.
                            storage = bucket.data.storage()._untyped()
                        except AttributeError:
                            storage = bucket.data.storage().untyped()
param_buffer = torch.tensor([], dtype=params_dtype, device=bucket.data.device).set_(storage)
param_buffer = param_buffer[bucket.offset : bucket.offset + bucket.data.numel()]
current_param_buffers[dtype].append(param_buffer)
elif self._megatron_version == MegatronVersion.V1:
                    try:
                        # Older PyTorch exposes `_untyped()`; it was renamed to `untyped()`.
                        storage = grad_buffer.data.storage()._untyped()
                    except AttributeError:
                        storage = grad_buffer.data.storage().untyped()
param_buffer = torch.tensor([], dtype=params_dtype, device=grad_buffer.data.device).set_(storage)
param_buffer = param_buffer[: grad_buffer.numel_padded]
current_param_buffers[dtype] = param_buffer
optimizer.param_buffers.append(current_param_buffers)
self._grad_buffers_freed = False