api/views.py (118 lines of code) (raw):

# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # import json import logging from stats.statistical_scoring import stat_score from typing import Any, Dict from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiExample, inline_serializer from rest_framework import viewsets from rest_framework.decorators import action from rest_framework.exceptions import APIException, bad_request from rest_framework.fields import CharField, FloatField, IntegerField from rest_framework.response import Response from api.models import Algorithm, Dataset, PredictionRequest from api.serializers import AlgorithmSerializer, PredictionRequestSerializer, DatasetSerializer from ml.classifiers import RandomForestClassifier from server.wsgi import registry # Create your views here. 
log = logging.getLogger(__name__)

# Classifier names scored by ``stats.statistical_scoring.stat_score`` rather
# than by an ``Algorithm`` registered in ``server.wsgi.registry``.
STAT_CLASSIFIERS = frozenset(
    ('manova', 'linearRegression', 'polynomialRegression'))


class AlgorithmViewSet(viewsets.ModelViewSet):
    """Model CRUD endpoints for ``Algorithm`` plus a ``predict`` action."""

    # permission_classes = []
    serializer_class = AlgorithmSerializer
    queryset = Algorithm.objects.all()

    @extend_schema(
        description='Predict credit risk for a loan',
        parameters=[
            OpenApiParameter(
                name='classifier',
                description='The algorithm/classifier to use',
                required=True,
                examples=[
                    OpenApiExample(
                        'Example 1',
                        value=RandomForestClassifier().__class__.__name__)
                ]),
            OpenApiParameter(
                name='dataset',
                description='The name of the dataset',
                examples=[OpenApiExample('Example 1', value='german')]),
            OpenApiParameter(
                name='status',
                description='The status of the algorithm',
                deprecated=True,
                examples=[OpenApiExample('Example 1', value='production')]),
            OpenApiParameter(
                name='version',
                description='Algorithm version',
                required=True,
                default='0.0.1',
                examples=[OpenApiExample('Example 1', value='0.0.1')]),
        ],
        operation_id='algorithms_predict',
        request=Dict[str, Any],
        # NOTE(review): several field names look misspelled
        # ('wilkis_lambda', 'hotelling_tawley', 'roys_reatest_roots'). They
        # must match the keys stat_score actually returns, so verify there
        # before renaming them here.
        responses=inline_serializer(
            name="PredictionResponse",
            fields={
                "probability": FloatField(),
                "label": CharField(),
                "method": CharField(),
                "color": CharField(),
                "wilkis_lambda": FloatField(),
                "pillais_trace": FloatField(),
                "hotelling_tawley": FloatField(),
                "roys_reatest_roots": FloatField(),
                "request_id": IntegerField(),
            }))
    @action(detail=False, methods=['post'])
    def predict(self, request, format=None):
        """Score ``request.data`` and persist the result as a PredictionRequest.

        Query parameters:
            classifier: required. The names in ``STAT_CLASSIFIERS`` are scored
                with ``stat_score``. Any other name selects an ``Algorithm``
                row whose registered model computes the prediction.
            dataset: dataset name (default ``"german"``).
            version: algorithm version (default ``"0.0.1"``).
            status: algorithm status (default ``"production"``).

        Returns:
            The prediction dict with an added ``request_id`` (the id of the
            saved PredictionRequest). On invalid input it returns an HTTP 400
            ``{"error": ...}`` response.

        Raises:
            APIException: HTTP 500 wrapping any unexpected error during
                scoring or persistence. DRF API exceptions propagate as-is.
        """
        params = request.query_params
        classifier_name = params.get("classifier")
        dataset_name = params.get("dataset", "german")
        version = params.get("version", "0.0.1")
        algorithm_status = params.get("status", "production")
        log.debug("Prediction request: %s", request)

        # `not` rather than `is None`: version always has a default, so only
        # an explicitly empty value (e.g. `?version=`) can be missing here.
        if not version:
            return Response(
                {"error": "Missing required query parameter: version"},
                status=400)
        if not classifier_name:
            return Response(
                {"error": "Missing required query parameter: classifier"},
                status=400)

        try:
            if classifier_name in STAT_CLASSIFIERS:
                prediction = stat_score(request.data, classifier_name)
                algorithm = None
            else:
                # .first() yields None when nothing matches; indexing with [0]
                # would raise IndexError and surface as a 500.
                algorithm = Algorithm.objects.filter(
                    classifier=classifier_name,
                    status=algorithm_status,
                    version=version,
                    dataset__name=dataset_name).first()
                if algorithm is None:
                    return Response(
                        {"error": "ML algorithm is not available"},
                        status=400)
                model = registry.classifiers[algorithm.id]
                prediction = model.compute_prediction(request.data)

            # Statistical results carry 'method' instead of 'label'.
            label = (prediction["label"] if "label" in prediction
                     else prediction["method"])
            prediction_request = PredictionRequest(
                input=json.dumps(request.data),
                response=prediction,
                prediction=label,
                feedback="",
                algorithm=algorithm)
            prediction_request.save()
            prediction["request_id"] = prediction_request.id
            return Response(prediction)
        except APIException:
            # Keep DRF's own status codes instead of collapsing them into 500.
            raise
        except Exception as e:
            log.exception("Prediction failed for classifier %r",
                          classifier_name)
            raise APIException(str(e)) from e


class PredictionRequestViewSet(viewsets.ModelViewSet):
    """Model CRUD endpoints for stored ``PredictionRequest`` records."""

    # permission_classes = []
    serializer_class = PredictionRequestSerializer
    queryset = PredictionRequest.objects.all()


class DatasetViewSet(viewsets.ReadOnlyModelViewSet):
    """Read-only (list/retrieve) endpoints for ``Dataset`` records."""

    # permission_classes = []
    serializer_class = DatasetSerializer
    queryset = Dataset.objects.all()