# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
# with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.
import datetime

from flask import Response, request
from flask.json.provider import DefaultJSONProvider
from werkzeug.routing import BaseConverter

import api.utils as utils
from api.PclusterApiHandler import (
    authenticated,
    cancel_job,
    create_user,
    delete_user,
    ec2_action,
    get_app_config,
    get_aws_config,
    get_cluster_config,
    get_custom_image_config,
    get_dcv_session,
    get_identity,
    get_version,
    get_instance_types,
    list_users,
    login,
    logout,
    price_estimate,
    queue_status,
    sacct,
    scontrol_job,
    CLIENT_ID, CLIENT_SECRET, USER_POOL_ID, pc
)
from api.costmonitoring import costs
from api.logging import parse_log_entry, push_log_entry
from api.pcm_globals import logger
from api.security.csrf import CSRF
from api.security.csrf.csrf import csrf_needed
from api.security.fingerprint import CognitoFingerprintGenerator
from api.validation import validated, EC2Action
from api.validation.schemas import CreateUser, DeleteUser, GetClusterConfig, GetCustomImageConfig, GetAwsConfig, GetInstanceTypes,\
     Login, PushLog, PriceEstimate, GetDcvSession, QueueStatus, ScontrolJob, CancelJob, Sacct

# Group names handed to `authenticated(...)` on the protected routes in run().
# NOTE(review): presumably the user groups allowed to call those routes --
# confirm against api.PclusterApiHandler.authenticated.
ADMINS_GROUP = {"admin"}

class RegexConverter(BaseConverter):
    """URL converter whose matching pattern is given inline in the route rule.

    The first converter argument is stored as ``regex``, e.g.
    ``'/<regex("(home|users).*"):base>'``.
    """

    def __init__(self, url_map, *items):
        """Initialise the base converter, then record the pattern.

        Args:
            url_map: Forwarded unchanged to ``BaseConverter.__init__``.
            *items: Converter arguments from the rule; only the first is used.

        Raises:
            IndexError: If the rule supplies no converter argument.
        """
        super().__init__(url_map)
        pattern = items[0]
        self.regex = pattern


class PClusterJSONEncoder(DefaultJSONProvider):
    """Make the model objects JSON serializable.

    Extends Flask's default JSON provider so that ``datetime.date`` values
    (which includes ``datetime.datetime``, a subclass of ``date``) are rendered
    through ``utils.to_iso_timestr``. Every other type is delegated to the
    parent provider's ``default`` hook.
    """

    # Not read anywhere in this file.
    # NOTE(review): presumably consumed by model/serialization code elsewhere -- confirm.
    include_nulls = False

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Args:
            obj: A value the JSON encoder could not serialize natively.

        Returns:
            The result of ``utils.to_iso_timestr(obj)`` for dates, otherwise
            whatever the parent provider's ``default`` returns.

        Raises:
            TypeError: Propagated from the parent ``default`` when ``obj`` is
                not a type it knows how to serialize.
        """
        if isinstance(obj, datetime.date):
            return utils.to_iso_timestr(obj)
        # ``DefaultJSONProvider.default`` is declared as a staticmethod that
        # accepts only the object. Calling it as
        # ``DefaultJSONProvider.default(self, obj)`` passes two arguments and
        # fails with an arity TypeError before Flask's own fallbacks (Decimal,
        # UUID, dataclasses, ``__html__``) can run. ``super().default(obj)``
        # resolves the staticmethod correctly and also keeps working should the
        # parent ever become a regular method.
        return super().default(obj)


def run():
    """Build and return the Flask application with every route registered.

    Creates the app through ``utils.build_flask_app``, sets the application
    root, installs the custom JSON provider and the ``regex`` URL converter,
    instantiates ``CSRF`` for the app, declares the error handlers and view
    functions, and finally registers the ``pc`` and ``costs`` blueprints.

    Returns:
        The configured app object produced by ``utils.build_flask_app``.
    """
    app = utils.build_flask_app(__name__)
    app.config["APPLICATION_ROOT"] = '/pcui'
    # Replace the app's JSON provider so dates go through PClusterJSONEncoder.default.
    app.json = PClusterJSONEncoder(app)
    # Makes `<regex("..."):name>` usable in the route rules further down.
    app.url_map.converters["regex"] = RegexConverter
    # NOTE(review): presumably this hooks CSRF protection into the app, keyed on a
    # Cognito-derived fingerprint -- confirm in api.security.csrf.
    CSRF(app, CognitoFingerprintGenerator(CLIENT_ID, CLIENT_SECRET, USER_POOL_ID))

    # 401 responses get a fixed plain-text body.
    @app.errorhandler(401)
    def custom_401(_error):
        return Response(
            "You are not authorized to perform this action.", 401
        )

    # Unknown URLs are answered through utils.serve_frontend rather than a bare 404.
    # NOTE(review): presumably so client-side routes resolve in the SPA -- confirm.
    @app.errorhandler(404)
    def page_not_found(_error):
        return utils.serve_frontend(app)

    # Root and any other path are handed to the frontend helper.
    @app.route("/", defaults={"path": ""})
    @app.route('/<path:path>')
    def serve(path):
        return utils.serve_frontend(app, path)

    # --- /manager API -------------------------------------------------------
    # Each view below is a thin wrapper that delegates to the same-named handler
    # imported from api.PclusterApiHandler. Decorators apply bottom-up, so the
    # wrapper produced by `authenticated(...)` is the outermost one around the
    # view, followed by `csrf_needed` (where present) and then `validated(...)`.
    # NOTE(review): presumably this means authentication is checked before CSRF
    # and schema validation -- confirm in the decorator implementations.
    # Routes declared without `methods=` use Flask's default method set (GET).

    @app.route("/manager/ec2_action", methods=["POST"])
    @authenticated(ADMINS_GROUP)
    @csrf_needed
    @validated(params=EC2Action)
    def ec2_action_():
        return ec2_action()

    @app.route("/manager/get_cluster_configuration")
    @authenticated(ADMINS_GROUP)
    @validated(params=GetClusterConfig)
    def get_cluster_config_():
        return get_cluster_config()

    @app.route("/manager/get_custom_image_configuration")
    @authenticated(ADMINS_GROUP)
    @validated(params=GetCustomImageConfig)
    def get_custom_image_config_():
        return get_custom_image_config()

    @app.route("/manager/get_aws_configuration")
    @authenticated(ADMINS_GROUP)
    @validated(params=GetAwsConfig)
    def get_aws_config_():
        return get_aws_config()

    @app.route("/manager/get_instance_types")
    @authenticated(ADMINS_GROUP)
    @validated(params=GetInstanceTypes)
    def get_instance_types_():
        return get_instance_types()

    @app.route("/manager/get_dcv_session")
    @authenticated(ADMINS_GROUP)
    @validated(params=GetDcvSession)
    def get_dcv_session_():
        return get_dcv_session()

    @app.route("/manager/get_identity")
    @authenticated(ADMINS_GROUP)
    def get_identity_():
        return get_identity()

    # No `authenticated` wrapper on this route.
    @app.route("/manager/get_version")
    def get_version_():
        return get_version()

    # No `authenticated` wrapper on this route.
    @app.route("/manager/get_app_config")
    def get_app_config_():
        return get_app_config()

    @app.route("/manager/list_users")
    @authenticated(ADMINS_GROUP)
    def list_users_():
        return list_users()

    @app.route("/manager/create_user", methods=["POST"])
    @authenticated(ADMINS_GROUP)
    @csrf_needed
    @validated(body=CreateUser)
    def create_user_():
        return create_user()

    @app.route("/manager/delete_user", methods=["DELETE"])
    @authenticated(ADMINS_GROUP)
    @csrf_needed
    @validated(params=DeleteUser)
    def delete_user_():
        return delete_user()

    @app.route("/manager/queue_status")
    @authenticated(ADMINS_GROUP)
    @validated(params=QueueStatus)
    def queue_status_():
        return queue_status()

    # NOTE(review): unlike ec2_action/create_user/delete_user, this route has no
    # `methods=` argument and no `csrf_needed`, although the name suggests it
    # changes state -- confirm this is intentional.
    @app.route("/manager/cancel_job")
    @authenticated(ADMINS_GROUP)
    @validated(params=CancelJob)
    def cancel_job_():
        return cancel_job()

    @app.route("/manager/price_estimate")
    @authenticated(ADMINS_GROUP)
    @validated(params=PriceEstimate)
    def price_estimate_():
        return price_estimate()

    @app.route("/manager/sacct", methods=["POST"])
    @authenticated(ADMINS_GROUP)
    @csrf_needed
    @validated(params=Sacct)
    def sacct_():
        return sacct()

    @app.route("/manager/scontrol_job")
    @authenticated(ADMINS_GROUP)
    @validated(params=ScontrolJob)
    def scontrol_job_():
        return scontrol_job()

    # --- Session endpoints (no `authenticated` wrapper) ---------------------

    @app.route("/login")
    @validated(params=Login)
    def login_():
        return login()

    @app.route("/logout")
    def logout_():
        return logout()

    # Accepts a JSON body with a 'logs' list; each entry is parsed with
    # parse_log_entry and forwarded to push_log_entry using the shared logger.
    @app.route('/logs', methods=['POST'])
    @authenticated(ADMINS_GROUP)
    @csrf_needed
    @validated(body=PushLog)
    def push_log():
        for entry in request.json['logs']:
            level, message, extra = parse_log_entry(logger, entry)
            push_log_entry(logger, level, message, extra)

        # Empty JSON object with an explicit 200 status.
        return {}, 200

    # --- Frontend section routes --------------------------------------------
    # NOTE(review): both rules below give `defaults` for variables that also
    # appear in the rule itself, and they overlap with the '/<path:path>' rule
    # of `serve` above -- confirm how Werkzeug resolves this and whether these
    # two routes are still needed.

    @app.route('/<regex("(home|clusters|users|configure|images).*"):base>', defaults={"base": ""})
    def catch_all(base):
        return utils.serve_frontend(app, base)

    # `u_path` is accepted but not used; only `base` is forwarded.
    @app.route('/<regex("(home|clusters|users|configure|images).*"):base>/<path:u_path>', defaults={"base": "", "u_path": ""})
    def catch_all2(base, u_path):
        return utils.serve_frontend(app, base)

    # Mount the blueprints under their URL prefixes.
    app.register_blueprint(pc, url_prefix='/api')
    app.register_blueprint(costs, url_prefix='/cost-monitoring')
    return app


# Script entry point: builds the app via run(). The returned app object is
# discarded here and nothing in this file calls `app.run()`.
# NOTE(review): presumably the app is normally served by an external WSGI
# runner that calls run() itself -- confirm.
if __name__ == "__main__":
    run()
