community/front-end/ofe/website/ghpcfe/views/asyncview.py (73 lines of code) (raw):
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" asyncviews.py """
import asyncio
import functools
import logging
from asgiref.sync import sync_to_async
from django.core import exceptions
from django.utils.decorators import classonlymethod
from django.views import generic
from rest_framework import viewsets
from rest_framework.authentication import SessionAuthentication
from rest_framework.authentication import TokenAuthentication
from rest_framework.authtoken.models import Token
from rest_framework.permissions import IsAuthenticated
from ..models import Cluster, Role, Task
from ..serializers import TaskSerializer
logger = logging.getLogger(__name__)
class RunningTasksViewSet(viewsets.ModelViewSet):
    # REST endpoint over Task records, built on DRF's ModelViewSet.
    # Documented with comments rather than a docstring so the class
    # __doc__ stays unset, exactly as before.

    # Data exposed: every Task row, rendered through TaskSerializer.
    queryset = Task.objects.all()
    serializer_class = TaskSerializer

    # Access control: the caller must be authenticated, either through a
    # session or through a token.
    authentication_classes = [SessionAuthentication, TokenAuthentication]
    permission_classes = (IsAuthenticated,)
# Strong references to in-flight record-deletion tasks. The event loop only
# keeps weak references to tasks, so a task nobody else references may be
# garbage collected before it has finished running.
_cleanup_tasks = set()


def _consume_task(record, task):
    """Done-callback for a background task: log the outcome, drop the record.

    Args:
        record: Task record to delete now that the task has finished. A falsy
            value means there is nothing to clean up.
        task: The finished asyncio task this callback was attached to.
    """
    # Task.exception() raises CancelledError for a cancelled task, which would
    # abort this callback before the record is cleaned up. Treat cancellation
    # as "no exception to report".
    failure = None if task.cancelled() else task.exception()
    logger.info("Task %s complete.", task.get_name(), exc_info=failure)
    if record:
        logger.info(
            " destroying task record %d-%s. Data: %s",
            record.id,
            record.title,
            record.data,
        )
        # A done-callback cannot await, so the deletion is wrapped with
        # sync_to_async and scheduled as its own task.
        cleanup = asyncio.create_task(sync_to_async(record.delete)())
        _cleanup_tasks.add(cleanup)
        cleanup.add_done_callback(_cleanup_tasks.discard)
class BackendAsyncView(generic.View):
    """Template class for backend async operations.

    `_cmd` calls `self.cmd`, which is not defined in this class, so
    subclasses are expected to provide a synchronous `cmd` method. It is
    called with the task record's primary key, the user's API token, and any
    extra arguments given to `create_task`. Subclasses may also override
    `get_task_record_data`.
    """

    # Strong references to running background tasks. The event loop only
    # keeps weak references, so an otherwise unreferenced task could be
    # garbage collected mid-execution.
    _background_tasks = set()

    @classonlymethod
    def as_view(cls, **initkwargs):
        """Build the view callable and flag it as a coroutine function.

        The `_is_coroutine` marker is the attribute that
        `asyncio.iscoroutinefunction` looks for.
        """
        view = super().as_view(**initkwargs)
        view._is_coroutine = asyncio.coroutines._is_coroutine  # pylint: disable=protected-access
        return view

    @sync_to_async
    def test_user_access_to_cluster(self, user, cluster_id):
        """Check that `user` is among the cluster's authorised users.

        Raises:
            exceptions.PermissionDenied: If `user` is not in
                `cluster.authorised_users`.
        """
        cluster = Cluster.objects.get(pk=cluster_id)
        if user not in cluster.authorised_users.all():
            raise exceptions.PermissionDenied

    @sync_to_async
    def test_user_is_cluster_admin(self, user):
        """Check that `user` holds the cluster-admin role.

        Raises:
            exceptions.PermissionDenied: If none of the user's roles has the
                id `Role.CLUSTERADMIN`.
        """
        if not any(role.id == Role.CLUSTERADMIN for role in user.roles.all()):
            raise exceptions.PermissionDenied

    @sync_to_async
    def make_task_record(self, user, title):
        """Create and return a Task record owned by `user`.

        The record's `data` comes from `get_task_record_data`.
        """
        task_data = self.get_task_record_data(self.request)
        # objects.create() already saves the new row; no extra save() needed.
        return Task.objects.create(owner=user, title=title, data=task_data)

    @sync_to_async
    def get_user_token(self, user):
        """Return the key of the auth token stored for `user`."""
        token = Token.objects.get(user=user)
        return token.key

    @sync_to_async
    def set_cluster_status_async(self, cluster_id, status):
        """Awaitable wrapper around `set_cluster_status`."""
        self.set_cluster_status(cluster_id, status)

    def set_cluster_status(self, cluster_id, status):
        """Load the cluster by primary key, set its status and save it."""
        c = Cluster.objects.get(pk=cluster_id)
        c.status = status
        c.save()

    def get_task_record_data(self, request):
        """Called from a synchronous context.

        Returns the payload stored in the task record's `data` field. This
        base implementation returns an empty dict.
        """
        return {}

    async def _cmd(self, *args, **kwargs):
        """Run the synchronous `self.cmd` via sync_to_async."""
        await sync_to_async(self.cmd, thread_sensitive=False)(*args, **kwargs)

    async def create_task(self, title, *args, **kwargs):
        """Start `cmd` as a background task and return its Task record.

        Args:
            title: Title stored on the task record and written to the log.
            *args: Extra positional arguments forwarded to `cmd`.
            **kwargs: Extra keyword arguments forwarded to `cmd`.

        Returns:
            The newly created Task record. `_consume_task` is attached as a
            done-callback and handles the record once the task finishes.
        """
        logger.info("Creating task %s", title)
        user = self.request.user
        token = await self.get_user_token(user)
        record = await self.make_task_record(user, title=title)
        task = asyncio.create_task(self._cmd(record.pk, token, *args, **kwargs))
        # Hold a reference until completion so the task cannot be collected
        # while it is still running.
        self._background_tasks.add(task)
        task.add_done_callback(functools.partial(_consume_task, record))
        task.add_done_callback(self._background_tasks.discard)
        return record