# jobs/webcompat-kb/webcompat_kb/bqhelpers.py
import logging
import uuid
from dataclasses import dataclass
from types import TracebackType
from typing import Any, Iterable, Mapping, Optional, Self, Sequence, cast
import google.auth
from google.cloud import bigquery

# A JSON-compatible value, mirroring what BigQuery's JSON APIs accept.
Json = Mapping[str, "Json"] | Sequence["Json"] | str | int | float | bool | None


def get_client(bq_project_id: str) -> bigquery.Client:
    """Create a BigQuery client using application-default credentials."""
credentials, _ = google.auth.default(
scopes=[
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/drive",
"https://www.googleapis.com/auth/bigquery",
]
)
return bigquery.Client(credentials=credentials, project=bq_project_id)
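
# Example usage (a sketch; "my-project" is a placeholder project id, and
# application-default credentials are assumed to be configured):
#
#     client = get_client("my-project")
#     list(client.query("SELECT 1").result())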


@dataclass
class RangePartition:
    """Integer-range partitioning spec for a table; see ensure_table."""
field: str
start: int
end: int
interval: int = 1
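
# Example (a sketch; "number" is a hypothetical integer column): partition
# rows with ids in [0, 1_000_000) into 100 buckets of 10_000 each:
#
#     partition = RangePartition(field="number", start=0, end=1_000_000, interval=10_000)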


class BigQuery:
    """Thin wrapper around bigquery.Client with a default dataset and a
    dry-run mode: when ``write`` is False, the write-path methods
    (ensure_table, write_table, insert_rows) log what they would do but
    skip the actual writes."""

    def __init__(self, client: bigquery.Client, default_dataset_id: str, write: bool):
self.client = client
self.default_dataset_id = default_dataset_id
self.write = write

    def get_dataset(self, dataset_id: Optional[str]) -> str:
        """Return the given dataset id, or the default if None."""
if dataset_id is None:
return self.default_dataset_id
return dataset_id

    def get_table_id(
        self, dataset_id: Optional[str], table: bigquery.Table | str
    ) -> bigquery.Table | str:
        """Resolve a table name to a fully qualified id.

        Table objects and ids that already contain a "." are returned
        unchanged; bare names are qualified with the project and dataset."""
if isinstance(table, bigquery.Table):
return table
if "." in table:
return table
dataset_id = self.get_dataset(dataset_id)
return f"{self.client.project}.{dataset_id}.{table}"

    def ensure_table(
        self,
        table_id: str,
        schema: Iterable[bigquery.SchemaField],
        recreate: bool = False,
        dataset_id: Optional[str] = None,
        partition: Optional[RangePartition] = None,
    ) -> bigquery.Table:
        """Create the table if it doesn't already exist.

        With ``recreate`` any existing table is dropped first. When writes
        are disabled the Table object is constructed but nothing is created.
        """
table = bigquery.Table(self.get_table_id(dataset_id, table_id), schema=schema)
if partition:
table.range_partitioning = bigquery.table.RangePartitioning(
bigquery.table.PartitionRange(
partition.start, partition.end, partition.interval
),
field=partition.field,
)
if self.write:
if recreate:
self.client.delete_table(table, not_found_ok=True)
self.client.create_table(table, exists_ok=True)
return table
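
    # Example (a sketch; the "bugs" table and its fields are hypothetical):
    #
    #     schema = [
    #         bigquery.SchemaField("number", "INTEGER"),
    #         bigquery.SchemaField("title", "STRING"),
    #     ]
    #     bq.ensure_table(
    #         "bugs", schema, partition=RangePartition("number", 0, 1_000_000, 10_000)
    #     )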

    def write_table(
        self,
        table: bigquery.Table | str,
        schema: list[bigquery.SchemaField],
        rows: Sequence[Mapping[str, Json]],
        overwrite: bool,
        dataset_id: Optional[str] = None,
    ) -> None:
        """Load rows into a table with a JSON load job.

        Appends by default; truncates the existing data first when
        ``overwrite`` is True."""
table = self.get_table_id(dataset_id, table)
job_config = bigquery.LoadJobConfig(
source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
schema=schema,
write_disposition="WRITE_APPEND" if not overwrite else "WRITE_TRUNCATE",
)
if self.write:
job = self.client.load_table_from_json(
cast(Iterable[dict[str, Any]], rows),
table,
job_config=job_config,
)
job.result()
logging.info(f"Wrote {len(rows)} records into {table}")
else:
logging.info(f"Skipping writes, would have written {len(rows)} to {table}")
for row in rows:
logging.debug(f" {row}")

    def insert_rows(
        self,
        table: str | bigquery.Table,
        rows: Sequence[Mapping[str, Any]],
        dataset_id: Optional[str] = None,
    ) -> None:
        """Write rows using the streaming insert API, logging any errors."""
        table = self.get_table_id(dataset_id, table)
        if self.write:
            if isinstance(table, str):
                # insert_rows requires the table schema, which a bare table
                # id doesn't carry, so fetch the full table metadata first.
                table = self.client.get_table(table)
            errors = self.client.insert_rows(table, rows)
if errors:
logging.error(errors)
else:
logging.info(f"Skipping writes, would have written {len(rows)} to {table}")
for row in rows:
logging.debug(f" {row}")

    def query(
        self,
        query: str,
        dataset_id: Optional[str] = None,
        parameters: Optional[Sequence[bigquery.query._AbstractQueryParameter]] = None,
    ) -> bigquery.table.RowIterator:
        """Run a query and return the resulting rows.

        Note that this cannot prevent writes if the SQL itself performs
        them, even when writes are otherwise disabled."""
job_config = bigquery.QueryJobConfig(
default_dataset=f"{self.client.project}.{self.get_dataset(dataset_id)}"
)
if parameters is not None:
job_config.query_parameters = parameters
logging.debug(query)
return self.client.query(query, job_config=job_config).result()
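
    # Example (a sketch): binding the @number parameter with the standard
    # ScalarQueryParameter type.
    #
    #     rows = bq.query(
    #         "SELECT title FROM bugs WHERE number = @number",
    #         parameters=[bigquery.ScalarQueryParameter("number", "INT64", 1)],
    #     )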

    def delete_table(
        self, table: bigquery.Table | str, not_found_ok: bool = False
    ) -> None:
        """Delete a table. Note this runs even when writes are disabled."""
return self.client.delete_table(table, not_found_ok=not_found_ok)

    def temporary_table(
        self,
        schema: Iterable[bigquery.SchemaField],
        rows: Optional[Sequence[Mapping[str, Any]]] = None,
        dataset_id: Optional[str] = None,
    ) -> "TemporaryTable":
        """Return a context manager providing a uniquely named scratch table."""
return TemporaryTable(self, schema, rows, dataset_id)
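
    # Example (a sketch; staged_rows is hypothetical): stage rows in a
    # scratch table, join against it, and rely on the context manager for
    # cleanup.
    #
    #     with bq.temporary_table(schema, rows=staged_rows) as tmp:
    #         tmp.query(f"SELECT * FROM `{tmp.name}` JOIN bugs USING (number)")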


class TemporaryTable:
    """Context manager for a uniquely named scratch table that exists for
    the duration of a ``with`` block, optionally pre-populated with rows.

    Note that the table is created and deleted even if the wrapping
    BigQuery helper has writes disabled.
    """

def __init__(
self,
client: BigQuery,
schema: Iterable[bigquery.SchemaField],
rows: Optional[Sequence[Mapping[str, Any]]] = None,
dataset_id: Optional[str] = None,
):
self.client = client
self.schema = schema
self.rows = rows
self.dataset_id = dataset_id
self.name = f"tmp_{uuid.uuid4()}"
self.table: Optional[bigquery.Table] = None

    def __enter__(self) -> Self:
        """Create the table and load any initial rows."""
self.table = bigquery.Table(
self.client.get_table_id(self.dataset_id, self.name), schema=self.schema
)
self.client.client.create_table(self.table)
if self.rows is not None:
self.client.client.load_table_from_json(
cast(Iterable[dict[str, Any]], self.rows),
self.table,
).result()
logging.info(f"Wrote {len(self.rows)} records into {self.name}")
return self

    def __exit__(
        self,
        type_: Optional[type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Delete the table, whether or not the block raised an exception."""
assert self.table is not None
logging.info(f"Removing temporary table {self.name}")
self.client.client.delete_table(self.table)
self.table = None

    def query(
        self,
        query: str,
        parameters: Optional[Sequence[bigquery.query._AbstractQueryParameter]] = None,
    ) -> bigquery.table.RowIterator:
        """Run a query with this table's dataset as the default dataset."""
job_config = bigquery.QueryJobConfig(
default_dataset=f"{self.client.client.project}.{self.client.get_dataset(self.dataset_id)}"
)
if parameters is not None:
job_config.query_parameters = parameters
logging.debug(query)
return self.client.client.query(query, job_config=job_config).result()
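
# Putting it together (a sketch; names and the row payload are placeholders):
#
#     client = get_client("my-project")
#     bq = BigQuery(client, default_dataset_id="webcompat", write=True)
#     schema = [bigquery.SchemaField("number", "INTEGER")]
#     bq.ensure_table("bugs", schema)
#     bq.write_table("bugs", schema, [{"number": 1}], overwrite=True)
#     for row in bq.query("SELECT COUNT(*) AS n FROM bugs"):
#         print(row["n"])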