# python/dataproc_templates/util/elasticsearch_transformations.py
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Dict, List, Any
import re
import json
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, flatten
from pyspark.sql.types import StructType, ArrayType
def flatten_struct_fields(
    nested_df: DataFrame
) -> DataFrame:
    """Return a Dataframe with the struct columns flattened.

    Walks the schema depth-first; each leaf (non-struct) field is selected
    by its dotted path and aliased with ``__`` joining the path segments,
    e.g. ``a.b.c`` becomes column ``a__b__c``.
    """
    selected = []
    # Stack of (path-prefix tuple, projected sub-DataFrame); LIFO traversal.
    pending = [((), nested_df)]
    while pending:
        prefix, frame = pending.pop()
        for field in frame.schema.fields:
            path = prefix + (field.name,)
            if isinstance(field.dataType, StructType):
                # Descend into the struct by projecting its children.
                pending.append((path, frame.select(field.name + ".*")))
            else:
                selected.append(col(".".join(path)).alias("__".join(path)))
    return nested_df.select(selected)
def detect_multidimensional_array_columns(
    df: DataFrame
) -> List[tuple]:
    """Return a list of (column_name, depth) tuples for every top-level
    column whose type is an array nested more than one level deep
    (e.g. array<array<int>> has depth 2)."""
    multidim_columns = []
    for field in df.schema.fields:
        # Check if the field is an array
        if isinstance(field.dataType, ArrayType):
            depth = 0
            element_type = field.dataType
            # Unwrap ArrayTypes to find the depth and the innermost element type
            while isinstance(element_type, ArrayType):
                depth += 1
                element_type = element_type.elementType
            # Only multi-dimensional arrays (depth >= 2) are reported.
            if depth > 1:
                multidim_columns.append((field.name, depth))
    return multidim_columns
def flatten_array_fields(
    df: DataFrame
) -> DataFrame:
    """Return a Dataframe with the multidimensional array columns flattened into one dimensional array columns"""
    # Each application of flatten() removes one level of nesting, so a
    # depth-d array needs d-1 passes to become one-dimensional.
    for name, depth in detect_multidimensional_array_columns(df):
        for _ in range(depth - 1):
            df = df.withColumn(name, flatten(col(name)))
    return df
def rename_duplicate_columns(
    dataframe_schema: Dict[str, Any],
    column_name_set: Optional[set] = None,
    parent: tuple = ()
) -> Dict[str, Any]:
    """Return a modified dataframe schema dict with duplicate columns renamed.

    Column names are compared case-insensitively by their fully qualified
    (dot-joined) path; a clashing name gets a ``_1``, ``_2``, ... suffix.
    The schema dict is modified in place and also returned.

    Args:
        dataframe_schema: Schema dict (as produced by ``json.loads(df.schema.json())``).
        column_name_set: Qualified names (lowercased) already taken; used
            internally to share state across recursive struct descents.
        parent: Tuple of ancestor field names forming the current path prefix.

    Returns:
        The (mutated) schema dict with unique field names.
    """
    # Bug fix: the original signature used a mutable default (``set()``),
    # which was shared across calls — a second invocation saw the previous
    # call's names and spuriously renamed non-duplicate columns.
    if column_name_set is None:
        column_name_set = set()
    if 'fields' in dataframe_schema:
        for field_def in dataframe_schema['fields']:
            qualified_name = '.'.join(parent + (field_def['name'],))
            candidate = qualified_name
            suffix = 1
            # Probe suffixes until the qualified name is unique (case-insensitive).
            while candidate.lower() in column_name_set:
                candidate = f"{qualified_name}_{suffix}"
                suffix += 1
            column_name_set.add(candidate.lower())
            # Only the last path segment is the field's own name.
            field_def['name'] = candidate.split('.')[-1]
            # Recurse into nested structs, extending the path prefix.
            if 'type' in field_def and isinstance(field_def['type'], dict):
                if field_def['type']['type'] == "struct":
                    field_def['type'] = rename_duplicate_columns(
                        field_def['type'],
                        column_name_set,
                        parent + (candidate.split('.')[-1],)
                    )
    return dataframe_schema
def modify_json_schema(
    dataframe_schema: Dict[str, Any]
) -> Dict[str, Any]:
    """Return a modified dataframe schema dict with the Special Characters replaced with _ in the column names"""
    if isinstance(dataframe_schema, dict):
        for key, value in dataframe_schema.items():
            if key == "name":
                # Collapse every run of non-alphanumeric characters into one underscore.
                dataframe_schema[key] = re.sub(r'[^a-zA-Z0-9_]+', '_', value)
            elif isinstance(value, dict):
                # Descend into nested schema dicts (e.g. struct/array types).
                modify_json_schema(value)
            elif isinstance(value, list):
                # Descend into each dict entry of list values (e.g. 'fields').
                for entry in value:
                    if isinstance(entry, dict):
                        modify_json_schema(entry)
    return dataframe_schema
def rename_columns(
    input_data: DataFrame,
) -> DataFrame:
    """Return a Dataframe with the Special Characters replaced with _ in the column names"""
    # Sanitize the top-level column names directly with a select expression.
    sanitized_exprs = [
        f"`{name}` as `{re.sub(r'[^a-zA-Z0-9_]+', '_', name)}`"
        for name in input_data.columns
    ]
    renamed_df = input_data.selectExpr(*sanitized_exprs)
    # Sanitize the nested field names by rewriting the JSON form of the
    # schema, then de-duplicate any names that collide after sanitizing.
    json_schema = rename_duplicate_columns(
        modify_json_schema(json.loads(renamed_df.schema.json()))
    )
    replaced_schema = StructType.fromJson(json_schema)
    # Casting to the rewritten complex types applies the renamed nested fields.
    for field in replaced_schema:
        if isinstance(field.dataType, (StructType, ArrayType)):
            renamed_df = renamed_df.withColumn(
                field.name, renamed_df[field.name].cast(field.dataType)
            )
    return renamed_df