def train_model()

in Project-BasicAlgorithm/core/utils.py [0:0]


def train_model(model_cls, params, train_x, train_y):
    """
    Train a model directly, or train it while grid-searching hyper-parameters.

    The estimator is always constructed as ``model_cls(**params.input_params)``.
    If ``params.search_params`` is truthy it is passed as the ``param_grid`` of a
    ``GridSearchCV`` wrapped around that estimator. The search's
    ``best_estimator_`` is returned, and every candidate's parameters are printed
    together with its mean test score. Otherwise the estimator is fitted directly
    on the training data and returned.

    NOTE(review): ``GridSearchCV`` is not imported in this block; it is presumably
    ``sklearn.model_selection.GridSearchCV`` imported at module level — confirm.

    :param model_cls: estimator class, instantiated with ``params.input_params`` as kwargs
    :param params: object exposing ``input_params`` (constructor kwargs) and
        ``search_params`` (parameter grid; a falsy value skips the search)
    :param train_x: training features passed to ``fit``
    :param train_y: training targets passed to ``fit``
    :return: the fitted estimator, or ``best_estimator_`` of the grid search
    """
    model = model_cls(**params.input_params)

    if not params.search_params:
        model.fit(train_x, train_y)
        return model

    search = GridSearchCV(estimator=model, param_grid=params.search_params)
    search.fit(train_x, train_y)
    best_model = search.best_estimator_

    # Report every evaluated candidate. A distinct loop name is used so the
    # caller's ``params`` argument is not shadowed by the search results.
    results = search.cv_results_
    for candidate_params, score in zip(results['params'], results['mean_test_score']):
        print(candidate_params, score)
    return best_model