migration_toolkit/sql_generators/copy_rows/copy_rows.py

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, NamedTuple

from common.bigquery_type import BigQueryType
from common.file_writer import write
from sql_generators.copy_rows.ddl_parser import DDLParser

logger = logging.getLogger(__name__)

COPY_DATA_SQL = (
    "INSERT INTO {destination_table}\n"
    "(\n"
    "  {destination_columns}\n"
    ")\n"
    "SELECT\n"
    "  {source_columns}\n"
    "FROM {source_table};"
)


class ColumnSchema(NamedTuple):
    source: BigQueryType
    destination: BigQueryType


# This dictionary provides the appropriate casts between tables that were created
# with Dataflow's "Datastream to BigQuery" template and tables that were created
# with Datastream's native BigQuery solution.
COLUMN_SCHEMAS_TO_CAST_EXPRESSION: Dict[ColumnSchema, str] = {
    # MySQL BINARY, VARBINARY
    # Oracle BLOB
    ColumnSchema(
        BigQueryType.BYTES, BigQueryType.STRING
    ): "SAFE_CONVERT_BYTES_TO_STRING({column_name})",
    # MySQL DATETIME
    ColumnSchema(
        BigQueryType.TIMESTAMP, BigQueryType.DATETIME
    ): "CAST({column_name} as DATETIME)",
    # MySQL DECIMAL
    ColumnSchema(
        BigQueryType.BIGNUMERIC, BigQueryType.NUMERIC
    ): "CAST({column_name} as NUMERIC)",
    # MySQL DECIMAL
    # Oracle NUMBER
    ColumnSchema(
        BigQueryType.BIGNUMERIC, BigQueryType.STRING
    ): "CAST({column_name} as STRING)",
    # MySQL TIME (stored as microseconds in a STRING; converted to an INTERVAL
    # of hours and seconds, aliased back to the original column name)
    ColumnSchema(BigQueryType.STRING, BigQueryType.INTERVAL): (
        "MAKE_INTERVAL(hour=>DIV(CAST({column_name} as INT64), "
        "3600000000), second=>DIV(MOD(CAST({column_name} as "
        "INT64), 3600000000), 1000000)) as {column_name}"
    ),
    # MySQL YEAR
    ColumnSchema(
        BigQueryType.STRING, BigQueryType.INT64
    ): "CAST({column_name} as INT64)",
    # MySQL JSON
    ColumnSchema(
        BigQueryType.STRING, BigQueryType.JSON
    ): "PARSE_JSON({column_name})",
    # Oracle NUMBER with negative precision
    ColumnSchema(
        BigQueryType.BIGNUMERIC, BigQueryType.INT64
    ): "CAST({column_name} as INT64)",
}


class CopyDataSQLGenerator:
    """Generates an INSERT ... SELECT statement that copies rows between two
    BigQuery tables whose schemas differ only in column types, inserting the
    casts registered in COLUMN_SCHEMAS_TO_CAST_EXPRESSION where needed."""

    def __init__(
        self,
        source_bigquery_table_ddl: str,
        destination_bigquery_table_ddl: str,
        filepath: str,
    ):
        self.source_ddl_parser = DDLParser(source_bigquery_table_ddl)
        self.destination_ddl_parser = DDLParser(destination_bigquery_table_ddl)
        self.filepath = filepath

    def generate_sql(self):
        source_columns = []
        destination_columns = []
        for (
            column_name,
            source_type,
        ) in self.source_ddl_parser.get_schema().items():
            destination_type = self.destination_ddl_parser.get_schema().get(
                column_name
            )
            if not destination_type:
                raise ValueError(
                    "Column names must match in source and destination, but could not"
                    f" find column name {column_name} in destination table. Destination"
                    f" schema is {self.destination_ddl_parser.get_schema()}"
                )
            column_name = f"`{column_name}`"
            destination_columns.append(column_name)
            if source_type == destination_type:
                logger.debug(f"Type match for column '{column_name}'")
                source_columns.append(column_name)
            else:
                logger.debug(
                    f"Type mismatch for column '{column_name}': {source_type} =>"
                    f" {destination_type}"
                )
                # Raises KeyError if no cast is registered for this type pair.
                source_columns.append(
                    COLUMN_SCHEMAS_TO_CAST_EXPRESSION[
                        ColumnSchema(source_type, destination_type)
                    ].format(column_name=column_name)
                )

        sql = COPY_DATA_SQL.format(
            destination_table=self.destination_ddl_parser.get_fully_qualified_table_name(),
            source_columns=",\n  ".join(source_columns),
            destination_columns=",\n  ".join(destination_columns),
            source_table=self.source_ddl_parser.get_fully_qualified_table_name(),
        )
        logger.info(f"Generated copy rows SQL statement:\n'{sql}'")
        self._write_to_file(sql)

    def _write_to_file(self, sql):
        write(filepath=self.filepath, data=sql)