in Project-BasicAlgorithm/train.py [0:0]
def create_model_version(model_name, key_metrics=None, run_id=None, auto_replace=True):
    """Register a run's model and (re)assign the "Production" stage of ``model_name``.

    Steps:
      1. Make sure the registered model ``model_name`` exists.
      2. If ``run_id`` is given, register ``runs:/<run_id>/sklearn_model`` as a
         new version of ``model_name``.
      3. Choose the version to promote: with ``key_metrics``, the version whose
         source run has the highest value of that metric; otherwise the version
         registered in step 2 (if any).
      4. Move every other version currently in "Production" to "Archived", then
         move the chosen version (if any) to "Production".

    The promotion target is resolved *before* anything is archived, so a failed
    registration or metric lookup leaves the current Production version in place.
    Note that when neither ``run_id`` nor ``key_metrics`` is given, existing
    Production versions are still archived and nothing is promoted.

    Args:
        model_name: Name of the registered model.
        key_metrics: Name of a run metric to maximise when choosing the
            Production version. Versions whose source run did not log this
            metric, or that have no source run, are skipped.
        run_id: MLflow run whose ``sklearn_model`` artifact is registered as a
            new version. If falsy, no new version is registered.
        auto_replace: Not read by this function; kept for call compatibility.

    Returns:
        The list of model versions most recently fetched from the registry.
        This is the search made after registration when ``key_metrics`` is set,
        otherwise the initial search. Stage fields are a snapshot taken before
        the stage transitions made here.

    Raises:
        ValueError: ``key_metrics`` is set but no version's run logged it.
    """
    client = mlflow.tracking.MlflowClient()
    filter_string = "name='{}'".format(model_name)
    versions = client.search_model_versions(filter_string)
    if not versions:
        _ensure_registered_model(client, model_name)

    target_version = None
    if run_id:
        uri = f"runs:/{run_id}/sklearn_model"
        mv = mlflow.register_model(uri, model_name)
        target_version = mv.version

    returned_versions = versions
    if key_metrics:
        # Re-search so the version registered above takes part in the comparison.
        returned_versions = client.search_model_versions(filter_string)
        target_version = _best_version_by_metric(client, returned_versions, key_metrics)

    # Only pre-existing versions can be in Production, so the initial snapshot
    # is enough. The target itself is left alone to avoid a needless
    # Archived -> Production round trip.
    for version in versions:
        if version.current_stage == "Production" and str(version.version) != str(target_version):
            client.transition_model_version_stage(
                model_name, version=version.version, stage="Archived"
            )

    if target_version is not None:
        if key_metrics:
            logger.info("register version: %s to Production", target_version)
        else:
            logger.info("register last version to Production")
        client.transition_model_version_stage(
            model_name, version=target_version, stage="Production"
        )
    return returned_versions


def _ensure_registered_model(client, model_name):
    """Create the registered model ``model_name`` unless it already exists."""
    from mlflow.exceptions import MlflowException

    try:
        client.create_registered_model(model_name)
    except MlflowException as err:
        # A registered model can exist with zero versions (e.g. created by an
        # earlier call without run_id); that is not an error here.
        if err.error_code != "RESOURCE_ALREADY_EXISTS":
            raise


def _best_version_by_metric(client, versions, key_metrics):
    """Return the version number whose source run has the highest ``key_metrics`` value."""
    version2metrics = []
    for version in versions:
        if not version.run_id:
            logger.warning("version %s has no source run; skipped", version.version)
            continue
        value = client.get_run(version.run_id).data.metrics.get(key_metrics)
        if value is None:
            logger.warning(
                "version %s: run %s did not log metric %r; skipped",
                version.version, version.run_id, key_metrics,
            )
            continue
        version2metrics.append((version.version, value))
    logger.info("version2metrics(%s): %s", key_metrics, version2metrics)
    if not version2metrics:
        raise ValueError(f"no version of the model has metric {key_metrics!r}")
    # max() keeps the first of equal maxima, i.e. search-result order breaks ties.
    return max(version2metrics, key=lambda pair: pair[1])[0]