chatlearn/models/megatron/memory_manager/base_trainer.py
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class and creator function for trainer memory managers."""
from abc import ABC, abstractmethod
from typing import List, Optional
import torch
from chatlearn.utils.flat_tensors import BucketizedFlatTensors
from chatlearn.utils.logger import log_rank_0
from chatlearn.utils.megatron_import_memory_helper import MegatronVersion, get_megatron_version
from chatlearn.utils.megatron_import_helper import (
DistributedDataParallel,
MixedPrecisionOptimizer,
DistributedOptimizer,
Float16OptimizerWithFloat16Params,
)
def create_trainer_memory_manager(
model,
optimizer,
use_distributed_optimizer,
accumulate_allreduce_grads_in_fp32,
params_dtype,
bucket_size_mb=0,
) -> 'BaseTrainerMemoryManager':
"""
Create a trainer memory manager based on megatron version.
"""
version = get_megatron_version()
if version in [MegatronVersion.V1, MegatronVersion.V2]:
# pylint: disable-next=import-outside-toplevel
from chatlearn.models.megatron.memory_manager.trainer_v1v2 import TrainerMemoryManagerV1V2
cls = TrainerMemoryManagerV1V2
elif version in [MegatronVersion.V3]:
# pylint: disable-next=import-outside-toplevel
from chatlearn.models.megatron.memory_manager.trainer_v3 import TrainerMemoryManagerV3
cls = TrainerMemoryManagerV3
elif version in [MegatronVersion.V4]:
# pylint: disable-next=import-outside-toplevel
from chatlearn.models.megatron.memory_manager.trainer_v4 import TrainerMemoryManagerV4
cls = TrainerMemoryManagerV4
else:
raise ValueError(f'Unsupported version of Megatron for trainer memory manager: {version}')
return cls(
model,
optimizer,
use_distributed_optimizer,
accumulate_allreduce_grads_in_fp32,
params_dtype,
bucket_size_mb,
)
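# A minimal usage sketch (illustrative only): the model/optimizer objects come from the
# Megatron setup elsewhere in ChatLearn, and the argument values below are hypothetical.
#
#   memory_manager = create_trainer_memory_manager(
#       model=ddp_model,
#       optimizer=megatron_optimizer,
#       use_distributed_optimizer=True,
#       accumulate_allreduce_grads_in_fp32=True,
#       params_dtype=torch.bfloat16,
#       bucket_size_mb=128,
#   )
#   memory_manager.offload_optimizer_states()  # free GPU memory, e.g. before generation
#   memory_manager.onload_optimizer_states()   # restore states before the next training step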


class BaseTrainerMemoryManager(ABC):
    """
    Base class for Megatron trainer memory managers. It provides the routines shared by all
    supported Megatron versions, such as offloading of optimizer states and of the main weights.
    """

    def __init__(
self,
model,
optimizer,
use_distributed_optimizer,
accumulate_allreduce_grads_in_fp32,
params_dtype,
bucket_size_mb=0,
):
self._model = model
self._optimizer = optimizer
self._accumulate_allreduce_grads_in_fp32 = accumulate_allreduce_grads_in_fp32
self._params_dtype = params_dtype
self._use_distributed_optimizer = use_distributed_optimizer
self._bucket_size_mb = bucket_size_mb
        assert isinstance(
            model, (DistributedDataParallel,)
        ), f'Only the DistributedDataParallel model type is supported, current type is {str(type(model))}.'
        assert isinstance(
            optimizer, (MixedPrecisionOptimizer,)
        ), f'Only MixedPrecisionOptimizer and its subclasses are supported, current type is {str(type(optimizer))}.'
# sanity check
if self._use_distributed_optimizer:
assert isinstance(optimizer, DistributedOptimizer)
else:
log_rank_0('Current optimizer is Float16OptimizerWithFloat16Params')
assert isinstance(optimizer, Float16OptimizerWithFloat16Params)
self._main_weights_offloaded = False
self._group_flat_main_weights: Optional[List[BucketizedFlatTensors]] = None
self._megatron_version = get_megatron_version()

    def _optimizer_load_state_bucket_into_device(self, device):
        """Move all optimizer state tensors onto the given device."""
state_dict = self._optimizer.optimizer.state_dict()
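        # For a typical Adam-style optimizer, state_dict['state'] maps each parameter index to
        # per-parameter state tensors such as 'exp_avg' and 'exp_avg_sq' (and possibly
        # 'master_param' when transformer_engine's optimizer is used).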
for tensors in state_dict['state'].values():
keys = list(tensors.keys())
for key in keys:
                # Compatible with transformer_engine v1.10, where state['master_param'] may be None.
if tensors[key] is not None:
tensors[key] = tensors[key].to(device=device, non_blocking=True)
        # Make sure all (possibly asynchronous) copies have finished before returning.
torch.cuda.synchronize()

    def offload_optimizer_states(self):
        """
        Offload optimizer states to CPU memory.
        """
self._optimizer_load_state_bucket_into_device(device='cpu')

    def onload_optimizer_states(self):
        """
        Load optimizer states back onto the current CUDA device.
        """
self._optimizer_load_state_bucket_into_device(device=torch.cuda.current_device())

    def _flat_param_groups(self, multi_groups: List[List[List[torch.Tensor]]]):
"""
Flatten parameters in param groups.
"""
return [
BucketizedFlatTensors(group, primary_store_device='cpu', bucket_size_mb=self._bucket_size_mb)
for groups in multi_groups
for group in groups
]

    def offload_main_weights(self):
        """
        Offload the fp32 main weights to CPU memory.
        """
if self._main_weights_offloaded:
            log_rank_0('offload_main_weights called while main weights are already offloaded; ignored.')
return
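        # The flattened views of the main weights are built lazily on the first offload and reused
        # afterwards. With the distributed optimizer the main weights live in
        # shard_fp32_from_float16_groups; otherwise they live in fp32_from_float16_groups.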
if self._group_flat_main_weights is None:
if self._use_distributed_optimizer:
self._group_flat_main_weights = self._flat_param_groups(
[self._optimizer.shard_fp32_from_float16_groups]
)
else:
self._group_flat_main_weights = self._flat_param_groups([self._optimizer.fp32_from_float16_groups])
for flat_main_weights in self._group_flat_main_weights:
flat_main_weights.copy_to_primary_store()
self._main_weights_offloaded = True

    def onload_main_weights(self):
        """
        Load the fp32 main weights back into GPU buffers.
        """
if not self._main_weights_offloaded:
            log_rank_0('onload_main_weights called while main weights are already onloaded; ignored.')
return
for flat_main_weights in self._group_flat_main_weights:
flat_main_weights.copy_to_gpu_buffer()
self._main_weights_offloaded = False

    @abstractmethod
    def offload_weights(self):
        """
        Offload the model weights.
        """

    @abstractmethod
    def onload_weights(self):
        """
        Onload the model weights.
        """

    @abstractmethod
    def free_grad_buffers(self):
        """
        Free the gradient buffers and related tensors.
        """

    @abstractmethod
    def build_grad_buffers(self):
        """
        Build the gradient buffers and related tensors.
        """