def bq_select_best_kmeans_model()

in python/pipelines/components/bigquery/component.py


def bq_select_best_kmeans_model(
        project_id: str,
        location: str,
        dataset_id: str,
        model_prefix: str,
        metric_name: str,
        metric_threshold: float,
        number_of_models_considered: int,
        metrics_logger: Output[Metrics],
        elected_model: Output[Artifact]) -> None:
    
    """Selects the best KMeans model from a set of models based on a given metric.

    Args:
        project_id: The project ID of the models.
        location: The location of the models.
        dataset_id: The dataset ID of the models.
        model_prefix: The prefix of the model IDs.
        metric_name: The name of the metric to use for comparison.
        metric_threshold: The worst acceptable value of the metric. For the
            supported metrics (lower is better), the selected model's metric
            must not exceed this value.
        number_of_models_considered: The number of most recently created models
            matching the prefix to evaluate.
        metrics_logger: The output artifact to log the metrics of the selected model.
        elected_model: The output artifact to store the metadata of the selected model.
    """

    from google.cloud import bigquery
    import logging
    from enum import Enum

    from google.api_core.gapic_v1.client_info import ClientInfo

    USER_AGENT_FEATURES = 'cloud-solutions/marketing-analytics-jumpstart-features-v1'
    USER_AGENT_PROPENSITY_TRAINING = 'cloud-solutions/marketing-analytics-jumpstart-propensity-training-v1'
    USER_AGENT_PROPENSITY_PREDICTION = 'cloud-solutions/marketing-analytics-jumpstart-propensity-prediction-v1'
    USER_AGENT_REGRESSION_TRAINING = 'cloud-solutions/marketing-analytics-jumpstart-regression-training-v1'
    USER_AGENT_REGRESSION_PREDICTION = 'cloud-solutions/marketing-analytics-jumpstart-regression-prediction-v1'
    USER_AGENT_SEGMENTATION_TRAINING = 'cloud-solutions/marketing-analytics-jumpstart-segmentation-training-v1'
    USER_AGENT_SEGMENTATION_PREDICTION = 'cloud-solutions/marketing-analytics-jumpstart-segmentation-prediction-v1'
    USER_AGENT_VBB_TRAINING = 'cloud-solutions/marketing-analytics-jumpstart-vbb-training-v1'
    USER_AGENT_VBB_EXPLANATION = 'cloud-solutions/marketing-analytics-jumpstart-vbb-explanation-v1'


    class MetricsEnum(Enum):
        DAVIES_BOULDIN_INDEX = 'davies_bouldin_index'
        MEAN_SQUARED_DISTANCE = 'mean_squared_distance'

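        # Both supported metrics are "lower is better", so a candidate wins
        # when its value is smaller than the current best.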
        def is_new_metric_better(self, new_value: float, old_value: float):
            return new_value < old_value

        @classmethod
        def list(cls):
            return list(map(lambda c: c.value, cls))

    # Construct a BigQuery client object.
    client = bigquery.Client(
        project=project_id,
        location=location,
        client_info=ClientInfo(user_agent=USER_AGENT_SEGMENTATION_PREDICTION)
    )

    # List all models in the given dataset.
    logging.info(f"Getting models from: {project_id}.{dataset_id}")
    models = client.list_models(dataset_id)  # Make an API request.

    # Keep the `number_of_models_considered` most recently created models whose
    # IDs start with `model_prefix`.
    models_to_compare = []
    counter = 0
    for model in models:
        if model.model_id.startswith(model_prefix):
            if counter < number_of_models_considered:
                models_to_compare.append(model)
                counter += 1
            else:
                candidate = model
                for idx, m in enumerate(models_to_compare):
                    # If the candidate is newer than a model already in the list,
                    # swap it in and keep comparing the displaced model.
                    if candidate.created.timestamp() > m.created.timestamp():
                        tmp = m
                        models_to_compare[idx] = candidate
                        candidate = tmp

    if len(models_to_compare) == 0:
        raise Exception(
            f"No models matching prefix '{model_prefix}' were found in dataset {dataset_id}")

    best_model = dict()
    best_eval_metrics = dict()
    for i in models_to_compare:
        logging.info(i.path)
        model_bq_name = f"{i.project}.{i.dataset_id}.{i.model_id}"
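        # ML.EVALUATE on a KMeans model returns a single row of clustering
        # metrics, including davies_bouldin_index and mean_squared_distance.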
        query = f"""
            SELECT * FROM ML.EVALUATE(MODEL `{model_bq_name}`)
        """
        query_job = client.query(
            query=query,
            location=location
        )

        r = list(query_job.result())[0]

        logging.info(f"keys {r.keys()}")
        logging.info(f"{metric_name} {r.get(metric_name)}")

        if (metric_name not in best_model) or MetricsEnum(metric_name).is_new_metric_better(r.get(metric_name), best_model[metric_name]):

            for k in r.keys():
                best_eval_metrics[k] = r.get(k)

            best_model[metric_name] = r.get(metric_name)
            best_model["resource_name"] = i.path
            best_model["uri"] = model_bq_name
            logging.info(
                f"New Model/Version elected | name: {model_bq_name} | metric name: {metric_name} | metric value: {best_model[metric_name]}")

    # Fail if even the best model does not meet the metric threshold
    # (for these metrics, lower is better).
    if MetricsEnum(metric_name).is_new_metric_better(metric_threshold, best_model[metric_name]):
        raise ValueError(
            f"Model evaluation metric {metric_name} of value {best_model[metric_name]} does not meet the threshold of {metric_threshold}")

    for k, v in best_eval_metrics.items():
        if k in MetricsEnum.list():
            metrics_logger.log_metric(k, v)

    # Populate the output artifact with the elected model's metadata.
    pId, dId, mId = best_model['uri'].split('.')
    elected_model.metadata = {
        "projectId": pId,
        "datasetId": dId,
        "modelId": mId,
        "resourceName": best_model["resource_name"]}