parquet_flask/io_logic/query_v4.py (173 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging from datetime import datetime import pyspark.sql.functions as F from parquet_flask.utils.file_utils import FileUtils from pyspark.sql.session import SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.functions import lit from pyspark.sql.types import Row from pyspark.sql.utils import AnalysisException from parquet_flask.io_logic.cdms_schema import CdmsSchema from parquet_flask.io_logic.parquet_query_condition_management_v4 import ParquetQueryConditionManagementV4 from parquet_flask.io_logic.partitioned_parquet_path import PartitionedParquetPath from parquet_flask.io_logic.query_v2 import QueryProps from parquet_flask.io_logic.cdms_constants import CDMSConstants from parquet_flask.utils.config import Config from parquet_flask.utils.general_utils import GeneralUtils LOGGER = logging.getLogger(__name__) class QueryV4: def __init__(self, props=QueryProps()): self.__props = props config = Config() self.__app_name = config.get_spark_app_name() self.__master_spark = config.get_value(Config.master_spark_url) self.__parquet_name = config.get_value(Config.parquet_file_name) self.__es_config = { 'es_url': config.get_value(Config.es_url), 'es_index': CDMSConstants.es_index_parquet_stats, 'es_port': int(config.get_value(Config.es_port, '443')), } self.__parquet_name = self.__parquet_name if not self.__parquet_name.endswith('/') else self.__parquet_name[:-1] self.__missing_depth_value = CDMSConstants.missing_depth_value self.__conditions = [] self.__sorting_columns = [CDMSConstants.time_col, CDMSConstants.platform_code_col, CDMSConstants.depth_col, CDMSConstants.lat_col, CDMSConstants.lon_col] self.__set_missing_depth_val() def __set_missing_depth_val(self): possible_missing_depth = Config().get_value(Config.missing_depth_value) if GeneralUtils.is_int(possible_missing_depth): self.__missing_depth_value = int(possible_missing_depth) return def __retrieve_spark(self): from parquet_flask.io_logic.retrieve_spark_session import RetrieveSparkSession spark = RetrieveSparkSession().retrieve_spark_session(self.__app_name, self.__master_spark) return spark def __strip_duplicates_maintain_order(self, condition_manager: ParquetQueryConditionManagementV4): LOGGER.warning(f'length of parquet_names: {len(condition_manager.parquet_names)}') distinct_list = [] distinct_set = set([]) for each in condition_manager.parquet_names: each: PartitionedParquetPath = each parquet_path = each.generate_path() if parquet_path in distinct_set: continue distinct_set.add(parquet_path) distinct_list.append(each) LOGGER.warning(f'length of distinct_parquet_names: {len(distinct_list)}') LOGGER.warning(f'distinct_parquet_names: {distinct_set}') return distinct_list def get_unioned_read_df(self, condition_manager: ParquetQueryConditionManagementV4, spark: SparkSession) -> DataFrame: cdms_spark_struct = CdmsSchema().get_schema_from_json(FileUtils.read_json(Config().get_value(Config.in_situ_schema))) if len(condition_manager.parquet_names) < 1: LOGGER.fatal(f'cannot find any in ES. returning None instead of searching entire parquet directory for now. ') return None # read_df: DataFrame = spark.read.schema(cdms_spark_struct).parquet(condition_manager.parquet_name) # return read_df read_df_list = [] distinct_parquet_names = self.__strip_duplicates_maintain_order(condition_manager) for each in distinct_parquet_names: each: PartitionedParquetPath = each try: temp_df: DataFrame = spark.read.schema(cdms_spark_struct).parquet(each.generate_path()) for k, v in each.get_df_columns().items(): temp_df: DataFrame = temp_df.withColumn(k, lit(v)) read_df_list.append(temp_df) except Exception as e: LOGGER.exception(f'failed to retrieve data from spark for: {each.generate_path()}') if len(read_df_list) < 1: return None main_read_df: DataFrame = read_df_list[0] for each in read_df_list[1:]: main_read_df = main_read_df.union(each) return main_read_df def __get_paged_result(self, result_df: DataFrame, total_result: int): remaining_size = total_result - self.__props.start_at current_page_size = remaining_size if remaining_size < self.__props.size else self.__props.size result = result_df.limit(self.__props.start_at + current_page_size).tail(current_page_size) return result def __get_paged_result_v2(self, result_df: DataFrame): offset = self.__props.start_at + self.__props.size limit = self.__props.size df = result_df.withColumn('_id', F.monotonically_increasing_id()) df = df.where(F.col('_id').between(offset, offset + limit)) return df.collect() def __is_in_old_page(self, current_item: dict) -> bool: return current_item[CDMSConstants.time_col] == self.__props.min_datetime and current_item[CDMSConstants.platform_col]['code'] <= self.__props.marker_platform_code def __get_sorting_params(self, query_result: DataFrame): return [query_result[k].asc() for k in self.__sorting_columns] def __get_nth_first_page(self, query_result: DataFrame): result_head = query_result.where(f"{CDMSConstants.time_col} = '{self.__props.min_datetime}'").sort(self.__get_sorting_params(query_result)).collect() new_index = -1 for i, each_row in enumerate(result_head): each_row: Row = each_row each_sha_256 = GeneralUtils.gen_sha_256_json_obj(each_row.asDict()) if each_sha_256 == self.__props.marker_platform_code: new_index = i break if new_index < 0: LOGGER.warning(f'comparing sha256: {self.__props.marker_platform_code}') for each_row in result_head: each_row: Row = each_row each_sha_256 = GeneralUtils.gen_sha_256_json_obj(each_row.asDict()) LOGGER.warning(f'each row: {str(each_row)}. each_sha_256: {each_sha_256}') raise ValueError(f'cannot find existing row. It should not happen.') result_page = query_result.take(self.__props.size + new_index + 1) result_tail = result_page[new_index + 1:] return result_tail def __get_page(self, query_result: DataFrame, total_result: int): if self.__props.size == 0: return [] if self.__props.marker_platform_code is not None: # pagination new logic return self.__get_nth_first_page(query_result) if total_result < 0: raise ValueError('total_result is not calculated for old pagination logic. This should not happen. Something has horribly gone wrong') # result = self.__get_paged_result_v2(query_result) return self.__get_paged_result(query_result, total_result) def __get_total_count(self, query_result: DataFrame): if self.__props.marker_platform_code is not None: LOGGER.debug(f'not counting total since this is an Nth page') return -1 LOGGER.debug(f'counting total') return int(query_result.count()) def search(self, spark_session=None): LOGGER.debug(f'<delay_check> query_v4_search started') condition_manager = ParquetQueryConditionManagementV4(self.__parquet_name, self.__missing_depth_value, self.__es_config, self.__props) condition_manager.manage_query_props() conditions = ' AND '.join(condition_manager.conditions) query_begin_time = datetime.now() LOGGER.debug(f'<delay_check> query begins at {query_begin_time}') spark = self.__retrieve_spark() if spark_session is None else spark_session created_spark_session_time = datetime.now() LOGGER.debug(f'<delay_check>spark session created at {created_spark_session_time}. duration: {created_spark_session_time - query_begin_time}') LOGGER.debug(f'__parquet_name: {condition_manager.parquet_name}') read_df: DataFrame = self.get_unioned_read_df(condition_manager, spark) if read_df is None: return { 'total': 0, 'results': [], } read_df_time = datetime.now() LOGGER.debug(f'<delay_check> parquet read created at {read_df_time}. duration: {read_df_time - created_spark_session_time}') query_result = read_df.where(conditions) query_result = query_result.sort(self.__get_sorting_params(query_result)) query_time = datetime.now() LOGGER.debug(f'<delay_check> parquet read filtered at {query_time}. duration: {query_time - read_df_time}') LOGGER.debug(f'<delay_check> total duration: {query_time - query_begin_time}') total_result = self.__get_total_count(query_result) LOGGER.debug(f'<delay_check> total calc count duration: {datetime.now() - query_time}') if self.__props.size < 1: LOGGER.debug(f'returning only the size: {total_result}') return { 'total': total_result, 'results': [], } query_time = datetime.now() # result = query_result.withColumn('_id', F.monotonically_increasing_id()) removing_cols = [CDMSConstants.time_obj_col, CDMSConstants.year_col, CDMSConstants.month_col] # result = result.where(F.col('_id').between(self.__props.start_at, self.__props.start_at + self.__props.size)).drop(*removing_cols) if len(condition_manager.columns) > 0: query_result = query_result.select(condition_manager.columns) else: query_result = query_result.drop(*removing_cols) LOGGER.debug(f'<delay_check> returning size : {total_result}') result = self.__get_page(query_result, total_result) query_result.unpersist() LOGGER.debug(f'<delay_check> total retrieval duration: {datetime.now() - query_time}') # spark.stop() return { 'total': total_result, 'results': [k.asDict() for k in result], }