jobs/broken-site-report-ml/broken_site_report_ml/main.py:

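"""Classify user-filed broken site reports with bugbug.

Overview of the flow implemented below: fetch reports filed since the last
successful run (plus older reports that never received a prediction) from
BigQuery, translate non-English comments to English with ML.TRANSLATE, send
batches of 20 reports to the bugbug HTTP service for classification, and
write the predictions, translations, and run metadata back to BigQuery.
"""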
import click
import datetime
import logging
import requests
import time

from google.cloud import bigquery

BUGBUG_HTTP_SERVER = "https://bugbug.herokuapp.com"
CLASSIFICATION_LABELS = {0: "valid", 1: "invalid"}


def classification_http_request(url, reports):
    """POST a batch of reports to the bugbug classification endpoint."""
    reports_list = list(reports.values())

    response = requests.post(
        url, headers={"X-Api-Key": "docker-etl"}, json={"reports": reports_list}
    )
    response.raise_for_status()

    return response.json()


def get_reports_classification(model, reports, retry_count=21, retry_sleep=10):
    """Get the classification for a list of reports.

    Args:
        model: The model to use for the classification.
        reports: The dict containing reports to classify with uuid used as keys.
        retry_count: The number of times to retry the request.
        retry_sleep: The number of seconds to sleep between retries.

    Returns:
        A dictionary with the uuids as keys and classification results as values.
    """
    if len(reports) == 0:
        return {}

    url = f"{BUGBUG_HTTP_SERVER}/{model}/predict/broken_site_report/batch"

    json_response = {}

    for _ in range(retry_count):
        response = classification_http_request(url, reports)

        # Check which reports are ready
        for uuid, data in response["reports"].items():
            if not data.get("ready", True):
                continue

            # The report is ready, add it to the json_response and pop it
            # up from the current batch
            reports.pop(uuid, None)
            json_response[uuid] = data

        if len(reports) == 0:
            break
        else:
            time.sleep(retry_sleep)
    else:
        total_sleep = retry_count * retry_sleep
        msg = f"Couldn't get {len(reports)} report classifications in {total_sleep} seconds, aborting"  # noqa
        logging.error(msg)
        raise Exception(msg)

    return json_response


def add_classification_results(client, bq_dataset_id, results):
    """Write bugbug classification results to the `bugbug_predictions` table."""
    res = []
    for uuid, result in results.items():
        bq_result = {
            "report_uuid": uuid,
            "label": CLASSIFICATION_LABELS[result["class"]],
            "created_at": datetime.datetime.utcnow().isoformat(),
            "probability": result["prob"][result["class"]],
        }
        res.append(bq_result)

    job_config = bigquery.LoadJobConfig(
        source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
        schema=[
            bigquery.SchemaField("report_uuid", "STRING", mode="REQUIRED"),
            bigquery.SchemaField("label", "STRING", mode="REQUIRED"),
            bigquery.SchemaField("created_at", "DATETIME", mode="REQUIRED"),
            bigquery.SchemaField("probability", "FLOAT"),
        ],
        write_disposition="WRITE_APPEND",
    )

    predictions_table = f"{bq_dataset_id}.bugbug_predictions"

    job = client.load_table_from_json(
        res,
        predictions_table,
        job_config=job_config,
    )

    logging.info("Writing to `bugbug_predictions` table")

    try:
        job.result()
    except Exception as e:
        print(f"ERROR: {e}")
        if job.errors:
            for error in job.errors:
                logging.error(error)

    table = client.get_table(predictions_table)
    logging.info(f"Loaded {len(res)} rows into {table}")


def save_translations(client, bq_dataset_id, results):
    """Write successful translations to the `translations` table."""
    res = []
    for uuid, result in results.items():
        # A falsy status means the translation succeeded
        if not result["status"]:
            bq_result = {
                "report_uuid": uuid,
                "translated_text": result["translated_text"],
                "language_code": result["language_code"],
            }
            res.append(bq_result)

    job_config = bigquery.LoadJobConfig(
        source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
        schema=[
            bigquery.SchemaField("report_uuid", "STRING", mode="REQUIRED"),
            bigquery.SchemaField("translated_text", "STRING"),
            bigquery.SchemaField("language_code", "STRING"),
        ],
        write_disposition="WRITE_APPEND",
    )

    translations_table = f"{bq_dataset_id}.translations"

    job = client.load_table_from_json(
        res,
        translations_table,
        job_config=job_config,
    )

    logging.info("Writing to `translations` table")

    try:
        job.result()
    except Exception as e:
        print(f"ERROR: {e}")
        if job.errors:
            for error in job.errors:
                logging.error(error)

    table = client.get_table(translations_table)
    logging.info(f"Loaded {len(res)} rows into {table}")


def record_classification_run(client, bq_dataset_id, is_ok, count):
    """Record the outcome of a classification run for bookkeeping."""
    rows_to_insert = [
        {
            "run_at": datetime.datetime.utcnow().isoformat(),
            "is_ok": is_ok,
            "report_count": count,
        },
    ]

    bugbug_runs_table = f"{bq_dataset_id}.bugbug_classification_runs"
    errors = client.insert_rows_json(bugbug_runs_table, rows_to_insert)

    if errors:
        logging.error(errors)
    else:
        logging.info("Last classification run recorded")


def get_last_classification_datetime(client, bq_dataset_id):
    """Return the datetime of the last successful run, or a fixed fallback."""
    query = f"""
        SELECT MAX(run_at) AS last_run_at
        FROM `{bq_dataset_id}.bugbug_classification_runs`
        WHERE is_ok = TRUE
    """
    res = client.query(query).result()
    row = list(res)[0]
    last_run_time = (
        row["last_run_at"] if row["last_run_at"] is not None else "2023-11-20T00:00:00"
    )
    return last_run_time


def get_reports_since_last_run(client, last_run_time, bq_dataset_id):
    """Fetch reports with non-empty descriptions filed since the last run."""
    query = f"""
        SELECT
            reports.uuid,
            reports.comments as body,
            COALESCE(reports.url, reports.uuid) as title,
            translations.translated_text
        FROM `moz-fx-data-shared-prod.org_mozilla_broken_site_report.user_reports_live` AS reports
        LEFT JOIN `{bq_dataset_id}.translations` AS translations
            ON reports.uuid = translations.report_uuid
        WHERE reported_at >= "{last_run_time}"
            AND comments != ""
        ORDER BY reported_at
    """
    query_job = client.query(query)
    return list(query_job.result())


def get_missed_reports(client, last_run_time, bq_dataset_id):
    """Fetch older reports that never received a prediction."""
    query = f"""
        SELECT
            reports.uuid,
            reports.comments as body,
            COALESCE(reports.url, reports.uuid) as title,
            translations.translated_text
        FROM `moz-fx-data-shared-prod.org_mozilla_broken_site_report.user_reports_live` AS reports
        LEFT JOIN `{bq_dataset_id}.bugbug_predictions` AS predictions
            ON reports.uuid = predictions.report_uuid
        LEFT JOIN `{bq_dataset_id}.translations` AS translations
            ON reports.uuid = translations.report_uuid
        WHERE predictions.report_uuid IS NULL
            AND reported_at < "{last_run_time}"
            AND reports.comments != ""
        ORDER BY reports.reported_at
    """
    query_job = client.query(query)
    return list(query_job.result())


def translate_by_uuid(client, uuids, bq_dataset_id):
    """Translate report comments to English with BigQuery ML.TRANSLATE."""
    uuids_sql = ", ".join(f"'{uuid}'" for uuid in uuids)
    query = f"""
        WITH reports AS (
            SELECT uuid, comments as text_content
            FROM `moz-fx-data-shared-prod.org_mozilla_broken_site_report.user_reports_live`
            WHERE uuid IN ({uuids_sql})
        )
        SELECT
            uuid,
            STRING(
                ml_translate_result.translations[0].detected_language_code
            ) AS language_code,
            STRING(
                ml_translate_result.translations[0].translated_text
            ) AS translated_text,
            ml_translate_status as status
        FROM ML.TRANSLATE(
            MODEL `{bq_dataset_id}.translation`,
            TABLE reports,
            STRUCT('translate_text' AS translate_mode, 'en' AS target_language_code)
        );
    """
    query_job = client.query(query)
    return list(query_job.result())


def translate_reports(client, reports, bq_dataset_id):
    result = {}

    # Only translate reports that weren't translated
    uuids_to_translate = [d["uuid"] for d in reports if not d["translated_text"]]

    if uuids_to_translate:
        translation_results = translate_by_uuid(
            client, uuids_to_translate, bq_dataset_id
        )
        # Convert BigQuery Row objects to plain dicts keyed by uuid,
        # keeping only successful translations (falsy status)
        result = {
            result["uuid"]: {field: value for field, value in result.items()}
            for result in translation_results
            if not result["status"]
        }

    return result


def deduplicate_reports(reports):
    """Drop duplicate reports, keeping the first occurrence of each uuid."""
    seen = set()
    # set.add() returns None, so `not seen.add(...)` is always True and
    # simply records the uuid as seen
    return [
        {field: value for field, value in report.items()}
        for report in reports
        if report["uuid"] not in seen and not seen.add(report["uuid"])
    ]


def chunk_list(data, size):
    """Yield successive chunks of `size` items from `data`."""
    for i in range(0, len(data), size):
        yield data[i : i + size]


@click.command()
@click.option("--bq_project_id", help="BigQuery project id", required=True)
@click.option("--bq_dataset_id", help="BigQuery dataset id", required=True)
def main(bq_project_id, bq_dataset_id):
    client = bigquery.Client(project=bq_project_id)

    # Get datetime of the last classification run
    last_run_time = get_last_classification_datetime(client, bq_dataset_id)

    # Get reports that were filed since last classification run
    # and have non-empty descriptions as well as reports that were missed
    new_reports = get_reports_since_last_run(client, last_run_time, bq_dataset_id)
    missed_reports = get_missed_reports(client, last_run_time, bq_dataset_id)

    combined = missed_reports + new_reports
    deduplicated_combined = deduplicate_reports(combined)

    translated = translate_reports(client, deduplicated_combined, bq_dataset_id)

    if translated:
        save_translations(client, bq_dataset_id, translated)

    for report in deduplicated_combined:
        if report["uuid"] in translated:
            report["translated_text"] = translated[report["uuid"]]["translated_text"]

    if not deduplicated_combined:
        logging.info(
            f"No new reports with filled descriptions were found since {last_run_time}"
        )
        return

    result_count = 0

    try:
        for chunk in chunk_list(deduplicated_combined, 20):
            objects_dict = {
                row["uuid"]: {
                    "uuid": row["uuid"],
                    "title": row["title"],
                    "body": (
                        row["translated_text"]
                        if row.get("translated_text")
                        else row["body"]
                    ),
                }
                for row in chunk
            }

            logging.info("Getting classification results from bugbug.")
            result = get_reports_classification(
                "invalidcompatibilityreport", objects_dict
            )

            if result:
                result_count += len(result)
                logging.info("Saving classification results to BQ.")
                add_classification_results(client, bq_dataset_id, result)
                record_classification_run(client, bq_dataset_id, True, len(result))

    except Exception as e:
        logging.error(e)
        record_classification_run(client, bq_dataset_id, False, 0)
        raise

    finally:
        logging.info(f"Total processed reports count: {result_count}")


if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    main()
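# Usage sketch (the ids below are hypothetical placeholders, not real
# project or dataset names):
#
#   python broken_site_report_ml/main.py \
#       --bq_project_id <gcp-project-id> \
#       --bq_dataset_id <dataset-id>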