parquet_flask/parquet_stat_extractor/statistics_retriever.py (136 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from pyspark.sql.dataframe import DataFrame
import pyspark.sql.functions as pyspark_functions

from parquet_flask.io_logic.cdms_constants import CDMSConstants

LOGGER = logging.getLogger(__name__)


class StatisticsRetriever:
    """Compute summary statistics over a CDMS observation DataFrame.

    Typical use::

        stats = StatisticsRetriever(df, observation_keys).start().to_json()

    Until :meth:`start` runs, every min/max attribute is ``None`` and the
    observation counts are empty. ``total`` is the exception: it is counted
    lazily on first access.
    """

    def __init__(self, input_dataset: DataFrame, observation_keys: list):
        """
        :param input_dataset: Spark DataFrame to summarise; it must contain the
            CDMSConstants lat / lon / depth / time_obj columns
        :param observation_keys: column names whose non-null counts are reported
            under ``observation_counts``
        """
        self.__observation_keys = list(observation_keys or [])
        self.__input_dataset = input_dataset
        self.__total = -1  # -1 marks "not computed yet"
        # Plain public attributes (formerly trivial property/setter pairs);
        # callers read and assign them by the same names as before.
        self.min_datetime = None  # epoch seconds (float) once start() has run
        self.max_datetime = None
        self.min_depth = None
        self.max_depth = None
        self.min_lat = None
        self.max_lat = None
        self.min_lon = None
        self.max_lon = None
        # start() replaces this with {key: count}. Initialise it as a dict too,
        # so to_json() returns the same type whether or not start() ran.
        self.__observation_count = {}

    @property
    def total(self) -> int:
        """Row count of the input dataset, computed lazily on first access."""
        if self.__total == -1:
            self.__total = int(self.__input_dataset.count())
        return self.__total

    @total.setter
    def total(self, val):
        """
        :param val: row count to cache
        :return: None
        """
        self.__total = val

    @staticmethod
    def __collect_single_row(dataset: DataFrame) -> dict:
        """Collect a global-aggregation result and return its only row as a dict.

        :raises ValueError: if the aggregation did not produce exactly one row
        """
        rows = dataset.collect()
        if len(rows) != 1:
            raise ValueError(f'invalid row count on stats function: {rows}')
        return rows[0].asDict()

    @staticmethod
    def __to_epoch(value):
        """Convert a datetime to epoch seconds, passing ``None`` through unchanged.

        Spark's min/max returns ``None`` for an empty or all-null column.
        """
        return None if value is None else value.timestamp()

    def __get_min_depth_exclude_missing_val(self):
        """Recompute ``min_depth``, ignoring rows that hold the missing-depth placeholder.

        ``min_depth`` ends up ``None`` when no row has a real depth value.
        """
        depth = pyspark_functions.col(CDMSConstants.depth_col)
        filtered_input_dataset = self.__input_dataset.where(depth != CDMSConstants.missing_depth_value)
        stats = self.__collect_single_row(
            filtered_input_dataset.select(pyspark_functions.min(depth).alias('min_depth'))
        )
        self.min_depth = stats['min_depth']

    def __compute_column_ranges(self):
        """Populate ``total`` and the lat/lon/depth/datetime min/max in one Spark job."""
        aggregations = [pyspark_functions.count(pyspark_functions.lit(1)).alias('total')]
        ranged_columns = (
            ('lat', CDMSConstants.lat_col),
            ('lon', CDMSConstants.lon_col),
            ('depth', CDMSConstants.depth_col),
            ('datetime', CDMSConstants.time_obj_col),
        )
        # Explicit aliases, rather than Spark's auto-generated "min(col)" names,
        # keep the result lookup independent of Spark's column-naming rules.
        for short_name, column in ranged_columns:
            aggregations.append(pyspark_functions.min(column).alias(f'min_{short_name}'))
            aggregations.append(pyspark_functions.max(column).alias(f'max_{short_name}'))

        stats = self.__collect_single_row(self.__input_dataset.select(*aggregations))
        self.total = int(stats['total'])
        self.min_lat = stats['min_lat']
        self.max_lat = stats['max_lat']
        self.min_lon = stats['min_lon']
        self.max_lon = stats['max_lon']
        self.min_depth = stats['min_depth']
        self.max_depth = stats['max_depth']
        self.min_datetime = self.__to_epoch(stats['min_datetime'])
        self.max_datetime = self.__to_epoch(stats['max_datetime'])

        # The missing-depth placeholder can win the min() aggregation; when it
        # does, recompute the minimum over real depth values only.
        # NOTE(review): max_depth is not corrected the same way; it would be the
        # placeholder only if every depth row were missing.
        if self.min_depth is not None and self.min_depth == CDMSConstants.missing_depth_value:
            self.__get_min_depth_exclude_missing_val()

    def __count_observations_per_key(self) -> dict:
        """Best-effort per-key non-null counts; a key that fails is logged and reported as 0."""
        observation_counts = {}
        for each_obs_key in self.__observation_keys:
            try:
                obs_count = self.__input_dataset.where(self.__input_dataset[each_obs_key].isNotNull()).count()
            except Exception:
                LOGGER.exception('error while getting total for key: %s', each_obs_key)
                obs_count = 0
            observation_counts[each_obs_key] = obs_count
        return observation_counts

    def __compute_observation_counts(self) -> dict:
        """Return ``{observation_key: number of non-null values}``.

        All keys are counted in a single aggregation. If that fails (for example
        because one key is not a column of the dataset), counting falls back to
        key by key, so a bad key only zeroes its own entry.
        """
        keys = self.__observation_keys
        if not keys:
            return {}
        try:
            # count(column) counts non-null values, the same result as
            # where(column.isNotNull()).count() but in one pass for all keys.
            counts_row = self.__input_dataset.select(
                *[pyspark_functions.count(self.__input_dataset[key]) for key in keys]
            ).first()
            return {key: int(counts_row[index]) for index, key in enumerate(keys)}
        except Exception:
            LOGGER.warning('combined observation count failed; counting each key separately', exc_info=True)
        return self.__count_observations_per_key()

    def to_json(self) -> dict:
        """Return the statistics as a plain dict.

        :return: dict with total, min/max datetime (epoch seconds), depth, lat,
            lon, and ``observation_counts`` ({key: non-null count})
        """
        return {
            'total': self.total,
            'min_datetime': self.min_datetime,
            'max_datetime': self.max_datetime,
            'min_depth': self.min_depth,
            'max_depth': self.max_depth,
            'min_lat': self.min_lat,
            'max_lat': self.max_lat,
            'min_lon': self.min_lon,
            'max_lon': self.max_lon,
            'observation_counts': dict(self.__observation_count),
        }

    def start(self):
        """Run the Spark aggregations and populate every statistic.

        :return: self, so calls can be chained with :meth:`to_json`
        :raises ValueError: if a global aggregation unexpectedly yields != 1 row
        """
        self.__compute_column_ranges()
        self.__observation_count = self.__compute_observation_counts()
        return self