def _parse_params()

in chatlearn/utils/arguments.py [0:0]


    def _parse_params(self, param_dict):
        """Parse params from param_dict.

        Populates this object from a nested parameter mapping:

        * ``param_dict["models"]`` (required; a missing key raises ``KeyError``):
          one ``ModelConfig`` is built per entry, given
          ``config_dir = self.config_dir``, and stored in
          ``self.models[model_name]``. If a model sets ``model_config_file``,
          that path is resolved against ``self.config_dir`` and parsed with
          ``parse_args_from_yaml`` into ``model_config.args_dict``.
        * ``param_dict["runtime"]`` (or the deprecated ``"rlhf"`` section, read
          only when ``"runtime"`` is absent) is applied to ``self.runtime_args``.
        * ``param_dict["runtime_env"]`` (optional) is applied to ``self.env_args``.
        * If ``self.runtime_args.log_config_file`` is set, that path is resolved
          against ``self.config_dir`` and parsed into
          ``self.runtime_args.log_args_dict``.

        NOTE(review): param_dict is presumably the dict loaded from the
        top-level YAML config file -- confirm with the caller.

        Args:
            param_dict: mapping of section name to section contents.

        Raises:
            KeyError: if ``param_dict`` has no ``"models"`` key.
            AssertionError: if a supplied value fails a type check (these are
                ``assert`` statements, so they are skipped under ``python -O``).
            RuntimeError: from ``set_param`` when a section it handles contains
                a key its config class does not define and the target object
                has no ``_args_dict``. Anything raised by the config objects'
                ``validate()`` also propagates.
        """

        def set_param(user_args, config_cls, instance):
            """Apply the user-supplied mapping ``user_args`` to ``instance``.

            Every attribute yielded by ``get_attributes(config_cls)`` is
            assigned on ``instance``. The user's value is used when present;
            otherwise the paired ``default_value`` is used. As a result, any
            value already on ``instance`` for an attribute the user omitted is
            overwritten with that default.

            Keys in ``user_args`` that ``config_cls`` does not define are
            stored in ``instance._args_dict`` when the instance has that
            attribute, and raise ``RuntimeError`` otherwise.
            ``instance.validate()`` is called at the end.
            """
            for attribute, default_value in get_attributes(config_cls):
                if attribute in user_args:
                    value = user_args[attribute]
                    if attribute == "colocation":
                        # Each group is expected to be a comma-separated string;
                        # spaces are removed and it is split into a list, so the
                        # result is a list of lists of strings.
                        colocation_list = []
                        for group in value:
                            colocation_list.append(group.replace(' ', '').split(','))
                        value = colocation_list
                    elif attribute == "data_ratio":
                        # A comma-separated string becomes a list of ints; any
                        # other type is passed through unchanged.
                        if isinstance(value, str):
                            value = [int(v) for v in value.split(',')]
                else:
                    value = default_value
                # Type check against the value currently on the instance
                # (skipped when that value is None).
                # NOTE(review): this checks isinstance(original_value, type(value)),
                # the reverse of the ModelConfig check further down
                # (isinstance(user_value, type(original_value))). For example, an
                # int value is accepted here for a bool attribute but rejected
                # there. Confirm which direction is intended.
                original_value = getattr(instance, attribute)
                if original_value is not None:
                    assert isinstance(original_value, type(value)), \
                        f"{instance}.{attribute} should be type of {type(original_value)} but got {type(value)}"

                setattr(instance, attribute, value)
            # Keys the config class does not define: keep them in _args_dict if
            # the instance has one, otherwise reject them.
            for user_attribute in user_args:
                if not hasattr(config_cls, user_attribute):
                    if hasattr(instance, "_args_dict"):
                        getattr(instance, "_args_dict")[user_attribute] = user_args[user_attribute]
                    else:
                        raise RuntimeError(f"attribute {user_attribute} not defined in {config_cls.__name__}")
            instance.validate()

        # Build one ModelConfig per entry under "models".
        for model_name, model_args in param_dict["models"].items():
            model_config = ModelConfig()
            model_config.config_dir = self.config_dir
            for user_attribute, user_value in model_args.items():
                # Only attributes declared on the ModelConfig class are applied;
                # anything else only logs a warning and is ignored.
                if hasattr(ModelConfig, user_attribute):
                    # Class-level value (not the instance's), used as the
                    # reference type for the check below.
                    original_value = getattr(ModelConfig, user_attribute)
                    # Deprecated alias: num_device is copied to num_gpu unless
                    # num_gpu is also given, in which case num_device is skipped
                    # entirely. When it is copied, num_device itself is still set
                    # on model_config by the generic setattr below.
                    if 'num_device' == user_attribute:
                        logger.warning("num_device is deprecated, please use num_gpu instead")
                        if 'num_gpu' not in model_args.keys():
                            setattr(model_config, "num_gpu", user_value)
                        else:
                            logger.warning("both num_device and num_gpu are set, use num_gpu")
                            continue
                    # Nested sections: apply the user's mapping to the sub-config
                    # object already on model_config. That object then becomes
                    # user_value, so the generic type check and setattr below
                    # operate on it.
                    if 'lora' == user_attribute:
                        set_param(user_value, LoraConfig, model_config.lora)
                        user_value = model_config.lora
                    elif "batch_generation" == user_attribute:
                        set_param(user_value, BatchGenerationConfig, model_config.batch_generation)
                        user_value = model_config.batch_generation
                    # Skip the type check when the class-level value is None.
                    if original_value is not None:
                        assert isinstance(user_value, type(original_value)), \
                            f"ModelConfig.{user_attribute} should be type of {type(original_value)} but got {type(user_value)} ({user_value})"
                    setattr(model_config, user_attribute, user_value)
                else:
                    logger.warning(f"unknown argument {user_attribute}")

            self.models[model_name] = model_config
            # Resolve the per-model config file relative to self.config_dir and
            # parse it into args_dict.
            if model_config.model_config_file:
                model_config.model_config_file = get_path(model_config.model_config_file, self.config_dir)
                model_config.args_dict = parse_args_from_yaml(model_config.model_config_file, self.config_dir)
        # "runtime" takes precedence; the deprecated "rlhf" section name is only
        # read when "runtime" is absent.
        if "runtime" in param_dict:
            set_param(param_dict["runtime"], RuntimeConfig, self.runtime_args)
        elif "rlhf" in param_dict:
            logger.warning("rlhf is deprecated, please use runtime as section name")
            set_param(param_dict["rlhf"], RuntimeConfig, self.runtime_args)
        if "runtime_env" in param_dict:
            set_param(param_dict["runtime_env"], RuntimeEnvConfig, self.env_args)

        # Resolve the log config file relative to self.config_dir and parse it
        # into log_args_dict.
        if self.runtime_args.log_config_file:
            self.runtime_args.log_config_file = get_path(self.runtime_args.log_config_file, self.config_dir)
            self.runtime_args.log_args_dict = parse_args_from_yaml(self.runtime_args.log_config_file, self.config_dir)

        def _get_and_check_type(value, default_value, key):
            """Normalize ``value`` and check it against ``default_value``'s type.

            str values are lower-cased first. When ``default_value`` is None,
            no type check is done. Otherwise, ``ValueError`` (naming ``key``)
            is raised if ``value`` is not an instance of
            ``type(default_value)``.

            Returns:
                The value, lower-cased if it is a str.
            """
            # NOTE: every str value is lower-cased, and this happens before the
            # type check.
            if isinstance(value, str):
                value = value.lower()
            if default_value is None:
                return value
            if not isinstance(value, type(default_value)):
                raise ValueError("%s type error, expected: %s." \
                                 % (key, type(default_value)))
            return value