in chatlearn/utils/arguments.py [0:0]
def _parse_params(self, param_dict):
    """Parse params from param_dict.

    Builds one ``ModelConfig`` per entry of ``param_dict["models"]`` and stores
    it in ``self.models``. Fills ``self.runtime_args`` from the ``runtime``
    section (or the deprecated ``rlhf`` section) and ``self.env_args`` from
    the ``runtime_env`` section.

    Args:
        param_dict: Parsed configuration mapping. It must contain a
            ``"models"`` key mapping model names to argument dicts. The
            ``"runtime"``, ``"rlhf"`` and ``"runtime_env"`` keys are optional.

    Raises:
        AssertionError: (when assertions are enabled) if a user value is not an
            instance of the type of the corresponding non-``None`` value.
        RuntimeError: if a section sets an attribute that its config class does
            not define and the target instance has no ``_args_dict``.
    """
    def set_param(user_args, config_cls, instance):
        """Apply ``user_args`` onto ``instance`` for every attribute of ``config_cls``.

        Each attribute reported by ``get_attributes(config_cls)`` is set on
        ``instance``: the user-supplied value when present, otherwise the
        default. User keys unknown to ``config_cls`` are stored in
        ``instance._args_dict`` when that attribute exists; otherwise they
        raise ``RuntimeError``. ``instance.validate()`` is called at the end.
        """
        # A section whose value is None (presumably an empty YAML key) is
        # treated as an empty mapping, so defaults still apply and validate()
        # still runs instead of failing on ``attribute in None``.
        if user_args is None:
            user_args = {}
        for attribute, default_value in get_attributes(config_cls):
            if attribute in user_args:
                value = user_args[attribute]
                if attribute == "colocation":
                    # Each group is a comma-separated string of names; spaces are dropped.
                    value = [group.replace(' ', '').split(',') for group in value]
                elif attribute == "data_ratio" and isinstance(value, str):
                    # Accept "1,2,3" as shorthand for [1, 2, 3].
                    value = [int(v) for v in value.split(',')]
            else:
                value = default_value
            original_value = getattr(instance, attribute)
            # The new value must be an instance of the existing value's type.
            # This matches the message below and the ModelConfig check further down.
            if original_value is not None:
                assert isinstance(value, type(original_value)), \
                    f"{instance}.{attribute} should be type of {type(original_value)} but got {type(value)}"
            setattr(instance, attribute, value)
        for user_attribute in user_args:
            if not hasattr(config_cls, user_attribute):
                if hasattr(instance, "_args_dict"):
                    getattr(instance, "_args_dict")[user_attribute] = user_args[user_attribute]
                else:
                    raise RuntimeError(f"attribute {user_attribute} not defined in {config_cls.__name__}")
        instance.validate()

    for model_name, model_args in param_dict["models"].items():
        model_config = ModelConfig()
        model_config.config_dir = self.config_dir
        for user_attribute, user_value in model_args.items():
            if not hasattr(ModelConfig, user_attribute):
                logger.warning(f"unknown argument {user_attribute}")
                continue
            # Reference value for the type check: the class-level attribute.
            original_value = getattr(ModelConfig, user_attribute)
            if user_attribute == 'num_device':
                # Deprecated alias; it only takes effect when num_gpu is not given.
                logger.warning("num_device is deprecated, please use num_gpu instead")
                if 'num_gpu' not in model_args:
                    setattr(model_config, "num_gpu", user_value)
                else:
                    logger.warning("both num_device and num_gpu are set, use num_gpu")
                continue
            # Nested sections are applied onto the sub-config objects already
            # held by model_config; that object then becomes the value to store.
            if user_attribute == 'lora':
                set_param(user_value, LoraConfig, model_config.lora)
                user_value = model_config.lora
            elif user_attribute == "batch_generation":
                set_param(user_value, BatchGenerationConfig, model_config.batch_generation)
                user_value = model_config.batch_generation
            if original_value is not None:
                assert isinstance(user_value, type(original_value)), \
                    f"ModelConfig.{user_attribute} should be type of {type(original_value)} but got {type(user_value)} ({user_value})"
            setattr(model_config, user_attribute, user_value)

        self.models[model_name] = model_config
        if model_config.model_config_file:
            # Resolve the per-model config file relative to config_dir, then load its args.
            model_config.model_config_file = get_path(model_config.model_config_file, self.config_dir)
            model_config.args_dict = parse_args_from_yaml(model_config.model_config_file, self.config_dir)

    if "runtime" in param_dict:
        set_param(param_dict["runtime"], RuntimeConfig, self.runtime_args)
    elif "rlhf" in param_dict:
        logger.warning("rlhf is deprecated, please use runtime as section name")
        set_param(param_dict["rlhf"], RuntimeConfig, self.runtime_args)
    if "runtime_env" in param_dict:
        set_param(param_dict["runtime_env"], RuntimeEnvConfig, self.env_args)
    if self.runtime_args.log_config_file:
        self.runtime_args.log_config_file = get_path(self.runtime_args.log_config_file, self.config_dir)
        self.runtime_args.log_args_dict = parse_args_from_yaml(self.runtime_args.log_config_file, self.config_dir)
def _get_and_check_type(value, default_value, key):
    """Normalize ``value`` and verify it has the same type as ``default_value``.

    Strings are lower-cased first, so every string setting ends up lower case.

    Args:
        value: The user-supplied value.
        default_value: The reference default. When it is ``None``, no type
            check is performed.
        key: Setting name, used only in the error message.

    Returns:
        ``value``, lower-cased if it is a string.

    Raises:
        ValueError: If ``default_value`` is not ``None`` and the (normalized)
            value is not an instance of ``type(default_value)``.
    """
    normalized = value.lower() if isinstance(value, str) else value
    expected_type = None if default_value is None else type(default_value)
    if expected_type is not None and not isinstance(normalized, expected_type):
        raise ValueError("%s type error, expected: %s." % (key, expected_type))
    return normalized