def add_classification_results()

in jobs/broken-site-report-ml/broken_site_report_ml/main.py [0:0]


def add_classification_results(client, bq_dataset_id, results):
    """Append classification results to the ``bugbug_predictions`` table.

    One row is written per entry of ``results``. Each row contains:

    - ``report_uuid``: the mapping key;
    - ``label``: ``CLASSIFICATION_LABELS[result["class"]]``;
    - ``probability``: ``result["prob"][result["class"]]``;
    - ``created_at``: a naive UTC timestamp shared by the whole batch.

    Args:
        client: BigQuery client used to run the load job and look up the table.
        bq_dataset_id: Dataset id; rows go to ``{bq_dataset_id}.bugbug_predictions``.
        results: Mapping of report UUID -> result dict. Each result dict must
            provide ``"class"`` and ``"prob"`` as described above.

    Load failures are logged rather than raised (best-effort write), and in
    that case no "Loaded ... rows" message is emitted.
    """
    if not results:
        # Nothing to write; avoid launching an empty load job.
        logging.info("No classification results to write to `bugbug_predictions`")
        return

    # The column is DATETIME (no time zone), so store a naive UTC value.
    # datetime.utcnow() is deprecated since Python 3.12; this is the
    # equivalent. It is computed once so every row of a run shares it.
    created_at = (
        datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat()
    )

    rows = [
        {
            "report_uuid": uuid,
            "label": CLASSIFICATION_LABELS[result["class"]],
            "created_at": created_at,
            "probability": result["prob"][result["class"]],
        }
        for uuid, result in results.items()
    ]

    job_config = bigquery.LoadJobConfig(
        source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
        schema=[
            bigquery.SchemaField("report_uuid", "STRING", mode="REQUIRED"),
            bigquery.SchemaField("label", "STRING", mode="REQUIRED"),
            bigquery.SchemaField("created_at", "DATETIME", mode="REQUIRED"),
            bigquery.SchemaField("probability", "FLOAT"),
        ],
        write_disposition="WRITE_APPEND",
    )

    predictions_table = f"{bq_dataset_id}.bugbug_predictions"

    job = client.load_table_from_json(
        rows,
        predictions_table,
        job_config=job_config,
    )

    logging.info("Writing to `bugbug_predictions` table")

    try:
        job.result()
    except Exception:
        # Best-effort: report the failure plus any per-row errors the job
        # collected, but do not go on to claim the rows were loaded.
        logging.exception("Failed to write to `bugbug_predictions` table")
        for error in job.errors or []:
            logging.error(error)
        return

    table = client.get_table(predictions_table)
    logging.info("Loaded %d rows into %s", len(rows), table)