in Project-BasicAlgorithm/train.py [0:0]
def main(
    algorithm,
    data_path,
    label_column,
    model_name,
    random_state,
    param_file,
    params,
    search_params,
):
    """Train a model inside an MLflow run and optionally create a model version.

    The data is loaded with ``load_data``, the training routine is resolved
    with ``get_training_func`` and then executed inside a fresh MLflow run.
    The returned metrics are printed and logged to MLflow, and the returned
    model is logged under the artifact path ``"sklearn_model"``.

    Args:
        algorithm: Passed to ``get_training_func`` to select the training
            routine.
        data_path: Passed to ``load_data`` as the data location.
        label_column: Passed to ``load_data``; presumably the name of the
            target column -- confirm against ``load_data``.
        model_name: When truthy, ``create_model_version`` is called with
            this name and the id of the run; otherwise that step is skipped.
        random_state: Forwarded to ``load_data`` as ``random_state``.
        param_file: Forwarded unchanged to the training routine.
        params: Forwarded unchanged to the training routine.
        search_params: Forwarded unchanged to the training routine.
    """
    dataset = load_data(data_path, label_column, random_state=random_state)
    train_x, train_y, test_x, test_y = dataset

    trainer = get_training_func(algorithm)

    # Tuning-related options are handed to the trainer exactly as received.
    tuning_options = {
        "param_file": param_file,
        "params": params,
        "search_params": search_params,
    }

    with mlflow.start_run() as active_run:
        fitted_model, scores = trainer(
            train_x, train_y, test_x, test_y, **tuning_options
        )

        print(scores)
        mlflow.log_metrics(scores)
        mlflow.sklearn.log_model(fitted_model, artifact_path="sklearn_model")

        # NOTE(review): the version is created while the run is still open;
        # confirm this nesting matches the intended behaviour.
        if model_name:
            create_model_version(
                model_name,
                key_metrics='f1-score',
                run_id=active_run.info.run_id,
            )