in assets/model_monitoring/components/src/action_analyzer/action_analyzer_identify_problem_traffic/run.py [0:0]
def run():
"""Identify problem traffic."""
    # Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("--data_with_groups", type=str)
parser.add_argument("--signal_scored_data", type=str)
parser.add_argument("--signal_output", type=str)
parser.add_argument("--signal_name", type=str)
parser.add_argument("--model_deployment_name", type=str, required=True)
parser.add_argument("--workspace_connection_arm_id", type=str, required=True)
parser.add_argument("--llm_summary_enabled", type=str)
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--top_p", type=float, default=1.0)
parser.add_argument("--num_samples", type=int, default=1)
parser.add_argument("--frequency_penalty", type=float, default=0.0)
parser.add_argument("--presence_penalty", type=float, default=0.0)
parser.add_argument("--stop", type=str, default=None)
parser.add_argument("--api_call_retry_backoff_factor", type=int, default=4)
parser.add_argument("--api_call_retry_max_count", type=int, default=10)
args = parser.parse_args()
    # Skip the action analyzer if there are no violated metrics.
violated_metrics = get_violated_metrics(args.signal_output, f"signals/{args.signal_name}")
    if not violated_metrics:
print("No violated metrics. No action will be generated.")
save_empty_dataframe(get_output_schema(), args.data_with_groups)
return
print("Violated metrics found: ", violated_metrics)
signal_scored_data_df = try_read_mltable_in_spark(args.signal_scored_data, "signal_scored_data")
print("gsq output df")
signal_scored_data_df.show()
    # Add topic_list, group_list, and violated_metrics columns as string type.
    # Different topics, groups, and metrics will be appended to these strings.
# Schema:
# +--------+-------+-------------+----------+---------+-------+------------------------+----------+----------+----------------+---------+ # noqa
    # |trace_id|span_id|root_question|completion|Coherence|Fluency|...(other metric scores)|topic_list|group_list|violated_metrics|root_span| # noqa
# +--------+-------+-------------+----------+---------+-------+------------------------+----------+----------+----------------+---------+ # noqa
df = signal_scored_data_df.withColumn(TOPIC_LIST_COLUMN, lit("")) \
.withColumn(GROUP_LIST_COLUMN, lit("")) \
.withColumn(VIOLATED_METRICS_COLUMN, lit(""))
    # Rename the prompt column to the root question column.
df = df.withColumn(ROOT_QUESTION_COLUMN, col(PROMPT_COLUMN)).drop(PROMPT_COLUMN)
for metrics in violated_metrics:
print("======Current metrics=====")
print(metrics)
score_name = metrics
        # Append this metric to the violated_metrics column for each query that violates it.
df = df.withColumn(VIOLATED_METRICS_COLUMN,
assign_violated_metrics(col(VIOLATED_METRICS_COLUMN), col(score_name), lit(metrics)))
        # Get the good and bad queries for this metric. Sample the good queries down to
        # the bad-query count in case there are too many good queries.
pdf = df.toPandas()
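        # Bad answers score below METRICS_VIOLATION_THRESHOLD; good answers score exactly GOOD_METRICS_VALUE.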
bad_answers = pdf[pdf[score_name] < METRICS_VIOLATION_THRESHOLD]
good_answers = pdf[pdf[score_name] == GOOD_METRICS_VALUE]
good_samples = good_answers.sample(n=min(len(bad_answers), len(good_answers)))
bad_queries = bad_answers[ROOT_QUESTION_COLUMN].tolist()
good_queries = good_samples[ROOT_QUESTION_COLUMN].tolist()
print(f"Sample size for current metrics: {len(bad_queries)}")
# Add default good group name and default bad group name for each query
df = df.withColumn(GROUP_LIST_COLUMN, assign_default_group(col(GROUP_LIST_COLUMN),
col(ROOT_QUESTION_COLUMN),
lit(metrics),
lit(json.dumps(good_queries)),
lit(json.dumps(bad_queries))))
        # For bad queries, add a semantic topic and split them into separate semantic groups.
if len(bad_queries) < GROUP_TOPIC_MIN_SAMPLE_SIZE:
            # Skip grouping if the sample size is too small.
print(f"Bad sample size {len(bad_queries)} is less than {GROUP_TOPIC_MIN_SAMPLE_SIZE}. Skip grouping and set default topic.") # noqa
df = df.withColumn(TOPIC_LIST_COLUMN, assign_default_topic(col(TOPIC_LIST_COLUMN),
lit(json.dumps(bad_queries)),
col(ROOT_QUESTION_COLUMN)))
else:
print("Add semantic grouping for bad queries")
topics_dict = bertopic_get_topic(bad_queries,
args.workspace_connection_arm_id,
args.model_deployment_name)
topic_group_columns = assign_bad_topic_and_group(col(TOPIC_LIST_COLUMN),
col(GROUP_LIST_COLUMN),
col(ROOT_QUESTION_COLUMN),
col(VIOLATED_METRICS_COLUMN),
lit(metrics),
lit(json.dumps(topics_dict)),
lit(args.llm_summary_enabled))
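            # Split the combined topic/group result into the topic_list and group_list columns.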
df = df.withColumn(TOPIC_LIST_COLUMN, topic_group_columns[0])
df = df.withColumn(GROUP_LIST_COLUMN, topic_group_columns[1])
        # For good queries, only add a semantic topic (no grouping).
if len(good_queries) < GROUP_TOPIC_MIN_SAMPLE_SIZE:
            # Skip topic assignment if the sample size is too small.
            print(f"Good sample size {len(good_queries)} is less than {GROUP_TOPIC_MIN_SAMPLE_SIZE}. Skip topic assignment and set default topic.") # noqa
df = df.withColumn(TOPIC_LIST_COLUMN, assign_default_topic(col(TOPIC_LIST_COLUMN),
lit(json.dumps(good_queries)),
col(ROOT_QUESTION_COLUMN)))
        else:
            print("Add semantic topic for good queries")
            topics_dict = bertopic_get_topic(good_queries,
args.workspace_connection_arm_id,
args.model_deployment_name)
df = df.withColumn(TOPIC_LIST_COLUMN, assign_good_topic(col(TOPIC_LIST_COLUMN),
col(ROOT_QUESTION_COLUMN),
col(score_name),
lit(json.dumps(topics_dict)),
lit(args.llm_summary_enabled)))
    # Keep only the sampled queries that were assigned a topic and group.
sampled_df = df.filter(col(TOPIC_LIST_COLUMN) != "")
sampled_df.show()
    # Convert the df from trace level to span level and extract retrieval info using the schema below.
# Schema:
# +--------+-------+-------------+----------+---------+-------+----------+----------+----------------+---------+------+-------------+-------------------------+ # noqa
    # |trace_id|span_id|root_question|completion|Coherence|Fluency|topic_list|group_list|violated_metrics|root_span|prompt|index_content|...(other retrieval info)| # noqa
# +--------+-------+-------------+----------+---------+-------+----------+----------+----------------+---------+------+-------------+-------------------------+ # noqa
span_level_df = convert_to_retrieval_span_df(sampled_df)
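    # Save the span-level dataframe with topic and group assignments to the output path.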
save_spark_df_as_mltable(span_level_df, args.data_with_groups)