"""Handles Swagger Parsing"""

import logging
from typing import Dict, List, Union

from samcli.commands.local.lib.swagger.integration_uri import IntegrationType, LambdaUri
from samcli.commands.local.lib.validators.identity_source_validator import IdentitySourceValidator
from samcli.local.apigw.authorizers.authorizer import Authorizer
from samcli.local.apigw.authorizers.lambda_authorizer import LambdaAuthorizer
from samcli.local.apigw.exceptions import (
    IncorrectOasWithDefaultAuthorizerException,
    InvalidOasVersion,
    InvalidSecurityDefinition,
    MultipleAuthorizerException,
)
from samcli.local.apigw.route import Route

LOG = logging.getLogger(__name__)


class SwaggerParser:
    """
    Reads routes, authorizers and binary media types out of a Swagger 2.x / OpenAPI 3.x
    document that uses the "x-amazon-apigateway-*" extension keys.
    """

    # Extension keys looked up inside the document
    _AUTHORIZER_KEY = "x-amazon-apigateway-authorizer"
    _INTEGRATION_KEY = "x-amazon-apigateway-integration"
    _ANY_METHOD_EXTENSION_KEY = "x-amazon-apigateway-any-method"
    _BINARY_MEDIA_TYPES_EXTENSION_KEY = "x-amazon-apigateway-binary-media-types"  # pylint: disable=C0103
    # Method name that the any-method extension key is normalized to when building routes
    _ANY_METHOD = "ANY"

    # Root level keys holding the document version, and the version prefixes that are recognized
    _SWAGGER = "swagger"
    _OPENAPI = "openapi"
    _2_X_VERSION = "2."
    _3_X_VERSION = "3."
    # Keys used to locate security (authorizer) definitions and references
    _SWAGGER_COMPONENTS = "components"
    _SWAGGER_SECURITY = "security"
    _SWAGGER_SECURITY_SCHEMES = "securitySchemes"
    _SWAGGER_SECURITY_DEFINITIONS = "securityDefinitions"
    # Keys read from the authorizer extension object
    _AUTHORIZER_TYPE = "type"
    _AUTHORIZER_PAYLOAD_VERSION = "authorizerPayloadFormatVersion"
    _AUTHORIZER_LAMBDA_URI = "authorizerUri"
    _AUTHORIZER_LAMBDA_VALIDATION = "identityValidationExpression"
    # Keys read from the security scheme itself (used for token based authorizers)
    _AUTHORIZER_NAME = "name"
    _AUTHORIZER_IN = "in"
    # Remaining keys read from the authorizer extension object
    _AUTHORIZER_IDENTITY_SOURCE = "identitySource"
    _AUTHORIZER_SIMPLE_RESPONSES = "enableSimpleResponses"

    def __init__(self, stack_path: str, swagger):
        """
        Constructs an Swagger Parser object

        :param str stack_path: Path of the stack the resource is located
        :param dict swagger: Dictionary representation of a Swagger document
        """
        self.swagger = swagger or {}
        self.stack_path = stack_path

    def get_binary_media_types(self):
        """
        Get the list of Binary Media Types from Swagger

        Returns
        -------
        list of str
            List of strings that represent the Binary Media Types for the API, defaulting to empty list is None

        """
        return self.swagger.get(self._BINARY_MEDIA_TYPES_EXTENSION_KEY) or []

    def get_authorizers(self, event_type: str = Route.API) -> Dict[str, Authorizer]:
        """
        Parse Swagger document and returns a map of Lambda Authorizer objects

        Security schemes are read from "securityDefinitions" for 2.x documents and from
        "components.securitySchemes" for 3.x documents. Schemes that cannot be turned into a
        Lambda authorizer are skipped with a warning.

        Parameters
        ----------
        event_type: str
            String indicating what type of API Gateway this is

        Returns
        -------
        dict[str, Authorizer]
            A map of authorizer names and Authorizer objects found in the body definition

        Raises
        ------
        InvalidOasVersion
            If the document version is neither 2.x nor 3.x
        InvalidSecurityDefinition
            If an authorizer of an HTTP API has an invalid payload version, or an identity source
            is not in a valid format
        """
        authorizers: Dict[str, Authorizer] = {}

        authorizer_dict: Dict[str, dict] = {}
        document_version = self._get_document_version()

        # "or {}" guards against sections that are present but explicitly set to null
        if document_version.startswith(SwaggerParser._2_X_VERSION):
            LOG.debug("Parsing Swagger document using 2.0 specification")
            authorizer_dict = self.swagger.get(SwaggerParser._SWAGGER_SECURITY_DEFINITIONS) or {}
        elif document_version.startswith(SwaggerParser._3_X_VERSION):
            LOG.debug("Parsing Swagger document using 3.0 specification")
            components = self.swagger.get(SwaggerParser._SWAGGER_COMPONENTS) or {}
            authorizer_dict = components.get(SwaggerParser._SWAGGER_SECURITY_SCHEMES) or {}
        else:
            raise InvalidOasVersion(
                f"An invalid OpenApi version was detected: '{document_version}', must be one of 2.x or 3.x",
            )

        for auth_name, properties in authorizer_dict.items():
            authorizer_object = properties.get(self._AUTHORIZER_KEY) if isinstance(properties, dict) else None

            # anything that is not a populated dictionary cannot describe a Lambda authorizer
            if not authorizer_object or not isinstance(authorizer_object, dict):
                LOG.warning("Skip parsing unsupported authorizer '%s'", auth_name)
                continue

            # a null or non-string type is normalized so it is rejected below instead of crashing on .lower()
            authorizer_type = str(authorizer_object.get(SwaggerParser._AUTHORIZER_TYPE) or "").lower()
            payload_version = authorizer_object.get(SwaggerParser._AUTHORIZER_PAYLOAD_VERSION)

            if event_type == Route.HTTP and payload_version not in LambdaAuthorizer.PAYLOAD_VERSIONS:
                raise InvalidSecurityDefinition(f"Authorizer '{auth_name}' contains an invalid payload version")

            # Rest Apis always use the 1.0 payload, regardless of what the definition states
            if event_type == Route.API:
                payload_version = LambdaAuthorizer.PAYLOAD_V1

            authorizer_uri = authorizer_object.get(SwaggerParser._AUTHORIZER_LAMBDA_URI)
            lambda_name = LambdaUri.get_function_name(authorizer_uri)

            if not lambda_name:
                # log the raw URI, the parsed name is always empty at this point
                LOG.warning(
                    "Unable to parse authorizerUri '%s' for authorizer '%s', skipping", authorizer_uri, auth_name
                )
                continue

            # only add authorizer if it is Lambda token or request based (not jwt)
            if authorizer_type not in LambdaAuthorizer.VALID_TYPES:
                LOG.warning("Lambda authorizer '%s' type '%s' is unsupported, skipping", auth_name, authorizer_type)
                continue

            identity_sources = self._get_lambda_identity_sources(
                auth_name, authorizer_type, event_type, properties, authorizer_object
            )

            validation_expression = authorizer_object.get(SwaggerParser._AUTHORIZER_LAMBDA_VALIDATION)
            if event_type == Route.HTTP and validation_expression:
                validation_expression = None

                LOG.warning(
                    "Validation expressions is only available on REST APIs, ignoring for Lambda authorizer '%s'",
                    auth_name,
                )

            enable_simple_response = authorizer_object.get(SwaggerParser._AUTHORIZER_SIMPLE_RESPONSES, False)

            if (
                event_type != Route.HTTP
                or payload_version != LambdaAuthorizer.PAYLOAD_V2
                or not isinstance(enable_simple_response, bool)
            ):
                enable_simple_response = False

                # only warn when the customer actually tried to configure simple responses
                if authorizer_object.get(SwaggerParser._AUTHORIZER_SIMPLE_RESPONSES) is not None:
                    LOG.warning(
                        "Simple responses are only available on HTTP APIs with payload version "
                        "2.0, ignoring for Lambda authorizer '%s'",
                        auth_name,
                    )

            # token based authorizers must have an identity source defined
            # this is determined by taking the header key in the properties
            # to form the identity source in a previous method call
            if not identity_sources and authorizer_type == LambdaAuthorizer.TOKEN:
                LOG.warning(
                    "Skip parsing Lambda authorizer '%s', must contain valid "
                    "identity sources for Rest Api based token authorizers",
                    auth_name,
                )
                continue

            lambda_authorizer = LambdaAuthorizer(
                authorizer_name=auth_name,
                type=authorizer_type,
                payload_version=payload_version,
                lambda_name=lambda_name,
                identity_sources=identity_sources,
                validation_string=validation_expression,
                use_simple_response=enable_simple_response,
            )

            authorizers[auth_name] = lambda_authorizer

            LOG.debug("Parsing Lambda authorizer '%s' type '%s'", auth_name, authorizer_type)

        return authorizers

    @staticmethod
    def _get_lambda_identity_sources(
        auth_name: str, auth_type: str, event_type: str, properties: dict, authorizer_object: dict
    ) -> List[str]:
        """
        Builds the identity source list of a Lambda Authorizer, based on its type (token or request)

        Token authorizers derive a single identity source from the header named in the security
        scheme. Every other type reads the comma separated "identitySource" string of the
        authorizer object and validates each entry.

        Parameters
        ----------
        auth_name: str
            Name of the authorizer used for logging
        auth_type: str
            Type of authorizer (token, request)
        event_type: str
            API Gateway type (API, HTTP API)
        properties: dict
            Swagger Lambda Authorizer properties
        authorizer_object: dict
            Lambda Authorizer integration properties

        Returns
        -------
        List[str]
            A list of identity sources

        Raises
        ------
        InvalidSecurityDefinition
            If one of the identity sources of a non-token authorizer fails validation
        """
        if auth_type == LambdaAuthorizer.TOKEN:
            header_name = properties.get(SwaggerParser._AUTHORIZER_NAME)
            located_in_header = properties.get(SwaggerParser._AUTHORIZER_IN) == "header"

            if not (located_in_header and header_name):
                LOG.warning(
                    "Missing properties for Lambda Authorizer '%s', "
                    "property 'in' must be set to 'header' and "
                    "property 'name' must be provided",
                    auth_name,
                )
                return []

            if event_type == Route.HTTP:
                LOG.info("Type 'token' for Lambda Authorizer '%s' is unsupported ", auth_name)
                return []

            return [f"method.request.header.{header_name}"]

        raw_sources = authorizer_object.get(SwaggerParser._AUTHORIZER_IDENTITY_SOURCE, "")

        # an empty string must short circuit here, since "".split(",") yields [""] rather than []
        if not raw_sources:
            return []

        collected: List[str] = []

        for candidate in raw_sources.split(","):
            cleaned = candidate.strip()

            if not IdentitySourceValidator.validate_identity_source(cleaned, event_type):
                raise InvalidSecurityDefinition(
                    f"Identity source '{cleaned}' for Lambda Authorizer '{auth_name}' "
                    "is not a valid identity source, check the spelling/format."
                )

            collected.append(cleaned)

        return collected

    def _get_document_version(self) -> str:
        """
        Helper method to fetch the Swagger document version

        Returns
        -------
        str
            A string representing a version, blank if not found
        """
        document_version = self.swagger.get(SwaggerParser._SWAGGER) or self.swagger.get(SwaggerParser._OPENAPI) or ""

        return str(document_version)

    def get_default_authorizer(self, event_type: str) -> Union[str, None]:
        """
        Parses the body definition to find root level Authorizer definitions

        Parameters
        ----------
        event_type: str
            String representing the type of API the definition body is defined as

        Returns
        -------
        Union[str, None]
            Returns the name of the authorizer, if there is one defined, otherwise None
        """
        document_version = self._get_document_version()
        authorizers = self.swagger.get(SwaggerParser._SWAGGER_SECURITY, [])

        if not authorizers:
            return None

        if not document_version.startswith(SwaggerParser._3_X_VERSION):
            raise IncorrectOasWithDefaultAuthorizerException(
                "Root level definition of default authorizers are only supported for API "
                "resources using an OpenApi 3.x body"
            )

        if len(authorizers) > 1:
            raise MultipleAuthorizerException(
                f"There must only be a single authorizer defined for a single route, found '{len(authorizers)}'"
            )

        if len(authorizers) == 1:
            # user has authorizer defined
            authorizer_object = authorizers[0]
            authorizer_object = list(authorizers[0])

            # make sure that authorizer actually has keys
            if len(authorizer_object) != 1:
                raise InvalidSecurityDefinition(
                    "Invalid default security definition found, there must be an authorizer defined."
                )

            authorizer_name = str(authorizer_object[0])

            LOG.debug("Found default authorizer: %s", authorizer_name)

            return authorizer_name

        return None

    def get_routes(self, event_type=Route.API) -> List[Route]:
        """
        Parses a swagger document and returns a list of APIs configured in the document.

        Swagger documents have the following structure
        {
            "/path1": {    # path
                "get": {   # method
                    "x-amazon-apigateway-integration": {   # integration
                        "type": "aws_proxy",

                        # URI contains the Lambda function ARN that needs to be parsed to get Function Name
                        "uri": {
                            "Fn::Sub":
                                "arn:aws:apigateway:aws:lambda:path/2015-03-31/functions/${LambdaFunction.Arn}/..."
                        }
                    }
                },
                "post": {
                },
            },
            "/path2": {
                ...
            }
        }

        Methods without a Lambda function integration are skipped. A method level "security"
        section may name at most one authorizer; an empty section (or an empty entry) opts the
        route out of the default authorizer.

        Parameters
        ----------
        event_type: str
            Type of API Gateway the routes belong to, passed through to every Route

        Returns
        -------
        list of samcli.local.apigw.route.Route
            List of APIs that are configured in the Swagger document

        Raises
        ------
        InvalidSecurityDefinition
            If a method level security section is not a list, or its entry is not a dictionary
            containing exactly one authorizer name
        MultipleAuthorizerException
            If a method level security section lists more than one authorizer
        """

        result = []
        # "or {}" guards against a "paths" key that is present but explicitly null
        paths_dict = self.swagger.get("paths") or {}

        for full_path, path_config in paths_dict.items():
            if not isinstance(path_config, dict):
                # nothing to iterate on, e.g. a path that was left empty in the template
                LOG.debug("Skipping path='%s' in Swagger document, path configuration is not a dictionary", full_path)
                continue

            for method, method_config in path_config.items():
                function_name = self._get_integration_function_name(method_config)
                if not function_name:
                    LOG.debug(
                        "Lambda function integration not found in Swagger document at path='%s' method='%s'",
                        full_path,
                        method,
                    )
                    continue

                normalized_method = method
                if normalized_method.lower() == self._ANY_METHOD_EXTENSION_KEY:
                    # Convert to a more commonly used method notation
                    normalized_method = self._ANY_METHOD
                payload_format_version = self._get_payload_format_version(method_config)

                authorizers = method_config.get(SwaggerParser._SWAGGER_SECURITY, None)

                authorizer_name = None
                use_default_authorizer = True
                if authorizers is not None:
                    if not isinstance(authorizers, list):
                        raise InvalidSecurityDefinition(
                            "Invalid security definition found, authorizers for "
                            f"path='{full_path}' method='{method}' must be a list"
                        )

                    if len(authorizers) > 1:
                        raise MultipleAuthorizerException(
                            "There must only be a single authorizer defined "
                            f"for path='{full_path}' method='{method}', found '{len(authorizers)}'"
                        )

                    if len(authorizers) == 1 and authorizers[0] != {}:
                        # user has authorizer defined
                        security_requirement = authorizers[0]

                        # the entry must map exactly one authorizer name to its scopes
                        if not isinstance(security_requirement, dict) or len(security_requirement) != 1:
                            raise InvalidSecurityDefinition(
                                "Invalid security definition found, authorizers for "
                                f"path='{full_path}' method='{method}' must contain an authorizer"
                            )

                        authorizer_name = str(next(iter(security_requirement)))
                    else:
                        # customer provided empty list, do not use default authorizer
                        use_default_authorizer = False

                route = Route(
                    function_name,
                    full_path,
                    methods=[normalized_method],
                    event_type=event_type,
                    payload_format_version=payload_format_version,
                    operation_name=method_config.get("operationId"),
                    stack_path=self.stack_path,
                    authorizer_name=authorizer_name,
                    use_default_authorizer=use_default_authorizer,
                )
                result.append(route)

        return result

    def _get_integration(self, method_config):
        """
        Get Integration defined in the method configuration.
        Integration configuration is defined under the special "x-amazon-apigateway-integration" key. We care only
        about Lambda integrations, which are of type aws_proxy, and ignore the rest.

        Parameters
        ----------
        method_config : dict
            Dictionary containing the method configuration which might contain integration settings

        Returns
        -------
        dict or None
            integration, if possible. None, if not.
        """
        if not isinstance(method_config, dict) or self._INTEGRATION_KEY not in method_config:
            return None

        integration = method_config[self._INTEGRATION_KEY]

        if (
            integration
            and isinstance(integration, dict)
            and integration.get("type").lower() == IntegrationType.aws_proxy.value
        ):
            # Integration must be "aws_proxy" otherwise we don't care about it
            return integration

        return None

    def _get_integration_function_name(self, method_config):
        """
        Tries to parse the Lambda Function name from the Integration defined in the method configuration.
        Integration configuration is defined under the special "x-amazon-apigateway-integration" key. We care only
        about Lambda integrations, which are of type aws_proxy, and ignore the rest. Integration URI is complex and
        hard to parse. Hence we do our best to extract function name out of integration URI. If not possible, we
        return None.

        Parameters
        ----------
        method_config : dict
            Dictionary containing the method configuration which might contain integration settings

        Returns
        -------
        string or None
            Lambda function name, if possible. None, if not.
        """
        integration = self._get_integration(method_config)
        if integration is None:
            return None

        return LambdaUri.get_function_name(integration.get("uri"))

    def _get_payload_format_version(self, method_config):
        """
        Get the "payloadFormatVersion" from the Integration defined in the method configuration.

        Parameters
        ----------
        method_config : dict
            Dictionary containing the method configuration which might contain integration settings

        Returns
        -------
        string or None
            Payload format version, if exists. None, if not.
        """
        integration = self._get_integration(method_config)
        if integration is None:
            return None

        return integration.get("payloadFormatVersion")
