# migration_toolkit/sql_generators/copy_rows/ddl_parser.py
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from re import Match
from typing import Dict, List, Optional

from common.bigquery_type import BigQueryType
from common.file_reader import read
logger = logging.getLogger(__name__)
class DDLParser:
    """Parses a BigQuery DDL file into a column schema and a table name.

    The DDL file is expected to be generated by the response of
    `fetch_bigquery_table_schema.py`: an optional leading `CREATE SCHEMA`
    line, followed by a `CREATE TABLE` statement whose column list sits
    between standalone "(" and ")" (or ");") lines.
    """

    def __init__(self, ddl_path):
        """Reads the DDL at `ddl_path` and extracts schema + table name.

        Raises:
            ValueError: if the column list cannot be located or a column
                line is not in the '<column_name> <column_type>' format.
        """
        ddl: List[str] = read(ddl_path).split("\n")
        # We only care about the `CREATE TABLE` DDL; drop a leading
        # `CREATE SCHEMA` line if present.
        if ddl[0].startswith("CREATE SCHEMA"):
            ddl = ddl[1:]
        self._schema: Dict[str, BigQueryType] = self._to_schema(ddl)
        self._fully_qualified_table_name = self._to_fully_qualified_table_name(
            ddl[0]
        )

    def get_schema(self) -> Dict[str, BigQueryType]:
        """Returns the parsed {column_name: BigQueryType} mapping."""
        return self._schema

    def get_fully_qualified_table_name(self):
        """Returns the `project.dataset.table` name, or None if not found."""
        return self._fully_qualified_table_name

    @staticmethod
    def _to_fully_qualified_table_name(ddl: str) -> Optional[str]:
        """Extracts the backtick-quoted name from a CREATE TABLE line.

        Returns None when the line contains no `CREATE TABLE` clause.
        """
        match: Optional[Match] = re.search(r"CREATE TABLE `(.*)`", ddl)
        return match.group(1) if match else None

    @staticmethod
    def _to_schema(ddl: List[str]) -> Dict[str, BigQueryType]:
        """Extracts the column definitions between the "(" and ")" lines.

        Raises:
            ValueError: if the column-list delimiters cannot be found.
        """
        ddl = [line.strip() for line in ddl]
        parse_error = ValueError(
            "Couldn't parse schema file. Make sure this file is generated by"
            " the response of `fetch_bigquery_table_schema.py`"
        )
        try:
            columns_start_index = ddl.index("(")
        except ValueError:
            # Raise the friendly error instead of the raw list.index one.
            raise parse_error from None
        # The column list may be terminated by either ")" or ");".
        for terminator in (")", ");"):
            try:
                columns_end_index = ddl.index(terminator)
                break
            except ValueError:
                continue
        else:
            raise parse_error from None
        # A trailing PRIMARY KEY constraint line is not a column; drop it.
        if "PRIMARY KEY" in ddl[columns_end_index - 1]:
            columns_end_index -= 1
        schema = ddl[columns_start_index + 1 : columns_end_index]
        return DDLParser._to_dict(schema)

    @staticmethod
    def _to_dict(schema):
        """Maps column definition lines to {column_name: BigQueryType}.

        Metadata columns (Datastream-injected) are skipped.

        Raises:
            ValueError: if a line is not '<column_name> <column_type>'.
        """
        d = {}
        for column in schema:
            column = DDLParser._strip_trailing_comma(column.strip())
            name: str = DDLParser._column_name(column)
            if DDLParser._is_metadata_column(name):
                logger.debug(f"Skipping metadata column {column}")
                continue
            try:
                source_type: BigQueryType = DDLParser._column_schema(column)
            # IndexError covers a line with no type token at all;
            # ValueError covers an unknown BigQuery type.
            except (ValueError, IndexError):
                raise ValueError(
                    "Expected column description to be in the format of"
                    f" '<column_name> <column_type>' but got: '{column}'"
                )
            d[name] = source_type
        return d

    @staticmethod
    def _strip_trailing_comma(s: str):
        # endswith() is safe for empty strings, unlike s[-1].
        return s[:-1] if s.endswith(",") else s

    @staticmethod
    def _is_metadata_column(column: str):
        # Datastream-injected columns are not part of the source schema.
        return column.startswith("_metadata_") or column == "datastream_metadata"

    @staticmethod
    def _column_schema(column: str) -> BigQueryType:
        """Parses the BigQuery type token of one column definition line.

        Parameterized types (e.g. NUMERIC(10, 2), STRING(50)) are matched
        by prefix. Raises ValueError for an unknown type and IndexError
        when the line has no type token; the caller translates both.
        """
        type_token = column.split()[1]
        if type_token.startswith("NUMERIC"):
            return BigQueryType.NUMERIC
        elif type_token.startswith("BIGNUMERIC"):
            return BigQueryType.BIGNUMERIC
        elif type_token.startswith("STRING"):
            return BigQueryType.STRING
        else:
            return BigQueryType(type_token)

    @staticmethod
    def _column_name(column: str):
        # Column names are backtick-quoted in the DDL; strip the quotes.
        return column.split()[0].strip("`")