def create_model_version()

in Project-AutoML/train.py [0:0]


def create_model_version(
    model_name,
    key_metrics=None,
    run_id=None,
    auto_replace=True,
    *,
    greater_is_better=True,
):
    """Register a new model version and promote one version to Production.

    The version to promote is chosen as follows:

    * ``key_metrics`` given: the version whose source run logged the best value
      of that metric (highest when ``greater_is_better`` is true, lowest
      otherwise). Versions without a source run, or whose run did not log the
      metric, are skipped with a warning.
    * otherwise, ``run_id`` given: the version just registered from that run.
    * neither given: nothing is promoted, but the versions currently in
      Production are still archived.

    Every version currently in Production, other than the chosen one, is moved
    to Archived. The winner is chosen before any stage is changed, so a failure
    while choosing leaves the current Production model in place.

    Args:
        model_name: Registered model name. Created if it has no versions yet.
        key_metrics: Name of the run metric used to rank versions, or None.
        run_id: MLflow run whose ``ARTIFACT_TAG`` artifact is registered as a
            new version, or None to register nothing.
        auto_replace: Currently unused; kept for backward compatibility.
        greater_is_better: Whether a larger ``key_metrics`` value is better.

    Returns:
        The model's versions as listed after all stage transitions.

    Raises:
        ValueError: ``key_metrics`` is given but no version has that metric.
        Errors raised by the MLflow client calls propagate unchanged.
    """
    # Function-scope import: only needed for the "already exists" check below.
    from mlflow.exceptions import MlflowException

    client = mlflow.tracking.MlflowClient()
    filter_string = f"name='{model_name}'"
    versions = client.search_model_versions(filter_string)

    if not versions:
        try:
            client.create_registered_model(model_name)
        except MlflowException as exc:
            # A registered model can exist with zero versions (e.g. all were
            # deleted); that is fine, any other error is real.
            if exc.error_code != "RESOURCE_ALREADY_EXISTS":
                raise

    new_version = None
    if run_id:
        uri = f"runs:/{run_id}/{ARTIFACT_TAG}"
        new_version = mlflow.register_model(uri, model_name).version

    # Decide what to promote *before* touching any stage.
    if key_metrics:
        version2metrics = []
        for version in client.search_model_versions(filter_string):
            # Versions registered from a plain path have no source run to read.
            if not version.run_id:
                continue
            run_metrics = client.get_run(version.run_id).data.metrics
            if key_metrics not in run_metrics:
                logger.warning(
                    "version %s has no metric %r; skipping",
                    version.version,
                    key_metrics,
                )
                continue
            version2metrics.append((version.version, run_metrics[key_metrics]))

        logger.info("version2metrics(%s): %s", key_metrics, version2metrics)

        if not version2metrics:
            raise ValueError(
                f"no version of model {model_name!r} has metric {key_metrics!r}"
            )
        pick = max if greater_is_better else min
        target_version = pick(version2metrics, key=lambda item: item[1])[0]
    else:
        target_version = new_version

    # Archive the current Production versions, except the one being promoted.
    # Version numbers are compared as strings to be type-agnostic.
    target = None if target_version is None else str(target_version)
    for version in versions:
        if version.current_stage == "Production" and str(version.version) != target:
            client.transition_model_version_stage(
                model_name, version=version.version, stage="Archived"
            )

    if target_version is not None:
        if key_metrics:
            logger.info("register version: %s to Production", target_version)
        client.transition_model_version_stage(
            model_name, version=target_version, stage="Production"
        )
        if not key_metrics:
            logger.info("register last version to Production")

    return client.search_model_versions(filter_string)