chatlearn/models/megatron/memory_manager/trainer_v3.py

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer memory manager for Megatron V3"""

from typing import List, Optional

import torch

from chatlearn.models.megatron.memory_manager.base_trainer import BaseTrainerMemoryManager
from chatlearn.utils.flat_tensors import BucketizedFlatTensors, FlatTensors
from chatlearn.utils.logger import log_rank_0
from chatlearn.utils.megatron_import_helper import tensor_parallel
from chatlearn.utils.megatron_import_memory_helper import BufferType
from chatlearn.utils.megatron_import_memory_helper import MegatronVersion, check_megatron_versions

check_megatron_versions([MegatronVersion.V3])

__all__ = ['TrainerMemoryManagerV3']


class TrainerMemoryManagerV3(BaseTrainerMemoryManager):
    """
    Memory manager for Megatron V3 trainer modules.
    """

    def __init__(
        self,
        model,
        optimizer,
        use_distributed_optimizer,
        accumulate_allreduce_grads_in_fp32,
        params_dtype,
        bucket_size_mb=0,
    ):
        super().__init__(
            model,
            optimizer,
            use_distributed_optimizer,
            accumulate_allreduce_grads_in_fp32,
            params_dtype,
            bucket_size_mb,
        )
        self._weights_offloaded = False
        self._grad_buffers_freed = False

        self._buffers = self._get_buffers(model)
        self._group_flat_weights: Optional[List[BucketizedFlatTensors]] = None

    @staticmethod
    def _get_buffers(model):
        # Several params share one buffer; keep each buffer exactly once.
        processed_buffers = set()
        buffers = []
        for _, buffer in model.param_to_buffer.items():
            if buffer not in processed_buffers:
                processed_buffers.add(buffer)
                buffers.append(buffer)
        return buffers

    def offload_weights(self):
        """
        offload weights
        """
        if self._weights_offloaded:
            log_rank_0('Call offload_weights when already offloaded. Ignore it.')
            return
        optimizer = self._optimizer

        # TODO(jiqi): support expert parallel params
        # In the V3 version, when the distributed optimizer is used, parameter data are
        # managed together with gradients in buffers.
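        # The shard groups and param-buffer views cleared below hold views into the GPU
        # parameter storage; dropping them first ensures no references survive the offload.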
        if self._use_distributed_optimizer:
            optimizer.shard_float16_groups.clear()
            optimizer.shard_fp32_groups.clear()
            optimizer.pbuf_view_items.clear()

            if self._group_flat_weights is None:
                self._group_flat_weights = []
                for buffer in self._buffers:
                    assert buffer.param_data is not None
                    self._group_flat_weights.append(
                        BucketizedFlatTensors([buffer.param_data], self._bucket_size_mb, 'cpu')
                    )

            # Remove references from params
            for p, _ in self._model.param_to_buffer.items():
                # save the shape for reconstruction
                p._saved_shape = p.shape
                p.data = FlatTensors._EMPTY_TENSOR

            # Remove references from buckets
            for buffer in self._buffers:
                for bucket in buffer.buckets:
                    bucket.param_data = None
        else:
            if self._group_flat_weights is None:
                self._group_flat_weights = self._flat_param_groups(
                    [
                        optimizer.float16_groups,
                        optimizer.fp32_from_fp32_groups,
                    ],
                )

        # Offload param_data of buffers
        for flat_weights in self._group_flat_weights:
            flat_weights.copy_to_primary_store()

        self._model.grad_accs.clear()

        self._weights_offloaded = True

    def onload_weights(self):
        """
        onload weights
        """
        if not self._weights_offloaded:
            log_rank_0('Call onload_weights when already onloaded. Ignore it.')
            return
        optimizer = self._optimizer

        # Onload param_data of buffers
        for flat_weights in self._group_flat_weights:
            flat_weights.copy_to_gpu_buffer()

        if self._use_distributed_optimizer:
            # Reconstruct references from buckets
            for buffer in self._buffers:
                assert buffer.param_data is not None
                for bucket_id, bucket in enumerate(buffer.buckets):
                    (start_index, end_index) = buffer.bucket_indices[bucket_id]
                    bucket.param_data = None
                    if buffer.param_data is not None:
                        bucket.param_data = buffer._get(
                            torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.PARAM
                        )

            # Reconstruct references from params
            for param, buffer in self._model.param_to_buffer.items():
                data_start_index, _, bucket_id = buffer.param_index_map[param]
                if buffer.param_data is not None:
                    param.data = buffer._get(param._saved_shape, data_start_index, buffer_type=BufferType.PARAM)

        model = self._model
        # Re-register grad acc hooks, see Megatron DistributedDataParallel#__init__.
        model.grad_accs = []
        for param in model.module.parameters():
            if param.requires_grad:
                # Expand so we get access to grad_fn.
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator function.
                grad_acc = param_tmp.grad_fn.next_functions[0][0]
                grad_acc.register_hook(model._make_param_hook(param, model.param_to_buffer))
                model.grad_accs.append(grad_acc)

        if not self._use_distributed_optimizer:
            self._weights_offloaded = False
            return

        optimizer.pbuf_view_items = optimizer._get_model_param_buffer_dp_views()

        shard_float16_groups = optimizer.shard_float16_groups
        shard_fp32_groups = optimizer.shard_fp32_groups
        param_gbuf_map = optimizer.model_param_gbuf_map
        opt_group_ranges = optimizer.opt_group_ranges
        model_gbuf_ranges = optimizer.gbuf_ranges

        # Rebuild shard_float16_groups and shard_fp32_groups,
        # see Megatron DistributedOptimizer#build_model_and_main_param_groups.
        for _, group_range in enumerate(opt_group_ranges):
            shard_float16_params_this_group = []
            shard_fp32_params_this_group = []
            shard_float16_groups.append(shard_float16_params_this_group)
            shard_fp32_groups.append(shard_fp32_params_this_group)

            for model_param in group_range["params"]:
                assert model_param.requires_grad
                gbuf_index, dtype, bucket_index = param_gbuf_map[model_param]
                gbuf_range = model_gbuf_ranges[gbuf_index][dtype][bucket_index]
                param_range = gbuf_range["param_map"][model_param]["param"]

                # fp16, bf16 params.
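                # Each shard rebuilt below is a flat 1-D view of the model param covering
                # only the slice owned by this rank's optimizer shard, matching what
                # build_model_and_main_param_groups creates at initialization.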
                if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
                    shard_model_param = model_param.detach().view(-1)[param_range.start : param_range.end]
                    tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared

                    shard_float16_params_this_group.append(shard_model_param)

                # fp32 params.
                elif model_param.type() == 'torch.cuda.FloatTensor':
                    shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
                    shard_fp32_params_this_group.append(shard_model_param)
                    tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared

                else:
                    raise TypeError(
                        'Wrapped parameters must be one of '
                        'torch.cuda.FloatTensor, '
                        'torch.cuda.HalfTensor, or '
                        'torch.cuda.BFloat16Tensor. '
                        'Received {}'.format(model_param.type())
                    )

        self._weights_offloaded = False

    def free_grad_buffers(self):
        """
        free grad buffers and related tensors
        """
        if self._grad_buffers_freed:
            log_rank_0('Call free_grad_buffers when already freed. Ignore it.')
            return
        optimizer = self._optimizer

        # Calling zero_grad here is necessary, although the exact reason is unclear.
        optimizer.zero_grad(True)

        # Remove references from params
        for p, _ in self._model.param_to_buffer.items():
            del p.main_grad

        # Remove references from buckets and free grad_data of buffers
        for buffer in self._buffers:
            for bucket in buffer.buckets:
                del bucket.grad_data
            del buffer.grad_data

        self._grad_buffers_freed = True

    def build_grad_buffers(self):
        """
        build grad buffers and related tensors
        """
        if not self._grad_buffers_freed:
            log_rank_0('Call build_grad_buffers when already built. Ignore it.')
            return

        # Build buffers and reconstruct references from buckets
        for buffer in self._buffers:
            buffer.grad_data = torch.zeros(
                buffer.numel,
                dtype=buffer.grad_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
            for bucket_id, bucket in enumerate(buffer.buckets):
                (start_index, end_index) = buffer.bucket_indices[bucket_id]
                bucket.grad_data = buffer._get(
                    torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.GRAD
                )

        # Reconstruct references from params
        for param, buffer in self._model.param_to_buffer.items():
            data_start_index, _, bucket_id = buffer.param_index_map[param]
            param.main_grad = buffer._get(param.data.shape, data_start_index, buffer_type=BufferType.GRAD)

        self._grad_buffers_freed = False
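# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module's public surface): how a
# caller might pair the offload/onload methods around a phase in which the
# trainer's GPU memory is lent to another model. `ddp_model` and
# `dist_optimizer` are placeholder names for the Megatron DDP module and
# DistributedOptimizer constructed by the surrounding framework.
#
#   manager = TrainerMemoryManagerV3(
#       model=ddp_model,
#       optimizer=dist_optimizer,
#       use_distributed_optimizer=True,
#       accumulate_allreduce_grads_in_fp32=True,
#       params_dtype=torch.bfloat16,
#       bucket_size_mb=128,
#   )
#   manager.offload_weights()      # copy param buffers to the CPU primary store
#   manager.free_grad_buffers()    # release grad_data of every buffer
#   ...                            # e.g. run generation with another model
#   manager.onload_weights()       # copy params back and rebuild shard views
#   manager.build_grad_buffers()   # reallocate grad buffers and main_grad refs
# ---------------------------------------------------------------------------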