migration_toolkit/sql_generators/copy_rows/ddl_parser.py (81 lines of code) (raw):

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from re import Match
from typing import Dict, List, Optional

from common.bigquery_type import BigQueryType
from common.file_reader import read

logger = logging.getLogger(__name__)

# Pre-compiled (and raw-string) pattern extracting the backtick-quoted
# fully-qualified table name from a `CREATE TABLE \`...\`` statement.
_CREATE_TABLE_PATTERN = re.compile(r"CREATE TABLE `(.*)`")

# Shared message for any failure to locate the column list in the DDL file.
_UNPARSEABLE_SCHEMA_MESSAGE = (
    "Couldn't parse schema file. Make sure this file is generated by"
    " the response of `fetch_bigquery_table_schema.py`"
)


class DDLParser:
    """Parses a BigQuery `CREATE TABLE` DDL file into a column schema.

    The file is expected to be the output of `fetch_bigquery_table_schema.py`:
    an optional leading `CREATE SCHEMA` statement followed by a `CREATE TABLE`
    statement whose column list sits between a line containing `(` and a line
    containing `)` or `);`.
    """

    def __init__(self, ddl_path):
        """Reads the DDL file at `ddl_path` and parses schema + table name.

        Raises:
            ValueError: if the column list cannot be located, or a column
                description is not in '<column_name> <column_type>' form.
        """
        ddl: List[str] = read(ddl_path).split("\n")
        # We only care about the `CREATE TABLE` DDL
        if ddl[0].startswith("CREATE SCHEMA"):
            ddl = ddl[1:]
        self._schema: Dict[str, BigQueryType] = self._to_schema(ddl)
        self._fully_qualified_table_name = self._to_fully_qualified_table_name(
            ddl[0]
        )

    def get_schema(self) -> Dict[str, BigQueryType]:
        """Returns the parsed mapping of column name -> BigQueryType."""
        return self._schema

    def get_fully_qualified_table_name(self):
        """Returns the table name (e.g. `project.dataset.table`), or None if
        the first DDL line did not contain a `CREATE TABLE` statement."""
        return self._fully_qualified_table_name

    @staticmethod
    def _to_fully_qualified_table_name(ddl: str) -> Optional[str]:
        """Extracts the backtick-quoted table name from a CREATE TABLE line.

        Returns None when no `CREATE TABLE \`...\`` pattern is present.
        (Previously annotated `-> str` while implicitly returning None.)
        """
        match: Optional[Match] = _CREATE_TABLE_PATTERN.search(ddl)
        if match:
            return match.group(1)
        return None

    @staticmethod
    def _to_schema(ddl: List[str]) -> Dict[str, BigQueryType]:
        """Locates the column list between `(` and `)`/`);` and parses it.

        Raises:
            ValueError: if either delimiter of the column list is missing.
        """
        ddl = [line.strip() for line in ddl]
        try:
            columns_start_index = ddl.index("(")
        except ValueError:
            # Previously a missing "(" surfaced as a bare, unhelpful
            # `'(' is not in list` error; give the same guidance as below.
            raise ValueError(_UNPARSEABLE_SCHEMA_MESSAGE) from None
        try:
            columns_end_index = ddl.index(")")
        except ValueError:
            try:
                columns_end_index = ddl.index(");")
            except ValueError:
                raise ValueError(_UNPARSEABLE_SCHEMA_MESSAGE) from None
        # A trailing PRIMARY KEY clause is part of the column list
        # syntactically but is not a column; exclude it.
        if "PRIMARY KEY" in ddl[columns_end_index - 1]:
            columns_end_index -= 1
        schema = ddl[columns_start_index + 1 : columns_end_index]
        return DDLParser._to_dict(schema)

    @staticmethod
    def _to_dict(schema):
        """Converts column description lines into a name -> BigQueryType dict.

        Blank lines and Datastream metadata columns are skipped.

        Raises:
            ValueError: if a line is not '<column_name> <column_type>'.
        """
        d = {}
        for column in schema:
            column = DDLParser._strip_trailing_comma(column.strip())
            if not column:
                # An empty line would crash _column_name with IndexError.
                continue
            name: str = DDLParser._column_name(column)
            if DDLParser._is_metadata_column(name):
                logger.debug(f"Skipping metadata column {column}")
                continue
            try:
                source_type: BigQueryType = DDLParser._column_schema(column)
            except (ValueError, IndexError) as err:
                # IndexError covers a line with a name but no type token,
                # which previously escaped this handler uncaught.
                raise ValueError(
                    "Expected column description to be in the format of"
                    f" '<column_name> <column_type>' but got: '{column}'"
                ) from err
            d[name] = source_type
        return d

    @staticmethod
    def _strip_trailing_comma(s: str):
        """Drops one trailing comma if present. Safe on empty strings
        (the previous `s[-1]` check raised IndexError for "")."""
        return s[:-1] if s.endswith(",") else s

    @staticmethod
    def _is_metadata_column(column: str):
        """True for Datastream-generated bookkeeping columns, which carry no
        source data and are excluded from the schema."""
        return column.startswith("_metadata_") or column == "datastream_metadata"

    @staticmethod
    def _column_schema(column: str) -> BigQueryType:
        """Maps the type token of a column description to a BigQueryType.

        Parameterized types (e.g. NUMERIC(10, 2), STRING(50)) are collapsed
        to their base type by prefix match; anything else must be a valid
        BigQueryType value.

        Raises:
            ValueError: if the token is not a known BigQueryType.
            IndexError: if the description has no type token at all.
        """
        type_token = column.split()[1]
        # Order matters only in that no prefix here is a prefix of another;
        # BIGNUMERIC does not start with NUMERIC, so checks are independent.
        for prefix, bq_type in (
            ("NUMERIC", BigQueryType.NUMERIC),
            ("BIGNUMERIC", BigQueryType.BIGNUMERIC),
            ("STRING", BigQueryType.STRING),
        ):
            if type_token.startswith(prefix):
                return bq_type
        return BigQueryType(type_token)

    @staticmethod
    def _column_name(column: str):
        """Returns the first token of the description, backticks stripped."""
        return column.split()[0].strip("`")