#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import json
import logging
from stats.statistical_scoring import stat_score
from typing import Any, Dict

from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiExample, inline_serializer

from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.exceptions import APIException, bad_request
from rest_framework.fields import CharField, FloatField, IntegerField
from rest_framework.response import Response

from api.models import Algorithm, Dataset, PredictionRequest
from api.serializers import AlgorithmSerializer, PredictionRequestSerializer, DatasetSerializer

from ml.classifiers import RandomForestClassifier

from server.wsgi import registry

# Module-level logger for the views defined below, named after this module.

log = logging.getLogger(__name__)


class AlgorithmViewSet(viewsets.ModelViewSet):
    """Standard ``ModelViewSet`` endpoints for ``Algorithm`` plus ``predict``.

    ``predict`` either delegates to ``stat_score`` for the statistical methods
    in ``_STATISTICAL_METHODS`` or runs the registered ML classifier matching
    the query parameters. It then stores the input and result as a
    ``PredictionRequest``.
    """

    # permission_classes = []
    serializer_class = AlgorithmSerializer
    queryset = Algorithm.objects.all()

    # Classifier names scored by ``stat_score`` rather than by a registered ML
    # model; predictions for these are saved without an ``Algorithm`` row.
    _STATISTICAL_METHODS = frozenset(
        ('manova', 'linearRegression', 'polynomialRegression'))

    # NOTE(review): several response keys below ('wilkis_lambda',
    # 'hotelling_tawley', 'roys_reatest_roots') look like misspellings of
    # Wilks' lambda / Hotelling-Lawley trace / Roy's greatest root. They must
    # match whatever stat_score emits, so confirm there before renaming.
    @extend_schema(
        description='Predict credit risk for a loan',
        parameters=[
            OpenApiParameter(
                name='classifier',
                description='The algorithm/classifier to use',
                required=True,
                examples=[
                    OpenApiExample(
                        'Example 1',
                        # Read the name off the class itself; instantiating a
                        # model at import time just for docs is wasted work.
                        value=RandomForestClassifier.__name__)
                ]),
            OpenApiParameter(
                name='dataset',
                description='The name of the dataset',
                examples=[OpenApiExample('Example 1', value='german')]),
            OpenApiParameter(
                name='status',
                description='The status of the algorithm',
                deprecated=True,
                examples=[OpenApiExample('Example 1', value='production')]),
            OpenApiParameter(
                name='version',
                description='Algorithm version',
                required=True,
                default='0.0.1',
                examples=[OpenApiExample('Example 1', value='0.0.1')]),
        ],
        operation_id='algorithms_predict',
        request=Dict[str, Any],
        responses=inline_serializer(name="PredictionResponse",
                                    fields={
                                        "probability": FloatField(),
                                        "label": CharField(),
                                        "method": CharField(),
                                        "color": CharField(),
                                        "wilkis_lambda": FloatField(),
                                        "pillais_trace": FloatField(),
                                        "hotelling_tawley": FloatField(),
                                        "roys_reatest_roots": FloatField(),
                                        "request_id": IntegerField()
                                    }))
    @action(detail=False, methods=['post'])
    def predict(self, request, format=None):
        """Compute a prediction for ``request.data`` and persist it.

        Query parameters:
            classifier: required. Either a name in ``_STATISTICAL_METHODS``
                or the ``classifier`` field of an ``Algorithm`` row.
            dataset: dataset name used to select the algorithm
                (default ``'german'``).
            version: algorithm version (default ``'0.0.1'``).
            status: algorithm status (default ``'production'``; deprecated).

        Returns:
            A ``Response`` with the prediction dict plus ``request_id`` (the id
            of the saved ``PredictionRequest``). Returns a 400 ``Response``
            with an ``{"error": ...}`` body when ``classifier`` is missing or
            no matching ML algorithm is available.

        Raises:
            APIException: any ``APIException`` raised inside is re-raised
                unchanged. Any other failure is wrapped in a plain
                ``APIException`` (HTTP 500) carrying the original message.
        """
        query = request.query_params
        classifier_name = query.get("classifier")
        dataset_name = query.get("dataset", "german")
        version = query.get("version", "0.0.1")
        algorithm_status = query.get("status", "production")

        log.debug("Prediction request: %s", request)

        # ``version`` always has a default, so only ``classifier`` can be
        # missing here.
        if classifier_name is None:
            return self._error_response(
                "Missing required query parameter: classifier")

        try:
            if classifier_name in self._STATISTICAL_METHODS:
                prediction = stat_score(request.data, classifier_name)
                algorithm = None
            else:
                # ``.first()`` yields None on no match, unlike ``[0]``, which
                # raises IndexError.
                algorithm = Algorithm.objects.filter(
                    classifier=classifier_name,
                    status=algorithm_status,
                    version=version,
                    dataset__name=dataset_name).first()
                if algorithm is None:
                    return self._error_response(
                        "ML algorithm is not available")
                try:
                    # Assumes ``registry.classifiers`` is a mapping keyed by
                    # algorithm id; verify against the registry.
                    model = registry.classifiers[algorithm.id]
                except KeyError:
                    return self._error_response(
                        "ML algorithm is not available")
                prediction = model.compute_prediction(request.data)

            # Statistical results may carry ``method`` instead of ``label``.
            label = (prediction["label"] if "label" in prediction
                     else prediction["method"])

            prediction_request = PredictionRequest(
                input=json.dumps(request.data),
                response=prediction,
                prediction=label,
                feedback="",
                algorithm=algorithm)
            prediction_request.save()

            prediction["request_id"] = prediction_request.id
        except APIException:
            # Already carries the intended HTTP status; don't turn it into a 500.
            raise
        except Exception as exc:
            log.exception("Prediction failed for classifier %r",
                          classifier_name)
            raise APIException(str(exc)) from exc

        return Response(prediction)

    @staticmethod
    def _error_response(message):
        """Return an HTTP 400 ``Response`` whose body is ``{"error": message}``."""
        return Response({"error": message}, status=400)


class PredictionRequestViewSet(viewsets.ModelViewSet):
    """Standard ``ModelViewSet`` endpoints over all ``PredictionRequest`` rows."""

    queryset = PredictionRequest.objects.all()
    serializer_class = PredictionRequestSerializer
    # permission_classes = []


class DatasetViewSet(viewsets.ReadOnlyModelViewSet):
    """Read-only (list/retrieve) endpoints over all ``Dataset`` rows."""

    queryset = Dataset.objects.all()
    serializer_class = DatasetSerializer
    # permission_classes = []
