def run()

in assets/model_monitoring/components/src/action_analyzer/action_analyzer_identify_problem_traffic/run.py


def run():
    """Identify problem traffic."""
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_with_groups", type=str)
    parser.add_argument("--signal_scored_data", type=str)
    parser.add_argument("--signal_output", type=str)
    parser.add_argument("--signal_name", type=str)
    parser.add_argument("--model_deployment_name", type=str, required=True)
    parser.add_argument("--workspace_connection_arm_id", type=str, required=True)
    parser.add_argument("--llm_summary_enabled", type=str)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--frequency_penalty", type=float, default=0.0)
    parser.add_argument("--presence_penalty", type=float, default=0.0)
    parser.add_argument("--stop", type=str, default=None)
    parser.add_argument("--api_call_retry_backoff_factor", type=int, default=4)
    parser.add_argument("--api_call_retry_max_count", type=int, default=10)
    args = parser.parse_args()

    # Skip action analyzer if no violated metrics
    violated_metrics = get_violated_metrics(args.signal_output, f"signals/{args.signal_name}")
    if violated_metrics == []:
        print("No violated metrics. No action will be generated.")
        save_empty_dataframe(get_output_schema(), args.data_with_groups)
        return

    print("Violated metrics found: ", violated_metrics)

    signal_scored_data_df = try_read_mltable_in_spark(args.signal_scored_data, "signal_scored_data")
    print("gsq output df")
    signal_scored_data_df.show()

    # Add topic_list, group_list, and violated_metrics columns in string type.
    # Different topics, groups, and metrics will be appended to these strings.
    # Schema:
    # +--------+-------+-------------+----------+---------+-------+------------------------+----------+----------+----------------+---------+  # noqa
    # |trace_id|span_id|root_question|completion|Coherence|Fluency|...(other metric scores)|topic_list|group_list|violated_metrics|root_span|  # noqa
    # +--------+-------+-------------+----------+---------+-------+------------------------+----------+----------+----------------+---------+  # noqa
    df = signal_scored_data_df.withColumn(TOPIC_LIST_COLUMN, lit("")) \
                              .withColumn(GROUP_LIST_COLUMN, lit("")) \
                              .withColumn(VIOLATED_METRICS_COLUMN, lit(""))
    # Rename the prompt column to the root question column
    df = df.withColumn(ROOT_QUESTION_COLUMN, col(PROMPT_COLUMN)).drop(PROMPT_COLUMN)
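    # Illustrative sketch (not part of the source file): the assign_* helpers used in the
    # loop below are assumed to be Spark UDFs that append values to these string-typed
    # "list" columns. A minimal, hypothetical UDF of that shape, reusing the same violation
    # rule as the pandas filtering below (score < METRICS_VIOLATION_THRESHOLD), could be:
    from pyspark.sql.functions import udf
    from pyspark.sql.types import StringType

    @udf(returnType=StringType())
    def example_assign_violated_metrics(violated_metrics, score, metric_name):
        """Hypothetical: append metric_name to the comma-separated list when score violates."""
        if score is not None and score < METRICS_VIOLATION_THRESHOLD:
            return f"{violated_metrics},{metric_name}" if violated_metrics else metric_name
        return violated_metrics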

    for metrics in violated_metrics:
        print("======Current metrics=====")
        print(metrics)
        score_name = metrics
        # Add violated metrics for each query
        df = df.withColumn(VIOLATED_METRICS_COLUMN,
                           assign_violated_metrics(col(VIOLATED_METRICS_COLUMN), col(score_name), lit(metrics)))

        # Get the good and bad queries. Sample the good queries in case their count is too large.
        pdf = df.toPandas()
        bad_answers = pdf[pdf[score_name] < METRICS_VIOLATION_THRESHOLD]
        good_answers = pdf[pdf[score_name] == GOOD_METRICS_VALUE]
        good_samples = good_answers.sample(n=min(len(bad_answers), len(good_answers)))
        bad_queries = bad_answers[ROOT_QUESTION_COLUMN].tolist()
        good_queries = good_samples[ROOT_QUESTION_COLUMN].tolist()
        print(f"Sample size for current metrics: {len(bad_queries)}")

        # Add default good group name and default bad group name for each query
        df = df.withColumn(GROUP_LIST_COLUMN, assign_default_group(col(GROUP_LIST_COLUMN),
                                                                   col(ROOT_QUESTION_COLUMN),
                                                                   lit(metrics),
                                                                   lit(json.dumps(good_queries)),
                                                                   lit(json.dumps(bad_queries))))
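        # (Assumption: assign_default_group presumably appends a default per-metric "good"
        # or "bad" group label to group_list for queries found in the sampled lists above,
        # so every sampled query belongs to at least one group even when the semantic
        # grouping below is skipped.)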

        # For bad queries, add semantic topics and separate them into different semantic groups
        if len(bad_queries) < GROUP_TOPIC_MIN_SAMPLE_SIZE:
            # Skip grouping if the sample size is too small
            print(f"Bad sample size {len(bad_queries)} is less than {GROUP_TOPIC_MIN_SAMPLE_SIZE}. Skip grouping and set default topic.")  # noqa
            df = df.withColumn(TOPIC_LIST_COLUMN, assign_default_topic(col(TOPIC_LIST_COLUMN),
                                                                       lit(json.dumps(bad_queries)),
                                                                       col(ROOT_QUESTION_COLUMN)))
        else:
            print("Add semantic grouping for bad queries")
            topics_dict = bertopic_get_topic(bad_queries,
                                             args.workspace_connection_arm_id,
                                             args.model_deployment_name)

            topic_group_columns = assign_bad_topic_and_group(col(TOPIC_LIST_COLUMN),
                                                             col(GROUP_LIST_COLUMN),
                                                             col(ROOT_QUESTION_COLUMN),
                                                             col(VIOLATED_METRICS_COLUMN),
                                                             lit(metrics),
                                                             lit(json.dumps(topics_dict)),
                                                             lit(args.llm_summary_enabled))
            df = df.withColumn(TOPIC_LIST_COLUMN, topic_group_columns[0])
            df = df.withColumn(GROUP_LIST_COLUMN, topic_group_columns[1])

        # For good queries, only add semantic topic (no grouping)
        print("Add semantic topic for good queries")
        if len(good_queries) < GROUP_TOPIC_MIN_SAMPLE_SIZE:
            # Skip grouping if the sample size is too small
            print(f"Good sample size {len(good_queries)} is less than {GROUP_TOPIC_MIN_SAMPLE_SIZE}. Skip grouping and set default topic.")  # noqa
            df = df.withColumn(TOPIC_LIST_COLUMN, assign_default_topic(col(TOPIC_LIST_COLUMN),
                                                                       lit(json.dumps(good_queries)),
                                                                       col(ROOT_QUESTION_COLUMN)))
        else:
            topics_dict = bertopic_get_topic(good_queries,
                                             args.workspace_connection_arm_id,
                                             args.model_deployment_name)

            df = df.withColumn(TOPIC_LIST_COLUMN, assign_good_topic(col(TOPIC_LIST_COLUMN),
                                                                    col(ROOT_QUESTION_COLUMN),
                                                                    col(score_name),
                                                                    lit(json.dumps(topics_dict)),
                                                                    lit(args.llm_summary_enabled)))

    # Take the sampled df with assigned topic and group
    sampled_df = df.filter(col(TOPIC_LIST_COLUMN) != "")
    sampled_df.show()

    # Convert the df from trace level to span level and extract retrieval info in the schema below
    # Schema:
    # +--------+-------+-------------+----------+---------+-------+----------+----------+----------------+---------+------+-------------+-------------------------+  # noqa
    # |trace_id|span_id|root_question|completion|Coherence|Fluency|topic_list|group_list|violated_metrics|root_span|prompt|index_content|...(other retrieval info)|  # noqa
    # +--------+-------+-------------+----------+---------+-------+----------+----------+----------------+---------+------+-------------+-------------------------+  # noqa
    span_level_df = convert_to_retrieval_span_df(sampled_df)
    save_spark_df_as_mltable(span_level_df, args.data_with_groups)
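
A note on the topic extraction step above: bertopic_get_topic is not shown in this excerpt. It is assumed to cluster the sampled queries into semantic topics with the BERTopic library (the real helper also takes the workspace connection and model deployment name, presumably for LLM-based topic naming when llm_summary_enabled is set). A minimal, hypothetical sketch of the clustering part only, keyed by raw topic ids instead of names, could look like:

def example_bertopic_topics(queries):
    """Hypothetical sketch: cluster queries into topics and return {topic_id: [queries]}."""
    from bertopic import BERTopic

    # min_topic_size is an assumption; the real component's settings are not shown here.
    topic_model = BERTopic(min_topic_size=2)
    topic_ids, _ = topic_model.fit_transform(queries)
    topics_dict = {}
    for query, topic_id in zip(queries, topic_ids):
        topics_dict.setdefault(str(topic_id), []).append(query)
    return topics_dict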