jobs/broken-site-report-ml/broken_site_report_ml/main.py:
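"""Classify incoming broken-site reports with the bugbug ML service.

The job loads user reports from BigQuery (both reports filed since the last
successful run and older ones that were missed), translates non-English
descriptions to English with BigQuery ML.TRANSLATE, asks the bugbug HTTP
service to classify each report as valid or invalid, and writes the
translations and predictions back to BigQuery.
"""
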
import datetime
import logging
import time

import click
import requests
from google.cloud import bigquery

BUGBUG_HTTP_SERVER = "https://bugbug.herokuapp.com"
CLASSIFICATION_LABELS = {0: "valid", 1: "invalid"}


def classification_http_request(url, reports):
    reports_list = list(reports.values())
    response = requests.post(
        url,
        headers={"X-Api-Key": "docker-etl"},
        json={"reports": reports_list},
        timeout=60,  # assumed value; added so a hung request cannot stall retries
    )
    response.raise_for_status()
    return response.json()


def get_reports_classification(model, reports, retry_count=21, retry_sleep=10):
    """Get the classification for a list of reports.

    Args:
        model: The model to use for the classification.
        reports: The dict containing reports to classify with uuid used as keys.
        retry_count: The number of times to retry the request.
        retry_sleep: The number of seconds to sleep between retries.

    Returns:
        A dictionary with the uuids as keys and classification results as values.
    """
    if len(reports) == 0:
        return {}

    url = f"{BUGBUG_HTTP_SERVER}/{model}/predict/broken_site_report/batch"
    json_response = {}

    for _ in range(retry_count):
        response = classification_http_request(url, reports)

        # Check which reports are ready
        for uuid, data in response["reports"].items():
            if not data.get("ready", True):
                continue

            # The report is ready, add it to the json_response and pop it
            # up from the current batch
            reports.pop(uuid, None)
            json_response[uuid] = data

        if len(reports) == 0:
            break
        else:
            time.sleep(retry_sleep)
    else:
        total_sleep = retry_count * retry_sleep
        msg = f"Couldn't get {len(reports)} report classifications in {total_sleep} seconds, aborting"  # noqa
        logging.error(msg)
        raise Exception(msg)

    return json_response


def add_classification_results(client, bq_dataset_id, results):
    res = []
    for uuid, result in results.items():
        bq_result = {
            "report_uuid": uuid,
            "label": CLASSIFICATION_LABELS[result["class"]],
            "created_at": datetime.datetime.utcnow().isoformat(),
            "probability": result["prob"][result["class"]],
        }
        res.append(bq_result)

    job_config = bigquery.LoadJobConfig(
        source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
        schema=[
            bigquery.SchemaField("report_uuid", "STRING", mode="REQUIRED"),
            bigquery.SchemaField("label", "STRING", mode="REQUIRED"),
            bigquery.SchemaField("created_at", "DATETIME", mode="REQUIRED"),
            bigquery.SchemaField("probability", "FLOAT"),
        ],
        write_disposition="WRITE_APPEND",
    )

    predictions_table = f"{bq_dataset_id}.bugbug_predictions"
    job = client.load_table_from_json(
        res,
        predictions_table,
        job_config=job_config,
    )

    logging.info("Writing to `bugbug_predictions` table")
    try:
        job.result()
    except Exception as e:
        logging.error(e)
        if job.errors:
            for error in job.errors:
                logging.error(error)

    table = client.get_table(predictions_table)
    logging.info(f"Loaded {len(res)} rows into {table}")


def save_translations(client, bq_dataset_id, results):
    res = []
    for uuid, result in results.items():
        # An empty ml_translate_status means the translation succeeded.
        if not result["status"]:
            bq_result = {
                "report_uuid": uuid,
                "translated_text": result["translated_text"],
                "language_code": result["language_code"],
            }
            res.append(bq_result)

    job_config = bigquery.LoadJobConfig(
        source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
        schema=[
            bigquery.SchemaField("report_uuid", "STRING", mode="REQUIRED"),
            bigquery.SchemaField("translated_text", "STRING"),
            bigquery.SchemaField("language_code", "STRING"),
        ],
        write_disposition="WRITE_APPEND",
    )

    translations_table = f"{bq_dataset_id}.translations"
    job = client.load_table_from_json(
        res,
        translations_table,
        job_config=job_config,
    )

    logging.info("Writing to `translations` table")
    try:
        job.result()
    except Exception as e:
        logging.error(e)
        if job.errors:
            for error in job.errors:
                logging.error(error)

    table = client.get_table(translations_table)
    logging.info(f"Loaded {len(res)} rows into {table}")


def record_classification_run(client, bq_dataset_id, is_ok, count):
    rows_to_insert = [
        {
            "run_at": datetime.datetime.utcnow().isoformat(),
            "is_ok": is_ok,
            "report_count": count,
        },
    ]

    bugbug_runs_table = f"{bq_dataset_id}.bugbug_classification_runs"
    errors = client.insert_rows_json(bugbug_runs_table, rows_to_insert)
    if errors:
        logging.error(errors)
    else:
        logging.info("Last classification run recorded")


def get_last_classification_datetime(client, bq_dataset_id):
    query = f"""
        SELECT MAX(run_at) AS last_run_at
        FROM `{bq_dataset_id}.bugbug_classification_runs`
        WHERE is_ok = TRUE
    """
    res = client.query(query).result()
    row = list(res)[0]
    last_run_time = (
        row["last_run_at"] if row["last_run_at"] is not None else "2023-11-20T00:00:00"
    )
    return last_run_time


def get_reports_since_last_run(client, last_run_time, bq_dataset_id):
    query = f"""
        SELECT
            reports.uuid,
            reports.comments as body,
            COALESCE(reports.url, reports.uuid) as title,
            translations.translated_text
        FROM
            `moz-fx-data-shared-prod.org_mozilla_broken_site_report.user_reports_live`
            AS reports
        LEFT JOIN `{bq_dataset_id}.translations` AS translations
            ON reports.uuid = translations.report_uuid
        WHERE reported_at >= "{last_run_time}" AND comments != ""
        ORDER BY reported_at
    """
    query_job = client.query(query)
    return list(query_job.result())


def get_missed_reports(client, last_run_time, bq_dataset_id):
    query = f"""
        SELECT
            reports.uuid,
            reports.comments as body,
            COALESCE(reports.url, reports.uuid) as title,
            translations.translated_text
        FROM
            `moz-fx-data-shared-prod.org_mozilla_broken_site_report.user_reports_live`
            AS reports
        LEFT JOIN `{bq_dataset_id}.bugbug_predictions` AS predictions
            ON reports.uuid = predictions.report_uuid
        LEFT JOIN `{bq_dataset_id}.translations` AS translations
            ON reports.uuid = translations.report_uuid
        WHERE predictions.report_uuid IS NULL AND reported_at < "{last_run_time}"
            AND reports.comments != ""
        ORDER BY reports.reported_at
    """
    query_job = client.query(query)
    return list(query_job.result())


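# translate_by_uuid runs BigQuery ML.TRANSLATE over the selected reports using
# the dataset's remote `translation` model. Each result row carries the
# detected source language, the English translation, and ml_translate_status,
# which (per BigQuery's remote-model convention) is an empty string when the
# translation succeeded.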
def translate_by_uuid(client, uuids, bq_dataset_id):
    uuids_sql = ", ".join(f"'{uuid}'" for uuid in uuids)
    query = f"""
        WITH reports AS (
            SELECT uuid, comments as text_content
            FROM
                `moz-fx-data-shared-prod.org_mozilla_broken_site_report.user_reports_live`
            WHERE uuid IN ({uuids_sql})
        )
        SELECT
            uuid,
            STRING(
                ml_translate_result.translations[0].detected_language_code
            ) AS language_code,
            STRING(
                ml_translate_result.translations[0].translated_text
            ) AS translated_text,
            ml_translate_status as status
        FROM
            ML.TRANSLATE(
                MODEL `{bq_dataset_id}.translation`,
                TABLE reports,
                STRUCT(
                    'translate_text' AS translate_mode,
                    'en' AS target_language_code
                )
            );
    """
    query_job = client.query(query)
    return list(query_job.result())


def translate_reports(client, reports, bq_dataset_id):
    result = {}

    # Only translate reports that weren't translated yet.
    uuids_to_translate = [d["uuid"] for d in reports if not d["translated_text"]]
    if uuids_to_translate:
        translation_results = translate_by_uuid(
            client, uuids_to_translate, bq_dataset_id
        )
        # Keep only successful translations (empty status), keyed by uuid.
        # The loop variable is named `row` so it does not shadow the dict
        # being built.
        result = {
            row["uuid"]: {field: value for field, value in row.items()}
            for row in translation_results
            if not row["status"]
        }

    return result


def deduplicate_reports(reports):
    # Keep the first report seen for each uuid, preserving input order, and
    # copy each row into a plain dict so it can be modified later.
    seen = set()
    deduplicated = []
    for report in reports:
        if report["uuid"] not in seen:
            seen.add(report["uuid"])
            deduplicated.append({field: value for field, value in report.items()})
    return deduplicated


def chunk_list(data, size):
    # Yield successive slices of `data`, each at most `size` items long.
    for i in range(0, len(data), size):
        yield data[i : i + size]


@click.command()
@click.option("--bq_project_id", help="BigQuery project id", required=True)
@click.option("--bq_dataset_id", help="BigQuery dataset id", required=True)
def main(bq_project_id, bq_dataset_id):
    client = bigquery.Client(project=bq_project_id)

    # Get datetime of the last classification run
    last_run_time = get_last_classification_datetime(client, bq_dataset_id)

    # Get reports that were filed since last classification run
    # and have non-empty descriptions as well as reports that were missed
    new_reports = get_reports_since_last_run(client, last_run_time, bq_dataset_id)
    missed_reports = get_missed_reports(client, last_run_time, bq_dataset_id)

    combined = missed_reports + new_reports
    deduplicated_combined = deduplicate_reports(combined)

    translated = translate_reports(client, deduplicated_combined, bq_dataset_id)
    if translated:
        save_translations(client, bq_dataset_id, translated)

    for report in deduplicated_combined:
        if report["uuid"] in translated:
            report["translated_text"] = translated[report["uuid"]]["translated_text"]

    if not deduplicated_combined:
        logging.info(
            f"No new reports with filled descriptions were found since {last_run_time}"
        )
        return

    result_count = 0
    try:
        for chunk in chunk_list(deduplicated_combined, 20):
            objects_dict = {
                row["uuid"]: {
                    "uuid": row["uuid"],
                    "title": row["title"],
                    "body": (
                        row["translated_text"]
                        if row.get("translated_text")
                        else row["body"]
                    ),
                }
                for row in chunk
            }

            logging.info("Getting classification results from bugbug.")
            result = get_reports_classification(
                "invalidcompatibilityreport", objects_dict
            )

            if result:
                result_count += len(result)
                logging.info("Saving classification results to BQ.")
                add_classification_results(client, bq_dataset_id, result)
                record_classification_run(client, bq_dataset_id, True, len(result))
    except Exception as e:
        logging.error(e)
        record_classification_run(client, bq_dataset_id, False, 0)
        raise
    finally:
        logging.info(f"Total processed reports count: {result_count}")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()
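# Example invocation; the project and dataset IDs below are placeholders, and
# application-default GCP credentials are assumed:
#
#   python broken_site_report_ml/main.py \
#       --bq_project_id=my-gcp-project \
#       --bq_dataset_id=my_dataset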