in bigquery_etl/pytest_plugin/sql.py [0:0]
def runtest(self):
    """Run."""
    test_name = self.fspath.basename
    query_name = self.fspath.dirpath().basename
    project_dir = (
        self.fspath.dirpath().dirpath().dirpath().dirname.replace("tests", "")
    )
    script_test = False

    # init tests write to dataset_query_test, instead of their default name
    path = self.fspath.dirname.replace("tests", "")
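    # test directories mirror the query directories under a "tests" prefix, so
    # stripping "tests" from the path points back at the directory that holds
    # the query templates (layout inferred from the path handling above)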
    if test_name == "test_init":
        query = render(
            "query.sql",
            template_folder=path,
            **{"is_init": lambda: True},
        )
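        # overriding is_init() to return True renders the init branch of the
        # query template, so this test exercises the initialization SQL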
    elif test_name == "test_script":
        script_test = True
        query = render("script.sql", template_folder=path)
    else:
        query = render("query.sql", template_folder=path)

    expect = load(self.fspath.strpath, "expect")
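    # expected rows come from the "expect" resource in the test directory
    # (e.g. expect.yaml or expect.ndjson)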

    tables = {}
    views = {}

    # generate tables for files with a supported table extension
    for resource in next(os.walk(self.fspath))[2]:
        if "." not in resource:
            continue  # tables require an extension
        table_name, extension = resource.rsplit(".", 1)
        if table_name.endswith(".schema") or table_name in (
            "expect",
            "query_params",
        ):
            continue  # not a table
        if extension in TABLE_EXTENSIONS or extension in ("yaml", "json"):
            if extension in TABLE_EXTENSIONS:
                source_format = TABLE_EXTENSIONS[extension]
                source_path = os.path.join(self.fspath.strpath, resource)
            else:
                source_format = TABLE_EXTENSIONS["ndjson"]
                source_path = (self.fspath.strpath, table_name)
            if "." in table_name:
                # combine project and dataset name with table name
                original, table_name = (
                    table_name,
                    table_name.replace(".", "_").replace("-", "_"),
                )
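                # e.g. my-project.dataset.table becomes my_project_dataset_table;
                # the pattern below also matches the original name when its
                # parts are wrapped in backticks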
                original_pattern = (
                    r"`?(?<![._])\b"
                    + r"`?\.`?".join(original.split("."))
                    + r"\b(?![._])`?"
                )
                query = re.sub(original_pattern, table_name, query)
            else:
                original = table_name
            # Second check for table name tweaks: if the table name ends with a
            # date, that date needs to be replaced with '*' for the query text
            # substitution to work.
            # e.g. moz-fx-data-marketing-prod.65789850.ga_sessions_20230214: a
            # query using that table refers to
            # moz-fx-data-marketing-prod.65789850.ga_sessions_* with the date
            # appended to allow for daily processing.
            try:
                datetime.datetime.strptime(table_name[-8:], "%Y%m%d")
            except ValueError:
                pass
            else:
                generic_table_name = table_name[:-8] + "*"
                generic_original = original[:-8] + "*"
                query = query.replace(generic_original, generic_table_name)

            tables[table_name] = Table(table_name, source_format, source_path)
            print(f"Initialized {table_name}")
        elif extension == "sql":
            if "." in table_name:
                # combine project and dataset name with table name
                original, table_name = (
                    table_name,
                    table_name.replace(".", "_").replace("-", "_"),
                )
                query = query.replace(original, table_name)
            views[table_name] = read(self.fspath.strpath, resource)
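            # each .sql resource becomes a view in the test dataset, created by
            # load_view below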

    # rewrite all udfs as temporary
    query = parse_routine.sub_local_routines(query, project_dir)

    # if we're testing an init query, make sure we don't require a partition
    # filter, since the test relies on `select * from {table}`
    query = query.replace(
        "require_partition_filter = TRUE", "require_partition_filter = FALSE"
    )
dataset_id = "_".join(self.fspath.strpath.split(os.path.sep)[-3:])
if "CIRCLE_BUILD_NUM" in os.environ:
dataset_id += f"_{os.environ['CIRCLE_BUILD_NUM']}"
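    # dataset_id is built from the last three path components (roughly
    # <dataset>_<table>_<test_name>), with the CI build number appended so
    # concurrent CI runs don't collide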

    bq = bigquery.Client()
    with dataset(bq, dataset_id) as default_dataset:
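        # dataset() is expected to create the temporary dataset and clean it up
        # on exit; all input tables and views are loaded into it before the
        # query runs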
        with ThreadPool(8) as pool:
            pool.map(
                partial(load_table, bq, default_dataset),
                tables.values(),
                chunksize=1,
            )
            pool.starmap(
                partial(load_view, bq, default_dataset), views.items(), chunksize=1
            )
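            # pool.map blocks until every table has loaded, so the views created
            # by starmap can safely reference the loaded tables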

        # configure job
        res_table = bigquery.TableReference(default_dataset, query_name)

        if script_test:
            job_config = bigquery.QueryJobConfig(
                default_dataset=default_dataset,
                query_parameters=get_query_params(self.fspath.strpath),
                use_legacy_sql=False,
            )
            bq.query(query, job_config=job_config).result()
            # retrieve the final state of the table the script wrote, so it can
            # be compared against the expected rows below
            job = bq.query(
                f"SELECT * FROM {dataset_id}.{query_name}", job_config=job_config
            )
        else:
            job_config = bigquery.QueryJobConfig(
                default_dataset=default_dataset,
                destination=res_table,
                query_parameters=get_query_params(self.fspath.strpath),
                use_legacy_sql=False,
                write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
            )
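            # results are written to <dataset>.<query_name>, and the job's own
            # result rows are what get compared against the expected rows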

            # run query
            job = bq.query(query, job_config=job_config)

        result = list(coerce_result(*job.result()))
        result.sort(
            key=lambda row: json.dumps(
                row, sort_keys=True, default=default_encoding
            )
        )
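        # rows are sorted by their JSON encoding so the comparison does not
        # depend on the order BigQuery returns them in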

        # make sure we encode dates correctly
        expect = json.loads(
            json.dumps(
                sorted(
                    expect,
                    key=lambda row: json.dumps(
                        row, sort_keys=True, default=default_encoding
                    ),
                ),
                default=default_encoding,
            )
        )

        print_and_test(expect, result)