python/dataproc_templates/util/dataframe_reader_wrappers.py
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from pyspark.sql import DataFrame, SparkSession

import dataproc_templates.util.template_constants as constants
from dataproc_templates.util.elasticsearch_transformations import rename_columns


def ingest_dataframe_from_cloud_storage(
    spark: SparkSession,
    args: dict,
    input_location: str,
    input_format: str,
    prefix: str,
    avro_format_override: Optional[str] = None,
) -> DataFrame:
    """Return a DataFrame read from Cloud Storage with the appropriate format and reader options applied."""
    input_data: DataFrame
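    # Map user-supplied template arguments to their Spark CSV reader option
    # names, keeping only options that were actually provided (non-empty).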
    csv_input_constant_options: dict = constants.get_csv_input_spark_options(prefix)
    spark_options = {csv_input_constant_options[k]: v
                     for k, v in args.items()
                     if k in csv_input_constant_options and v}
    if input_format == constants.FORMAT_PRQT:
        input_data = spark.read \
            .parquet(input_location)
    elif input_format == constants.FORMAT_AVRO:
        input_data = spark.read \
            .format(avro_format_override or constants.FORMAT_AVRO_EXTD) \
            .load(input_location)
    elif input_format == constants.FORMAT_CSV:
        input_data = spark.read \
            .format(constants.FORMAT_CSV) \
            .options(**spark_options) \
            .load(input_location)
    elif input_format == constants.FORMAT_JSON:
        input_data = spark.read \
            .json(input_location)
    elif input_format == constants.FORMAT_DELTA:
        input_data = spark.read \
            .format(constants.FORMAT_DELTA) \
            .load(input_location)
    else:
        # Fail fast with a clear error instead of returning an unbound variable
        raise ValueError(f"Unsupported input format: {input_format}")
    return input_data
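

# Example usage (illustrative sketch, not part of the template API; the bucket
# path and the option prefix below are placeholders):
#
#   spark = SparkSession.builder.appName("gcs-ingest").getOrCreate()
#   df = ingest_dataframe_from_cloud_storage(
#       spark=spark,
#       args={},
#       input_location="gs://my-bucket/input/",
#       input_format=constants.FORMAT_CSV,
#       prefix="gcs.bigquery.input.",
#   )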


def ingest_dataframe_from_elasticsearch(
    spark: SparkSession,
    es_node: str,
    es_index: str,
    es_user: str,
    es_password: str,
    es_api_key: str,
    args: dict,
    prefix: str,
) -> DataFrame:
    """Return a DataFrame read from Elasticsearch via the Elasticsearch-Hadoop connector."""
    input_data: DataFrame
    es_spark_connector_input_options: dict = constants.get_es_spark_connector_input_options(prefix)
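    # Map user-supplied template arguments to their Elasticsearch-Hadoop
    # connector option names, keeping only options that were provided.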
    es_spark_connector_options = {es_spark_connector_input_options[k]: v
                                  for k, v in args.items()
                                  if k in es_spark_connector_input_options and v}
    # Make Spark SQL case sensitive
    spark.conf.set('spark.sql.caseSensitive', True)
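    # Prefer API-key authentication when a key is supplied; otherwise fall
    # back to basic (user/password) authentication.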
    if es_api_key is not None:
        es_conf_json = {
            "es.nodes": es_node,
            "es.resource": es_index,
            "es.net.http.header.Authorization": es_api_key,
            "es.output.json": "true"
        }
    else:
        es_conf_json = {
            "es.nodes": es_node,
            "es.resource": es_index,
            "es.net.http.auth.user": es_user,
            "es.net.http.auth.pass": es_password,
            "es.output.json": "true"
        }
    # Merge the required connection attributes with any optional connector
    # options; user-supplied options take precedence on key collisions.
    es_conf_json.update(es_spark_connector_options)
    # Read as an RDD of (document ID, JSON document) pairs
    input_data = spark.sparkContext.newAPIHadoopRDD(constants.FORMAT_ELASTICSEARCH,
                                                    constants.ELASTICSEARCH_KEY_CLASS,
                                                    constants.ELASTICSEARCH_VALUE_CLASS,
                                                    conf=es_conf_json)
    # Remove the Elasticsearch ID from the RDD
    input_data = input_data.flatMap(lambda x: x[1:])
    # Convert into a DataFrame
    input_data = spark.read.json(input_data)
    # Remove special characters from the column names
    input_data = rename_columns(input_data)
    return input_data
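

# Example usage (illustrative sketch; the node, index, credentials, and option
# prefix below are placeholders, not values defined by the templates):
#
#   df = ingest_dataframe_from_elasticsearch(
#       spark=spark,
#       es_node="localhost:9200",
#       es_index="my-index",
#       es_user="elastic",
#       es_password="changeme",
#       es_api_key=None,
#       args={},
#       prefix="elasticsearch.bq.input.",
#   )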