# src/translation/scripts/hive/extract_hive_ddls.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
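"""
Extract Hive table DDLs through Spark, write each DDL to GCS as a .sql file,
and record per-table extraction metadata in a BigQuery audit table.
"""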
import argparse
import ast
from datetime import datetime
from os.path import abspath
from google.cloud import bigquery, storage
from pyspark.sql import SparkSession


def WriteToCloud(ddl, bucket, path, table_name):
    """
    Write the DDL string to GCS as <path>/<table_name>.sql in the given bucket
    """
print("Writing DDL to GCS: " + table_name)
client = storage.Client()
bucket = client.get_bucket(bucket)
blob = bucket.blob(path + "/" + table_name + ".sql")
blob.upload_from_string(ddl)


def get_spark_session(host_ip):
    """
    Create a Hive-enabled Spark session using the Hive metastore thrift host IP provided
    """
print("Connecting to Metastore: " + "thrift://" + host_ip + ":9083")
warehouse_location = abspath("spark-warehouse")
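    # Hive-enabled session pointing at the remote metastore (thrift, port 9083),
    # with the Spark-BigQuery connector jar loaded from the public spark-lib bucket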
spark = (
SparkSession.builder.appName("extract_hove_ddl")
.config("hive.metastore.uris", "thrift://" + host_ip + ":9083")
.config("spark.jars", "gs://spark-lib/bigquery/spark-bigquery-latest_2.12.jar")
.config("spark.sql.warehouse.dir", warehouse_location)
.config("spark.sql.debug.maxToStringFields", 2000)
.enableHiveSupport()
.getOrCreate()
)
return spark


def read_translation_config(translation_config):
    """
    Flatten the translation config JSON into a dictionary of extraction settings
    """
dict = {}
dict["bigquery_audit_table"] = "hive_ddl_metadata"
dict["bq_dataset_audit"] = "dmt_logs"
dict["host_ip"] = translation_config["hive_config"]["server_config"]["connection"][
"host"
]
source_path = translation_config["migrationTask"]["translationConfigDetails"][
"gcsSourcePath"
]
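    # gcsSourcePath has the form gs://<bucket>/<path>; split it into bucket name and object prefix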
dict["bucket_name"] = source_path.split("/")[2]
dict["gcs_ddl_output_path"] = source_path.split("/", 3)[-1]
nm_map_list = translation_config["migrationTask"]["translationConfigDetails"][
"nameMappingList"
]["name_map"]
dict["hive_db"] = list(set(d["source"]["schema"] for d in nm_map_list))[0]
dict["bq_dataset"] = list(set(d["target"]["schema"] for d in nm_map_list))[0]
input_tables = translation_config["source_ddl_extract_table_list"]
dict["input_tables_list"] = [x.lower() for x in input_tables.split(",")]
return dict


def get_table_list(dict, spark):
    """
    Build the list of tables whose DDLs should be extracted
    """
table_list = []
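    # List every table in the Hive database, then narrow to the configured list unless it is "*"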
tables = spark.catalog.listTables(dict["hive_db"])
for tbl in tables:
table_list.append(tbl.name.lower())
if dict["input_tables_list"][0] != "*":
table_list = dict["input_tables_list"]
else:
table_list = table_list
return table_list


def get_table_format(tbl, hive_db, spark):
    """
    Map the table's Hive InputFormat to a storage format label (AVRO, PARQUET, ORC, CSV or OTHER)
    """
df = spark.sql(f"describe formatted {hive_db}.{tbl}")
format_str = (
df.filter("col_name == 'InputFormat'").select("data_type").first()[0].upper()
)
if "AVRO" in format_str:
return "AVRO"
elif "PARQUET" in format_str:
return "PARQUET"
elif "ORC" in format_str:
return "ORC"
elif "TEXT" in format_str:
return "CSV"
else:
return "OTHER"


def get_partition_cluster_info(ddl_hive):
    """
    Flag whether the DDL contains PARTITIONED BY and CLUSTERED BY clauses
    """
partitioning_flag = ""
clustering_flag = ""
if "PARTITIONED BY" in ddl_hive:
partitioning_flag = "Y"
else:
partitioning_flag = "N"
if "CLUSTERED BY" in ddl_hive:
clustering_flag = "Y"
else:
clustering_flag = "N"
return partitioning_flag, clustering_flag


def get_tbl_delimiter(hive_ddl_str):
    """
    Get the field delimiter for TEXT tables from the DDL
    Default value: '\\001' (the default Hive field delimiter)
    """
if "field.delim' = " in hive_ddl_str:
delim = repr(hive_ddl_str.split("field.delim' = ")[1].split("'")[1])
else:
delim = "\001"
return delim


def get_hive_ddls(dict, run_id, spark):
    """
    Extract Hive DDLs and table metadata, write the DDLs to GCS,
    and log one audit row per table to BigQuery
    """
run_time = datetime.now()
dbCheck = spark.catalog._jcatalog.databaseExists(dict["hive_db"])
hive_db = dict["hive_db"]
bq_dataset = dict["bq_dataset"]
if dbCheck:
table_list = get_table_list(dict, spark)
for tbl in table_list:
print(f"Extracting DDL for Table {tbl}")
ddl_hive = ""
try:
ddl_hive_df = spark.sql(f"show create table {hive_db}.{tbl} as serde")
ddl_hive = (
ddl_hive_df.first()[0]
.split("\nLOCATION '")[0]
.split("\nSTORED AS")[0]
)
except Exception:
print(f"Could not get DDL for table: {tbl}, trying without SERDE now..")
if len(ddl_hive) < 1:
try:
ddl_hive_df = spark.sql(f"show create table {hive_db}.{tbl}")
ddl_hive = ddl_hive_df.first()[0].split("\nUSING ")[0]
except Exception as e:
print(e)
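            # Qualify the table name with its database, make the statement idempotent
            # with IF NOT EXISTS, and terminate it before writing the DDL to GCS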
if len(ddl_hive) > 1:
ddl_hive = (
ddl_hive.replace(f"{hive_db}.", "").replace(
f" TABLE {tbl}", f" TABLE IF NOT EXISTS {hive_db}.{tbl}"
)
+ ";"
)
WriteToCloud(
ddl_hive, dict["bucket_name"], dict["gcs_ddl_output_path"], tbl
)
storage_format = get_table_format(tbl, hive_db, spark)
partition_flag, cluster_flag = get_partition_cluster_info(ddl_hive)
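                # The field delimiter is only looked up for TEXT (CSV) tables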
if storage_format == "CSV":
field_delimiter = get_tbl_delimiter(ddl_hive_df.first()[0])
else:
field_delimiter = "NA"
ddl_extracted = "YES"
else:
(
ddl_extracted,
partition_flag,
cluster_flag,
storage_format,
field_delimiter,
) = ("NO", "", "", "", "")
metadata_list = [
{
"run_id": run_id,
"start_time": str(run_time),
"database": hive_db,
"bq_dataset": bq_dataset,
"table": tbl,
"field_delimiter": field_delimiter,
"partition_flag": partition_flag,
"cluster_flag": cluster_flag,
"format": storage_format,
"ddl_extracted": ddl_extracted,
}
]
client = bigquery.Client()
client.insert_rows_json(
dict["bq_dataset_audit"] + "." + dict["bigquery_audit_table"],
metadata_list,
)
spark.stop()


def main(translation_config, run_id):
dict = read_translation_config(translation_config)
spark = get_spark_session(dict["host_ip"])
get_hive_ddls(dict, run_id, spark)
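

# Example invocation (illustrative; --build_config is a Python-literal dict with a
# top-level "config" key, parsed with ast.literal_eval):
#   python extract_hive_ddls.py --build_config "{'config': {...}}"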
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--build_config", help="Hive Data Load Config")
args = parser.parse_args()
translation_config_str = args.build_config
translation_config = ast.literal_eval(translation_config_str)["config"]
run_id = translation_config["unique_id"]
main(translation_config, run_id)