# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import numpy as np
from netCDF4._netCDF4 import Dataset
from webservice.algorithms.NexusCalcHandler import NexusCalcHandler
from webservice.metrics import MetricsRecord, SparkAccumulatorMetricsField, NumberMetricsField
from webservice.webmodel import NexusProcessingException

logger = logging.getLogger(__name__)


class NexusCalcSparkHandler(NexusCalcHandler):
    class SparkJobContext(object):
        """Context manager that borrows one job name from a shared stack.

        On entry a name is popped from ``job_stack``; on exit the same name
        is appended back.  When the stack is empty on entry,
        ``MaxConcurrentJobsReached`` is raised and nothing is borrowed.
        """

        class MaxConcurrentJobsReached(Exception):
            """Raised by ``__enter__`` when the job stack has no names left."""

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

        def __init__(self, job_stack):
            # job_stack is the caller's list; it is mutated in place.
            self.log = logging.getLogger(__name__)
            self.job_name = None
            self.spark_job_stack = job_stack

        def __enter__(self):
            """Pop a job name off the stack and return this context."""
            try:
                borrowed = self.spark_job_stack.pop()
            except IndexError:
                # An empty stack means every job name is currently in use.
                raise self.MaxConcurrentJobsReached()
            self.job_name = borrowed
            self.log.debug("Using %s" % borrowed)
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            """Give the borrowed job name (if any) back to the stack."""
            if self.job_name is None:
                # Nothing was borrowed, so there is nothing to return.
                return
            self.log.debug("Returning %s" % self.job_name)
            self.spark_job_stack.append(self.job_name)

    def __init__(self, tile_service_factory, sc=None, **kwargs):
        """Initialise the handler and wrap ``calc`` in a Spark job context.

        :param tile_service_factory: forwarded to ``NexusCalcHandler.__init__``.
        :param sc: object stored as ``self._sc``; the wrapper calls
            ``setLocalProperty`` and ``setJobGroup`` on it (presumably a
            ``SparkContext`` -- confirm against callers).
        :param kwargs: forwarded unchanged to ``NexusCalcHandler.__init__``.

        If the instance has a bound method named ``calc``, it is replaced by a
        wrapper that borrows a job name from ``self.spark_job_stack`` for the
        duration of the call and raises ``NexusProcessingException`` (code 503)
        when no job name is available.
        """
        import inspect
        from functools import wraps

        NexusCalcHandler.__init__(self, tile_service_factory=tile_service_factory, **kwargs)
        self._sc = sc
        self.log = logging.getLogger(__name__)
        # max_concurrent_jobs = algorithm_config.getint("spark", "maxconcurrentjobs") if algorithm_config.has_section(
        #     "spark") and algorithm_config.has_option("spark", "maxconcurrentjobs") else 10
        max_concurrent_jobs = 10
        # One name per allowed concurrent job: "Job 1" .. "Job 10".
        self.spark_job_stack = ["Job %s" % x for x in range(1, max_concurrent_jobs + 1)]

        def with_spark_job_context(calc_func):
            @wraps(calc_func)
            def wrapped(*args, **kwargs1):
                try:
                    with NexusCalcSparkHandler.SparkJobContext(self.spark_job_stack) as job_context:
                        # TODO Pool and Job are forced to a 1-to-1 relationship
                        spark_context = calc_func.__self__._sc
                        spark_context.setLocalProperty("spark.scheduler.pool", job_context.job_name)
                        spark_context.setJobGroup(job_context.job_name, "a spark job")
                        return calc_func(*args, **kwargs1)
                except NexusCalcSparkHandler.SparkJobContext.MaxConcurrentJobsReached:
                    raise NexusProcessingException(code=503,
                                                   reason="Max concurrent requests reached. Please try again later.")

            return wrapped

        # Look up "calc" directly instead of enumerating (and thereby
        # evaluating) every attribute of the instance to find one name.
        calc = getattr(self, "calc", None)
        if inspect.ismethod(calc):
            self.calc = with_spark_job_context(calc)

    def _setQueryParams(self, ds, bounds, start_time=None, end_time=None,
                        start_year=None, end_year=None, clim_month=None,
                        fill=-9999.):
        self._ds = ds
        self._minLat, self._maxLat, self._minLon, self._maxLon = bounds
        self._startTime = start_time
        self._endTime = end_time
        self._startYear = start_year
        self._endYear = end_year
        self._climMonth = clim_month
        self._fill = fill

    def _set_info_from_tile_set(self, nexus_tiles):
        ntiles = len(nexus_tiles)
        self.log.debug('Attempting to extract info from {0} tiles'. \
                       format(ntiles))
        status = False
        self._latRes = None
        self._lonRes = None
        for tile in nexus_tiles:
            self.log.debug('tile coords:')
            self.log.debug('tile lats: {0}'.format(tile.latitudes))
            self.log.debug('tile lons: {0}'.format(tile.longitudes))
            if self._latRes is None:
                lats = tile.latitudes.data
                if (len(lats) > 1):
                    self._latRes = abs(lats[1] - lats[0])
            if self._lonRes is None:
                lons = tile.longitudes.data
                if (len(lons) > 1):
                    self._lonRes = abs(lons[1] - lons[0])
            if ((self._latRes is not None) and
                    (self._lonRes is not None)):
                lats_agg = np.concatenate([tile.latitudes.compressed()
                                           for tile in nexus_tiles])
                lons_agg = np.concatenate([tile.longitudes.compressed()
                                           for tile in nexus_tiles])
                self._minLatCent = np.min(lats_agg)
                self._maxLatCent = np.max(lats_agg)
                self._minLonCent = np.min(lons_agg)
                self._maxLonCent = np.max(lons_agg)
                self._nlats = int((self._maxLatCent - self._minLatCent) /
                                  self._latRes + 0.5) + 1
                self._nlons = int((self._maxLonCent - self._minLonCent) /
                                  self._lonRes + 0.5) + 1
                status = True
                break
        return status

    def _find_global_tile_set(self, metrics_callback=None):
        # This only works for a single dataset.  If more than one is provided,
        # we use the first one and ignore the rest.
        if type(self._ds) in (list, tuple):
            ds = self._ds[0]
        else:
            ds = self._ds

        # See what time stamps are in the specified range.
        t_in_range = self._tile_service.find_days_in_range_asc(self._minLat,
                                                               self._maxLat,
                                                               self._minLon,
                                                               self._maxLon,
                                                               ds,
                                                               self._startTime,
                                                               self._endTime,
                                                               metrics_callback=metrics_callback)

        # Empty tile set will be returned upon failure to find the global
        # tile set.
        nexus_tiles = []

        # Check one time stamp at a time and attempt to extract the global
        # tile set.
        for t in t_in_range:
            nexus_tiles = self._tile_service.get_tiles_bounded_by_box(self._minLat, self._maxLat, self._minLon,
                                                                      self._maxLon, ds=ds, start_time=t, end_time=t,
                                                                      metrics_callback=metrics_callback)
            if self._set_info_from_tile_set(nexus_tiles):
                # Successfully retrieved global tile set from nexus_tiles,
                # so no need to check any other time stamps.
                break
        return nexus_tiles

    def _find_tile_bounds(self, t):
        lats = t.latitudes
        lons = t.longitudes
        if (len(lats.compressed()) > 0) and (len(lons.compressed()) > 0):
            min_lat = np.ma.min(lats)
            max_lat = np.ma.max(lats)
            min_lon = np.ma.min(lons)
            max_lon = np.ma.max(lons)
            good_inds_lat = np.where(lats.mask == False)[0]
            good_inds_lon = np.where(lons.mask == False)[0]
            min_y = np.min(good_inds_lat)
            max_y = np.max(good_inds_lat)
            min_x = np.min(good_inds_lon)
            max_x = np.max(good_inds_lon)
            bounds = (min_lat, max_lat, min_lon, max_lon,
                      min_y, max_y, min_x, max_x)
        else:
            self.log.warn('Nothing in this tile!')
            bounds = None
        return bounds

    @staticmethod
    def query_by_parts(tile_service, min_lat, max_lat, min_lon, max_lon,
                       dataset, start_time, end_time, part_dim=0):
        nexus_max_tiles_per_query = 100
        # print 'trying query: ',min_lat, max_lat, min_lon, max_lon, \
        #    dataset, start_time, end_time
        try:
            tiles = \
                tile_service.find_tiles_in_box(min_lat, max_lat,
                                               min_lon, max_lon,
                                               dataset,
                                               start_time=start_time,
                                               end_time=end_time,
                                               fetch_data=False)
            assert (len(tiles) <= nexus_max_tiles_per_query)
        except:
            # print 'failed query: ',min_lat, max_lat, min_lon, max_lon, \
            #    dataset, start_time, end_time
            if part_dim == 0:
                # Partition by latitude.
                mid_lat = (min_lat + max_lat) / 2
                nexus_tiles = NexusCalcSparkHandler.query_by_parts(tile_service,
                                                                   min_lat, mid_lat,
                                                                   min_lon, max_lon,
                                                                   dataset,
                                                                   start_time, end_time,
                                                                   part_dim=part_dim)
                nexus_tiles.extend(NexusCalcSparkHandler.query_by_parts(tile_service,
                                                                        mid_lat,
                                                                        max_lat,
                                                                        min_lon,
                                                                        max_lon,
                                                                        dataset,
                                                                        start_time,
                                                                        end_time,
                                                                        part_dim=part_dim))
            elif part_dim == 1:
                # Partition by longitude.
                mid_lon = (min_lon + max_lon) / 2
                nexus_tiles = NexusCalcSparkHandler.query_by_parts(tile_service,
                                                                   min_lat, max_lat,
                                                                   min_lon, mid_lon,
                                                                   dataset,
                                                                   start_time, end_time,
                                                                   part_dim=part_dim)
                nexus_tiles.extend(NexusCalcSparkHandler.query_by_parts(tile_service,
                                                                        min_lat,
                                                                        max_lat,
                                                                        mid_lon,
                                                                        max_lon,
                                                                        dataset,
                                                                        start_time,
                                                                        end_time,
                                                                        part_dim=part_dim))
            elif part_dim == 2:
                # Partition by time.
                mid_time = (start_time + end_time) / 2
                nexus_tiles = NexusCalcSparkHandler.query_by_parts(tile_service,
                                                                   min_lat, max_lat,
                                                                   min_lon, max_lon,
                                                                   dataset,
                                                                   start_time, mid_time,
                                                                   part_dim=part_dim)
                nexus_tiles.extend(NexusCalcSparkHandler.query_by_parts(tile_service,
                                                                        min_lat,
                                                                        max_lat,
                                                                        min_lon,
                                                                        max_lon,
                                                                        dataset,
                                                                        mid_time,
                                                                        end_time,
                                                                        part_dim=part_dim))
        else:
            # No exception, so query Cassandra for the tile data.
            # print 'Making NEXUS query to Cassandra for %d tiles...' % \
            #    len(tiles)
            # t1 = time.time()
            # print 'NEXUS call start at time %f' % t1
            # sys.stdout.flush()
            nexus_tiles = list(tile_service.fetch_data_for_tiles(*tiles))
            nexus_tiles = list(tile_service.mask_tiles_to_bbox(min_lat, max_lat,
                                                               min_lon, max_lon,
                                                               nexus_tiles))
            # t2 = time.time()
            # print 'NEXUS call end at time %f' % t2
            # print 'Seconds in NEXUS call: ', t2-t1
            # sys.stdout.flush()

        # print 'Returning %d tiles' % len(nexus_tiles)
        return nexus_tiles

    @staticmethod
    def _prune_tiles(nexus_tiles):
        del_ind = np.where([np.all(tile.data.mask) for tile in nexus_tiles])[0]
        for i in np.flipud(del_ind):
            del nexus_tiles[i]

    def _lat2ind(self, lat):
        return int((lat - self._minLatCent) / self._latRes + 0.5)

    def _lon2ind(self, lon):
        return int((lon - self._minLonCent) / self._lonRes + 0.5)

    def _ind2lat(self, y):
        return self._minLatCent + y * self._latRes

    def _ind2lon(self, x):
        return self._minLonCent + x * self._lonRes

    def _create_nc_file_time1d(self, a, fname, varname, varunits=None,
                               fill=None):
        """Write a 1-D time series to a new NETCDF4 file.

        :param a: 1-D array-like (must have ``shape``) whose elements are
            mappings with ``'mean'`` and ``'time'`` entries.
        :param fname: path of the file to create (opened in ``"w"`` mode).
        :param varname: name of the data variable holding the ``'mean'``
            values.
        :param varunits: if not None, set as the ``units`` attribute of the
            data variable.
        :param fill: fill value passed to ``createVariable`` for the data
            variable.
        :raises AssertionError: if ``a`` is not one-dimensional.
        """
        self.log.debug('a={0}'.format(a))
        self.log.debug('shape a = {0}'.format(a.shape))
        assert len(a.shape) == 1
        time_dim = len(a)
        # The context manager closes the file even if a write below raises.
        with Dataset(fname, "w", format="NETCDF4") as rootgrp:
            rootgrp.createDimension("time", time_dim)
            vals = rootgrp.createVariable(varname, "f4", dimensions=("time",),
                                          fill_value=fill)
            # Stored as 64-bit floats: the units below are epoch seconds, and
            # 32-bit floats cannot represent present-day values to better
            # than roughly two minutes.
            times = rootgrp.createVariable("time", "f8", dimensions=("time",))
            vals[:] = [d['mean'] for d in a]
            times[:] = [d['time'] for d in a]
            if varunits is not None:
                vals.units = varunits
            times.units = 'seconds since 1970-01-01 00:00:00'

    def _create_nc_file_latlon2d(self, a, fname, varname, varunits=None,
                                 fill=None):
        """Write a 2-D (lat, lon) array to a new NETCDF4 file.

        The ``lat`` and ``lon`` coordinate variables are filled with evenly
        spaced values between the min and max coordinate centres stored on the
        instance (``_minLatCent``/``_maxLatCent``, ``_minLonCent``/
        ``_maxLonCent``), one per row/column of ``a``.

        :param a: 2-D array; axis 0 is written as ``lat``, axis 1 as ``lon``.
        :param fname: path of the file to create (opened in ``"w"`` mode).
        :param varname: name of the data variable.
        :param varunits: if not None, set as the ``units`` attribute of the
            data variable.
        :param fill: fill value passed to ``createVariable`` for the data
            variable.
        :raises AssertionError: if ``a`` is not two-dimensional.
        """
        self.log.debug('a={0}'.format(a))
        self.log.debug('shape a = {0}'.format(a.shape))
        assert len(a.shape) == 2
        lat_dim, lon_dim = a.shape
        # The context manager closes the file even if a write below raises.
        with Dataset(fname, "w", format="NETCDF4") as rootgrp:
            rootgrp.createDimension("lat", lat_dim)
            rootgrp.createDimension("lon", lon_dim)
            vals = rootgrp.createVariable(varname, "f4",
                                          dimensions=("lat", "lon",),
                                          fill_value=fill)
            lats = rootgrp.createVariable("lat", "f4", dimensions=("lat",))
            lons = rootgrp.createVariable("lon", "f4", dimensions=("lon",))
            vals[:, :] = a
            lats[:] = np.linspace(self._minLatCent,
                                  self._maxLatCent, lat_dim)
            lons[:] = np.linspace(self._minLonCent,
                                  self._maxLonCent, lon_dim)
            if varunits is not None:
                vals.units = varunits
            # NOTE(review): CF-style readers usually expect "degrees_north" /
            # "degrees_east"; left unchanged here -- confirm with consumers.
            lats.units = "degrees north"
            lons.units = "degrees east"

    def _create_nc_file(self, a, fname, varname, **kwargs):
        """Write ``a`` to a NetCDF file via ``_create_nc_file_latlon2d``.

        All positional arguments and any extra keyword arguments are passed
        through unchanged.
        """
        writer = self._create_nc_file_latlon2d
        writer(a, fname, varname, **kwargs)

    def _spark_nparts(self, nparts_requested):
        max_parallelism = 128
        num_partitions = min(nparts_requested if nparts_requested > 0
                             else self._sc.defaultParallelism,
                             max_parallelism)
        return num_partitions

    def _create_metrics_record(self):
        """Build a fresh ``MetricsRecord`` for one calculation.

        The record holds five accumulator-backed fields, each with its own new
        accumulator created from ``self._sc`` and starting at 0, followed by
        two plain number fields.
        """
        # (key, description) for every accumulator-backed field, in order.
        accumulator_specs = (
            ('num_tiles', 'Number of tiles fetched'),
            ('partitions', 'Number of Spark partitions'),
            ('cassandra', 'Cumulative time to fetch data from Cassandra'),
            ('solr', 'Cumulative time to fetch data from Solr'),
            ('calculation', 'Cumulative time to do calculations'),
        )
        fields = [
            SparkAccumulatorMetricsField(key=key,
                                         description=description,
                                         accumulator=self._sc.accumulator(0))
            for key, description in accumulator_specs
        ]
        fields.append(NumberMetricsField(key='reduce',
                                         description='Actual time to reduce results'))
        fields.append(NumberMetricsField(key="actual_time",
                                         description="Total (actual) time"))
        return MetricsRecord(fields)
