# analysis/webservice/algorithms_spark/NexusCalcSparkHandler.py
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
from netCDF4._netCDF4 import Dataset
from webservice.algorithms.NexusCalcHandler import NexusCalcHandler
from webservice.metrics import MetricsRecord, SparkAccumulatorMetricsField, NumberMetricsField
from webservice.webmodel import NexusProcessingException
logger = logging.getLogger(__name__)
class NexusCalcSparkHandler(NexusCalcHandler):
class SparkJobContext(object):
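        """
        Context manager that throttles concurrent Spark jobs. On entry it pops a
        job name from the shared job stack (raising MaxConcurrentJobsReached when
        the stack is empty); on exit it pushes the name back so the slot can be
        reused.
        """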
class MaxConcurrentJobsReached(Exception):
def __init__(self, *args, **kwargs):
Exception.__init__(self, *args, **kwargs)
def __init__(self, job_stack):
self.spark_job_stack = job_stack
self.job_name = None
self.log = logging.getLogger(__name__)
def __enter__(self):
try:
self.job_name = self.spark_job_stack.pop()
self.log.debug("Using %s" % self.job_name)
except IndexError:
raise NexusCalcSparkHandler.SparkJobContext.MaxConcurrentJobsReached()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.job_name is not None:
self.log.debug("Returning %s" % self.job_name)
self.spark_job_stack.append(self.job_name)
def __init__(self, tile_service_factory, sc=None, **kwargs):
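        """
        Build the handler, keep a reference to the SparkContext, and create a
        stack of named job slots. The maximum number of concurrent jobs is
        currently hard-coded to 10 (the commented-out block below shows it was
        previously read from the algorithm config). The calc method is then
        wrapped so every request must acquire a slot before running.
        """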
import inspect
NexusCalcHandler.__init__(self, tile_service_factory=tile_service_factory, **kwargs)
self.spark_job_stack = []
self._sc = sc
# max_concurrent_jobs = algorithm_config.getint("spark", "maxconcurrentjobs") if algorithm_config.has_section(
# "spark") and algorithm_config.has_option("spark", "maxconcurrentjobs") else 10
max_concurrent_jobs = 10
        self.spark_job_stack = ["Job %s" % x for x in range(1, max_concurrent_jobs + 1)]
self.log = logging.getLogger(__name__)
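        # Decorator applied to this instance's calc method (see the loop below):
        # it acquires a job slot via SparkJobContext, assigns the Spark scheduler
        # pool and job group from the slot name, and converts
        # MaxConcurrentJobsReached into an HTTP 503 response.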
def with_spark_job_context(calc_func):
from functools import wraps
@wraps(calc_func)
def wrapped(*args, **kwargs1):
try:
with NexusCalcSparkHandler.SparkJobContext(self.spark_job_stack) as job_context:
# TODO Pool and Job are forced to a 1-to-1 relationship
calc_func.__self__._sc.setLocalProperty("spark.scheduler.pool", job_context.job_name)
calc_func.__self__._sc.setJobGroup(job_context.job_name, "a spark job")
return calc_func(*args, **kwargs1)
except NexusCalcSparkHandler.SparkJobContext.MaxConcurrentJobsReached:
raise NexusProcessingException(code=503,
reason="Max concurrent requests reached. Please try again later.")
return wrapped
for member in inspect.getmembers(self, predicate=inspect.ismethod):
if member[0] == "calc":
setattr(self, member[0], with_spark_job_context(member[1]))
def _setQueryParams(self, ds, bounds, start_time=None, end_time=None,
start_year=None, end_year=None, clim_month=None,
fill=-9999.):
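        """
        Cache the query parameters (dataset, bounding box, time range,
        climatology years/month, and fill value) on the handler for use by the
        helper methods below.
        """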
self._ds = ds
self._minLat, self._maxLat, self._minLon, self._maxLon = bounds
self._startTime = start_time
self._endTime = end_time
self._startYear = start_year
self._endYear = end_year
self._climMonth = clim_month
self._fill = fill
def _set_info_from_tile_set(self, nexus_tiles):
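        """
        Derive the grid geometry from a set of tiles: the lat/lon resolution is
        taken from the first tile that has at least two coordinate values, and
        the grid extents and dimensions are computed from the cell-center
        coordinates of all tiles. Returns True if the geometry could be
        determined, False otherwise.
        """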
ntiles = len(nexus_tiles)
        self.log.debug('Attempting to extract info from {0} tiles'.format(ntiles))
status = False
self._latRes = None
self._lonRes = None
for tile in nexus_tiles:
self.log.debug('tile coords:')
self.log.debug('tile lats: {0}'.format(tile.latitudes))
self.log.debug('tile lons: {0}'.format(tile.longitudes))
if self._latRes is None:
lats = tile.latitudes.data
if (len(lats) > 1):
self._latRes = abs(lats[1] - lats[0])
if self._lonRes is None:
lons = tile.longitudes.data
if (len(lons) > 1):
self._lonRes = abs(lons[1] - lons[0])
if ((self._latRes is not None) and
(self._lonRes is not None)):
lats_agg = np.concatenate([tile.latitudes.compressed()
for tile in nexus_tiles])
lons_agg = np.concatenate([tile.longitudes.compressed()
for tile in nexus_tiles])
self._minLatCent = np.min(lats_agg)
self._maxLatCent = np.max(lats_agg)
self._minLonCent = np.min(lons_agg)
self._maxLonCent = np.max(lons_agg)
self._nlats = int((self._maxLatCent - self._minLatCent) /
self._latRes + 0.5) + 1
self._nlons = int((self._maxLonCent - self._minLonCent) /
self._lonRes + 0.5) + 1
status = True
break
return status
def _find_global_tile_set(self, metrics_callback=None):
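        """
        Walk the time steps in the query range and return the first tile set
        from which the global grid geometry can be derived (see
        _set_info_from_tile_set). Returns an empty list if no suitable time
        step is found.
        """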
# This only works for a single dataset. If more than one is provided,
# we use the first one and ignore the rest.
        if isinstance(self._ds, (list, tuple)):
ds = self._ds[0]
else:
ds = self._ds
# See what time stamps are in the specified range.
t_in_range = self._tile_service.find_days_in_range_asc(self._minLat,
self._maxLat,
self._minLon,
self._maxLon,
ds,
self._startTime,
self._endTime,
metrics_callback=metrics_callback)
# Empty tile set will be returned upon failure to find the global
# tile set.
nexus_tiles = []
# Check one time stamp at a time and attempt to extract the global
# tile set.
for t in t_in_range:
nexus_tiles = self._tile_service.get_tiles_bounded_by_box(self._minLat, self._maxLat, self._minLon,
self._maxLon, ds=ds, start_time=t, end_time=t,
metrics_callback=metrics_callback)
if self._set_info_from_tile_set(nexus_tiles):
# Successfully retrieved global tile set from nexus_tiles,
# so no need to check any other time stamps.
break
return nexus_tiles
def _find_tile_bounds(self, t):
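        """
        Return the tile's bounds as (min_lat, max_lat, min_lon, max_lon,
        min_y, max_y, min_x, max_x), where y/x are indices of the unmasked
        latitudes/longitudes, or None if the tile has no unmasked coordinates.
        """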
lats = t.latitudes
lons = t.longitudes
if (len(lats.compressed()) > 0) and (len(lons.compressed()) > 0):
min_lat = np.ma.min(lats)
max_lat = np.ma.max(lats)
min_lon = np.ma.min(lons)
max_lon = np.ma.max(lons)
good_inds_lat = np.where(lats.mask == False)[0]
good_inds_lon = np.where(lons.mask == False)[0]
min_y = np.min(good_inds_lat)
max_y = np.max(good_inds_lat)
min_x = np.min(good_inds_lon)
max_x = np.max(good_inds_lon)
bounds = (min_lat, max_lat, min_lon, max_lon,
min_y, max_y, min_x, max_x)
else:
            self.log.warning('Nothing in this tile!')
bounds = None
return bounds
@staticmethod
def query_by_parts(tile_service, min_lat, max_lat, min_lon, max_lon,
dataset, start_time, end_time, part_dim=0):
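        """
        Recursively fetch tiles for a bounding box and time range. If the
        metadata query fails or returns more than nexus_max_tiles_per_query
        tiles, the region is bisected along part_dim (0 = latitude,
        1 = longitude, 2 = time) and each half is queried separately;
        otherwise the tile data is fetched and masked to the bounding box.
        """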
nexus_max_tiles_per_query = 100
# print 'trying query: ',min_lat, max_lat, min_lon, max_lon, \
# dataset, start_time, end_time
try:
tiles = \
tile_service.find_tiles_in_box(min_lat, max_lat,
min_lon, max_lon,
dataset,
start_time=start_time,
end_time=end_time,
fetch_data=False)
assert (len(tiles) <= nexus_max_tiles_per_query)
        except Exception:
# print 'failed query: ',min_lat, max_lat, min_lon, max_lon, \
# dataset, start_time, end_time
if part_dim == 0:
# Partition by latitude.
mid_lat = (min_lat + max_lat) / 2
nexus_tiles = NexusCalcSparkHandler.query_by_parts(tile_service,
min_lat, mid_lat,
min_lon, max_lon,
dataset,
start_time, end_time,
part_dim=part_dim)
nexus_tiles.extend(NexusCalcSparkHandler.query_by_parts(tile_service,
mid_lat,
max_lat,
min_lon,
max_lon,
dataset,
start_time,
end_time,
part_dim=part_dim))
elif part_dim == 1:
# Partition by longitude.
mid_lon = (min_lon + max_lon) / 2
nexus_tiles = NexusCalcSparkHandler.query_by_parts(tile_service,
min_lat, max_lat,
min_lon, mid_lon,
dataset,
start_time, end_time,
part_dim=part_dim)
nexus_tiles.extend(NexusCalcSparkHandler.query_by_parts(tile_service,
min_lat,
max_lat,
mid_lon,
max_lon,
dataset,
start_time,
end_time,
part_dim=part_dim))
elif part_dim == 2:
# Partition by time.
                mid_time = (start_time + end_time) // 2
nexus_tiles = NexusCalcSparkHandler.query_by_parts(tile_service,
min_lat, max_lat,
min_lon, max_lon,
dataset,
start_time, mid_time,
part_dim=part_dim)
nexus_tiles.extend(NexusCalcSparkHandler.query_by_parts(tile_service,
min_lat,
max_lat,
min_lon,
max_lon,
dataset,
mid_time,
end_time,
part_dim=part_dim))
else:
# No exception, so query Cassandra for the tile data.
# print 'Making NEXUS query to Cassandra for %d tiles...' % \
# len(tiles)
# t1 = time.time()
# print 'NEXUS call start at time %f' % t1
# sys.stdout.flush()
nexus_tiles = list(tile_service.fetch_data_for_tiles(*tiles))
nexus_tiles = list(tile_service.mask_tiles_to_bbox(min_lat, max_lat,
min_lon, max_lon,
nexus_tiles))
# t2 = time.time()
# print 'NEXUS call end at time %f' % t2
# print 'Seconds in NEXUS call: ', t2-t1
# sys.stdout.flush()
# print 'Returning %d tiles' % len(nexus_tiles)
return nexus_tiles
@staticmethod
def _prune_tiles(nexus_tiles):
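        """
        Remove tiles whose data is entirely masked. Indices are deleted in
        reverse order so that earlier indices remain valid during deletion.
        """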
del_ind = np.where([np.all(tile.data.mask) for tile in nexus_tiles])[0]
for i in np.flipud(del_ind):
del nexus_tiles[i]
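    # Converters between grid indices and cell-center coordinates, based on the
    # resolution and extents computed in _set_info_from_tile_set.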
def _lat2ind(self, lat):
return int((lat - self._minLatCent) / self._latRes + 0.5)
def _lon2ind(self, lon):
return int((lon - self._minLonCent) / self._lonRes + 0.5)
def _ind2lat(self, y):
return self._minLatCent + y * self._latRes
def _ind2lon(self, x):
return self._minLonCent + x * self._lonRes
def _create_nc_file_time1d(self, a, fname, varname, varunits=None,
fill=None):
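        """
        Write a 1-D time series to a NetCDF4 file. Each element of 'a' is
        expected to be a dict with 'time' (epoch seconds) and 'mean' keys.
        """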
self.log.debug('a={0}'.format(a))
self.log.debug('shape a = {0}'.format(a.shape))
assert len(a.shape) == 1
time_dim = len(a)
rootgrp = Dataset(fname, "w", format="NETCDF4")
rootgrp.createDimension("time", time_dim)
vals = rootgrp.createVariable(varname, "f4", dimensions=("time",),
fill_value=fill)
times = rootgrp.createVariable("time", "f4", dimensions=("time",))
vals[:] = [d['mean'] for d in a]
times[:] = [d['time'] for d in a]
if varunits is not None:
vals.units = varunits
times.units = 'seconds since 1970-01-01 00:00:00'
rootgrp.close()
def _create_nc_file_latlon2d(self, a, fname, varname, varunits=None,
fill=None):
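        """
        Write a 2-D lat/lon grid to a NetCDF4 file. Coordinate values are
        reconstructed as evenly spaced points between the grid extents computed
        in _set_info_from_tile_set.
        """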
self.log.debug('a={0}'.format(a))
self.log.debug('shape a = {0}'.format(a.shape))
assert len(a.shape) == 2
lat_dim, lon_dim = a.shape
rootgrp = Dataset(fname, "w", format="NETCDF4")
rootgrp.createDimension("lat", lat_dim)
rootgrp.createDimension("lon", lon_dim)
vals = rootgrp.createVariable(varname, "f4",
dimensions=("lat", "lon",),
fill_value=fill)
lats = rootgrp.createVariable("lat", "f4", dimensions=("lat",))
lons = rootgrp.createVariable("lon", "f4", dimensions=("lon",))
vals[:, :] = a
lats[:] = np.linspace(self._minLatCent,
self._maxLatCent, lat_dim)
lons[:] = np.linspace(self._minLonCent,
self._maxLonCent, lon_dim)
if varunits is not None:
vals.units = varunits
lats.units = "degrees north"
lons.units = "degrees east"
rootgrp.close()
def _create_nc_file(self, a, fname, varname, **kwargs):
self._create_nc_file_latlon2d(a, fname, varname, **kwargs)
def _spark_nparts(self, nparts_requested):
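        """
        Return the number of Spark partitions to use: the requested count if
        positive, otherwise the SparkContext's defaultParallelism, capped at
        max_parallelism (128).
        """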
max_parallelism = 128
num_partitions = min(nparts_requested if nparts_requested > 0
else self._sc.defaultParallelism,
max_parallelism)
return num_partitions
def _create_metrics_record(self):
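        """
        Build the MetricsRecord for a request: Spark accumulators for tile
        count, partition count, and cumulative Cassandra, Solr, and calculation
        times, plus plain number fields for reduce time and total elapsed time.
        """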
return MetricsRecord([
SparkAccumulatorMetricsField(key='num_tiles',
description='Number of tiles fetched',
accumulator=self._sc.accumulator(0)),
SparkAccumulatorMetricsField(key='partitions',
description='Number of Spark partitions',
accumulator=self._sc.accumulator(0)),
SparkAccumulatorMetricsField(key='cassandra',
description='Cumulative time to fetch data from Cassandra',
accumulator=self._sc.accumulator(0)),
SparkAccumulatorMetricsField(key='solr',
description='Cumulative time to fetch data from Solr',
accumulator=self._sc.accumulator(0)),
SparkAccumulatorMetricsField(key='calculation',
description='Cumulative time to do calculations',
accumulator=self._sc.accumulator(0)),
NumberMetricsField(key='reduce', description='Actual time to reduce results'),
NumberMetricsField(key="actual_time", description="Total (actual) time")
])