Project-AutoML/automl/params.py (52 lines of code) (raw):

# Licensed to Apache Software Foundation (ASF) under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Apache Software Foundation (ASF) licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import json
import os
import warnings
from logging import getLogger

logger = getLogger(__name__)

# Bare identifiers that check_input_params still converts: Python's keyword
# constants. Any other bare name is kept as a string (see check_input_params).
_KEYWORD_CONSTANTS = frozenset({"True", "False", "None"})


class Params:
    """Collect parameters from a JSON file, a ``key=value`` string and kwargs.

    The three sources are merged with increasing priority: values from
    ``param_file`` are overridden by ``param_str``, which are overridden by
    ``kwargs``. The merged values are type-adapted by
    :meth:`check_input_params` and stored in ``self.input_params``.
    """

    def __init__(self, param_file=None, param_str=None, **kwargs):
        """Build the merged, type-adapted parameter dict.

        :param param_file: path to a JSON file holding a parameter object;
            ``None`` or a blank string means "no file".
        :param param_str: ``"k1=v1;k2=v2"`` string; ``None`` or blank means
            "no string parameters".
        :param kwargs: extra parameters; they take precedence over both
            ``param_file`` and ``param_str``.
        """
        input_params = self.parse_file(param_file)
        input_params.update(self.parse_param_str(param_str))
        input_params.update(kwargs)
        self.input_params = self.check_input_params(input_params)
        logger.info("params : %s", self)

    def check_input_params(self, input_params):
        """Convert parameter values into Python objects where possible.

        Each value is passed to ``eval``. If evaluation succeeds, the result
        replaces the value, for example ``"0.1"`` -> ``0.1``,
        ``"[1, 2]"`` -> ``[1, 2]`` and ``"True"`` -> ``True``. If evaluation
        fails for any reason, the value is kept unchanged. Failure reasons
        include a syntax error, an unknown name, or a value that is not
        source text.

        A bare identifier other than ``True``/``False``/``None`` is always kept
        as a string. ``eval`` injects ``__builtins__`` even into an empty
        globals dict, so evaluating ``"max"`` or ``"min"`` would otherwise
        return the built-in function instead of the intended string.

        NOTE(review): ``eval`` executes arbitrary expressions. Parameter
        sources (file, string, kwargs) must be trusted. Consider
        ``ast.literal_eval`` if only literal values are required.

        :param input_params: merged parameter dict; values are replaced in
            place (only existing keys are reassigned, so iterating while
            assigning is safe).
        :return: the same ``input_params`` dict.
        """
        for key, value in input_params.items():
            if isinstance(value, str):
                text = value.strip()
                if text.isidentifier() and text not in _KEYWORD_CONSTANTS:
                    # Keep names like "max"/"adam" as plain strings.
                    continue
            try:
                input_params[key] = eval(value, {}, {})
            except Exception:
                # Not evaluable (bad syntax, unknown name, non-str value, ...):
                # deliberately keep the original value.
                continue
        return input_params

    @staticmethod
    def parse_file(path):
        """Load parameters from the JSON file at ``path``.

        :param path: ``str`` or ``os.PathLike`` path to a JSON file. ``None``
            or a blank string yields an empty dict.
        :return: the decoded JSON document (expected to be a JSON object).
        """
        path = os.fspath(path) if path else ""
        if not path.strip():
            return {}
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale encoding.
        with open(path, "r", encoding="utf-8") as r_f:
            return json.load(r_f)

    @staticmethod
    def parse_param_str(param_str):
        """Parse a ``"key1=value1;key2=value2"`` string into a dict.

        Keys and values are stripped of surrounding whitespace; values are left
        as strings (conversion happens in ``check_input_params``). Empty
        segments are skipped. A segment that does not contain exactly one
        ``=`` is skipped with a ``UserWarning``. A later duplicate key
        overrides an earlier one.

        :param param_str: the parameter string; ``None`` or blank gives ``{}``.
        :return: dict of ``str`` keys to ``str`` values.
        """
        param_str = param_str or ""
        if not param_str.strip():
            return {}
        params = {}
        for name_value_pair in param_str.split(";"):
            if not name_value_pair.strip():
                continue
            k_v = name_value_pair.split("=")
            if len(k_v) != 2:
                warnings.warn(f"{name_value_pair} error, will be ignore")
                continue
            key, value = k_v
            params[key.strip()] = value.strip()
        return params

    def __str__(self):
        """Return ``"input_params: <dict>"`` for logging/debugging."""
        return f"input_params: {self.input_params}"

    __repr__ = __str__