jobs/webcompat-kb/webcompat_kb/bqhelpers.py:

import logging
import uuid
from dataclasses import dataclass
from types import TracebackType
from typing import Any, Iterable, Mapping, Optional, Self, Sequence, cast

import google.auth
from google.cloud import bigquery

Json = Mapping[str, "Json"] | Sequence["Json"] | str | int | float | bool | None


def get_client(bq_project_id: str) -> bigquery.Client:
    """Create a BigQuery client from application-default credentials."""
    credentials, _ = google.auth.default(
        scopes=[
            "https://www.googleapis.com/auth/cloud-platform",
            "https://www.googleapis.com/auth/drive",
            "https://www.googleapis.com/auth/bigquery",
        ]
    )
    return bigquery.Client(credentials=credentials, project=bq_project_id)


@dataclass
class RangePartition:
    """Integer-range partitioning specification for a table."""

    field: str
    start: int
    end: int
    interval: int = 1


class BigQuery:
    """Wrapper around a bigquery.Client with a default dataset and a write flag.

    When ``write`` is False, the mutating helpers log what they would have
    written instead of performing the write."""

    def __init__(self, client: bigquery.Client, default_dataset_id: str, write: bool):
        self.client = client
        self.default_dataset_id = default_dataset_id
        self.write = write

    def get_dataset(self, dataset_id: Optional[str]) -> str:
        if dataset_id is None:
            return self.default_dataset_id
        return dataset_id

    def get_table_id(
        self, dataset_id: Optional[str], table: bigquery.Table | str
    ) -> bigquery.Table | str:
        """Return a fully qualified table id, qualifying bare table names
        with the project and dataset."""
        if isinstance(table, bigquery.Table):
            return table
        if "." in table:
            return table
        dataset_id = self.get_dataset(dataset_id)
        return f"{self.client.project}.{dataset_id}.{table}"

    def ensure_table(
        self,
        table_id: str,
        schema: Iterable[bigquery.SchemaField],
        recreate: bool = False,
        dataset_id: Optional[str] = None,
        partition: Optional[RangePartition] = None,
    ) -> bigquery.Table:
        """Create the table if writes are enabled, optionally dropping it first."""
        table = bigquery.Table(self.get_table_id(dataset_id, table_id), schema=schema)
        if partition:
            table.range_partitioning = bigquery.table.RangePartitioning(
                bigquery.table.PartitionRange(
                    partition.start, partition.end, partition.interval
                ),
                field=partition.field,
            )
        if self.write:
            if recreate:
                self.client.delete_table(table, not_found_ok=True)
            self.client.create_table(table, exists_ok=True)
        return table

    def write_table(
        self,
        table: bigquery.Table | str,
        schema: list[bigquery.SchemaField],
        rows: Sequence[Mapping[str, Json]],
        overwrite: bool,
        dataset_id: Optional[str] = None,
    ) -> None:
        """Load rows into a table via a JSON load job, appending or truncating."""
        table = self.get_table_id(dataset_id, table)
        job_config = bigquery.LoadJobConfig(
            source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
            schema=schema,
            write_disposition="WRITE_APPEND" if not overwrite else "WRITE_TRUNCATE",
        )
        if self.write:
            job = self.client.load_table_from_json(
                cast(Iterable[dict[str, Any]], rows),
                table,
                job_config=job_config,
            )
            job.result()
            logging.info(f"Wrote {len(rows)} records into {table}")
        else:
            logging.info(f"Skipping writes, would have written {len(rows)} to {table}")
            for row in rows:
                logging.debug(f" {row}")

    def insert_rows(
        self,
        table: str | bigquery.Table,
        rows: Sequence[Mapping[str, Any]],
        dataset_id: Optional[str] = None,
    ) -> None:
        """Insert rows via the streaming API, logging any per-row errors."""
        table = self.get_table_id(dataset_id, table)
        if self.write:
            errors = self.client.insert_rows(table, rows)
            if errors:
                logging.error(errors)
        else:
            logging.info(f"Skipping writes, would have written {len(rows)} to {table}")
            for row in rows:
                logging.debug(f" {row}")

    def query(
        self,
        query: str,
        dataset_id: Optional[str] = None,
        parameters: Optional[Sequence[bigquery.query._AbstractQueryParameter]] = None,
    ) -> bigquery.table.RowIterator:
        """Run a query.

        Note that this can't prevent writes in the case that the SQL itself
        performs writes."""
        job_config = bigquery.QueryJobConfig(
            default_dataset=f"{self.client.project}.{self.get_dataset(dataset_id)}"
        )
        if parameters is not None:
            job_config.query_parameters = parameters
        logging.debug(query)
        return self.client.query(query, job_config=job_config).result()

    def delete_table(
        self, table: bigquery.Table | str, not_found_ok: bool = False
    ) -> None:
        return self.client.delete_table(table, not_found_ok=not_found_ok)

    def temporary_table(
        self,
        schema: Iterable[bigquery.SchemaField],
        rows: Optional[Sequence[Mapping[str, Any]]] = None,
        dataset_id: Optional[str] = None,
    ) -> "TemporaryTable":
        """Return a context manager for a uniquely named temporary table."""
        return TemporaryTable(self, schema, rows, dataset_id)


class TemporaryTable:
    """Context manager that creates a uniquely named table on ``__enter__``
    (optionally pre-loaded with rows) and deletes it on ``__exit__``.

    Note: it uses the underlying client directly, so it performs writes
    regardless of the ``write`` flag on the wrapping BigQuery object."""

    def __init__(
        self,
        client: BigQuery,
        schema: Iterable[bigquery.SchemaField],
        rows: Optional[Sequence[Mapping[str, Any]]] = None,
        dataset_id: Optional[str] = None,
    ):
        self.client = client
        self.schema = schema
        self.rows = rows
        self.dataset_id = dataset_id
        self.name = f"tmp_{uuid.uuid4()}"
        self.table: Optional[bigquery.Table] = None

    def __enter__(self) -> Self:
        self.table = bigquery.Table(
            self.client.get_table_id(self.dataset_id, self.name), schema=self.schema
        )
        self.client.client.create_table(self.table)
        if self.rows is not None:
            self.client.client.load_table_from_json(
                cast(Iterable[dict[str, Any]], self.rows),
                self.table,
            ).result()
            logging.info(f"Wrote {len(self.rows)} records into {self.name}")
        return self

    def __exit__(
        self,
        type_: Optional[type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        assert self.table is not None
        logging.info(f"Removing temporary table {self.name}")
        self.client.client.delete_table(self.table)
        self.table = None

    def query(
        self,
        query: str,
        parameters: Optional[Sequence[bigquery.query._AbstractQueryParameter]] = None,
    ) -> bigquery.table.RowIterator:
        job_config = bigquery.QueryJobConfig(
            default_dataset=f"{self.client.client.project}.{self.client.get_dataset(self.dataset_id)}"
        )
        if parameters is not None:
            job_config.query_parameters = parameters
        logging.debug(query)
        return self.client.client.query(query, job_config=job_config).result()
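
# Usage sketch (illustrative only, not part of the original module): shows one
# way the helpers above might be wired together. The project id "my-project",
# dataset "my_dataset", and the "bugs"/"bug_id" names are hypothetical.
#
#     client = get_client("my-project")
#     bq = BigQuery(client, "my_dataset", write=True)
#     schema = [bigquery.SchemaField("bug_id", "INTEGER")]
#     bq.ensure_table("bugs", schema)
#     bq.write_table("bugs", schema, [{"bug_id": 1}], overwrite=False)
#     with bq.temporary_table(schema, rows=[{"bug_id": 2}]) as tmp:
#         for row in tmp.query(f"SELECT bug_id FROM `{tmp.name}`"):
#             logging.info(row["bug_id"])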