in chatlearn/utils/arguments.py
def _validate_params(self):
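    """Validate runtime and per-model arguments, filling in defaults where values are unset."""
    # train_global_batch_size defaults to train_micro_batch_size and must be a multiple of it
    # that does not exceed sample_per_episode.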
if self.runtime_args.train_global_batch_size is None:
self.runtime_args.train_global_batch_size = self.runtime_args.train_micro_batch_size
assert self.runtime_args.train_global_batch_size % self.runtime_args.train_micro_batch_size == 0, \
f"train_global_batch_size should be times of train_micro_batch_size," \
f"but got {self.runtime_args.train_global_batch_size}/{self.runtime_args.train_micro_batch_size}"
assert self.runtime_args.train_global_batch_size <= self.runtime_args.sample_per_episode, \
"train_global_batch_size should be less than or equal to sample_per_episode, " \
f"got {self.runtime_args.train_global_batch_size} and {self.runtime_args.sample_per_episode}"
assert self.runtime_args.stream_data_loader_type.lower() in ["fixed", "dynamic"]
assert self.runtime_args.cpu_schedule_strategy in [strategy.value for strategy in RAY_PG_STRATEGY]
assert self.runtime_args.param_sync_comm_type in list(PARAM_SYNC_COMM_TYPE)
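    # A list of data paths requires a matching list of mixing ratios of the same length.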
if isinstance(self.runtime_args.data_path, list):
assert self.runtime_args.data_ratio is not None and isinstance(self.runtime_args.data_ratio, list), (
f"expect data_ratio to be list when data_path is list, got {self.runtime_args.data_ratio}"
)
assert len(self.runtime_args.data_path) == len(self.runtime_args.data_ratio), (
"expect data_path and data_ratio to have same length, "
f"got {len(self.runtime_args.data_path)} and {len(self.runtime_args.data_ratio)}"
)
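    # Per-model validation: resource defaults, parallelism sizes, and derived replica counts.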
for model_name, model_args in self.models.items():
if model_args.num_gpu >= 1:
if model_args.gpu_per_process is None:
model_args.gpu_per_process = 1
else:
assert model_args.gpu_per_process <= model_args.num_gpu, \
f"{model_name}: gpu_per_process: {model_args.gpu_per_process}, num_cpu: {model_args.num_gpu}"
elif model_args.num_cpu >= 1:
if model_args.cpu_per_process is None:
model_args.cpu_per_process = 1
else:
assert model_args.cpu_per_process <= model_args.num_cpu, \
f"{model_name}: cpu_per_process: {model_args.cpu_per_process}, num_cpu: {model_args.num_cpu}"
if model_args.generation_batch_size is not None and model_args.generation_batch_size <= 0:
model_args.generation_batch_size = DYNAMIC_BATCH_SIZE
if model_args.generation_batch_size is None:
if self.runtime_args.generation_batch_size:
model_args.generation_batch_size = self.runtime_args.generation_batch_size
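        # Tensor/pipeline parallel and ZeRO sizes default to 1 when not configured.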
for key in ["pipeline_model_parallel_size", "tensor_model_parallel_size", "zero_size"]:
if model_args.args_dict.get(key) is not None:
setattr(model_args, key, model_args.args_dict.get(key))
assert getattr(model_args, key) >= 1
elif getattr(model_args, key) is None:
setattr(model_args, key, 1)
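        # fsdp_size of -1 means "shard across all GPUs of this model"; unset means 1.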
for key in ["fsdp_size"]:
if getattr(model_args, key) is not None:
setattr(model_args, key, getattr(model_args, key))
if getattr(model_args, key) == -1:
print(f"set_fsdp_size {getattr(model_args, key)} to num_gpu: {model_args.num_gpu}")
setattr(model_args, key, model_args.num_gpu)
assert getattr(model_args, key) >= 1
elif getattr(model_args, key) is None:
setattr(model_args, key, 1)
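        # expert_model_parallel_size and moe_expert_model_parallel_size are aliases: they must
        # agree when both are set, otherwise use whichever is given, defaulting to 1.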
ep_size = model_args.args_dict.get("expert_model_parallel_size")
moe_ep_size = model_args.args_dict.get("moe_expert_model_parallel_size")
if ep_size is not None and moe_ep_size is not None:
assert ep_size == moe_ep_size, (
f"{model_name}: if you set moe_expert_model_parallel_size ({moe_ep_size}), "
f"it must be equal to expert_model_parallel_size ({ep_size})"
)
finalized_ep_size = ep_size
elif ep_size is not None:
finalized_ep_size = ep_size
elif moe_ep_size is not None:
finalized_ep_size = moe_ep_size
else:
finalized_ep_size = 1
assert finalized_ep_size >= 1
setattr(model_args, "expert_model_parallel_size", finalized_ep_size)
if model_args.tensor_model_parallel_size > 1 or model_args.pipeline_model_parallel_size > 1 or model_args.expert_model_parallel_size > 1:
assert model_args.zero_size == 1 or model_args.zero_size is None
assert model_args.fsdp_size == 1 or model_args.fsdp_size is None
assert model_args.num_gpu % (
model_args.tensor_model_parallel_size * model_args.pipeline_model_parallel_size * model_args.expert_model_parallel_size) == 0, \
f"{model_name}: num_gpu must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size * " \
f"expert_model_parallel_size, but got num_gpu = {model_args.num_gpu}, " \
f"tensor_model_parallel_size = {model_args.tensor_model_parallel_size}, " \
f"pipeline_model_parallel_size = {model_args.pipeline_model_parallel_size}, and "\
f"expert_model_parallel_size = {model_args.expert_model_parallel_size}."
assert model_args.num_gpu > 0 or model_args.num_cpu > 0, \
f"{model_name} num_gpu: {model_args.num_gpu}, num_cpu: {model_args.num_cpu}, at least one of them should be set"
if model_args.num_gpu >= 1:
if model_args.zero_size > 1:
assert model_args.num_gpu % model_args.zero_size == 0
model_args.num_replica = model_args.num_gpu // model_args.zero_size
elif model_args.fsdp_size > 1:
# For FSDP, num_gpu must be divisible by fsdp_size
assert model_args.num_gpu % model_args.fsdp_size == 0
model_args.num_replica = model_args.num_gpu // (
model_args.tensor_model_parallel_size * model_args.pipeline_model_parallel_size \
* model_args.expert_model_parallel_size * model_args.fsdp_size)
else:
model_args.num_replica = model_args.num_gpu // (
model_args.tensor_model_parallel_size * model_args.pipeline_model_parallel_size * model_args.expert_model_parallel_size)
elif model_args.num_cpu >= 1:
model_args.num_replica = model_args.num_cpu // model_args.cpu_per_process
assert model_args.num_replica * model_args.generation_batch_size <= self.runtime_args.sample_per_episode, \
f"{model_name}: num_replica * batch_size {model_args.num_replica}*{model_args.generation_batch_size} " + \
f"should be less than or equal to sample_per_episode {self.runtime_args.sample_per_episode}"
if model_args.batch_generation.min_prompt_length:
logger.info(f"Enable batch generation: \
min_prompt_length = {model_args.batch_generation.min_prompt_length}")
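        # free_memory implies offloading weights; trainable models additionally free grad buffers
        # and offload optimizer states.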
if model_args.free_memory:
model_args.offload_weights = True
if model_args.trainable:
model_args.free_grad_buffers = True
model_args.offload_optimizer_states = True
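    # A model may appear in at most one colocation group.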
if self.runtime_args.colocation and len(self.runtime_args.colocation) > 0:
model_set = set()
for colocate_models in self.runtime_args.colocation:
for model_name in colocate_models:
assert model_name not in model_set, f"Model {model_name} should only appear once in colocation group"
model_set.add(model_name)
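    # Namespace the output directory by experiment name so runs write to distinct directories.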
if self.runtime_args.exp_name not in self.runtime_args.output_dir:
self.runtime_args.output_dir = f"{self.runtime_args.output_dir}/{self.runtime_args.exp_name}"
logger.info(f"Env Config: \n{self.env_args}")
logger.info(f"Runtime Config: \n{self.runtime_args}")
for name, model_args in self.models.items():
logger.info(f"Model({name}) Config: \n{model_args}")