parquet_flask/io_logic/ingest_new_file.py (113 lines of code):

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from math import isnan
from os import environ
import json

import pandas
from pyspark.sql.dataframe import DataFrame

from parquet_flask.io_logic.cdms_constants import CDMSConstants
from parquet_flask.io_logic.cdms_schema import CdmsSchema
from parquet_flask.io_logic.retrieve_spark_session import RetrieveSparkSession
from parquet_flask.io_logic.sanitize_record import SanitizeRecord
from parquet_flask.utils.config import Config
from parquet_flask.utils.file_utils import FileUtils
from pyspark.sql.functions import to_timestamp, year, month, lit, col
import pyspark.sql.functions as pyspark_functions  # XXX: why not merge with import statement above?
from pyspark.sql.types import StringType, DoubleType

LOGGER = logging.getLogger(__name__)

GEOSPATIAL_INTERVAL = 30


def get_geospatial_interval(project: str) -> dict:
    """
    Get the geospatial interval dict from the environment variable. If not found, return an empty dict.

    :param project: project name
    :return: geospatial interval dict
    """
    interval_dict = {}
    geo_spatial_interval_by_platform = environ.get(CDMSConstants.geospatial_interval_by_platform)
    if not geo_spatial_interval_by_platform:
        return interval_dict
    geo_spatial_interval_by_platform_dict = json.loads(geo_spatial_interval_by_platform)
    if not isinstance(geo_spatial_interval_by_platform_dict, dict):
        return interval_dict
    if project not in geo_spatial_interval_by_platform_dict or not isinstance(geo_spatial_interval_by_platform_dict[project], dict):
        return interval_dict
    return geo_spatial_interval_by_platform_dict[project]


class IngestNewJsonFile:
    def __init__(self, is_overwriting=False):
        self.__sss = RetrieveSparkSession()
        config = Config()
        self.__app_name = config.get_spark_app_name()
        self.__master_spark = config.get_value('master_spark_url')
        self.__mode = 'overwrite' if is_overwriting else 'append'
        self.__parquet_name = config.get_value('parquet_file_name')
        self.__sanitize_record = True

    @property
    def sanitize_record(self):
        return self.__sanitize_record

    @sanitize_record.setter
    def sanitize_record(self, val):
        """
        :param val: bool - whether to sanitize incoming records before ingestion
        :return: None
        """
        self.__sanitize_record = val
        return

    @staticmethod
    def create_df(spark_session, data_list, job_id, provider, project):
        LOGGER.debug(f'creating data frame with length {len(data_list)}')
        df = spark_session.createDataFrame(data_list)
        # spark_session.sparkContext.addPyFile('/usr/app/parquet_flask/lat_lon_udf.py')
        LOGGER.debug(f'adding columns')
        df: DataFrame = df.withColumn(CDMSConstants.time_obj_col, to_timestamp(CDMSConstants.time_col))\
            .withColumn(CDMSConstants.year_col, year(CDMSConstants.time_col))\
            .withColumn(CDMSConstants.month_col, month(CDMSConstants.time_col))\
            .withColumn(CDMSConstants.platform_code_col, df[CDMSConstants.platform_col][CDMSConstants.code_col])\
            .withColumn(CDMSConstants.job_id_col, lit(job_id))\
            .withColumn(CDMSConstants.provider_col, lit(provider))\
            .withColumn(CDMSConstants.project_col, lit(project))
        geospatial_interval_dict = get_geospatial_interval(project)
        try:
            df: DataFrame = df.withColumn(
                CDMSConstants.geo_spatial_interval_col,
                pyspark_functions.udf(
                    lambda platform_code, latitude, longitude: f'{int(latitude - divmod(latitude, int(geospatial_interval_dict.get(platform_code, GEOSPATIAL_INTERVAL)))[1])}_{int(longitude - divmod(longitude, int(geospatial_interval_dict.get(platform_code, GEOSPATIAL_INTERVAL)))[1])}',
                    StringType())(df[CDMSConstants.platform_code_col], df[CDMSConstants.lat_col], df[CDMSConstants.lon_col]))
            df: DataFrame = df.repartition(1)  # combine into a single partition to increase output file size
            # .withColumn('ingested_date', lit(TimeUtils.get_current_time_str()))
            LOGGER.debug(f'create writer')
            all_partitions = [
                CDMSConstants.provider_col,
                CDMSConstants.project_col,
                CDMSConstants.platform_code_col,
                CDMSConstants.geo_spatial_interval_col,
                CDMSConstants.year_col,
                CDMSConstants.month_col,
                CDMSConstants.job_id_col
            ]
            # df = df.repartition(1)  # XXX: is this line repeated?
            df_writer = df.write
            LOGGER.debug(f'create partitions')
            df_writer = df_writer.partitionBy(all_partitions)
            LOGGER.debug(f'created partitions')
        except BaseException as e:
            LOGGER.exception(f'unexpected exception. latitude: {df[CDMSConstants.lat_col]}. longitude: {df[CDMSConstants.lon_col]}')
            raise e
        return df_writer

    def ingest(self, abs_file_path, job_id):
        """
        This method assumes that the incoming file has data conforming to the in_situ_schema file,
        so it will definitely have `time`, `project`, and `provider`.

        :param abs_file_path: absolute path of the JSON file to ingest
        :param job_id: ingestion job id
        :return: int - number of records
        """
        if not FileUtils.file_exist(abs_file_path):
            raise ValueError('missing file to ingest. path: {}'.format(abs_file_path))
        LOGGER.debug(f'sanitizing the file?: {self.__sanitize_record}')
        if self.sanitize_record is True:
            input_json = SanitizeRecord(Config().get_value('in_situ_schema')).start(abs_file_path)
        else:
            if not FileUtils.file_exist(abs_file_path):
                raise ValueError('json file does not exist: {}'.format(abs_file_path))
            input_json = FileUtils.read_json(abs_file_path)
            for each_record in input_json[CDMSConstants.observations_key]:
                if 'depth' in each_record:
                    each_record['depth'] = float(each_record['depth'])
                if 'wind_from_direction' in each_record:
                    each_record['wind_from_direction'] = float(each_record['wind_from_direction'])
                if 'wind_to_direction' in each_record:
                    each_record['wind_to_direction'] = float(each_record['wind_to_direction'])
        df_writer = self.create_df(
            self.__sss.retrieve_spark_session(self.__app_name, self.__master_spark),
            input_json[CDMSConstants.observations_key],
            job_id,
            input_json[CDMSConstants.provider_col],
            input_json[CDMSConstants.project_col])
        df_writer.mode(self.__mode).parquet(self.__parquet_name, compression='GZIP')  # snappy GZIP
        LOGGER.debug(f'finished writing parquet')
        return len(input_json[CDMSConstants.observations_key])
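

# A minimal, self-contained sketch (not part of the original module) showing how the
# geospatial-interval lookup and the grid-cell labelling applied by the UDF in create_df
# fit together. The project name 'sample_project' and the platform codes '30'/'41' are
# hypothetical; the environment variable named by CDMSConstants.geospatial_interval_by_platform
# is assumed to hold a JSON object keyed by project, mapping platform codes to interval sizes
# in degrees, with GEOSPATIAL_INTERVAL (30) as the fallback.
if __name__ == '__main__':
    environ[CDMSConstants.geospatial_interval_by_platform] = json.dumps({
        'sample_project': {'30': 10, '41': 5}  # hypothetical project / platform-code intervals
    })
    interval_dict = get_geospatial_interval('sample_project')

    def sample_interval_label(platform_code: str, latitude: float, longitude: float) -> str:
        # same bucketing as the UDF: snap lat/lon down to the nearest multiple of the
        # platform's interval and join them into a '<lat>_<lon>' partition label
        interval = int(interval_dict.get(platform_code, GEOSPATIAL_INTERVAL))
        return f'{int(latitude - divmod(latitude, interval)[1])}_{int(longitude - divmod(longitude, interval)[1])}'

    print(sample_interval_label('30', 47.5, -118.5))        # 10-degree cells -> 40_-120
    print(sample_interval_label('unlisted', 47.5, -118.5))  # falls back to 30-degree cells -> 30_-120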