def qualify_table_references_in_file()

in bigquery_etl/util/common.py [0:0]


def qualify_table_references_in_file(path: Path) -> str:
    """Add project id and dataset id to table/view references and persistent udfs in a given query.

    e.g.:
    `table` -> `target_project.default_dataset.table`
    `dataset.table` -> `target_project.dataset.table`

    This allows a query to run in a different project than the sql dir it is located in
    while referencing the same tables.

    The target project and the default dataset are taken from the directory
    names three and two levels above the file, i.e. the path is expected to
    look like ``<project>/<dataset>/<table>/<file>``.

    Table references are discovered on the rendered SQL (both the regular and
    the ``is_init`` rendering), but the substitutions are applied to the raw
    file text, so any templating in the file is kept in the returned string.

    Args:
        path: Path of the query file to qualify.

    Returns:
        The raw text of the file with unqualified table/view references and
        ``udf``/``udf_js`` function references prefixed with the target project
        (and default dataset where the reference had none).

    Raises:
        NotImplementedError: If the file is named ``script.sql`` or ``udf.sql``
            or contains a line starting with ``DECLARE``.
    """
    # Read the file once; the same text is used for the script check and as
    # the base for the substitutions below.
    raw_sql = path.read_text()

    # sqlglot cannot handle scripts with variables and control statements
    if re.search(r"^\s*DECLARE\b", raw_sql, flags=re.MULTILINE) or path.name in (
        "script.sql",
        "udf.sql",
    ):
        raise NotImplementedError(
            "Cannot qualify table_references of query scripts or UDFs"
        )

    # determine the default target project and dataset from the path
    target_project = path.parent.parent.parent.name
    default_dataset = path.parent.parent.name

    # sqlglot doesn't support Jinja, so we need to render the queries and
    # init queries to raw SQL
    sql_query = render(
        path.name,
        template_folder=path.parent,
        format=False,
        **DEFAULT_QUERY_TEMPLATE_VARS,
    )
    init_query = render(
        path.name,
        template_folder=path.parent,
        format=False,
        is_init=lambda: True,
        metrics=MetricHub(),
    )
    # use sqlglot to get the SQL AST
    init_query_statements = sqlglot.parse(
        init_query,
        read="bigquery",
    )
    sql_query_statements = sqlglot.parse(sql_query, read="bigquery")

    # tuples of (table identifier, replacement string)
    table_replacements: Set[Tuple[str, str]] = set()

    # find all non-fully qualified table/view references including backticks
    for statements in (init_query_statements, sql_query_statements):
        for statement in statements:
            if statement is None:
                continue

            cte_names = {
                cte.alias_or_name.lower() for cte in statement.find_all(sqlglot.exp.CTE)
            }

            table_aliases = {
                alias.alias_or_name.lower()
                for alias in statement.find_all(sqlglot.exp.TableAlias)
            }

            # Names defined inside the statement itself (CTEs and aliases) must
            # not be qualified. Build the matchers once per statement and escape
            # the names so they are always matched literally.
            local_name_patterns = [
                re.compile(rf"^{re.escape(name)}(?![a-zA-Z0-9_])")
                for name in cte_names | table_aliases
            ]

            for table_expr in statement.find_all(sqlglot.exp.Table):
                # existing table ref including backticks without alias
                table_expr.set("alias", "")
                reference_string = table_expr.sql(dialect="bigquery")

                bare_reference = reference_string.replace("`", "").lower()
                if any(
                    pattern.match(bare_reference) for pattern in local_name_patterns
                ):
                    continue

                # project id is parsed as the catalog attribute
                # but information_schema region may also be parsed as catalog
                if table_expr.catalog.startswith("region-"):
                    project_name = f"{target_project}.{table_expr.catalog}"
                elif table_expr.catalog == "":  # no project id
                    project_name = target_project
                else:  # project id exists
                    continue

                # fully qualified table ref
                replacement_string = f"`{project_name}.{table_expr.db or default_dataset}.{table_expr.name}`"

                table_replacements.add((reference_string, replacement_string))

    updated_query = raw_sql

    # Sorted so that the substitutions are applied in a reproducible order
    # instead of the arbitrary iteration order of the set.
    for identifier, replacement in sorted(table_replacements):
        # Escape the whole identifier so that every character (".", "*", ...)
        # is matched literally in the file text.
        escaped_identifier = re.escape(identifier)

        if identifier.count(".") == 0:
            # if no dataset and project, only replace if it follows a FROM, JOIN, or implicit cross join
            # (the keywords are matched case-sensitively)
            regex = (
                r"(?P<from>(FROM|JOIN)\s+)"
                r"(?P<cross_join>[a-zA-Z0-9_`.\-]+\s*,\s*)?"
                rf"{escaped_identifier}(?![a-zA-Z0-9_`.])"
            )
            # keep the matched keyword and any preceding cross-joined table
            replacement = r"\g<from>\g<cross_join>" + replacement
        else:
            # ensure match is against the full identifier and no project id already
            regex = rf"(?<![a-zA-Z0-9_`.]){escaped_identifier}(?![a-zA-Z0-9_`.])"

        updated_query = re.sub(regex, replacement, updated_query)

    # replace udfs from udf/udf_js that do not have a project qualifier
    regex = r"(?<![a-zA-Z0-9_.])`?(?P<dataset>udf(_js)?)`?\.`?(?P<name>[a-zA-Z0-9_]+)`?"
    updated_query = re.sub(
        regex,
        rf"`{target_project}.\g<dataset>.\g<name>`",
        updated_query,
    )

    return updated_query