def main()

in Project-BasicAlgorithm/train.py [0:0]


def main(algorithm, data_path, label_column, model_name, random_state, param_file, params, search_params):
    """Train one model under an MLflow run and optionally create a model version.

    Steps:
      1. Split the dataset via ``load_data`` (seeded by ``random_state``).
      2. Look up the trainer for ``algorithm`` via ``get_training_func``.
      3. Inside a fresh MLflow run, train the model, print and log its
         metrics, and log the fitted model with the sklearn flavor under
         the ``"sklearn_model"`` artifact path.
      4. If ``model_name`` is truthy, call ``create_model_version`` for the
         finished run with ``key_metrics='f1-score'``.

    Args:
        algorithm: Selector passed to ``get_training_func``.
        data_path: Dataset location, forwarded to ``load_data``.
        label_column: Target column name, forwarded to ``load_data``.
        model_name: Name handed to ``create_model_version``. A falsy value
            skips that step.
        random_state: Seed forwarded to ``load_data``.
        param_file, params, search_params: Forwarded unchanged as keyword
            arguments to the training function.

    Returns:
        None.
    """
    x_train, y_train, x_test, y_test = load_data(
        data_path, label_column, random_state=random_state
    )
    train = get_training_func(algorithm)

    # The trainer receives its hyper-parameter configuration as keyword arguments.
    trainer_kwargs = dict(
        param_file=param_file,
        params=params,
        search_params=search_params,
    )

    with mlflow.start_run() as active_run:
        fitted_model, scores = train(x_train, y_train, x_test, y_test, **trainer_kwargs)
        print(scores)
        mlflow.log_metrics(scores)
        mlflow.sklearn.log_model(fitted_model, artifact_path="sklearn_model")

    # Guard clause: without a model name there is nothing to version.
    if not model_name:
        return

    # The run object stays readable after its context has closed.
    # NOTE(review): create_model_version presumably registers the run's model
    # and uses 'f1-score' to rank versions. Confirm against its definition.
    create_model_version(
        model_name, key_metrics='f1-score', run_id=active_run.info.run_id
    )