chatlearn/models/deepspeed_module.py
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DeepSpeed module"""
from datetime import timedelta
import math
import os
import random

import deepspeed
import numpy as np
import torch
from torch import distributed as dist
from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.integrations import HfDeepSpeedConfig
from transformers.trainer import get_scheduler

from chatlearn.utils.utils import dict_to_simplenamespace
from .deepspeed.deepspeed_utils import get_eval_ds_config, get_tokenizer, get_train_ds_config, create_optimizer
from .deepspeed.deepspeed_utils import save_hf_format, save_zero_three_model
from .torch_module import TorchModule


class DeepSpeedModule(TorchModule):
    """DeepSpeedModule is the class for models accelerated with DeepSpeed.

    Args
    ----
    name : str
        model name
    """

    def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self.trainable:
            # Inference-only: the micro batch size always follows the generation batch size.
            if self.model_args.get("train_micro_batch_size") != self.module_args.generation_batch_size:
                self._logger.info(
                    f"{self.name}: overwriting train_micro_batch_size with generation_batch_size "
                    f"{self.module_args.generation_batch_size}")
            self.train_micro_batch_size = self.module_args.generation_batch_size
else:
self.train_micro_batch_size = self.runtime_args.train_micro_batch_size
self.train_global_batch_size = self.runtime_args.train_global_batch_size
self.zero_size = self.module_args.zero_size

    def set_seed(self, seed):
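        """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs for reproducibility."""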
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

    def setup_distributed(self, timeout):
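        """Bind this worker to its local GPU and initialize the distributed backend through DeepSpeed."""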
self.set_seed(self.seed)
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
deepspeed.init_distributed(timeout=timeout)

    def prepare(self, *models_or_model_optim_pairs):
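        """Wrap each model, or (model, optimizer, scheduler) tuple, with a DeepSpeed engine."""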
ret = []
for arg in models_or_model_optim_pairs:
if not isinstance(arg, tuple):
ret.append(self._ds_init_eval_model(arg))
else:
                assert len(arg) == 3, f"Expect a (model, optimizer, scheduler) tuple, got a tuple of length {len(arg)}"
ret.append(self._ds_init_train_model(*arg))
return ret[0] if len(ret) == 1 else ret

    def _ds_init_eval_model(self, model):
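        """Build an inference-only DeepSpeed engine around the given model."""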
ds_config = self.get_ds_eval_config(offload=getattr(model, "_offload", False))
local_rank = int(os.environ['LOCAL_RANK'])
engine, *_ = deepspeed.initialize(
model=model,
args={"local_rank": local_rank},
config=ds_config,
dist_init_required=True,
)
model = engine
return model

    def _ds_init_train_model(self, model, optim, scheduler):
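        """Build a training DeepSpeed engine, returning the wrapped model, optimizer and scheduler."""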
ds_config = self.get_ds_train_config()
local_rank = int(os.environ['LOCAL_RANK'])
engine, optim, _, scheduler = deepspeed.initialize(
model=model,
optimizer=optim,
lr_scheduler=scheduler,
config=ds_config,
args={"local_rank": local_rank},
dist_init_required=True,
)
model = engine
return model, optim, scheduler

    def get_ds_eval_config(self, offload=False):
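        """DeepSpeed config for inference-only engines: ZeRO stage 3 is kept for parameter sharding, any other stage falls back to 0."""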
# DS Config
ds_config = get_eval_ds_config(offload=offload, stage=self.zero_stage if self.zero_stage == 3 else 0, bf16=self.bf16)
ds_config["train_micro_batch_size_per_gpu"] = self.train_micro_batch_size
ds_config["train_batch_size"] = self.train_micro_batch_size * self.zero_size
return ds_config

    def get_ds_train_config(self):
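        """DeepSpeed config for training engines, with batch sizes and gradient accumulation derived from the runtime args."""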
# DS Config
ds_config = get_train_ds_config(
offload=False,
adam_offload=self.adam_offload,
stage=self.zero_stage,
bf16=self.bf16,
max_norm=self.max_norm,
grad_accum_dtype="fp32",
disable_trace_cache=self.disable_trace_cache,
)
ds_config["train_micro_batch_size_per_gpu"] = self.train_micro_batch_size
ds_config["gradient_accumulation_steps"] = self.train_global_batch_size // self.train_micro_batch_size // self.world_size
ds_config["train_batch_size"] = self.train_global_batch_size
return ds_config

    def create_model(self, args):
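        """Load the HuggingFace causal-LM checkpoint specified by `pretrain_or_model`."""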
        # NOTE: attn_implementation="flash_attention_2" requires the flash-attn package to be installed.
model = AutoModelForCausalLM.from_pretrained(
args.pretrain_or_model,
trust_remote_code=True,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16 if self.bf16 else "auto"
)
return model

    def model_setup(self):
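        """Create the tokenizer and model, and for trainable modules also the optimizer and
        learning-rate scheduler, then wrap everything with DeepSpeed engines."""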
super().model_setup()
args = dict_to_simplenamespace(self.model_args)
self.prompt_max_len = getattr(args, "prompt_max_len", 1024)
self.args = args
self.zero_stage = getattr(args, "zero_stage", 3)
self.bf16 = args.bf16
self.seed = getattr(args, "seed", 42)
self.max_norm = getattr(args, "max_norm", 1.0)
dist_timeout = getattr(args, 'distributed_timeout', 30)
self.setup_distributed(timedelta(minutes=dist_timeout))
# TODO: deal with offload later
ds_config = self.get_ds_eval_config(offload=False)
        # To efficiently deploy DeepSpeed ZeRO stage 3, the HfDeepSpeedConfig object must be
        # instantiated (and kept alive) before the model is created.
        # https://huggingface.co/transformers/v4.9.2/main_classes/deepspeed.html
dschf = HfDeepSpeedConfig(ds_config) if ds_config is not None and self.zero_stage == 3 else None # pylint: disable=unused-variable
model = self.create_model(self.args)
self.tokenizer = get_tokenizer(
args.pretrain_or_model, model, "left", use_fast=True
)
if self.trainable:
if getattr(args, "gradient_checkpointing", False):
model.gradient_checkpointing_enable()
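                # Gradient checkpointing re-runs forward passes, which can invalidate the ZeRO-3
                # parameter trace/prefetch cache, so disable it in the training config.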
self.disable_trace_cache = True
learning_rate = float(args.learning_rate)
self.adam_offload = False
num_update_steps_per_episodes = self.runtime_args.sample_per_episode // self.train_global_batch_size
l2 = float(args.l2)
max_steps = math.ceil(self.runtime_args.num_episode * num_update_steps_per_episodes)
optimizer = create_optimizer(
model, self.adam_offload, lr=learning_rate, betas=(0.9, 0.95), weight_decay=l2
)
            scheduler = get_scheduler(
                "cosine_with_min_lr",
                optimizer,
                num_warmup_steps=math.ceil(max_steps * 0.03),
                num_training_steps=max_steps,
                scheduler_specific_kwargs={"min_lr": learning_rate * 0.1},
            )
self.model, self.optimizer, self.scheduler = self.prepare((model, optimizer, scheduler))
else:
self.model = self.prepare(model)
self.generation_config = GenerationConfig.from_pretrained(args.pretrain_or_model, trust_remote_code=True)
self.tokenizer.eos_token_id = self.generation_config.eos_token_id
if not self.trainable:
self.model.eval()

    @property
    def data_parallel_size(self):
        """
        Data parallel world size, which equals the global world size under DeepSpeed ZeRO.

        :meta private:
        """
        return dist.get_world_size()

    @property
    def data_parallel_rank(self):
        """
        Data parallel rank of this worker, which equals its global rank under DeepSpeed ZeRO.

        :meta private:
        """
        return dist.get_rank()

    def save_checkpoint(self, iteration):
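        """Save the model and tokenizer in HuggingFace format, gathering ZeRO-3 partitioned weights when needed."""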
save_dir = f"{self.runtime_args.output_dir}/save_model/{self.name}/{iteration}"
save_hf_format(self.model, self.tokenizer, save_dir)
save_zero_three_model(self.model, torch.distributed.get_rank(), save_dir, self.zero_stage)
self._logger.info(f"save checkpoint to {save_dir}")