bigquery_etl/pytest_plugin/routine.py (75 lines of code) (raw):
"""PyTest plugin for running udf tests."""
import os
import re
import pytest
from google.api_core.exceptions import BadRequest
from google.cloud import bigquery
from bigquery_etl.config import ConfigLoader
from bigquery_etl.util.common import project_dirs
from ..routine.parse_routine import (
GENERIC_DATASET,
PROCEDURE_FILE,
UDF_FILE,
parse_routines,
)
from .sql_test import dataset
_parsed_routines = None


def parsed_routines():
    """Return the parsed routines, parsing them on the first call only.

    The result is a dict mapping each routine's ``filepath`` to the routine
    object. It is kept in the module-level ``_parsed_routines`` so that later
    calls return the same dict without parsing again.
    """
    global _parsed_routines
    if _parsed_routines is not None:
        return _parsed_routines

    # Projects are read from the "routine" / "test_projects" config entry,
    # with project_dirs() passed as the fallback value.
    projects = ConfigLoader.get("routine", "test_projects", fallback=project_dirs())
    by_filepath = {}
    for project in projects:
        for routine in parse_routines(project):
            by_filepath[routine.filepath] = routine

    # Only publish the cache once every project parsed without error.
    _parsed_routines = by_filepath
    return _parsed_routines
def pytest_configure(config):
    """Declare the ``routine`` marker in the pytest ``markers`` ini setting."""
    marker_line = "routine: mark routine tests."
    config.addinivalue_line("markers", marker_line)
def pytest_collect_file(parent, path):
    """Collect routine definition files as ``RoutineFile`` nodes.

    A file is skipped when its directory path contains ``tests/data`` or when
    its basename is neither ``UDF_FILE`` nor ``PROCEDURE_FILE``; in both cases
    ``None`` is returned. Otherwise a ``RoutineFile`` is created for it.
    """
    directory = str(path.dirpath())
    if "tests/data" in directory:
        return None
    if path.basename not in (UDF_FILE, PROCEDURE_FILE):
        return None
    return RoutineFile.from_parent(parent, fspath=path)
class RoutineFile(pytest.File):
    """Collector for one routine definition file."""

    def collect(self):
        """Yield one ``RoutineTest`` per test query of this file's routine.

        The node is tagged with the ``routine`` marker, the matching routine is
        looked up in ``parsed_routines()`` and stored on ``self.routine``, and
        the tests are named ``<routine name>#<n>`` with ``n`` starting at 1.
        """
        self.add_marker("routine")

        # The lookup key is this file's path relative to the collector five
        # levels above it, so climb to that ancestor first.
        ancestor = self
        for _ in range(5):
            ancestor = ancestor.parent
        relative_path = str(self.path.relative_to(ancestor.path))

        self.routine = parsed_routines()[relative_path]
        for number, query in enumerate(self.routine.tests_full_sql, start=1):
            yield RoutineTest.from_parent(
                self, name=f"{self.routine.name}#{number}", query=query
            )
class RoutineTest(pytest.Item):
    """Test item that runs a single routine test query."""

    def __init__(self, name, parent, query):
        """Store the query; add a strict xfail marker if it contains ``#xfail``."""
        super().__init__(name, parent)
        self.query = query
        expected_to_fail = "#xfail" in query
        if expected_to_fail:
            self.add_marker(pytest.mark.xfail(strict=True))

    def safe_name(self):
        """Return the lowercased name as a slug of word characters.

        Characters other than word characters and whitespace are dropped, the
        result is stripped, and runs of underscores/whitespace become one ``_``.
        """
        cleaned = re.sub(r"[^\w\s_]", "", self.name.lower())
        return re.sub(r"[_\s]+", "_", cleaned.strip())

    def reportinfo(self):
        """Return the default report info with ``self.name`` as the title."""
        location = super().reportinfo()[:2]
        return location + (self.name,)

    def repr_failure(self, excinfo):
        """Report ``BadRequest`` errors as their message, without a traceback."""
        if not excinfo.errisinstance(BadRequest):
            return super().repr_failure(excinfo)
        return str(excinfo.value)

    def _prunetraceback(self, excinfo):
        """Cut the traceback at this file and filter the remaining entries."""
        from_this_file = excinfo.traceback.cut(path=__file__)
        excinfo.traceback = from_this_file.filter()

    def runtest(self):
        """Run the query inside a dataset named after this test.

        The dataset id is ``safe_name()``, suffixed with ``CIRCLE_BUILD_NUM``
        when that environment variable is set. Occurrences of
        ``GENERIC_DATASET`` in the query are replaced with the dataset's id
        before the query is submitted and its result awaited.
        """
        client = bigquery.Client()
        dataset_id = self.safe_name()
        build_num = os.environ.get("CIRCLE_BUILD_NUM")
        if build_num is not None:
            dataset_id = f"{dataset_id}_{build_num}"

        with dataset(client, dataset_id) as default_dataset:
            job_config = bigquery.QueryJobConfig(
                use_legacy_sql=False, default_dataset=default_dataset
            )
            sql = self.query.replace(GENERIC_DATASET, default_dataset.dataset_id)
            client.query(sql, job_config=job_config).result()