def run()

in assets/model_monitoring/components/src/token_statistics_compute_metrics/run.py [0:0]


def run():
    """Execute the main function."""
    # Init spark session
    spark = init_spark()

    # Parse argument
    parser = argparse.ArgumentParser()
    parser.add_argument("--token_dataset", type=str)
    parser.add_argument("--group_pivot_column_name", type=str, required=False, nargs="?")
    parser.add_argument("--signal_metrics", type=str)

    args = parser.parse_args()
    token_df = try_read_mltable_in_spark_with_error(args.token_dataset, "token_dataset")
    group_pivot_column_name = args.group_pivot_column_name

    # failed calls dont have an id, so imputing them
    token_df = impute_ids_for_failed_calls(token_df)

    # basic data quality checks and filter out rows that dont meet the quality criteria
    token_df = check_data_quality(token_df)

    # designate the group column
    token_df = token_df.withColumnRenamed("node_id", "group")

    #  designate the group_pivot column
    #  if group_pivot_column_name is null, then add a column with value Not Provided
    if group_pivot_column_name is None:
        token_df = token_df.withColumn("group_pivot", lit('').cast("string"))
    else:
        token_df = token_df.withColumnRenamed(group_pivot_column_name, "group_pivot")

    # total tokens = prompt tokens + completion tokens
    token_df = token_df.withColumn("total_tokens",
                                   token_df["prompt_tokens"] + token_df["completion_tokens"])

    # compute GPU utilization metrics for group
    # create a copy of token_df where group_pivot is set to Aggregate
    token_df_aggregate_group_pivot = token_df.withColumn("group_pivot", lit('Aggregate').cast("string"))
    # compute GPU utilization metrics for group
    gpu_utilization_metrics_group_only = compute_GPU_utilization_metrics(token_df_aggregate_group_pivot)
    # compute GPU waste metrics for group
    # These metrics are computed if we have max_tokens and finish_reason in the dataset
    gpu_waste_metrics_group_only = spark.createDataFrame([], gpu_utilization_metrics_group_only.schema)
    if ("finish_reason" in token_df_aggregate_group_pivot.columns)\
            and ("max_tokens" in token_df_aggregate_group_pivot.columns):
        gpu_waste_metrics_group_only = compute_GPU_waste_metrics(token_df_aggregate_group_pivot)

    gpu_utilization_metrics = spark.createDataFrame([], gpu_utilization_metrics_group_only.schema)
    gpu_waste_metrics = spark.createDataFrame([], gpu_utilization_metrics_group_only.schema)

    # if group_pivot_column_name is not null, then compute metrics for group_pivot as well
    if (group_pivot_column_name is not None):
        # compute GPU utilization metrics for group and group_pivot
        gpu_utilization_metrics = compute_GPU_utilization_metrics(token_df)
        # compute GPU waste metrics for group and group_pivot
        # These metrics are computed if we have max_tokens and finish_reason in the dataset
        if ("finish_reason" in token_df.columns) and ("max_tokens" in token_df.columns):
            gpu_waste_metrics = compute_GPU_waste_metrics(token_df)

    # Union the metrics
    GPU_token_stats_metrics = gpu_utilization_metrics.unionAll(gpu_waste_metrics)\
        .unionAll(gpu_utilization_metrics_group_only)\
        .unionAll(gpu_waste_metrics_group_only)

    # Add threshold_value column and set the value to null
    GPU_token_stats_metrics = GPU_token_stats_metrics.withColumn("threshold_value", lit(None).cast("float"))

    # Save metrics in default blob store and log it in active run
    save_spark_df_as_mltable(GPU_token_stats_metrics, args.signal_metrics)