Project-BasicAlgorithm/core/training/params.py (123 lines of code) (raw):

# Licensed to Apache Software Foundation (ASF) under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Apache Software Foundation (ASF) licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import inspect
import json
import warnings
from copy import deepcopy


class Params:
    """Collect, merge and type-coerce constructor parameters for ``cls``.

    Fixed parameters are merged from three sources, later ones overriding
    earlier ones: a JSON file (``param_file``), a ``"k1=v1;k2=v2"`` string
    (``param_str``) and explicit keyword arguments. Each value is then
    coerced to the type of the matching ``__init__`` default of ``cls``
    (see ``match_type``), or run through a same-named parser method if the
    class defines one (see ``class_weight``).

    ``search_params`` is a ``"k1=[...];k2=[...]"`` string. Each value is
    evaluated into a collection of candidates, and every candidate is coerced
    the same way as the fixed parameters.

    Attributes:
        input_params: dict of coerced fixed parameters.
        search_params: dict mapping a parameter name to a list of coerced
            candidate values. Entries that could not be evaluated into a
            collection are dropped with a warning.
    """

    def __init__(
        self, cls, param_file=None, param_str=None, search_params=None, **kwargs
    ):
        input_params = self.parse_file(param_file)
        input_params.update(self.parse_param_str(param_str))
        input_params.update(kwargs)
        self.input_params = self.check_input_params(cls, input_params)
        self.search_params = self.check_search_params(cls, search_params)

    def _get_parser(self, key):
        """Return the callable attribute named ``key``, or None.

        A method named after a parameter (like ``class_weight``) acts as a
        custom parser for that parameter. Non-callable attributes (e.g. the
        ``input_params`` dict) are ignored so they are never "called".
        """
        parser = getattr(self, key, None)
        return parser if callable(parser) else None

    def _coerce(self, key, value, default_params):
        """Convert one ``value`` of parameter ``key``.

        A custom parser wins. Otherwise the value is matched to the type of the
        ``cls`` default. Unknown keys pass through unchanged.
        """
        parser = self._get_parser(key)
        if parser is not None:
            return parser(value)
        if key in default_params:
            return self.match_type(default_params[key], value)
        return value

    def check_input_params(self, cls, input_params):
        """Coerce every value of ``input_params`` in place and return the dict.

        Args:
            cls: class whose ``__init__`` defaults provide reference types.
            input_params: dict of raw parameter values (mutated).

        Returns:
            The same dict, with coerced values.
        """
        default_params = self.load_cls_default_params(cls)
        # Re-assigning existing keys does not change the dict's size, so it is
        # safe while iterating.
        for key, value in input_params.items():
            input_params[key] = self._coerce(key, value, default_params)
        return input_params

    def check_search_params(self, cls, search_params):
        """Parse ``search_params`` into ``{name: [coerced candidates]}``.

        Args:
            cls: class whose ``__init__`` defaults provide reference types.
            search_params: ``"k1=[...];k2=[...]"`` string, or None/empty.

        Returns:
            A new dict. Entries whose value does not evaluate to a non-string
            collection are skipped with a warning.
        """
        raw_params = self.parse_param_str(search_params)
        default_params = self.load_cls_default_params(cls)
        checked = {}
        for key, raw_values in raw_params.items():
            # SECURITY: eval() executes arbitrary code. It is kept so that
            # expressions such as ``range(1, 5)`` remain usable, but
            # search_params must never come from an untrusted source.
            try:
                evaluated = eval(raw_values)
                # A string would otherwise be iterated character by character.
                candidates = (
                    None if isinstance(evaluated, (str, bytes)) else list(evaluated)
                )
            except Exception:
                candidates = None
            if candidates is None:
                warnings.warn(
                    f"value : {raw_values} error, is must be list of something"
                )
                continue
            checked[key] = [
                self._coerce(key, value, default_params) for value in candidates
            ]
        return checked

    @staticmethod
    def load_cls_default_params(cls):
        """Return ``{name: default}`` for the parameters of ``cls.__init__``.

        Only parameters that declare a default are included. Both
        positional-or-keyword and keyword-only parameters count. ``self``
        and ``*args`` / ``**kwargs`` are skipped because they have no default.
        ``inspect.signature`` follows ``__wrapped__``, so decorated
        initialisers still report their real parameters. Defaults are
        deep-copied so callers may mutate them.
        """
        signature = inspect.signature(cls.__init__)
        return {
            name: deepcopy(param.default)
            for name, param in signature.parameters.items()
            if param.default is not inspect.Parameter.empty
        }

    @staticmethod
    def match_type(refer_var, input_var):
        """Convert ``input_var`` to the type of the reference value ``refer_var``.

        - bool reference: ``str(input_var).lower()`` of ``"true"`` / ``"false"``
          maps to True / False; anything else warns and becomes False.
        - None reference: no type information, so ``input_var`` is returned
          unchanged.
        - otherwise: ``type(refer_var)(input_var)``. This may raise, e.g.
          ``ValueError`` for ``int("0.5")``.
        """
        # bool must be tested before the generic branch: bool("false") is True.
        if isinstance(refer_var, bool):
            tag = str(input_var).lower()
            if tag not in {"true", "false"}:
                warnings.warn(f"value : {input_var} error, set to False")
            return tag == "true"
        if refer_var is not None:
            return type(refer_var)(input_var)
        return input_var

    @staticmethod
    def parse_file(path):
        """Load parameters from the JSON file at ``path``.

        Returns ``{}`` when ``path`` is None or blank. The file is read as
        UTF-8, the encoding JSON mandates.
        """
        if not path or not str(path).strip():
            return {}
        with open(path, "r", encoding="utf-8") as r_f:
            return json.load(r_f)

    @staticmethod
    def parse_param_str(param_str):
        """Parse ``"k1=v1;k2=v2"`` into ``{"k1": "v1", "k2": "v2"}``.

        Keys and values are whitespace-stripped strings. Empty segments are
        skipped. A segment without exactly one ``=`` is ignored with a warning.
        For a duplicated key, the last occurrence wins.
        """
        params = {}
        for name_value_pair in (param_str or "").split(";"):
            if not name_value_pair.strip():
                continue
            k_v = name_value_pair.split("=")
            if len(k_v) != 2:
                warnings.warn(f"{name_value_pair} error, will be ignore")
                continue
            key, value = k_v
            params[key.strip()] = value.strip()
        return params

    @staticmethod
    def class_weight(value):
        """Parser for the ``class_weight`` parameter.

        Evaluates ``value`` (e.g. ``"{0: 1, 1: 2}"`` -> dict). If evaluation
        fails (e.g. ``"balanced"``, or a non-string value), it returns
        ``value`` unchanged.
        """
        # SECURITY: eval() executes arbitrary code; only feed trusted input.
        try:
            return eval(value)
        except Exception:
            return value

    def __str__(self):
        return (
            f"input_params: {self.input_params}\n"
            f"search_params: {self.search_params}"
        )

    __repr__ = __str__


class LightGBMParams(Params):
    """Params subclass named for LightGBM; it inherits all behaviour unchanged."""
def _init_defaults(cls):
    """Return ``{name: default}`` for every ``cls.__init__`` parameter that has one.

    Covers positional-or-keyword and keyword-only parameters alike. An
    estimator with no keyword-only parameters therefore still yields a dict,
    never None. ``inspect.signature`` follows ``__wrapped__``, so decorated
    initialisers report their real parameters. Defaults are deep-copied.
    """
    return {
        name: deepcopy(param.default)
        for name, param in inspect.signature(cls.__init__).parameters.items()
        if param.default is not inspect.Parameter.empty
    }


class XGBoostParams(Params):
    """Params whose reference defaults come from ``xgboost.XGBModel``."""

    @staticmethod
    def load_cls_default_params(cls):
        """Return ``XGBModel``'s constructor defaults; ``cls`` is ignored.

        ``objective`` is pinned to ``"binary:logistic"`` and
        ``use_label_encoder`` to True.
        """
        # NOTE(review): ``use_label_encoder`` appears to be deprecated in recent
        # xgboost releases and absent from 2.x; confirm it is still wanted.
        from xgboost import XGBModel

        params = _init_defaults(XGBModel)
        params["objective"] = "binary:logistic"
        params["use_label_encoder"] = True
        return params


class LrParams(Params):
    """Params whose reference defaults come from sklearn ``LogisticRegression``."""

    @staticmethod
    def load_cls_default_params(cls):
        """Return ``LogisticRegression``'s constructor defaults; ``cls`` is ignored.

        ``penalty`` is pinned to ``"l2"`` whatever the installed default is.
        """
        from sklearn.linear_model import LogisticRegression

        params = _init_defaults(LogisticRegression)
        params["penalty"] = "l2"
        return params


class SVMParams(Params):
    """Params whose reference defaults come from sklearn ``SVC``."""

    @staticmethod
    def load_cls_default_params(cls):
        """Return ``SVC``'s constructor defaults; ``cls`` is ignored."""
        from sklearn.svm import SVC

        return _init_defaults(SVC)