django_airavata/apps/api/view_utils.py (180 lines of code) (raw):

import logging
import os
from collections import OrderedDict
from datetime import datetime
from pathlib import Path

import pytz
from airavata_django_portal_sdk import user_storage
from django.conf import settings
from django.http import Http404
from django.http.request import QueryDict
from rest_framework import mixins, pagination, permissions
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.utils.urls import remove_query_param, replace_query_param
from rest_framework.viewsets import GenericViewSet

logger = logging.getLogger(__name__)


class GenericAPIBackedViewSet(GenericViewSet):
    """
    Base ViewSet for resources that are loaded through the Airavata API
    instead of from Django model querysets.

    Subclasses supply data through get_list() and get_instance().
    """

    # Make lookup_value_regex match any sequence of non-forward-slash
    # characters. Many Airavata ids contain periods ('.'), which the default
    # lookup_value_regex in DRF doesn't allow.
    lookup_value_regex = '[^/]+'

    def get_list(self):
        """
        Return the collection for the list action. Subclasses must implement.
        """
        raise NotImplementedError()

    def get_instance(self, lookup_value):
        """
        Return the single instance identified by `lookup_value`, or None if it
        does not exist. Subclasses must implement.
        """
        raise NotImplementedError()

    def get_queryset(self):
        if isinstance(self, mixins.ListModelMixin):
            return self.get_list()
        else:
            # get_queryset() is invoked whenever a detail extra action route
            # returns a many valued response. For ViewSets that have such
            # actions, return None here so they don't need to provide a
            # get_list() implementation
            return None

    def get_object(self):
        """
        Look up the instance named by the URL kwarg via get_instance().

        Raises Http404 when get_instance() returns None, then runs DRF's
        object-level permission checks before returning the instance.
        """
        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
        lookup_value = self.kwargs[lookup_url_kwarg]
        inst = self.get_instance(lookup_value)
        if inst is None:
            raise Http404
        self.check_object_permissions(self.request, inst)
        return inst

    @property
    def username(self):
        """Username of the user making the current request."""
        return self.request.user.username

    @property
    def gateway_id(self):
        """Gateway id configured in settings.GATEWAY_ID."""
        return settings.GATEWAY_ID

    @property
    def authz_token(self):
        # NOTE(review): `authz_token` is not a standard Django request
        # attribute; presumably attached by project middleware -- confirm.
        return self.request.authz_token


class ReadOnlyAPIBackedViewSet(mixins.RetrieveModelMixin,
                               mixins.ListModelMixin,
                               GenericAPIBackedViewSet):
    """
    A viewset that provides default `retrieve()` and `list()` actions.

    Subclasses must implement the following:
    * get_list(self)
    * get_instance(self, lookup_value)
    """
    pass


class APIBackedViewSet(mixins.CreateModelMixin,
                       mixins.RetrieveModelMixin,
                       mixins.UpdateModelMixin,
                       mixins.DestroyModelMixin,
                       mixins.ListModelMixin,
                       GenericAPIBackedViewSet):
    """
    A viewset that provides default `create()`, `retrieve()`, `update()`,
    `partial_update()`, `destroy()` and `list()` actions.

    Subclasses must implement the following:
    * get_list(self)
    * get_instance(self, lookup_value)
    * perform_create(self, serializer) - should return instance with id
      populated
    * perform_update(self, serializer)
    * perform_destroy(self, instance)
    """
    pass


class APIResultIterator(object):
    """
    Iterable container over API results which allows limit/offset style
    slicing.

    Subclasses implement get_results(limit, offset). Slicing records the
    requested window on the instance (self.limit / self.offset) and returns a
    lazy iterator over that window.
    """

    limit = -1
    offset = 0

    def __init__(self, query_params=None):
        self.query_params = (query_params if query_params is not None
                             else QueryDict())

    def get_results(self, limit=-1, offset=0):
        raise NotImplementedError("Subclasses must implement get_results")

    def __iter__(self):
        results = self.get_results(self.limit, self.offset)
        for result in results:
            yield result

    def __getitem__(self, key):
        if isinstance(key, slice):
            # Follow normal slice semantics for omitted bounds: a missing
            # start means 0 and a missing stop means "no limit" (-1, the same
            # default get_results() uses). The slice step is ignored.
            start = key.start if key.start is not None else 0
            if key.stop is None:
                self.limit = -1
            else:
                self.limit = key.stop - start
            self.offset = start
            return iter(self)
        else:
            return self.get_results(1, key)


class APIResultPagination(pagination.LimitOffsetPagination):
    """
    Based on DRF's LimitOffsetPagination; Airavata API pagination results
    don't have a known count, so it isn't always possible to know how many
    pages there are.
    """

    default_limit = 10

    def paginate_queryset(self, queryset, request, view=None):
        """
        Return the current page of `queryset` as a list, or None when
        pagination is disabled (see get_limit()).
        """
        assert isinstance(
            queryset,
            APIResultIterator), "queryset is not an APIResultIterator: {}".format(queryset)
        self.query_params = queryset.query_params.copy()
        self.limit = self.get_limit(request)
        if self.limit is None:
            return None

        self.offset = self.get_offset(request)
        self.request = request

        # When a paged view is called from another view (for example, to get
        # the initial data to display), this pagination class needs to know
        # the name of the view being paginated.
        if view and hasattr(view, 'pagination_viewname'):
            self.viewname = view.pagination_viewname

        return list(queryset[self.offset:self.offset + self.limit])

    def get_limit(self, request):
        """
        Return the page size, or None to disable pagination.

        An integer `limit` query parameter <= 0 disables pagination. Every
        other case, including a malformed value, is delegated to DRF's
        LimitOffsetPagination.get_limit(), which falls back to default_limit
        when the value can't be parsed.
        """
        raw_limit = request.query_params.get(self.limit_query_param)
        if raw_limit is not None:
            try:
                if int(raw_limit) <= 0:
                    return None
            except ValueError:
                # Untrusted input like ?limit=abc; let DRF apply its default
                # instead of failing the request with a server error.
                pass
        return super().get_limit(request)

    def get_paginated_response(self, data):
        # The total count is unknown, so assume there is a next page whenever
        # this page came back full.
        has_next_link = len(data) >= self.limit
        return Response(OrderedDict([
            ('next', self.get_next_link() if has_next_link else None),
            ('previous', self.get_previous_link()),
            ('results', data),
            ('limit', self.limit),
            ('offset', self.offset)
        ]))

    def get_next_link(self):
        url = self.get_base_url()
        url = replace_query_param(url, self.limit_query_param, self.limit)

        offset = self.offset + self.limit
        return replace_query_param(url, self.offset_query_param, offset)

    def get_previous_link(self):
        if self.offset <= 0:
            return None

        url = self.get_base_url()
        url = replace_query_param(url, self.limit_query_param, self.limit)

        if self.offset - self.limit <= 0:
            return remove_query_param(url, self.offset_query_param)

        offset = self.offset - self.limit
        return replace_query_param(url, self.offset_query_param, offset)

    def get_base_url(self):
        """
        Absolute URL that next/previous links are built from.

        When a pagination_viewname was supplied, the links point at that view
        with the iterator's query parameters; otherwise they point at the
        current request URL.
        """
        if hasattr(self, 'viewname'):
            base_url = self.request.build_absolute_uri(reverse(self.viewname))
            if len(self.query_params) > 0:
                base_url += f"?{self.query_params.urlencode()}"
            return base_url
        else:
            return self.request.build_absolute_uri()


def convert_utc_iso8601_to_date(iso8601_utc_string):
    """
    Parse a UTC ISO 8601 string with fractional seconds and a trailing 'Z'
    into a timezone-aware (UTC) datetime.

    This is meant to convert a JavaScript `new Date().toJSON()` value.
    Raises ValueError if the string does not match that exact format.
    """
    timestamp = datetime.strptime(
        iso8601_utc_string, "%Y-%m-%dT%H:%M:%S.%fZ")
    timestamp = timestamp.replace(tzinfo=pytz.UTC)
    logger.debug("convert_utc_iso8601_to_date(%s)=%s",
                 iso8601_utc_string, timestamp)
    return timestamp


class IsInAdminsGroupPermission(permissions.BasePermission):
    """
    Allow safe (read) methods to gateway admins and read-only gateway admins;
    allow all other methods to gateway admins only.
    """

    message = "User must be member of the Admins or Read Only Admins groups."

    def has_permission(self, request, view):
        # Read Only Admins can make GET requests only
        if request.method in permissions.SAFE_METHODS:
            return (request.is_gateway_admin or
                    request.is_read_only_gateway_admin)
        else:
            return request.is_gateway_admin


class ReadOnly(permissions.BasePermission):
    """Allow only safe (read-only) HTTP methods."""

    def has_permission(self, request, view):
        return request.method in permissions.SAFE_METHODS


def is_shared_dir(path):
    """
    Return True if `path` is exactly one of the configured shared directories
    (the keys of settings.GATEWAY_DATA_SHARED_DIRECTORIES).
    """
    shared_dirs: dict = getattr(settings, 'GATEWAY_DATA_SHARED_DIRECTORIES', {})
    return any(Path(shared_dir) == Path(path) for shared_dir in shared_dirs)


def is_shared_path(path):
    """
    Return True if relative `path` is a configured shared directory or lies
    inside one. Absolute paths are never considered shared (see FIXME below).
    """
    shared_dirs: dict = getattr(settings, 'GATEWAY_DATA_SHARED_DIRECTORIES', {})
    # FIXME: path returned when creating a new directory in user storage is an
    # absolute path. Assume that when an absolute path is given that it was
    # for a newly created directory and so it is not a shared path
    if os.path.isabs(path):
        return False
    # Check if path starts with a shared directory. commonpath() returns a
    # normalized path (no trailing separator, no '.' components), so the
    # configured key must be normalized too; otherwise a key like "Shared/"
    # would never match and shared-directory protection would be bypassed.
    return any(
        os.path.commonpath((shared_dir, path)) == os.path.normpath(shared_dir)
        for shared_dir in shared_dirs
    )


class BaseSharedDirPermission(permissions.BasePermission):
    """
    Restrict writes inside configured shared directories.

    Safe methods are always allowed. For other methods, if the target path is
    inside a shared directory only gateway admins may proceed, and nobody may
    DELETE a shared directory itself. Subclasses supply the target path via
    get_path().
    """

    def get_path(self, request, view) -> str:
        raise NotImplementedError()

    def has_permission(self, request, view):
        if request.method in permissions.SAFE_METHODS:
            return True
        path = self.get_path(request, view)
        # check if path starts with a shared directory
        if is_shared_path(path):
            # No user can delete a shared directory
            if is_shared_dir(path) and request.method == 'DELETE':
                return False
            # Only admins can create/update/delete files/directories in a
            # shared directory
            return request.is_gateway_admin
        return True


class DataProductSharedDirPermission(BaseSharedDirPermission):
    """
    Shared-directory permission for requests that identify a data product by
    the 'data-product-uri' (or legacy 'product-uri') query parameter.
    """

    def _get_data_product_metadata(self, request):
        """Fetch storage metadata for the data product named in the query."""
        data_product_uri = request.query_params.get(
            'data-product-uri', request.query_params.get('product-uri', ''))
        return user_storage.get_data_product_metadata(
            request, data_product_uri=data_product_uri)

    def get_path(self, request, view) -> str:
        return self._get_data_product_metadata(request)["path"]

    def has_permission(self, request, view):
        # Special handling for remote API, just get the userHasWriteAccess
        # attribute and use that
        if hasattr(settings, 'GATEWAY_DATA_STORE_REMOTE_API'):
            if request.method in permissions.SAFE_METHODS:
                return True
            file_metadata = self._get_data_product_metadata(request)
            return file_metadata["userHasWriteAccess"]
        else:
            return super().has_permission(request, view)


class UserStorageSharedDirPermission(BaseSharedDirPermission):
    """
    Shared-directory permission for user storage requests that carry the
    target path directly.
    """

    def get_path(self, request, view):
        # 'path' can be a url path parameter, query parameter or in the
        # request body (data)
        return request.query_params.get(
            'path', request.data.get('path', view.kwargs.get('path')))