def runtest()

in bigquery_etl/pytest_plugin/sql.py [0:0]


    def runtest(self):
        """Run the SQL test case stored in this item's directory.

        The steps, in order:

        1. Render ``query.sql`` (or ``script.sql`` for ``test_script``) from
           the query directory derived from this item's path.
        2. Scan the test directory for fixture files. Files with a supported
           table extension (or ``yaml``/``json``) become tables, ``.sql``
           files become views. Dotted fixture names are flattened to a
           single identifier and references in the query text are rewritten
           to match.
        3. Rewrite routines through ``parse_routine.sub_local_routines`` and
           disable ``require_partition_filter`` in the query text.
        4. Load tables and views inside the ``dataset(...)`` context, run the
           query, and hand the sorted expected and actual rows to
           ``print_and_test``.
        """
        # self.fspath is the test directory; its parent is named after the query.
        test_name = self.fspath.basename
        query_name = self.fspath.dirpath().basename
        # Four levels above the test directory, with "tests" stripped from the
        # path string so it points into the source tree instead of the test tree.
        project_dir = (
            self.fspath.dirpath().dirpath().dirpath().dirname.replace("tests", "")
        )

        script_test = False

        # init tests write to dataset_query_test, instead of their default name
        path = self.fspath.dirname.replace("tests", "")
        if test_name == "test_init":
            query = render(
                "query.sql",
                template_folder=path,
                **{"is_init": lambda: True},
            )
        elif test_name == "test_script":
            script_test = True
            query = render("script.sql", template_folder=path)
        else:
            query = render("query.sql", template_folder=path)

        expect = load(self.fspath.strpath, "expect")

        # Both are keyed by the flattened table/view name.
        tables = {}
        views = {}

        # generate tables for files with a supported table extension
        # next(os.walk(...))[2] yields only the file names directly inside the
        # test directory (no recursion into sub-directories).
        for resource in next(os.walk(self.fspath))[2]:
            if "." not in resource:
                continue  # tables require an extension
            table_name, extension = resource.rsplit(".", 1)
            if table_name.endswith(".schema") or table_name in (
                "expect",
                "query_params",
            ):
                continue  # not a table
            if extension in TABLE_EXTENSIONS or extension in ("yaml", "json"):
                if extension in TABLE_EXTENSIONS:
                    source_format = TABLE_EXTENSIONS[extension]
                    source_path = os.path.join(self.fspath.strpath, resource)
                else:
                    # yaml/json fixtures use the ndjson source format and are
                    # described by a (directory, name) pair instead of a file
                    # path; how that pair is consumed is defined outside this
                    # method (presumably by load_table).
                    source_format = TABLE_EXTENSIONS["ndjson"]
                    source_path = (self.fspath.strpath, table_name)
                if "." in table_name:
                    # combine project and dataset name with table name
                    original, table_name = (
                        table_name,
                        table_name.replace(".", "_").replace("-", "_"),
                    )
                    # Match the dotted name with optional backticks around
                    # each part, but not when it is embedded in a longer
                    # dotted/underscored identifier. Each part is escaped so
                    # that characters from the file name are always matched
                    # literally and can never be read as regex syntax.
                    original_pattern = (
                        r"`?(?<![._])\b"
                        + r"`?\.`?".join(
                            re.escape(part) for part in original.split(".")
                        )
                        + r"\b(?![._])`?"
                    )
                    query = re.sub(original_pattern, table_name, query)
                else:
                    # Keep `original` defined for the date-suffix handling below.
                    original = table_name

                # second check for tablename tweaks.
                # if the tablename ends with a date then need to replace that date with '*' for the
                # query text substitution to work.
                # e.g. see moz-fx-data-marketing-prod.65789850.ga_sessions_20230214
                # A query using that table uses moz-fx-data-marketing-prod.65789850.ga_sessions_*
                # with the date appended to allow for daily processing.
                try:
                    datetime.datetime.strptime(table_name[-8:], "%Y%m%d")
                except ValueError:
                    # Name does not end in a YYYYMMDD date: nothing to rewrite.
                    pass
                else:
                    generic_table_name = table_name[:-8] + "*"
                    generic_original = original[:-8] + "*"
                    query = query.replace(generic_original, generic_table_name)
                tables[table_name] = Table(table_name, source_format, source_path)
                print(f"Initialized {table_name}")
            elif extension == "sql":
                if "." in table_name:
                    # combine project and dataset name with table name
                    original, table_name = (
                        table_name,
                        table_name.replace(".", "_").replace("-", "_"),
                    )
                    query = query.replace(original, table_name)
                views[table_name] = read(self.fspath.strpath, resource)

        # rewrite all udfs as temporary
        query = parse_routine.sub_local_routines(query, project_dir)

        # if we're reading an initialization function, ensure that we're not
        # using a partition filter since we rely on `select * from {table}`
        query = query.replace(
            "require_partition_filter = TRUE", "require_partition_filter = FALSE"
        )

        # Dataset id is built from the last three path components of the test
        # directory, plus the CI build number when one is set in the environment.
        dataset_id = "_".join(self.fspath.strpath.split(os.path.sep)[-3:])
        if "CIRCLE_BUILD_NUM" in os.environ:
            dataset_id += f"_{os.environ['CIRCLE_BUILD_NUM']}"

        bq = bigquery.Client()
        with dataset(bq, dataset_id) as default_dataset:
            # Load fixtures concurrently; all tables are submitted before views.
            with ThreadPool(8) as pool:
                pool.map(
                    partial(load_table, bq, default_dataset),
                    tables.values(),
                    chunksize=1,
                )
                pool.starmap(
                    partial(load_view, bq, default_dataset), views.items(), chunksize=1
                )

            # configure job
            res_table = bigquery.TableReference(default_dataset, query_name)

            if script_test:
                # Scripts get no destination table: the script is run to
                # completion first and the table named after the query is read
                # back with a second query.
                job_config = bigquery.QueryJobConfig(
                    default_dataset=default_dataset,
                    query_parameters=get_query_params(self.fspath.strpath),
                    use_legacy_sql=False,
                )

                bq.query(query, job_config=job_config).result()
                # Retrieve final state of the table after the script has run
                job = bq.query(
                    f"SELECT * FROM {dataset_id}.{query_name}", job_config=job_config
                )

            else:
                job_config = bigquery.QueryJobConfig(
                    default_dataset=default_dataset,
                    destination=res_table,
                    query_parameters=get_query_params(self.fspath.strpath),
                    use_legacy_sql=False,
                    write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
                )

                # run query
                job = bq.query(query, job_config=job_config)

            # Sort both sides by their canonical JSON text so the comparison
            # does not depend on row order.
            result = list(coerce_result(*job.result()))
            result.sort(
                key=lambda row: json.dumps(
                    row, sort_keys=True, default=default_encoding
                )
            )
            # make sure we encode dates correctly
            # (round-trip through JSON so values are encoded by default_encoding)
            expect = json.loads(
                json.dumps(
                    sorted(
                        expect,
                        key=lambda row: json.dumps(
                            row, sort_keys=True, default=default_encoding
                        ),
                    ),
                    default=default_encoding,
                )
            )

            print_and_test(expect, result)