Project-BasicAlgorithm/core/training/params.py (123 lines of code) (raw):
# Licensed to Apache Software Foundation (ASF) under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Apache Software Foundation (ASF) licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import inspect
import json
import warnings
from copy import deepcopy
class Params:
    """Collect, merge and type-coerce hyper-parameters for a model class.

    Fixed parameters are merged from three sources, later ones overriding
    earlier ones: a JSON file, a ``name=value;name=value`` string, and
    keyword arguments. Search parameters come from a second string of the
    same form whose values are Python expressions that evaluate to an
    iterable of candidate values.

    Each value is converted by a parser method named after the parameter
    when one exists (see :meth:`class_weight`); otherwise, if the parameter
    has a default in the model constructor, it is coerced to the type of
    that default. Anything else is stored unchanged.

    Attributes:
        input_params: dict mapping parameter name to its converted value.
        search_params: dict mapping parameter name to a list of converted
            candidate values (or to the raw string when it was malformed).
    """

    # Methods of this class that must never be mistaken for a per-parameter
    # parser, even if a hyper-parameter happens to share their name.
    _NON_PARSER_NAMES = frozenset(
        {
            "check_input_params",
            "check_search_params",
            "load_cls_default_params",
            "match_type",
            "parse_file",
            "parse_param_str",
        }
    )

    def __init__(
        self, cls, param_file=None, param_str=None, search_params=None, **kwargs
    ):
        """Build the fixed and search parameter dictionaries.

        Args:
            cls: Model class whose constructor defaults drive type coercion.
            param_file: Optional path to a JSON file of fixed parameters.
            param_str: Optional ``name=value;...`` string of fixed parameters.
            search_params: Optional ``name=expr;...`` string of candidates.
            **kwargs: Fixed parameters with the highest precedence.
        """
        merged = self.parse_file(param_file)
        merged.update(self.parse_param_str(param_str))
        merged.update(kwargs)
        self.input_params = self.check_input_params(cls, merged)
        self.search_params = self.check_search_params(cls, search_params)

    def _find_parser(self, key):
        """Return the dedicated parser method for ``key``, or ``None``.

        Only public, callable attributes that are not part of this class's
        own machinery qualify, so a parameter called e.g. ``match_type`` or
        ``input_params`` is not dispatched to an unrelated attribute.
        """
        if not isinstance(key, str):
            return None
        if key.startswith("_") or key in self._NON_PARSER_NAMES:
            return None
        candidate = getattr(self, key, None)
        return candidate if callable(candidate) else None

    def _convert(self, key, value, default_params):
        """Convert one value using its parser or the default's type."""
        parser = self._find_parser(key)
        if parser is not None:
            return parser(value)
        if key in default_params:
            return self.match_type(default_params[key], value)
        return value

    def check_input_params(self, cls, input_params):
        """Convert every fixed parameter in place and return the dict.

        Args:
            cls: Model class used to look up constructor defaults.
            input_params: dict of raw parameter values; it is mutated.

        Returns:
            The same dict with converted values.
        """
        default_params = self.load_cls_default_params(cls)
        # Snapshot the items so reassigning values cannot disturb iteration.
        for key, value in list(input_params.items()):
            input_params[key] = self._convert(key, value, default_params)
        return input_params

    def check_search_params(self, cls, search_params):
        """Parse and convert the candidate values of each search parameter.

        Args:
            cls: Model class used to look up constructor defaults.
            search_params: ``name=expr;...`` string (or ``None``/empty).

        Returns:
            dict mapping each name to a list of converted candidates. When an
            expression cannot be evaluated, or does not evaluate to an
            iterable, a warning is issued and the raw string is left as the
            value for that name.
        """
        search_params = self.parse_param_str(search_params)
        default_params = self.load_cls_default_params(cls)
        for key, raw in list(search_params.items()):
            try:
                # SECURITY: eval executes arbitrary code from the parameter
                # string; only pass trusted input. ast.literal_eval would be
                # safer if plain literals are all that is ever needed.
                # list() also rejects scalars such as "5" with a TypeError,
                # which is reported through the same warning.
                candidates = list(eval(raw))
            except Exception:
                warnings.warn(f"value : {raw} error, is must be list of something")
                continue
            search_params[key] = [
                self._convert(key, candidate, default_params)
                for candidate in candidates
            ]
        return search_params

    @staticmethod
    def load_cls_default_params(cls):
        """Return ``{name: default}`` for the constructor of ``cls``.

        Every parameter of ``cls.__init__`` that declares a default is
        included, positional and keyword-only alike. Parameters without a
        default (``self``, ``*args``, ``**kwargs``, required arguments) are
        skipped. Defaults are deep-copied so callers may mutate them.
        """
        signature = inspect.signature(cls.__init__)
        return {
            name: deepcopy(parameter.default)
            for name, parameter in signature.parameters.items()
            if parameter.default is not inspect.Parameter.empty
        }

    @staticmethod
    def match_type(refer_var, input_var):
        """Coerce ``input_var`` to the type of ``refer_var``.

        Booleans are read from the text ``true``/``false`` in any letter
        case; any other text triggers a warning and yields ``False``. A
        ``None`` reference carries no type information, so the input is
        returned unchanged. Otherwise the reference's type is called on the
        input, which may raise if the value is not convertible.
        """
        if isinstance(refer_var, bool):
            tag = str(input_var).lower()
            if tag not in {"true", "false"}:
                warnings.warn(f"value : {input_var} error, set to False")
            return tag == "true"
        if refer_var is None:
            return input_var
        return type(refer_var)(input_var)

    @staticmethod
    def parse_file(path):
        """Load parameters from a JSON file.

        Returns an empty dict when ``path`` is ``None`` or blank.
        """
        path = path or ""
        if not path.strip():
            return {}
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as r_f:
            return json.load(r_f)

    @staticmethod
    def parse_param_str(param_str):
        """Parse ``"a=1; b=2"`` into ``{"a": "1", "b": "2"}``.

        Blank segments are skipped. A segment that does not contain exactly
        one ``=`` is ignored with a warning. Names and values are stripped
        of surrounding whitespace and kept as strings.
        """
        param_str = param_str or ""
        if not param_str.strip():
            return {}
        params = {}
        for segment in param_str.split(";"):
            if not segment.strip():
                continue
            parts = segment.split("=")
            if len(parts) != 2:
                warnings.warn(f"{segment} error, will be ignore")
                continue
            key, value = parts
            params[key.strip()] = value.strip()
        return params

    @staticmethod
    def class_weight(value):
        """Parser for the ``class_weight`` parameter.

        Evaluates ``value`` as a Python expression (e.g. ``"{0: 1, 1: 3}"``)
        and returns the result; if evaluation fails for any ordinary reason
        the value is returned unchanged.
        """
        try:
            # SECURITY: eval executes arbitrary code; only pass trusted input.
            return eval(value)
        except Exception:
            return value

    def __str__(self):
        """Return a two-line summary of the fixed and search parameters."""
        return (
            f"input_params: {self.input_params}\n"
            f"search_params: {self.search_params}"
        )

    __repr__ = __str__
class LightGBMParams(Params):
    """Parameter container for LightGBM models; adds nothing to ``Params``."""
class XGBoostParams(Params):
    """Parameter container whose defaults are taken from ``xgboost.XGBModel``."""

    @staticmethod
    def load_cls_default_params(cls):
        """Return ``XGBModel`` constructor defaults with two fixed overrides.

        The ``cls`` argument is not consulted: defaults are always read from
        ``XGBModel`` through ``Params.load_cls_default_params``. Afterwards
        ``objective`` is set to ``"binary:logistic"`` and
        ``use_label_encoder`` to ``True``.
        """
        # Imported inside the method, so xgboost is loaded only when it runs.
        from xgboost import XGBModel

        defaults = Params.load_cls_default_params(XGBModel)
        defaults.update(
            {"objective": "binary:logistic", "use_label_encoder": True}
        )
        return defaults
class LrParams(Params):
    """Parameter container for scikit-learn's ``LogisticRegression``."""

    @staticmethod
    def load_cls_default_params(cls):
        """Return the keyword-only defaults of ``LogisticRegression``.

        The ``cls`` argument is not consulted. The result is a deep copy of
        the keyword-only defaults reported by ``inspect.getfullargspec``,
        with ``penalty`` then set to ``"l2"``.
        """
        # Imported inside the method, so sklearn is loaded only when it runs.
        from sklearn.linear_model import LogisticRegression

        spec = inspect.getfullargspec(LogisticRegression)
        defaults = deepcopy(spec.kwonlydefaults)
        defaults["penalty"] = "l2"
        return defaults
class SVMParams(Params):
    """Parameter container for scikit-learn's ``SVC``."""

    @staticmethod
    def load_cls_default_params(cls):
        """Return a deep copy of the keyword-only defaults of ``SVC``.

        The ``cls`` argument is not consulted; the defaults are those
        reported by ``inspect.getfullargspec`` for ``sklearn.svm.SVC``.
        """
        # Imported inside the method, so sklearn is loaded only when it runs.
        from sklearn.svm import SVC

        spec = inspect.getfullargspec(SVC)
        return deepcopy(spec.kwonlydefaults)