# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from datetime import datetime

import pyspark.sql.functions as F
from parquet_flask.utils.file_utils import FileUtils
from pyspark.sql.session import SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import lit
from pyspark.sql.types import Row
from pyspark.sql.utils import AnalysisException

from parquet_flask.io_logic.cdms_schema import CdmsSchema
from parquet_flask.io_logic.parquet_query_condition_management_v4 import ParquetQueryConditionManagementV4
from parquet_flask.io_logic.partitioned_parquet_path import PartitionedParquetPath
from parquet_flask.io_logic.query_v2 import QueryProps
from parquet_flask.io_logic.cdms_constants import CDMSConstants
from parquet_flask.utils.config import Config
from parquet_flask.utils.general_utils import GeneralUtils

LOGGER = logging.getLogger(__name__)


class QueryV4:
    """Run an in-situ parquet query through Spark.

    The parquet partitions to read are resolved by ``ParquetQueryConditionManagementV4`` (which is
    handed the Elasticsearch settings built in ``__init__``). They are read with the CDMS schema,
    unioned, filtered with the manager's SQL conditions, sorted by ``__sorting_columns`` and paged
    according to ``QueryProps``.
    """

    def __init__(self, props=None):
        """Read Spark, parquet and Elasticsearch settings from ``Config``.

        :param props: ``QueryProps`` for this query. ``None`` (the default) creates a fresh
            ``QueryProps()`` for this instance. A ``QueryProps()`` default argument would instead be
            built once at import time and shared by every instance created without ``props``.
        """
        self.__props = QueryProps() if props is None else props
        config = Config()
        self.__app_name = config.get_spark_app_name()
        self.__master_spark = config.get_value(Config.master_spark_url)
        self.__parquet_name = config.get_value(Config.parquet_file_name)
        self.__es_config = {
            'es_url': config.get_value(Config.es_url),
            'es_index': CDMSConstants.es_index_parquet_stats,
            'es_port': int(config.get_value(Config.es_port, '443')),
        }
        # Drop a single trailing '/' from the configured parquet location.
        self.__parquet_name = self.__parquet_name if not self.__parquet_name.endswith('/') else self.__parquet_name[:-1]
        self.__missing_depth_value = CDMSConstants.missing_depth_value
        self.__conditions = []
        # Deterministic sort order. "Nth page" pagination locates the marker row by its index in
        # this order, so every request must sort the same way.
        self.__sorting_columns = [CDMSConstants.time_col, CDMSConstants.platform_code_col, CDMSConstants.depth_col, CDMSConstants.lat_col, CDMSConstants.lon_col]
        self.__set_missing_depth_val()

    def __set_missing_depth_val(self):
        """Override the default missing-depth value with ``Config.missing_depth_value`` when that is an int."""
        possible_missing_depth = Config().get_value(Config.missing_depth_value)
        if GeneralUtils.is_int(possible_missing_depth):
            self.__missing_depth_value = int(possible_missing_depth)

    def __retrieve_spark(self):
        """Obtain a Spark session for the configured app name and master URL."""
        # Imported here rather than at module level; presumably to defer Spark session setup
        # (or avoid an import cycle) -- NOTE(review): confirm.
        from parquet_flask.io_logic.retrieve_spark_session import RetrieveSparkSession
        return RetrieveSparkSession().retrieve_spark_session(self.__app_name, self.__master_spark)

    def __strip_duplicates_maintain_order(self, condition_manager: ParquetQueryConditionManagementV4):
        """Return ``condition_manager.parquet_names`` without entries whose generated path repeats.

        The first ``PartitionedParquetPath`` seen for each path is kept, in the original order.
        """
        LOGGER.debug('length of parquet_names: %d', len(condition_manager.parquet_names))
        distinct_list = []
        seen_paths = set()
        for each in condition_manager.parquet_names:
            each: PartitionedParquetPath = each
            parquet_path = each.generate_path()
            if parquet_path in seen_paths:
                continue
            seen_paths.add(parquet_path)
            distinct_list.append(each)
        # Debug level: this runs on every query, and the second line dumps every partition path.
        LOGGER.debug('length of distinct_parquet_names: %d', len(distinct_list))
        LOGGER.debug('distinct_parquet_names: %s', seen_paths)
        return distinct_list

    def get_unioned_read_df(self, condition_manager: ParquetQueryConditionManagementV4, spark: SparkSession) -> DataFrame:
        """Read every distinct partition in ``condition_manager.parquet_names`` and union them.

        Each partition is read with the in-situ schema named by ``Config.in_situ_schema``. The
        entries of ``PartitionedParquetPath.get_df_columns()`` are then added as literal columns.
        Partitions that fail to read are logged and skipped.

        :param condition_manager: manager whose ``parquet_names`` lists the partitions to read.
        :param spark: session used for reading.
        :return: the unioned DataFrame, or ``None`` when there are no partitions or none could be read.
        """
        if len(condition_manager.parquet_names) < 1:
            LOGGER.fatal('cannot find any in ES. returning None instead of searching entire parquet directory for now. ')
            return None
        # Load the schema only once we know there is something to read.
        cdms_spark_struct = CdmsSchema().get_schema_from_json(FileUtils.read_json(Config().get_value(Config.in_situ_schema)))
        read_df_list = []
        for each in self.__strip_duplicates_maintain_order(condition_manager):
            each: PartitionedParquetPath = each
            try:
                temp_df: DataFrame = spark.read.schema(cdms_spark_struct).parquet(each.generate_path())
                for k, v in each.get_df_columns().items():
                    temp_df = temp_df.withColumn(k, lit(v))
                read_df_list.append(temp_df)
            except Exception:
                # Best effort: one unreadable partition should not fail the whole query.
                LOGGER.exception(f'failed to retrieve data from spark for: {each.generate_path()}')
        if len(read_df_list) < 1:
            return None
        main_read_df: DataFrame = read_df_list[0]
        for each in read_df_list[1:]:
            # NOTE(review): ``union`` matches columns by position, not by name. This assumes every
            # partition yields the added columns in the same order.
            main_read_df = main_read_df.union(each)
        return main_read_df

    def __get_paged_result(self, result_df: DataFrame, total_result: int):
        """Offset pagination: rows ``[start_at, start_at + size)`` of ``result_df``, clipped to ``total_result``."""
        remaining_size = total_result - self.__props.start_at
        if remaining_size <= 0:
            # start_at is at or past the end. Without this guard the page size would be <= 0 and a
            # negative value would reach ``limit``/``tail``.
            return []
        current_page_size = min(remaining_size, self.__props.size)
        # Keep the first start_at + page rows, then take the last `page` of those.
        return result_df.limit(self.__props.start_at + current_page_size).tail(current_page_size)

    def __get_paged_result_v2(self, result_df: DataFrame):
        """Alternative offset pagination based on a generated row id (currently not called).

        NOTE(review): ``monotonically_increasing_id`` is unique and increasing but NOT consecutive
        across partitions, so ids match row positions only for a single-partition frame.
        """
        offset = self.__props.start_at
        limit = self.__props.size
        df = result_df.withColumn('_id', F.monotonically_increasing_id())
        # ``between`` is inclusive on both ends, so stop at offset + limit - 1 to get ``limit`` rows.
        df = df.where(F.col('_id').between(offset, offset + limit - 1)).drop('_id')
        return df.collect()

    def __is_in_old_page(self, current_item: dict) -> bool:
        """True when ``current_item`` has the marker timestamp and a platform code <= the marker code (currently not called)."""
        return current_item[CDMSConstants.time_col] == self.__props.min_datetime and current_item[CDMSConstants.platform_col]['code'] <= self.__props.marker_platform_code

    def __get_sorting_params(self, query_result: DataFrame):
        """Ascending sort expressions for ``__sorting_columns``, bound to ``query_result``."""
        return [query_result[k].asc() for k in self.__sorting_columns]

    def __get_nth_first_page(self, query_result: DataFrame):
        """Return the page of rows that follows the marker row ("Nth page" pagination).

        The marker is the row whose timestamp equals ``props.min_datetime`` and whose SHA-256 (of
        its dict form) equals ``props.marker_platform_code``. Up to ``props.size`` rows after it, in
        ``query_result`` order, are returned.

        :raises ValueError: when no row at ``min_datetime`` matches the marker hash.
        """
        # Build the predicate as a Column expression rather than splicing the request-supplied
        # min_datetime into a SQL string, so quotes in the value cannot change the expression.
        same_time_rows = query_result.where(F.col(CDMSConstants.time_col) == self.__props.min_datetime)
        result_head = same_time_rows.sort(self.__get_sorting_params(query_result)).collect()
        new_index = -1
        for i, each_row in enumerate(result_head):
            each_row: Row = each_row
            if GeneralUtils.gen_sha_256_json_obj(each_row.asDict()) == self.__props.marker_platform_code:
                new_index = i
                break
        if new_index < 0:
            LOGGER.warning(f'comparing sha256: {self.__props.marker_platform_code}')
            for each_row in result_head:
                each_row: Row = each_row
                each_sha_256 = GeneralUtils.gen_sha_256_json_obj(each_row.asDict())
                LOGGER.warning(f'each row: {str(each_row)}. each_sha_256: {each_sha_256}')
            raise ValueError('cannot find existing row. It should not happen.')
        # NOTE(review): indexing into the full result with new_index assumes min_datetime is the
        # earliest timestamp in query_result (i.e. the query conditions start there) -- confirm.
        result_page = query_result.take(self.__props.size + new_index + 1)
        return result_page[new_index + 1:]

    def __get_page(self, query_result: DataFrame, total_result: int):
        """Dispatch to marker-based ("Nth page") or offset-based pagination.

        :raises ValueError: offset pagination was requested but ``total_result`` was not computed.
        """
        if self.__props.size == 0:
            return []
        if self.__props.marker_platform_code is not None:  # pagination new logic
            return self.__get_nth_first_page(query_result)
        if total_result < 0:
            raise ValueError('total_result is not calculated for old pagination logic. This should not happen. Something has horribly gone wrong')
        return self.__get_paged_result(query_result, total_result)

    def __get_total_count(self, query_result: DataFrame):
        """Row count of ``query_result``, or ``-1`` for "Nth page" requests where counting is skipped."""
        if self.__props.marker_platform_code is not None:
            LOGGER.debug('not counting total since this is an Nth page')
            return -1
        LOGGER.debug('counting total')
        return int(query_result.count())

    def search(self, spark_session=None):
        """Run the query and return ``{'total': ..., 'results': [row dict, ...]}``.

        :param spark_session: Spark session to use. When ``None``, one is obtained through
            ``RetrieveSparkSession``.
        :return: ``total`` is ``-1`` for "Nth page" requests (``marker_platform_code`` set) because
            counting is skipped. The result is ``total`` 0 with no results when nothing could be read.
        """
        LOGGER.debug('<delay_check> query_v4_search started')
        condition_manager = ParquetQueryConditionManagementV4(self.__parquet_name, self.__missing_depth_value, self.__es_config, self.__props)
        condition_manager.manage_query_props()

        conditions = ' AND '.join(condition_manager.conditions)
        query_begin_time = datetime.now()
        LOGGER.debug(f'<delay_check> query begins at {query_begin_time}')
        spark = self.__retrieve_spark() if spark_session is None else spark_session
        created_spark_session_time = datetime.now()
        LOGGER.debug(f'<delay_check>spark session created at {created_spark_session_time}. duration: {created_spark_session_time - query_begin_time}')
        LOGGER.debug(f'__parquet_name: {condition_manager.parquet_name}')
        read_df: DataFrame = self.get_unioned_read_df(condition_manager, spark)
        if read_df is None:
            return {
                'total': 0,
                'results': [],
            }
        read_df_time = datetime.now()
        LOGGER.debug(f'<delay_check> parquet read created at {read_df_time}. duration: {read_df_time - created_spark_session_time}')
        # An empty string is not a valid Spark SQL predicate; only filter when there are conditions.
        query_result = read_df.where(conditions) if conditions else read_df
        query_result = query_result.sort(self.__get_sorting_params(query_result))
        query_time = datetime.now()
        LOGGER.debug(f'<delay_check> parquet read filtered at {query_time}. duration: {query_time - read_df_time}')
        LOGGER.debug(f'<delay_check> total duration: {query_time - query_begin_time}')
        total_result = self.__get_total_count(query_result)
        LOGGER.debug(f'<delay_check> total calc count duration: {datetime.now() - query_time}')
        if self.__props.size < 1:
            LOGGER.debug(f'returning only the size: {total_result}')
            return {
                'total': total_result,
                'results': [],
            }
        retrieval_start_time = datetime.now()
        if len(condition_manager.columns) > 0:
            query_result = query_result.select(condition_manager.columns)
        else:
            # No explicit projection: drop the derived time / partition helper columns.
            removing_cols = [CDMSConstants.time_obj_col, CDMSConstants.year_col, CDMSConstants.month_col]
            query_result = query_result.drop(*removing_cols)
        LOGGER.debug(f'<delay_check> returning size : {total_result}')
        result = self.__get_page(query_result, total_result)
        query_result.unpersist()
        LOGGER.debug(f'<delay_check> total retrieval duration: {datetime.now() - retrieval_start_time}')
        return {
            'total': total_result,
            'results': [k.asDict() for k in result],
        }
