data-access/nexustiles/backends/zarr/backend.py
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
from datetime import datetime
from urllib.parse import urlparse
import numpy as np
import numpy.ma as ma
import s3fs
import xarray as xr
from nexustiles.AbstractTileService import AbstractTileService
from nexustiles.exception import NexusTileServiceException
from nexustiles.model.nexusmodel import Tile, BBox, TileVariable
from pytz import timezone
from shapely.geometry import MultiPolygon, box
from yarl import URL
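# All times in this backend are handled as integer seconds since the Unix
# epoch (UTC); EPOCH anchors the datetime64-to-epoch-seconds conversions below.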
EPOCH = timezone('UTC').localize(datetime(1970, 1, 1))
ISO_8601 = '%Y-%m-%dT%H:%M:%S%z'
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt="%Y-%m-%dT%H:%M:%S", stream=sys.stdout)
logger = logging.getLogger(__name__)
class ZarrBackend(AbstractTileService):
def __init__(self, dataset_name, path, config=None):
AbstractTileService.__init__(self, dataset_name)
config = config if config is not None else {}
self.__config = config
logger.info(f'Opening zarr backend at {path} for dataset {self._name}')
url = urlparse(path)
self.__url = path
self.__store_type = url.scheme
self.__host = url.netloc
self.__path = url.path
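        # The config dict is expected to look roughly like this illustrative
        # sketch (key names come from the lookups below; the values are assumed):
        #   {
        #       'variable': 'analysed_sst',        # or 'variables': ['a', 'b']
        #       'coords': {'latitude': 'lat', 'longitude': 'lon',
        #                  'time': 'time', 'depth': 'depth'},
        #       'aws': {'public': True}            # only consulted for s3:// paths
        #   }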
if 'variable' in config:
data_vars = config['variable']
elif 'variables' in config:
data_vars = config['variables']
else:
raise KeyError('Data variables not provided in config')
if isinstance(data_vars, str):
self.__variables = [data_vars]
elif isinstance(data_vars, list):
self.__variables = data_vars
else:
raise TypeError(f'Improper type for variables config: {type(data_vars)}')
self.__variables = [v.strip('\"\'') for v in self.__variables]
self.__longitude = config['coords']['longitude']
self.__latitude = config['coords']['latitude']
self.__time = config['coords']['time']
self.__depth = config['coords'].get('depth')
if self.__store_type in ['', 'file']:
store = self.__path
elif self.__store_type == 's3':
try:
aws_cfg = self.__config['aws']
if aws_cfg['public']:
# region = aws_cfg.get('region', 'us-west-2')
# store = f'https://{self.__host}.s3.{region}.amazonaws.com{self.__path}'
                    s3 = s3fs.S3FileSystem(anon=True)
                    store = s3fs.S3Map(root=path, s3=s3, check=False)
                else:
                    s3 = s3fs.S3FileSystem(anon=False, key=aws_cfg['accessKeyID'], secret=aws_cfg['secretAccessKey'])
                    store = s3fs.S3Map(root=path, s3=s3, check=False)
except Exception as e:
                logger.error(f'Failed to set up S3 store at {path}; the dataset will be skipped. Cause: {e}')
raise NexusTileServiceException(f'Cannot open S3 dataset ({e})')
else:
raise ValueError(self.__store_type)
try:
self.__ds: xr.Dataset = xr.open_zarr(store, consolidated=True)
except Exception as e:
            logger.error(f'Failed to open zarr dataset at {self.__path}; it will be skipped. Cause: {e}')
raise NexusTileServiceException(f'Cannot open dataset ({e})')
lats = self.__ds[self.__latitude].to_numpy()
delta = lats[1] - lats[0]
if delta < 0:
logger.warning(f'Latitude coordinate for {self._name} is in descending order. Flipping it to ascending')
self.__ds = self.__ds.isel({self.__latitude: slice(None, None, -1)})
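    # Minimal construction example (illustrative; the store path, dataset name,
    # and coordinate names are assumptions, not taken from a real deployment):
    #   backend = ZarrBackend(
    #       'my_dataset', 'file:///data/my_dataset.zarr',
    #       config={'variable': 'analysed_sst',
    #               'coords': {'latitude': 'lat', 'longitude': 'lon', 'time': 'time'}})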
def heartbeat(self) -> bool:
# TODO: This is temporary, eventually we should use the logic to be introduced for SDAP-517 (PR#312) to evaluate
# if data is accessible currently.
return True
def get_dataseries_list(self, simple=False):
ds = {
"shortName": self._name,
"title": self._name,
"type": "zarr"
}
if not simple:
try:
min_date = self.get_min_time([])
max_date = self.get_max_time([])
ds['start'] = min_date
ds['end'] = max_date
ds['iso_start'] = datetime.utcfromtimestamp(min_date).strftime(ISO_8601)
ds['iso_end'] = datetime.utcfromtimestamp(max_date).strftime(ISO_8601)
ds['metadata'] = dict(self.__ds.attrs)
except Exception as e:
logger.error(f'Failed to access dataset for {self._name}. Cause: {e}')
ds['error'] = "Dataset is currently unavailable"
ds['error_reason'] = str(e)
return [ds]
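    # This backend keeps no tile index: a "tile ID" is a self-describing nts://
    # URL built by __to_url, e.g. (illustrative values)
    #   nts://my_dataset?min_lat=-10.0&max_lat=10.0&min_lon=100.0&max_lon=120.0&min_time=0&max_time=86400
    # so the ID lookup methods below simply echo their inputs back.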
def find_tile_by_id(self, tile_id, **kwargs):
return [tile_id]
def find_tiles_by_id(self, tile_ids, ds=None, **kwargs):
return tile_ids
def find_days_in_range_asc(self, min_lat, max_lat, min_lon, max_lon, dataset, start_time, end_time,
metrics_callback=None, **kwargs):
start = datetime.now()
if not isinstance(start_time, datetime):
start_time = datetime.utcfromtimestamp(start_time)
if not isinstance(end_time, datetime):
end_time = datetime.utcfromtimestamp(end_time)
sel = {
self.__latitude: slice(min_lat, max_lat),
self.__longitude: slice(min_lon, max_lon),
self.__time: slice(start_time, end_time)
}
times = self.__ds.sel(sel)[self.__time].to_numpy()
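        # xarray usually yields datetime64 values here; normalize them to
        # integer seconds since the Unix epoch to match the rest of the API.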
if np.issubdtype(times.dtype, np.datetime64):
times = (times - np.datetime64(EPOCH)).astype('timedelta64[s]').astype(int)
times = sorted(times.tolist())
if metrics_callback:
metrics_callback(backend=(datetime.now() - start).total_seconds())
return times
def find_tile_by_polygon_and_most_recent_day_of_year(self, bounding_polygon, ds, day_of_year, **kwargs):
"""
Given a bounding polygon, dataset, and day of year, find tiles in that dataset with the same bounding
polygon and the closest day of year.
For example:
given a polygon minx=0, miny=0, maxx=1, maxy=1; dataset=MY_DS; and day of year=32
search for first tile in MY_DS with identical bbox and day_of_year <= 32 (sorted by day_of_year desc)
Valid matches:
minx=0, miny=0, maxx=1, maxy=1; dataset=MY_DS; day of year = 32
minx=0, miny=0, maxx=1, maxy=1; dataset=MY_DS; day of year = 30
Invalid matches:
minx=1, miny=0, maxx=2, maxy=1; dataset=MY_DS; day of year = 32
minx=0, miny=0, maxx=1, maxy=1; dataset=MY_OTHER_DS; day of year = 32
minx=0, miny=0, maxx=1, maxy=1; dataset=MY_DS; day of year = 30 if minx=0, miny=0, maxx=1, maxy=1; dataset=MY_DS; day of year = 32 also exists
:param bounding_polygon: The exact bounding polygon of tiles to search for
:param ds: The dataset name being searched
:param day_of_year: Tile day of year to search for, tile nearest to this day (without going over) will be returned
:return: List of one tile from ds with bounding_polygon on or before day_of_year or raise NexusTileServiceException if no tile found
"""
times = self.__ds[self.__time].to_numpy()
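        # `times` is typically datetime64[ns]; astype(datetime) on nanosecond
        # data yields integer nanoseconds, hence the divisions by 1e9 below to
        # recover epoch seconds.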
to_doy = lambda dt: datetime.utcfromtimestamp(int(dt)).timetuple().tm_yday
vfunc = np.vectorize(to_doy)
days_of_year = vfunc(times.astype(datetime) / 1e9)
try:
time = times[np.where(days_of_year <= day_of_year)[0][-1]].astype(datetime) / 1e9
except IndexError:
raise NexusTileServiceException(reason='No tiles matched')
min_lon, min_lat, max_lon, max_lat = bounding_polygon.bounds
return self.find_tiles_in_box(
min_lat, max_lat, min_lon, max_lon, ds, time, time
)
def find_all_tiles_in_box_at_time(self, min_lat, max_lat, min_lon, max_lon, dataset, time, **kwargs):
return self.find_tiles_in_box(min_lat, max_lat, min_lon, max_lon, dataset, time, time, **kwargs)
def find_all_tiles_in_polygon_at_time(self, bounding_polygon, dataset, time, **kwargs):
return self.find_tiles_in_polygon(bounding_polygon, dataset, time, time, **kwargs)
def find_tiles_in_box(self, min_lat, max_lat, min_lon, max_lon, ds=None, start_time=0, end_time=-1, **kwargs):
        if isinstance(start_time, datetime):
            start_time = (start_time - EPOCH).total_seconds()
        if isinstance(end_time, datetime):
            end_time = (end_time - EPOCH).total_seconds()
params = {
'min_lat': min_lat,
'max_lat': max_lat,
'min_lon': min_lon,
'max_lon': max_lon
}
times = None
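        # With distinct=True (the default), one synthetic tile is produced per
        # distinct timestep in range; otherwise a single tile spans the whole
        # [start_time, end_time] window.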
if 0 <= start_time <= end_time:
if kwargs.get('distinct', True):
times_asc = self.find_days_in_range_asc(min_lat, max_lat, min_lon, max_lon, ds, start_time, end_time)
times = [(t, t) for t in times_asc]
else:
times = [(start_time, end_time)]
if 'depth' in kwargs:
params['depth'] = kwargs['depth']
elif 'min_depth' in kwargs or 'max_depth' in kwargs:
params['min_depth'] = kwargs.get('min_depth')
params['max_depth'] = kwargs.get('max_depth')
if times:
return [ZarrBackend.__to_url(self._name, min_time=t[0], max_time=t[1], **params) for t in times]
else:
return [ZarrBackend.__to_url(self._name, **params)]
def find_tiles_in_polygon(self, bounding_polygon, ds=None, start_time=None, end_time=None, **kwargs):
        # Gridded Zarr subsetting is box-based, so reduce the polygon to its
        # bounding box and search that envelope.
        min_lon, min_lat, max_lon, max_lat = bounding_polygon.bounds
        return self.find_tiles_in_box(min_lat, max_lat, min_lon, max_lon, ds, start_time, end_time, **kwargs)
def find_tiles_by_metadata(self, metadata, ds=None, start_time=0, end_time=-1, **kwargs):
"""
Return list of tiles whose metadata matches the specified metadata, start_time, end_time.
:param metadata: List of metadata values to search for tiles e.g ["river_id_i:1", "granule_s:granule_name"]
:param ds: The dataset name to search
:param start_time: The start time to search for tiles
:param end_time: The end time to search for tiles
:return: A list of tiles
"""
raise NotImplementedError()
def find_tiles_by_exact_bounds(self, bounds, ds, start_time, end_time, **kwargs):
"""
The method will return tiles with the exact given bounds within the time range. It differs from
find_tiles_in_polygon in that only tiles with exactly the given bounds will be returned as opposed to
doing a polygon intersection with the given bounds.
:param bounds: (minx, miny, maxx, maxy) bounds to search for
:param ds: Dataset name to search
:param start_time: Start time to search (seconds since epoch)
:param end_time: End time to search (seconds since epoch)
:param kwargs: fetch_data: True/False = whether or not to retrieve tile data
:return:
"""
        min_lon, min_lat, max_lon, max_lat = bounds
return self.find_tiles_in_box(min_lat, max_lat, min_lon, max_lon, ds, start_time, end_time, **kwargs)
def find_all_boundary_tiles_at_time(self, min_lat, max_lat, min_lon, max_lon, dataset, time, **kwargs):
# Due to the precise nature of gridded Zarr's subsetting, it doesn't make sense to have a boundary region like
# this
return []
def find_tiles_along_line(self, start_point, end_point, ds=None, start_time=0, end_time=-1, **kwargs):
raise NotImplementedError()
def get_min_max_time_by_granule(self, ds, granule_name):
raise NotImplementedError()
def get_dataset_overall_stats(self, ds):
raise NotImplementedError()
def get_stats_within_box_at_time(self, min_lat, max_lat, min_lon, max_lon, dataset, time, **kwargs):
raise NotImplementedError()
def get_bounding_box(self, tile_ids):
"""
Retrieve a bounding box that encompasses all of the tiles represented by the given tile ids.
:param tile_ids: List of tile ids
:return: shapely.geometry.Polygon that represents the smallest bounding box that encompasses all of the tiles
"""
        queries = [URL(u).query for u in tile_ids]
        bounds = [
            (
                float(q['min_lon']),
                float(q['min_lat']),
                float(q['max_lon']),
                float(q['max_lat'])
            )
            for q in queries
        ]
poly = MultiPolygon([box(*b) for b in bounds])
return box(*poly.bounds)
def __get_ds_min_max_date(self):
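        # Dataset-wide bounds of the time coordinate, normalized to integer
        # epoch seconds when the coordinate is stored as datetime64.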
min_date = self.__ds[self.__time].min().to_numpy()
max_date = self.__ds[self.__time].max().to_numpy()
if np.issubdtype(min_date.dtype, np.datetime64):
min_date = (min_date - np.datetime64(EPOCH)).astype('timedelta64[s]').astype(int).item()
if np.issubdtype(max_date.dtype, np.datetime64):
max_date = (max_date - np.datetime64(EPOCH)).astype('timedelta64[s]').astype(int).item()
return min_date, max_date
def get_min_time(self, tile_ids, ds=None):
"""
Get the minimum tile date from the list of tile ids
:param tile_ids: List of tile ids
:param ds: Filter by a specific dataset. Defaults to None (queries all datasets)
:return: long time in seconds since epoch
"""
        times = [URL(tid).query.get('min_time') for tid in tile_ids]
        times = [int(t) for t in times if t is not None]
        if len(times) == 0:
            min_date, max_date = self.__get_ds_min_max_date()
            return min_date
        else:
            return min(times)
def get_max_time(self, tile_ids, ds=None):
"""
Get the maximum tile date from the list of tile ids
:param tile_ids: List of tile ids
:param ds: Filter by a specific dataset. Defaults to None (queries all datasets)
:return: long time in seconds since epoch
"""
        times = [URL(tid).query.get('max_time') for tid in tile_ids]
        times = [int(t) for t in times if t is not None]
        if len(times) == 0:
            min_date, max_date = self.__get_ds_min_max_date()
            return max_date
        else:
            return max(times)
def get_distinct_bounding_boxes_in_polygon(self, bounding_polygon, ds, start_time, end_time):
"""
Get a list of distinct tile bounding boxes from all tiles within the given polygon and time range.
:param bounding_polygon: The bounding polygon of tiles to search for
:param ds: The dataset name to search
:param start_time: The start time to search for tiles
:param end_time: The end time to search for tiles
:return: A list of distinct bounding boxes (as shapely polygons) for tiles in the search polygon
"""
raise NotImplementedError()
def get_tile_count(self, ds, bounding_polygon=None, start_time=0, end_time=-1, metadata=None, **kwargs):
"""
Return number of tiles that match search criteria.
:param ds: The dataset name to search
:param bounding_polygon: The polygon to search for tiles
:param start_time: The start time to search for tiles
:param end_time: The end time to search for tiles
:param metadata: List of metadata values to search for tiles e.g ["river_id_i:1", "granule_s:granule_name"]
:return: number of tiles that match search criteria
"""
raise NotImplementedError()
def fetch_data_for_tiles(self, *tiles):
for tile in tiles:
self.__fetch_data_for_tile(tile)
return tiles
def __fetch_data_for_tile(self, tile: Tile):
bbox: BBox = tile.bbox
min_lat = None
min_lon = None
max_lat = None
max_lon = None
min_time = tile.min_time
max_time = tile.max_time
# if min_time:
# min_time = datetime.utcfromtimestamp(min_time)
#
# if max_time:
# max_time = datetime.utcfromtimestamp(max_time)
if bbox:
min_lat = bbox.min_lat
min_lon = bbox.min_lon
max_lat = bbox.max_lat
max_lon = bbox.max_lon
sel_g = {
self.__latitude: slice(min_lat, max_lat),
self.__longitude: slice(min_lon, max_lon),
}
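        # Time selection: no bounds -> keep the full time axis; equal bounds ->
        # the single nearest timestep; otherwise a closed slice over the range.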
sel_t = {}
if min_time is None and max_time is None:
sel_t = None
method = None
elif min_time == max_time:
sel_t[self.__time] = [min_time] # List, otherwise self.__time dim will be dropped
method = 'nearest'
else:
sel_t[self.__time] = slice(min_time, max_time)
method = None
tile.variables = [
TileVariable(v, v) for v in self.__variables
]
matched = self.__ds.sel(sel_g)
if sel_t is not None:
matched = matched.sel(sel_t, method=method)
tile.latitudes = ma.masked_invalid(matched[self.__latitude].to_numpy())
tile.longitudes = ma.masked_invalid(matched[self.__longitude].to_numpy())
times = matched[self.__time].to_numpy()
if np.issubdtype(times.dtype, np.datetime64):
times = (times - np.datetime64(EPOCH)).astype('timedelta64[s]').astype(int)
tile.times = ma.masked_invalid(times)
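        # masked_invalid masks NaNs so fill values are excluded downstream; for
        # multi-variable datasets the variables are stacked along a new leading
        # axis and the tile is flagged as multi.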
var_data = [matched[var].to_numpy() for var in self.__variables]
if len(self.__variables) > 1:
tile.data = ma.masked_invalid(var_data)
tile.is_multi = True
else:
tile.data = ma.masked_invalid(var_data[0])
tile.is_multi = False
def _metadata_store_docs_to_tiles(self, *store_docs):
return [ZarrBackend.__nts_url_to_tile(d) for d in store_docs]
@staticmethod
def __nts_url_to_tile(nts_url):
tile = Tile()
url = URL(nts_url)
tile.tile_id = nts_url
try:
min_lat = float(url.query['min_lat'])
min_lon = float(url.query['min_lon'])
max_lat = float(url.query['max_lat'])
max_lon = float(url.query['max_lon'])
tile.bbox = BBox(min_lat, max_lat, min_lon, max_lon)
except KeyError:
pass
tile.dataset = url.path
tile.dataset_id = url.path
try:
# tile.min_time = int(url.query['min_time'])
tile.min_time = datetime.utcfromtimestamp(int(url.query['min_time']))
except KeyError:
pass
try:
# tile.max_time = int(url.query['max_time'])
tile.max_time = datetime.utcfromtimestamp(int(url.query['max_time']))
except KeyError:
pass
tile.meta_data = {}
return tile
@staticmethod
def __to_url(dataset, **kwargs):
        kwargs.pop('dataset', None)
        kwargs.pop('ds', None)
        # If any params are numpy dtypes, extract them to base python types
        params = {
            kw: v.item() if isinstance(v, np.generic) else v
            for kw, v in kwargs.items() if v is not None
        }
return str(URL.build(
scheme='nts',
host='',
path=dataset,
query=params
))
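# Illustrative round trip (assumed values; __to_url and __nts_url_to_tile are
# name-mangled internals, reached via find_tiles_in_box and
# _metadata_store_docs_to_tiles):
#   find_tiles_in_box(-10.0, 10.0, 100.0, 120.0, start_time=0, end_time=86400)
#   -> ['nts://my_dataset?min_lat=-10.0&max_lat=10.0&...&min_time=0&max_time=0', ...]
#   _metadata_store_docs_to_tiles(*ids) then parses each URL back into a Tile.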