chatlearn/utils/arguments.py

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""arguments from command or yaml."""

import argparse
import ast
import os
from typing import List, Optional, Union
import re

import yaml

from chatlearn.utils.constant import (
    DYNAMIC_BATCH_SIZE,
    LORA_LAYER,
    RAY_PG_STRATEGY,
    PARAM_SYNC_COMM_TYPE,
    ROUTED_EXPERT_REGROUPING_COMM_TYPE,
    TrainingShffuleMode)
from chatlearn.utils.logger import logger
from chatlearn.utils.utils import get_attributes


def get_path(fn, folder):
    if not fn.startswith("/") and not fn.startswith(folder):
        fn = os.path.join(folder, fn)
    assert os.path.exists(fn), f'{fn} does not exist'
    return fn


def convert_type(data):
    try:
        return ast.literal_eval(data)
    except Exception:
        return data


def parse_value(value):
    if isinstance(value, dict):
        return {k: parse_value(v) for k, v in value.items()}

    if isinstance(value, str):
        if value.strip().startswith("${"):
            # ${env_name:default_value}
            placeholder = value.replace("${", "")[:-1]
            placeholder = placeholder.split(":")
            env_name = placeholder[0]
            if env_name in os.environ:
                value = convert_type(os.environ[env_name])
            else:
                if len(placeholder) > 1:
                    value = convert_type(placeholder[1])
                else:
                    logger.warning(f"cannot find value for {env_name}, set to None")
                    value = None
        # handle scientific notation (e.g., "5e-6", "5E+10")
        elif re.match(r"^[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?$", value):
            try:
                value = float(value)
            except Exception:
                pass
    return value


def update_dict(src, dst):
    # do not overwrite existing keys in dst
    for k, v in src.items():
        if k not in dst:
            dst[k] = v
        else:
            if isinstance(v, dict) and isinstance(dst[k], dict):
                update_dict(v, dst[k])


def parse_args_from_yaml(config_file, config_dir):
    with open(config_file, 'r', encoding='utf-8') as stream:
        config_vars = yaml.load(stream, Loader=yaml.SafeLoader)
        # empty yaml file
        if config_vars is None:
            return {}
        config_vars = {key: parse_value(value) for key, value in config_vars.items()}
        if 'includes' in config_vars:
            includes_vars = {}
            # iterate in reverse order, so that a later include overwrites an earlier one
            for base in reversed(config_vars["includes"]):
                base_path = get_path(base, config_dir)
                base_config = parse_args_from_yaml(base_path, config_dir)
                update_dict(base_config, includes_vars)
            update_dict(includes_vars, config_vars)
        return config_vars


def parse_args():
    """Parse all arguments."""
    parser = argparse.ArgumentParser(description='ChatLearn Arguments',
                                     allow_abbrev=False)
    parser.add_argument("-c", "--config",
                        required=False,
                        help="where to load YAML configuration",
                        metavar="FILE")

    args, _ = parser.parse_known_args()

    if args.config:
        config_dir = os.path.dirname(args.config)
        args_yaml = parse_args_from_yaml(args.config, config_dir)
    else:
        config_dir = None
        args_yaml = None
    config = Config(args_yaml, config_dir)
    return config
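
# A minimal, hedged sketch (illustrative only, never called by ChatLearn) of how
# parse_value resolves "${ENV:default}" placeholders and scientific notation; the
# environment-variable name and values below are hypothetical.
def _example_parse_value():
    # with MY_EPISODE_NUM unset, the default after the colon is used and converted to int
    num_episode = parse_value("${MY_EPISODE_NUM:10}")
    # strings in scientific notation are converted to float
    learning_rate = parse_value("5e-6")
    return num_episode, learning_rate  # -> (10, 5e-06) when MY_EPISODE_NUM is unset
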

class BaseConfig:
    """Base class includes some common format functions."""

    def __init__(self):
        self._finalize = True

    def __str__(self):
        members = [attr for attr in dir(self) \
                   if not callable(getattr(self, attr)) and not attr.startswith("__")]
        ser_str = self.__class__.__name__ + " {\n"
        for key in members:
            if key.startswith('_'):
                continue
            attr = getattr(self, key)
            attr = '"{}"'.format(attr) if isinstance(attr, str) else attr
            ser_str += "    %s = %s,\n" % (key, attr)
        ser_str += "}"
        return ser_str

    def __repr__(self):
        return self.__str__()

    def validate(self):
        pass


class SubConfig(BaseConfig):
    """Sub Config"""
    _is_changed = False

    def __setattr__(self, name, value):
        if not name.startswith("_") and getattr(self, name) != value:
            self._is_changed = True
        super().__setattr__(name, value)

    def is_changed(self):
        return self._is_changed


class LoraConfig(SubConfig):
    """Config for lora"""
    #: enable lora, default False.
    enable_lora: bool = False
    #: "part_module_name" specifies which modules are converted to their LoRA counterparts.
    #: By default, it is set to None, which means there is no restriction and any module matched
    #: by "lora_layer" can be converted. If it is set to a specific value (e.g., "encoder"),
    #: only the modules whose name scope contains "encoder" will be converted to LoRA.
    part_module_name: str = None
    #: The rank value of the LoRA, which is the r dimension of the A/B matrix.
    lora_dim: int = 8
    #: The LoRA dropout ratio controls whether dropout computation is inserted in the forward pass
    #: of the LoRA layer. By default, the dropout ratio is set to 0.0.
    lora_dropout: float = 0.0
    #: When adding the values of the LoRA A and B matrices to the original weight matrix,
    #: the scaling value is applied as "W = W + A * B * lora_scaling". By default, the scaling value
    #: is set to 1.0.
    lora_scaling: float = 1.0
    #: The layer class names involved in LoRA training in the model, separated by commas.
    lora_layer: str = LORA_LAYER
    #: LoRA training is enabled only in the ColumnParallelLinear layer of the MHA QKV module.
    column_only_qkv: bool = False


class BatchGenerationConfig(SubConfig):
    """Config for batch generation ranking and memory-efficiency."""

    #: [optional] sort prompts by length each episode.
    ranking: bool = False
    #: [optional] min prompt length in the first stage of batch generation.
    min_prompt_length: int = 0
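
# Hedged sketch (illustrative only, never called by ChatLearn): SubConfig tracks whether
# any public field was modified, which ModelConfig.__str__ uses to decide whether to print
# the nested lora/batch_generation sections. The values below are hypothetical.
def _example_subconfig_change_tracking():
    lora = LoraConfig()
    assert not lora.is_changed()  # an untouched sub-config is skipped when printing ModelConfig
    lora.lora_dim = 16            # any assignment that differs from the default flips the flag
    assert lora.is_changed()
    return lora
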

class ModelConfig(BaseConfig):
    """Config for model."""

    #: [legacy] number of GPUs used for one model, default 0.
    num_device: int = 0
    #: [required] number of GPUs used for one model, default 0, same as num_device
    num_gpu: int = 0
    #: [required] number of CPUs used for one model, default 0
    num_cpu: int = 0
    #: [optional] gpu per process, e.g., for PyTorch DDP, Megatron, DeepSpeed, `gpu_per_process` is set to 1
    gpu_per_process: int = None
    #: [optional] cpu per process
    cpu_per_process: int = None
    #: [optional] number of module replicas,
    #: for gpu model, num_replica = num_gpu // (TP * PP * DP * EP),
    #: for cpu model, num_replica = num_cpu // cpu_per_process
    num_replica: int = 1
    #: [required] whether model is trainable
    trainable: bool = False
    #: [optional] tensor model parallel size
    tensor_model_parallel_size: int = None
    #: [optional] pipeline model parallel size
    pipeline_model_parallel_size: int = None
    #: [optional] expert model parallel size for Megatron-Core
    expert_model_parallel_size: int = None
    #: [optional] zero size
    zero_size: int = None
    #: [optional] FSDP parallel size
    fsdp_size: int = None
    #: [optional] config file for model
    model_config_file: str = ""
    config_dir: str = ""
    #: [optional] model type, e.g., Torch/Tensorflow, etc
    model_type: str = ""
    #: [optional] placeholder for other args
    args_dict: dict = None
    #: [optional] generation batch size, will overwrite generation batch size in RuntimeConfig
    generation_batch_size: int = None
    #: lora config
    lora: LoraConfig = None
    #: batch generation config
    batch_generation: BatchGenerationConfig = None
    #: offload optimizer states
    offload_optimizer_states = False
    #: parameter sync frequency
    sync_frequency = 1
    #: offload weights
    offload_weights = False
    #: free grad buffers
    free_grad_buffers = False
    #: overall switch for offload optimizer states/weights and free grad buffers
    free_memory = False
    #: force to free memory
    force_free_memory = False

    def __init__(self):
        super().__init__()
        self.args_dict = {}
        self.lora = LoraConfig()
        self.batch_generation = BatchGenerationConfig()

    def __str__(self):
        members = [attr for attr in dir(self) \
                   if not callable(getattr(self, attr)) and not attr.startswith("__")]
        ser_str = self.__class__.__name__ + " {\n"
        for key in members:
            if key.startswith('_'):
                continue
            attr = getattr(self, key)
            if key in ["lora", "batch_generation"]:
                if not attr.is_changed():
                    continue
            attr = '"{}"'.format(attr) if isinstance(attr, str) else attr
            ser_str += "    %s = %s,\n" % (key, attr)
        ser_str += "}"
        return ser_str
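
# Hedged sketch of a per-model YAML section (the model name, file name and sizes are
# hypothetical). Config._parse_params copies known keys onto a ModelConfig, parses the
# nested lora/batch_generation sections into the SubConfigs above, only warns on unknown
# keys, and loads model_config_file (resolved against the config directory) into args_dict:
#
#   models:
#     policy:
#       num_gpu: 8
#       trainable: false
#       generation_batch_size: 4
#       model_config_file: policy_model.yaml
#       batch_generation:
#         ranking: true
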

class RuntimeConfig(BaseConfig):
    """training related configs."""

    #: [required] number of episodes. One episode includes an inference and training loop.
    num_episode: int = 5000
    #: [required] number of samples per episode.
    sample_per_episode: int = 1000
    #: [optional] number of training epochs per episode. default set to 1.
    num_training_epoch: int = 1
    #: [optional] max iteration per sample, for mcts-style search algorithms
    max_iteration_per_sample = 1
    #: [required] generation(inference) batch size.
    generation_batch_size: int = 2
    #: [required] training micro batch size.
    train_micro_batch_size: int = 2
    #: [required] training global batch size.
    train_global_batch_size: int = None
    #: [required] save checkpoint per `save_episode_interval` episodes.
    save_episode_interval: int = None
    #: [optional] log time and memory per `log_interval` iterations.
    log_interval: int = 1
    #: [required]: data_path for dataset or a list of data_path for different kinds of datasets
    data_path: Optional[Union[List[str], str]] = None
    #: [optional]: the ratio for each kind of data_path in a training episode, default: None
    data_ratio: Optional[Union[List[int], int]] = None
    #: [optional]: shuffle in each epoch of dataset, default: True
    data_shuffle: Optional[bool] = True
    #: [optional]: rerank batch of data by row, default: False
    data_rerank: Optional[bool] = False
    #: [optional]: colocate models onto the same devices
    colocation: List[str] = []
    #: [optional]: eval every N episodes; if 0, eval is disabled
    eval_episode_interval: int = 0
    #: [optional]: enable resume training when data checkpoint is set
    enable_resume_training: bool = True
    #: [optional]: checkpoint for dataloader
    data_checkpoint_path: str = None
    #: [optional]: max data checkpoint nums
    max_data_ckpt_nums: int = None
    #: [optional]: load data checkpoint from iteration
    load_data_checkpoint_iteration: int = None
    #: [optional]: stream_data_loader type, ["fixed", "dynamic"]
    stream_data_loader_type: str = "fixed"
    #: private
    debug: bool = False
    #: enable nsys nvtx
    nsys: bool = False
    #: profiler dir
    profiler_dir: str = None
    #: coalesced buffer size in MB
    coalesced_buffer_mb: int = 100
    #: concurrent parameter sync
    concurrent_comm: bool = True
    #: parameter sync communication type, broadcast/p2p
    param_sync_comm_type: str = PARAM_SYNC_COMM_TYPE.BROADCAST.value
    #: parameter sync max workers
    param_sync_max_workers: int = None
    #: communication type to regroup routed experts, allgather/alltoall
    routed_expert_regrouping_comm_type: str = ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL
    #: max number of relay episodes; if `max_relay_episode` is set to -1, then relay all episodes,
    #: if `max_relay_episode` is set to 0, then relay is disabled
    max_relay_episode: int = 0
    #: relay after n episodes
    relay_episode_offset: int = 0
    #: training shuffle mode
    training_shuffle_mode: str = TrainingShffuleMode.BATCH
    #: consumed samples
    consumed_samples: int = 0
    #: concurrent model setup
    concurrent_setup: bool = False
    #: bucket size in the memory manager to reduce peak memory
    bucket_size_mb_in_memory_manager: int = 1024
    #: free collective group after parameter synchronization and rebuild before the next synchronization
    free_sync_collective_group: bool = False
    #: [optional] cpu only model schedule policy, PACK or SPREAD
    #: PACK: All provided bundles are packed onto a single node on a best-effort basis.
    #: SPREAD: Each bundle is spread onto separate nodes on a best-effort basis.
    cpu_schedule_strategy: str = RAY_PG_STRATEGY.SPREAD.value
    #: exp name for each run
    exp_name: str = "CHATLEARN"
    #: output dir
    output_dir: str = "./"
    #: validate param sync
    validate_param_sync: bool = False
    #: whether to eval before training
    enable_eval_before_training: bool = False
    #: policy to regroup queue
    policy_to_regroup_queue: str = "global_barrier"
    #: configuration file path for logging
    log_config_file: str = ""
    #: [optional] placeholder for log_args_dict
    log_args_dict: dict = None

    def __init__(self):
        super().__init__()
        self._args_dict = {}

    def get(self, key):
        """
        Get other config by key.

        Args
        ----
        key: str
            key to get config
        """
        if key not in self._args_dict:
            logger.warning(f"{key} not found in RuntimeConfig")
        else:
            return self._args_dict[key]

    def validate(self):
        """
        :meta private:
        """
        for key in self._args_dict:
            if key == "save_interval":
                raise Exception("save_interval is deprecated, please use save_episode_interval to save checkpoints")
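
# Hedged sketch of a `runtime` YAML section (model names, paths and values are
# hypothetical). Config._parse_params splits each colocation entry on commas and turns a
# comma-separated data_ratio string into a list of ints:
#
#   runtime:
#     num_episode: 100
#     sample_per_episode: 512
#     colocation:
#       - policy,reference
#       - reward,value
#     data_path:
#       - /path/to/math_data
#       - /path/to/code_data
#     data_ratio: 2,1
#
# which yields runtime_args.colocation == [["policy", "reference"], ["reward", "value"]]
# and runtime_args.data_ratio == [2, 1].
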

class RuntimeEnvConfig(BaseConfig):
    """Runtime env config, you can refer to https://docs.ray.io/en/latest/ray-core/handling-dependencies.html for more information."""

    #: pip install packages
    pip: List[str] = []
    #: python modules
    py_modules: List[str] = []
    #: working directory
    working_dir: str = os.getcwd()
    #: platform, e.g., DLC
    platform: str = ""
    #: excludes files from packaging
    excludes: List[str] = []

    def __init__(self):
        super().__init__()
        self._args_dict = {}

    def get(self, key):
        """
        Get other config by key.

        Args
        ----
        key: str
            Key to get config.
        """
        if key not in self._args_dict:
            logger.warning(f"{key} not found in RuntimeEnvConfig")
        else:
            return self._args_dict[key]
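
# Hedged sketch of a `runtime_env` YAML section (package and pattern values are
# hypothetical). Keys that are not attributes of RuntimeEnvConfig fall into its internal
# _args_dict and can be read back through get():
#
#   runtime_env:
#     pip:
#       - transformers
#     excludes:
#       - "*.ckpt"
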

class Config(BaseConfig):
    """A class to manage chatlearn configuration.

    Args
    ----
    param_dict: dict
        dict format of parameters."""

    def __init__(self, param_dict=None, config_dir=None):
        super().__init__()
        self._finalize = False
        self.models = {}
        self.env_args = RuntimeEnvConfig()
        self.runtime_args = RuntimeConfig()
        self.config_dir = config_dir
        self._active_module_args = None

        self.initialized = False
        if param_dict:
            self._parse_params(param_dict)
            self._validate_params()
        # remove later, just for compatibility
        self.rlhf_args = self.runtime_args
        self._finalize = True

    def _parse_params(self, param_dict):
        """Parse params from param_dict."""

        def set_param(user_args, config_cls, instance):
            for attribute, default_value in get_attributes(config_cls):
                if attribute in user_args:
                    value = user_args[attribute]
                    if attribute == "colocation":
                        colocation_list = []
                        for group in value:
                            colocation_list.append(group.replace(' ', '').split(','))
                        value = colocation_list
                    elif attribute == "data_ratio":
                        if isinstance(value, str):
                            value = [int(v) for v in value.split(',')]
                else:
                    value = default_value
                original_value = getattr(instance, attribute)
                if original_value is not None:
                    assert isinstance(original_value, type(value)), \
                        f"{instance}.{attribute} should be type of {type(original_value)} but got {type(value)}"
                setattr(instance, attribute, value)
            for user_attribute in user_args:
                if not hasattr(config_cls, user_attribute):
                    if hasattr(instance, "_args_dict"):
                        getattr(instance, "_args_dict")[user_attribute] = user_args[user_attribute]
                    else:
                        raise RuntimeError(f"attribute {user_attribute} not defined in {config_cls.__name__}")
            instance.validate()

        for model_name, model_args in param_dict["models"].items():
            model_config = ModelConfig()
            model_config.config_dir = self.config_dir
            for user_attribute, user_value in model_args.items():
                if hasattr(ModelConfig, user_attribute):
                    original_value = getattr(ModelConfig, user_attribute)
                    if 'num_device' == user_attribute:
                        logger.warning("num_device is deprecated, please use num_gpu instead")
                        if 'num_gpu' not in model_args.keys():
                            setattr(model_config, "num_gpu", user_value)
                        else:
                            logger.warning("both num_device and num_gpu are set, use num_gpu")
                        continue
                    if 'lora' == user_attribute:
                        set_param(user_value, LoraConfig, model_config.lora)
                        user_value = model_config.lora
                    elif "batch_generation" == user_attribute:
                        set_param(user_value, BatchGenerationConfig, model_config.batch_generation)
                        user_value = model_config.batch_generation
                    if original_value is not None:
                        assert isinstance(user_value, type(original_value)), \
                            f"ModelConfig.{user_attribute} should be type of {type(original_value)} but got {type(user_value)} ({user_value})"
                    setattr(model_config, user_attribute, user_value)
                else:
                    logger.warning(f"unknown argument {user_attribute}")

            self.models[model_name] = model_config
            if model_config.model_config_file:
                model_config.model_config_file = get_path(model_config.model_config_file, self.config_dir)
                model_config.args_dict = parse_args_from_yaml(model_config.model_config_file, self.config_dir)

        if "runtime" in param_dict:
            set_param(param_dict["runtime"], RuntimeConfig, self.runtime_args)
        elif "rlhf" in param_dict:
            logger.warning("rlhf is deprecated, please use runtime as the section name")
            set_param(param_dict["rlhf"], RuntimeConfig, self.runtime_args)
        if "runtime_env" in param_dict:
            set_param(param_dict["runtime_env"], RuntimeEnvConfig, self.env_args)

        if self.runtime_args.log_config_file:
            self.runtime_args.log_config_file = get_path(self.runtime_args.log_config_file, self.config_dir)
            self.runtime_args.log_args_dict = parse_args_from_yaml(self.runtime_args.log_config_file, self.config_dir)

    def _get_and_check_type(value, default_value, key):
        # To be noticed: all str type values should be in lower case.
        if isinstance(value, str):
            value = value.lower()
        if default_value is None:
            return value
        if not isinstance(value, type(default_value)):
            raise ValueError("%s type error, expected: %s." \
                             % (key, type(default_value)))
        return value

    def _validate_params(self):
        if self.runtime_args.train_global_batch_size is None:
            self.runtime_args.train_global_batch_size = self.runtime_args.train_micro_batch_size
        assert self.runtime_args.train_global_batch_size % self.runtime_args.train_micro_batch_size == 0, \
            f"train_global_batch_size should be a multiple of train_micro_batch_size, " \
            f"but got {self.runtime_args.train_global_batch_size}/{self.runtime_args.train_micro_batch_size}"
        assert self.runtime_args.train_global_batch_size <= self.runtime_args.sample_per_episode, \
            "train_global_batch_size should be less than or equal to sample_per_episode, " \
            f"got {self.runtime_args.train_global_batch_size} and {self.runtime_args.sample_per_episode}"
        assert self.runtime_args.stream_data_loader_type.lower() in ["fixed", "dynamic"]
        assert self.runtime_args.cpu_schedule_strategy in [strategy.value for strategy in RAY_PG_STRATEGY]
        assert self.runtime_args.param_sync_comm_type in list(PARAM_SYNC_COMM_TYPE)

        if isinstance(self.runtime_args.data_path, list):
            assert self.runtime_args.data_ratio is not None and isinstance(self.runtime_args.data_ratio, list), (
                f"expect data_ratio to be a list when data_path is a list, got {self.runtime_args.data_ratio}"
            )
            assert len(self.runtime_args.data_path) == len(self.runtime_args.data_ratio), (
                "expect data_path and data_ratio to have the same length, "
                f"got {len(self.runtime_args.data_path)} and {len(self.runtime_args.data_ratio)}"
            )

        for model_name, model_args in self.models.items():
            if model_args.num_gpu >= 1:
                if model_args.gpu_per_process is None:
                    model_args.gpu_per_process = 1
                else:
                    assert model_args.gpu_per_process <= model_args.num_gpu, \
                        f"{model_name}: gpu_per_process: {model_args.gpu_per_process}, num_gpu: {model_args.num_gpu}"
            elif model_args.num_cpu >= 1:
                if model_args.cpu_per_process is None:
                    model_args.cpu_per_process = 1
                else:
                    assert model_args.cpu_per_process <= model_args.num_cpu, \
                        f"{model_name}: cpu_per_process: {model_args.cpu_per_process}, num_cpu: {model_args.num_cpu}"
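
            # A non-positive generation_batch_size is treated as a request for dynamic
            # batching and is replaced by DYNAMIC_BATCH_SIZE below; when it is unset, the
            # model falls back to the runtime-level generation_batch_size.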
            if model_args.generation_batch_size is not None and model_args.generation_batch_size <= 0:
                model_args.generation_batch_size = DYNAMIC_BATCH_SIZE
            if model_args.generation_batch_size is None:
                if self.runtime_args.generation_batch_size:
                    model_args.generation_batch_size = self.runtime_args.generation_batch_size

            for key in ["pipeline_model_parallel_size", "tensor_model_parallel_size", "zero_size"]:
                if model_args.args_dict.get(key) is not None:
                    setattr(model_args, key, model_args.args_dict.get(key))
                    assert getattr(model_args, key) >= 1
                elif getattr(model_args, key) is None:
                    setattr(model_args, key, 1)

            for key in ["fsdp_size"]:
                if getattr(model_args, key) is not None:
                    setattr(model_args, key, getattr(model_args, key))
                    if getattr(model_args, key) == -1:
                        print(f"set_fsdp_size {getattr(model_args, key)} to num_gpu: {model_args.num_gpu}")
                        setattr(model_args, key, model_args.num_gpu)
                    assert getattr(model_args, key) >= 1
                elif getattr(model_args, key) is None:
                    setattr(model_args, key, 1)

            ep_size = model_args.args_dict.get("expert_model_parallel_size")
            moe_ep_size = model_args.args_dict.get("moe_expert_model_parallel_size")
            if ep_size is not None and moe_ep_size is not None:
                assert ep_size == moe_ep_size, (
                    f"{model_name}: if you set moe_expert_model_parallel_size ({moe_ep_size}), "
                    f"it must be equal to expert_model_parallel_size ({ep_size})"
                )
                finalized_ep_size = ep_size
            elif ep_size is not None:
                finalized_ep_size = ep_size
            elif moe_ep_size is not None:
                finalized_ep_size = moe_ep_size
            else:
                finalized_ep_size = 1
            assert finalized_ep_size >= 1
            setattr(model_args, "expert_model_parallel_size", finalized_ep_size)

            if model_args.tensor_model_parallel_size > 1 or model_args.pipeline_model_parallel_size > 1 or model_args.expert_model_parallel_size > 1:
                assert model_args.zero_size == 1 or model_args.zero_size is None
                assert model_args.fsdp_size == 1 or model_args.fsdp_size is None
                assert model_args.num_gpu % (
                    model_args.tensor_model_parallel_size * model_args.pipeline_model_parallel_size *
                    model_args.expert_model_parallel_size) == 0, \
                    f"{model_name}: num_gpu must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size * " \
                    f"expert_model_parallel_size, but got num_gpu = {model_args.num_gpu}, " \
                    f"tensor_model_parallel_size = {model_args.tensor_model_parallel_size}, " \
                    f"pipeline_model_parallel_size = {model_args.pipeline_model_parallel_size}, and " \
                    f"expert_model_parallel_size = {model_args.expert_model_parallel_size}."
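
            # Worked example (hypothetical sizes): num_gpu = 16 with
            # tensor_model_parallel_size = 4, pipeline_model_parallel_size = 2 and
            # expert_model_parallel_size = 1 passes the divisibility check above
            # (16 % (4 * 2 * 1) == 0) and yields num_replica = 16 // (4 * 2 * 1) = 2 below.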
            assert model_args.num_gpu > 0 or model_args.num_cpu > 0, \
                f"{model_name} num_gpu: {model_args.num_gpu}, num_cpu: {model_args.num_cpu}, at least one of them should be set"

            if model_args.num_gpu >= 1:
                if model_args.zero_size > 1:
                    assert model_args.num_gpu % model_args.zero_size == 0
                    model_args.num_replica = model_args.num_gpu // model_args.zero_size
                elif model_args.fsdp_size > 1:
                    # For FSDP, num_gpu must be divisible by fsdp_size
                    assert model_args.num_gpu % model_args.fsdp_size == 0
                    model_args.num_replica = model_args.num_gpu // (
                        model_args.tensor_model_parallel_size * model_args.pipeline_model_parallel_size *
                        model_args.expert_model_parallel_size * model_args.fsdp_size)
                else:
                    model_args.num_replica = model_args.num_gpu // (
                        model_args.tensor_model_parallel_size * model_args.pipeline_model_parallel_size *
                        model_args.expert_model_parallel_size)
            elif model_args.num_cpu >= 1:
                model_args.num_replica = model_args.num_cpu // model_args.cpu_per_process

            assert model_args.num_replica * model_args.generation_batch_size <= self.runtime_args.sample_per_episode, \
                f"{model_name}: num_replica * batch_size {model_args.num_replica}*{model_args.generation_batch_size} " + \
                f"should be less than or equal to sample_per_episode {self.runtime_args.sample_per_episode}"

            if model_args.batch_generation.min_prompt_length:
                logger.info(f"Enable batch generation: "
                            f"min_prompt_length = {model_args.batch_generation.min_prompt_length}")

            if model_args.free_memory:
                model_args.offload_weights = True
                if model_args.trainable:
                    model_args.free_grad_buffers = True
                    model_args.offload_optimizer_states = True

        if self.runtime_args.colocation and len(self.runtime_args.colocation) > 0:
            model_set = set()
            for colocate_models in self.runtime_args.colocation:
                for model_name in colocate_models:
                    assert model_name not in model_set, f"Model {model_name} should only appear once in colocation group"
                    model_set.add(model_name)

        if self.runtime_args.exp_name not in self.runtime_args.output_dir:
            self.runtime_args.output_dir = f"{self.runtime_args.output_dir}/{self.runtime_args.exp_name}"

        logger.info(f"Env Config: \n{self.env_args}")
        logger.info(f"Runtime Config: \n{self.runtime_args}")
        for name, model_args in self.models.items():
            logger.info(f"Model({name}) Config: \n{model_args}")

    @property
    def active_module_args(self):
        return self._active_module_args

    @active_module_args.setter
    def active_module_args(self, config):
        self._active_module_args = config
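
# Hedged end-to-end sketch (illustrative only; the YAML path and model name are
# hypothetical). A training entry point typically builds the configuration through
# parse_args() with `-c` pointing at a YAML file, then reads the parsed sections:
#
#     config = parse_args()                    # e.g. python train.py -c configs/ppo.yaml
#     policy_args = config.models["policy"]    # ModelConfig parsed from models.policy
#     num_replica = policy_args.num_replica    # derived in Config._validate_params
#     batch = config.runtime_args.generation_batch_size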