analysis/webservice/algorithms_spark/NexusCalcSparkHandler.py (295 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import numpy as np
from netCDF4._netCDF4 import Dataset

from webservice.algorithms.NexusCalcHandler import NexusCalcHandler
from webservice.metrics import MetricsRecord, SparkAccumulatorMetricsField, NumberMetricsField
from webservice.webmodel import NexusProcessingException

logger = logging.getLogger(__name__)


class NexusCalcSparkHandler(NexusCalcHandler):
    class SparkJobContext(object):
        class MaxConcurrentJobsReached(Exception):
            def __init__(self, *args, **kwargs):
                Exception.__init__(self, *args, **kwargs)

        def __init__(self, job_stack):
            self.spark_job_stack = job_stack
            self.job_name = None
            self.log = logging.getLogger(__name__)

        def __enter__(self):
            try:
                self.job_name = self.spark_job_stack.pop()
                self.log.debug("Using %s" % self.job_name)
            except IndexError:
                raise NexusCalcSparkHandler.SparkJobContext.MaxConcurrentJobsReached()
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            if self.job_name is not None:
                self.log.debug("Returning %s" % self.job_name)
                self.spark_job_stack.append(self.job_name)

    def __init__(self, tile_service_factory, sc=None, **kwargs):
        import inspect

        NexusCalcHandler.__init__(self, tile_service_factory=tile_service_factory, **kwargs)

        self.spark_job_stack = []
        self._sc = sc
        # max_concurrent_jobs = algorithm_config.getint("spark", "maxconcurrentjobs") if algorithm_config.has_section(
        #     "spark") and algorithm_config.has_option("spark", "maxconcurrentjobs") else 10
        max_concurrent_jobs = 10
        self.spark_job_stack = list(["Job %s" % x for x in range(1, max_concurrent_jobs + 1)])
        self.log = logging.getLogger(__name__)

        def with_spark_job_context(calc_func):
            from functools import wraps

            @wraps(calc_func)
            def wrapped(*args, **kwargs1):
                try:
                    with NexusCalcSparkHandler.SparkJobContext(self.spark_job_stack) as job_context:
                        # TODO Pool and Job are forced to a 1-to-1 relationship
                        calc_func.__self__._sc.setLocalProperty("spark.scheduler.pool", job_context.job_name)
                        calc_func.__self__._sc.setJobGroup(job_context.job_name, "a spark job")
                        return calc_func(*args, **kwargs1)
                except NexusCalcSparkHandler.SparkJobContext.MaxConcurrentJobsReached:
                    raise NexusProcessingException(code=503,
                                                   reason="Max concurrent requests reached. Please try again later.")

            return wrapped

        for member in inspect.getmembers(self, predicate=inspect.ismethod):
            if member[0] == "calc":
                setattr(self, member[0], with_spark_job_context(member[1]))
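
    # The constructor rebinds ``calc`` through ``with_spark_job_context`` so that each
    # request must first claim a slot from ``spark_job_stack``; when all
    # ``max_concurrent_jobs`` slots are in use the request fails fast with HTTP 503
    # instead of queueing, and the Spark scheduler pool / job group are set to the
    # claimed slot name. A minimal illustration (hypothetical subclass, shown only to
    # clarify the wrapping, not part of this module):
    #
    #     class ExampleSparkHandler(NexusCalcSparkHandler):
    #         def calc(self, compute_options, **kwargs):
    #             ...  # runs inside a SparkJobContext named "Job 1".."Job 10"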

    def _setQueryParams(self, ds, bounds, start_time=None, end_time=None,
                        start_year=None, end_year=None, clim_month=None, fill=-9999.):
        self._ds = ds
        self._minLat, self._maxLat, self._minLon, self._maxLon = bounds
        self._startTime = start_time
        self._endTime = end_time
        self._startYear = start_year
        self._endYear = end_year
        self._climMonth = clim_month
        self._fill = fill

    def _set_info_from_tile_set(self, nexus_tiles):
        ntiles = len(nexus_tiles)
        self.log.debug('Attempting to extract info from {0} tiles'.format(ntiles))
        status = False
        self._latRes = None
        self._lonRes = None
        for tile in nexus_tiles:
            self.log.debug('tile coords:')
            self.log.debug('tile lats: {0}'.format(tile.latitudes))
            self.log.debug('tile lons: {0}'.format(tile.longitudes))
            if self._latRes is None:
                lats = tile.latitudes.data
                if len(lats) > 1:
                    self._latRes = abs(lats[1] - lats[0])
            if self._lonRes is None:
                lons = tile.longitudes.data
                if len(lons) > 1:
                    self._lonRes = abs(lons[1] - lons[0])
            if (self._latRes is not None) and (self._lonRes is not None):
                lats_agg = np.concatenate([tile.latitudes.compressed()
                                           for tile in nexus_tiles])
                lons_agg = np.concatenate([tile.longitudes.compressed()
                                           for tile in nexus_tiles])
                self._minLatCent = np.min(lats_agg)
                self._maxLatCent = np.max(lats_agg)
                self._minLonCent = np.min(lons_agg)
                self._maxLonCent = np.max(lons_agg)
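                # Grid dimensions follow from the span of the cell centers and the
                # per-cell resolution, rounded to the nearest integer. For example, a
                # 0.25-degree grid with latitude centers from -89.875 to 89.875 gives
                # int(179.75 / 0.25 + 0.5) + 1 = 720 rows.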
                self._nlats = int((self._maxLatCent - self._minLatCent) /
                                  self._latRes + 0.5) + 1
                self._nlons = int((self._maxLonCent - self._minLonCent) /
                                  self._lonRes + 0.5) + 1
                status = True
                break
        return status

    def _find_global_tile_set(self, metrics_callback=None):
        # This only works for a single dataset. If more than one is provided,
        # we use the first one and ignore the rest.
        if type(self._ds) in (list, tuple):
            ds = self._ds[0]
        else:
            ds = self._ds

        # See what time stamps are in the specified range.
        t_in_range = self._tile_service.find_days_in_range_asc(self._minLat, self._maxLat,
                                                               self._minLon, self._maxLon,
                                                               ds, self._startTime, self._endTime,
                                                               metrics_callback=metrics_callback)

        # Empty tile set will be returned upon failure to find the global
        # tile set.
        nexus_tiles = []

        # Check one time stamp at a time and attempt to extract the global
        # tile set.
        for t in t_in_range:
            nexus_tiles = self._tile_service.get_tiles_bounded_by_box(self._minLat, self._maxLat,
                                                                      self._minLon, self._maxLon,
                                                                      ds=ds, start_time=t, end_time=t,
                                                                      metrics_callback=metrics_callback)
            if self._set_info_from_tile_set(nexus_tiles):
                # Successfully retrieved global tile set from nexus_tiles,
                # so no need to check any other time stamps.
                break
        return nexus_tiles

    def _find_tile_bounds(self, t):
        lats = t.latitudes
        lons = t.longitudes
        if (len(lats.compressed()) > 0) and (len(lons.compressed()) > 0):
            min_lat = np.ma.min(lats)
            max_lat = np.ma.max(lats)
            min_lon = np.ma.min(lons)
            max_lon = np.ma.max(lons)
            good_inds_lat = np.where(lats.mask == False)[0]
            good_inds_lon = np.where(lons.mask == False)[0]
            min_y = np.min(good_inds_lat)
            max_y = np.max(good_inds_lat)
            min_x = np.min(good_inds_lon)
            max_x = np.max(good_inds_lon)
            bounds = (min_lat, max_lat, min_lon, max_lon,
                      min_y, max_y, min_x, max_x)
        else:
            self.log.warn('Nothing in this tile!')
            bounds = None
        return bounds

    @staticmethod
    def query_by_parts(tile_service, min_lat, max_lat, min_lon, max_lon,
                       dataset, start_time, end_time, part_dim=0):
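        """
        Fetch tiles for the given bounding box and time range, recursively bisecting
        the query whenever it would exceed ``nexus_max_tiles_per_query`` tiles (or
        otherwise fail). ``part_dim`` selects the bisection axis: 0 = latitude,
        1 = longitude, 2 = time. Tile data is only fetched and masked to the
        bounding box once a query fits within the limit.
        """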
        nexus_max_tiles_per_query = 100
        # print 'trying query: ', min_lat, max_lat, min_lon, max_lon, \
        #     dataset, start_time, end_time
        try:
            tiles = \
                tile_service.find_tiles_in_box(min_lat, max_lat, min_lon, max_lon,
                                               dataset,
                                               start_time=start_time,
                                               end_time=end_time,
                                               fetch_data=False)
            assert (len(tiles) <= nexus_max_tiles_per_query)
        except:
            # print 'failed query: ', min_lat, max_lat, min_lon, max_lon, \
            #     dataset, start_time, end_time
            if part_dim == 0:
                # Partition by latitude.
                mid_lat = (min_lat + max_lat) / 2
                nexus_tiles = NexusCalcSparkHandler.query_by_parts(tile_service,
                                                                   min_lat, mid_lat,
                                                                   min_lon, max_lon,
                                                                   dataset, start_time, end_time,
                                                                   part_dim=part_dim)
                nexus_tiles.extend(NexusCalcSparkHandler.query_by_parts(tile_service,
                                                                        mid_lat, max_lat,
                                                                        min_lon, max_lon,
                                                                        dataset, start_time, end_time,
                                                                        part_dim=part_dim))
            elif part_dim == 1:
                # Partition by longitude.
                mid_lon = (min_lon + max_lon) / 2
                nexus_tiles = NexusCalcSparkHandler.query_by_parts(tile_service,
                                                                   min_lat, max_lat,
                                                                   min_lon, mid_lon,
                                                                   dataset, start_time, end_time,
                                                                   part_dim=part_dim)
                nexus_tiles.extend(NexusCalcSparkHandler.query_by_parts(tile_service,
                                                                        min_lat, max_lat,
                                                                        mid_lon, max_lon,
                                                                        dataset, start_time, end_time,
                                                                        part_dim=part_dim))
            elif part_dim == 2:
                # Partition by time.
                mid_time = (start_time + end_time) / 2
                nexus_tiles = NexusCalcSparkHandler.query_by_parts(tile_service,
                                                                   min_lat, max_lat,
                                                                   min_lon, max_lon,
                                                                   dataset, start_time, mid_time,
                                                                   part_dim=part_dim)
                nexus_tiles.extend(NexusCalcSparkHandler.query_by_parts(tile_service,
                                                                        min_lat, max_lat,
                                                                        min_lon, max_lon,
                                                                        dataset, mid_time, end_time,
                                                                        part_dim=part_dim))
        else:
            # No exception, so query Cassandra for the tile data.
            # print 'Making NEXUS query to Cassandra for %d tiles...' % len(tiles)
            # t1 = time.time()
            # print 'NEXUS call start at time %f' % t1
            # sys.stdout.flush()
            nexus_tiles = list(tile_service.fetch_data_for_tiles(*tiles))
            nexus_tiles = list(tile_service.mask_tiles_to_bbox(min_lat, max_lat,
                                                               min_lon, max_lon,
                                                               nexus_tiles))
            # t2 = time.time()
            # print 'NEXUS call end at time %f' % t2
            # print 'Seconds in NEXUS call: ', t2 - t1
            # sys.stdout.flush()

        # print 'Returning %d tiles' % len(nexus_tiles)
        return nexus_tiles

    @staticmethod
    def _prune_tiles(nexus_tiles):
        del_ind = np.where([np.all(tile.data.mask) for tile in nexus_tiles])[0]
        for i in np.flipud(del_ind):
            del nexus_tiles[i]

    def _lat2ind(self, lat):
        return int((lat - self._minLatCent) / self._latRes + 0.5)

    def _lon2ind(self, lon):
        return int((lon - self._minLonCent) / self._lonRes + 0.5)

    def _ind2lat(self, y):
        return self._minLatCent + y * self._latRes

    def _ind2lon(self, x):
        return self._minLonCent + x * self._lonRes

    def _create_nc_file_time1d(self, a, fname, varname, varunits=None, fill=None):
        self.log.debug('a={0}'.format(a))
        self.log.debug('shape a = {0}'.format(a.shape))
        assert len(a.shape) == 1
        time_dim = len(a)
        rootgrp = Dataset(fname, "w", format="NETCDF4")
        rootgrp.createDimension("time", time_dim)
        vals = rootgrp.createVariable(varname, "f4", dimensions=("time",),
                                      fill_value=fill)
        times = rootgrp.createVariable("time", "f4", dimensions=("time",))
        vals[:] = [d['mean'] for d in a]
        times[:] = [d['time'] for d in a]
        if varunits is not None:
            vals.units = varunits
        times.units = 'seconds since 1970-01-01 00:00:00'
        rootgrp.close()

    def _create_nc_file_latlon2d(self, a, fname, varname, varunits=None, fill=None):
        self.log.debug('a={0}'.format(a))
        self.log.debug('shape a = {0}'.format(a.shape))
        assert len(a.shape) == 2
        lat_dim, lon_dim = a.shape
        rootgrp = Dataset(fname, "w", format="NETCDF4")
        rootgrp.createDimension("lat", lat_dim)
        rootgrp.createDimension("lon", lon_dim)
        vals = rootgrp.createVariable(varname, "f4", dimensions=("lat", "lon",),
                                      fill_value=fill)
        lats = rootgrp.createVariable("lat", "f4", dimensions=("lat",))
        lons = rootgrp.createVariable("lon", "f4", dimensions=("lon",))
        vals[:, :] = a
        lats[:] = np.linspace(self._minLatCent, self._maxLatCent, lat_dim)
        lons[:] = np.linspace(self._minLonCent, self._maxLonCent, lon_dim)
        if varunits is not None:
            vals.units = varunits
        lats.units = "degrees north"
        lons.units = "degrees east"
        rootgrp.close()

    def _create_nc_file(self, a, fname, varname, **kwargs):
        self._create_nc_file_latlon2d(a, fname, varname, **kwargs)

    def _spark_nparts(self, nparts_requested):
        max_parallelism = 128
        num_partitions = min(nparts_requested if nparts_requested > 0 else self._sc.defaultParallelism,
                             max_parallelism)
        return num_partitions
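
    # ``_spark_nparts`` caps the partition count at 128 and falls back to the Spark
    # context's ``defaultParallelism`` when the request is non-positive. For example,
    # ``self._spark_nparts(0)`` with defaultParallelism == 8 returns 8, while
    # ``self._spark_nparts(1000)`` returns 128.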

    def _create_metrics_record(self):
        return MetricsRecord([
            SparkAccumulatorMetricsField(key='num_tiles',
                                         description='Number of tiles fetched',
                                         accumulator=self._sc.accumulator(0)),
            SparkAccumulatorMetricsField(key='partitions',
                                         description='Number of Spark partitions',
                                         accumulator=self._sc.accumulator(0)),
            SparkAccumulatorMetricsField(key='cassandra',
                                         description='Cumulative time to fetch data from Cassandra',
                                         accumulator=self._sc.accumulator(0)),
            SparkAccumulatorMetricsField(key='solr',
                                         description='Cumulative time to fetch data from Solr',
                                         accumulator=self._sc.accumulator(0)),
            SparkAccumulatorMetricsField(key='backend',
                                         description='Cumulative time to fetch data from external backend(s)',
                                         accumulator=self._sc.accumulator(0)),
            SparkAccumulatorMetricsField(key='calculation',
                                         description='Cumulative time to do calculations',
                                         accumulator=self._sc.accumulator(0)),
            NumberMetricsField(key='reduce', description='Actual time to reduce results'),
            NumberMetricsField(key="actual_time", description="Total (actual) time")
        ])
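
# A rough usage sketch (illustrative only; the subclass name, variables, and option
# handling are hypothetical and not part of this module). A concrete algorithm
# typically implements ``calc`` and leans on the helpers above:
#
#     class MyMapSparkHandler(NexusCalcSparkHandler):
#         def calc(self, compute_options, **kwargs):
#             metrics_record = self._create_metrics_record()
#             self._setQueryParams(dataset_name,
#                                  (min_lat, max_lat, min_lon, max_lon),
#                                  start_time=start_seconds, end_time=end_seconds)
#             nexus_tiles = self._find_global_tile_set()
#             nparts = self._spark_nparts(nparts_requested)
#             rdd = self._sc.parallelize(nexus_tiles, nparts)
#             ...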