"""
Custom Lambda Authorizer class definition
"""

import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from json import JSONDecodeError, loads
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
from urllib.parse import parse_qsl

from samcli.commands.local.lib.validators.identity_source_validator import IdentitySourceValidator
from samcli.local.apigw.authorizers.authorizer import Authorizer
from samcli.local.apigw.exceptions import InvalidLambdaAuthorizerResponse, InvalidSecurityDefinition
from samcli.local.apigw.route import Route

# Top-level keys read from a Lambda authorizer's JSON response
_RESPONSE_PRINCIPAL_ID = "principalId"
_RESPONSE_CONTEXT = "context"
_RESPONSE_POLICY_DOCUMENT = "policyDocument"
# Keys (and the "Allow" effect value) read from the statements of the response's policy document
_RESPONSE_IAM_STATEMENT = "Statement"
_RESPONSE_IAM_EFFECT = "Effect"
_RESPONSE_IAM_EFFECT_ALLOW = "Allow"
_RESPONSE_IAM_ACTION = "Action"
_RESPONSE_IAM_RESOURCE = "Resource"
# Boolean key read from a "simple" response (payload version 2.0 with simple responses enabled)
_SIMPLE_RESPONSE_IS_AUTH = "isAuthorized"
# IAM action a policy statement must grant for the request to be authorized
_IAM_INVOKE_ACTION = "execute-api:Invoke"


class IdentitySource(ABC):
    """
    Abstract base for an identity source: a named value on an incoming request
    (header, query parameter, context variable, stage variable, ...) that a
    Lambda authorizer requires before it is invoked.

    Concrete subclasses implement `find_identity_value` to look the value up in
    the keyword arguments they are given.
    """

    def __init__(self, identity_source: str):
        """
        Parameters
        ----------
        identity_source: str
            The identity source without any prefix
        """
        self.identity_source = identity_source

    def is_valid(self, **kwargs) -> bool:
        """
        Validates if the identity source is present

        Parameters
        ----------
        kwargs: dict
            Key word arguments to search in

        Returns
        -------
        bool:
            True if the identity source is present
        """
        return self.find_identity_value(**kwargs) is not None

    @abstractmethod
    def find_identity_value(self, **kwargs) -> Any:
        """
        Returns the identity value if found, otherwise None
        """

    def __eq__(self, other) -> bool:
        # Equal only when both are the same concrete class naming the same source
        return (
            isinstance(other, IdentitySource)
            and self.identity_source == other.identity_source
            and self.__class__ == other.__class__
        )

    def __hash__(self) -> int:
        # Defining __eq__ alone sets __hash__ to None (unhashable); hash on exactly
        # the attributes __eq__ compares so equal objects hash equally.
        return hash((self.__class__, self.identity_source))


class HeaderIdentitySource(IdentitySource):
    """
    Identity source backed by a request header.
    """

    def find_identity_value(self, **kwargs) -> Optional[str]:
        """
        Finds the header value that the identity source corresponds to

        Parameters
        ----------
        kwargs
            Keyword arguments that should contain `headers`

        Returns
        -------
        Optional[str]
            The string value of the header if it is found and non-empty, otherwise None
        """
        # `headers` may be absent or passed explicitly as None; treat both as "no headers"
        # instead of failing on `None.get`.
        headers = kwargs.get("headers") or {}
        # NOTE(review): the lookup is case-sensitive unless `headers` is a case-insensitive
        # mapping — confirm what the caller passes.
        value = headers.get(self.identity_source)

        return str(value) if value else None

    def is_valid(self, **kwargs) -> bool:
        """
        Validates whether the required header is present and matches the
        validation expression, if defined.

        Parameters
        ----------
        kwargs: dict
            Keyword arguments containing the incoming sources and, optionally,
            `validation_expression` (a regular expression)

        Returns
        -------
        bool
            True if present and valid
        """
        identity_source = self.find_identity_value(**kwargs)
        validation_expression = kwargs.get("validation_expression")

        if validation_expression and identity_source is not None:
            # re.match anchors the expression at the start of the value only
            return re.match(validation_expression, identity_source) is not None

        return identity_source is not None


class QueryIdentitySource(IdentitySource):
    """
    Identity source backed by a query string parameter.
    """

    def find_identity_value(self, **kwargs) -> Optional[str]:
        """
        Look up the query string parameter named by this identity source.

        Parameters
        ----------
        kwargs
            Keyword arguments; the raw query string is read from `querystring`

        Returns
        -------
        Optional[str]
            The first non-empty value of the parameter, or None when the query
            string is missing/empty or does not carry the parameter
        """
        raw_query = kwargs.get("querystring", "")
        if not raw_query:
            return None

        pairs: List[Tuple[str, str]] = parse_qsl(raw_query)
        matches = (
            param_value
            for param_name, param_value in pairs
            if param_name == self.identity_source and param_value
        )
        return next(matches, None)


class ContextIdentitySource(IdentitySource):
    """
    Identity source backed by a request context variable.
    """

    def find_identity_value(self, **kwargs) -> Optional[str]:
        """
        Finds the context value that the identity source corresponds to

        The name is first looked up as a flat key. When that yields nothing and the
        name contains dots (e.g. `identity.sourceIp`), it is then resolved as a path
        through nested dictionaries.

        Parameters
        ----------
        kwargs
            Keyword arguments that should contain `context`

        Returns
        -------
        Optional[str]
            The string value of the context variable if it is found and non-empty, otherwise None
        """
        # `context` may be absent or passed explicitly as None
        context = kwargs.get("context") or {}
        value = context.get(self.identity_source)

        if value is None and "." in self.identity_source:
            # NOTE(review): presumably the caller's context can be nested (as API Gateway
            # context objects are); the flat lookup above is kept first for compatibility.
            value = self._find_nested_value(context, self.identity_source.split("."))

        return str(value) if value else None

    @staticmethod
    def _find_nested_value(context: Dict[str, Any], path: List[str]) -> Any:
        """
        Follow `path` key by key through nested dictionaries; None if any step is missing.
        """
        current: Any = context
        for key in path:
            if not isinstance(current, dict):
                return None
            current = current.get(key)
        return current


class StageVariableIdentitySource(IdentitySource):
    """
    Identity source backed by an API stage variable.
    """

    def find_identity_value(self, **kwargs) -> Optional[str]:
        """
        Finds the stage variable value that the identity source corresponds to

        Parameters
        ----------
        kwargs
            Keyword arguments that should contain `stageVariables`

        Returns
        -------
        Optional[str]
            The stage variable if it is found and non-empty, otherwise None
        """
        # If `stageVariables` is present but None (e.g. no stage variables configured —
        # confirm against caller), `kwargs.get(..., {})` would return None and `.get`
        # would raise AttributeError; fall back to an empty mapping instead.
        stage_variables = kwargs.get("stageVariables") or {}
        value = stage_variables.get(self.identity_source)

        return str(value) if value else None


@dataclass
class LambdaAuthorizer(Authorizer):
    """
    A Lambda (custom) authorizer attached to a local API Gateway route.

    Holds the identity sources the authorizer requires and interprets the output of
    the authorizer's Lambda function: either as an IAM policy document or, for payload
    version 2.0 with simple responses enabled, as an `isAuthorized` boolean.
    """

    TOKEN = "token"
    REQUEST = "request"
    VALID_TYPES = [TOKEN, REQUEST]

    PAYLOAD_V1 = "1.0"
    PAYLOAD_V2 = "2.0"
    PAYLOAD_VERSIONS = [PAYLOAD_V1, PAYLOAD_V2]

    # Identity source prefixes paired with the class that resolves the remainder of the
    # string; checked in order. Left unannotated so @dataclass does not make it a field.
    _IDENTITY_SOURCE_TYPES = (
        ("method.request.header.", HeaderIdentitySource),
        ("$request.header.", HeaderIdentitySource),
        ("method.request.querystring.", QueryIdentitySource),
        ("$request.querystring.", QueryIdentitySource),
        ("context.", ContextIdentitySource),
        ("$context.", ContextIdentitySource),
        ("stageVariables.", StageVariableIdentitySource),
        ("$stageVariables.", StageVariableIdentitySource),
    )

    # IAM effect whose match overrides any Allow (unannotated: not a dataclass field)
    _IAM_EFFECT_DENY = "Deny"

    def __init__(
        self,
        authorizer_name: str,
        type: str,
        lambda_name: str,
        identity_sources: List[str],
        payload_version: str,
        validation_string: Optional[str] = None,
        use_simple_response: bool = False,
    ):
        """
        Creates a Lambda Authorizer class

        Parameters
        ----------
        authorizer_name: str
            The name of the Lambda Authorizer
        type: str
            The type of authorizer this is (token or request)
        lambda_name: str
            The name of the Lambda function this authorizer invokes
        identity_sources: List[str]
            A list of strings that this authorizer uses
        payload_version: str
            The payload format version (1.0 or 2.0)
        validation_string: Optional[str] = None
            The regular expression that can be used to validate headers
        use_simple_response: bool = False
            Boolean representing whether to return a simple response or not

        Raises
        ------
        InvalidSecurityDefinition
            If any identity source is invalid
        """
        self.authorizer_name = authorizer_name
        self.lambda_name = lambda_name
        self.type = type
        self.validation_string = validation_string
        self.payload_version = payload_version
        self.use_simple_response = use_simple_response

        # authorizer_name must be set first: it is used in the validation error message
        self._parse_identity_sources(identity_sources)

    def __eq__(self, other):
        # identity sources are compared order-insensitively
        return (
            isinstance(other, LambdaAuthorizer)
            and self.lambda_name == other.lambda_name
            and sorted(self._identity_sources_raw) == sorted(other._identity_sources_raw)
            and self.validation_string == other.validation_string
            and self.use_simple_response == other.use_simple_response
            and self.payload_version == other.payload_version
            and self.authorizer_name == other.authorizer_name
            and self.type == other.type
        )

    @property
    def identity_sources(self) -> List[IdentitySource]:
        """
        The list of identity source validation objects

        Returns
        -------
        List[IdentitySource]
            A list of concrete identity source validation objects
        """
        return self._identity_sources

    @identity_sources.setter
    def identity_sources(self, identity_sources: List[str]) -> None:
        """
        Parses and sets the identity source validation objects

        Parameters
        ----------
        identity_sources: List[str]
            A list of strings of identity sources
        """
        self._parse_identity_sources(identity_sources)

    def _parse_identity_sources(self, identity_sources: List[str]) -> None:
        """
        Validate identity source strings and build their lookup objects

        Parameters
        ----------
        identity_sources: List[str]
            A list of identity sources to parse

        Raises
        ------
        InvalidSecurityDefinition
            If an identity source is accepted neither for `Route.API` nor for `Route.HTTP`
        """
        # validate every incoming identity source before building anything
        for source in identity_sources:
            is_valid = IdentitySourceValidator.validate_identity_source(
                source, Route.API
            ) or IdentitySourceValidator.validate_identity_source(source, Route.HTTP)

            if not is_valid:
                raise InvalidSecurityDefinition(
                    f"Invalid identity source '{source}' for Lambda authorizer '{self.authorizer_name}'"
                )

        # keep a private copy so later mutation of the caller's list cannot change __eq__
        self._identity_sources_raw = list(identity_sources)
        self._identity_sources = []

        for identity_source in self._identity_sources_raw:
            for prefix, identity_source_class in self._IDENTITY_SOURCE_TYPES:
                if not identity_source.startswith(prefix):
                    continue

                # NOTE (lucashuy):
                # need to ignore the typing here so that mypy doesn't complain
                # about instantiating an abstract class
                #
                # `identity_source_class` (which comes from `_IDENTITY_SOURCE_TYPES`)
                # is always a concrete class
                identity_source_validator = identity_source_class(  # type: ignore
                    identity_source=identity_source[len(prefix) :]
                )
                self._identity_sources.append(identity_source_validator)
                break

    def is_valid_response(self, response: Union[str, bytes], method_arn: str) -> bool:
        """
        Validates whether a Lambda authorizer request is authenticated or not.

        Parameters
        ----------
        response: Union[str, bytes]
            JSON string containing the output from a Lambda authorizer
        method_arn: str
            The method ARN of the route that invoked the Lambda authorizer

        Returns
        -------
        bool
            True if the request is properly authenticated

        Raises
        ------
        InvalidLambdaAuthorizerResponse
            If the response is not a JSON object or is missing required properties
        """
        invalid_message = f"Authorizer {self.authorizer_name} return an invalid response payload"

        try:
            json_response = loads(response)
        except (ValueError, JSONDecodeError) as ex:
            raise InvalidLambdaAuthorizerResponse(invalid_message) from ex

        # JSON such as `[]`, `null` or `"text"` parses fine but is not a response object;
        # reject it here rather than failing later with an AttributeError.
        if not isinstance(json_response, dict):
            raise InvalidLambdaAuthorizerResponse(invalid_message)

        if self.payload_version == LambdaAuthorizer.PAYLOAD_V2 and self.use_simple_response:
            return self._validate_simple_response(json_response)

        # validate IAM policy document
        LambdaAuthorizerIAMPolicyValidator.validate_policy_document(self.authorizer_name, json_response)
        LambdaAuthorizerIAMPolicyValidator.validate_statement(self.authorizer_name, json_response)

        return self._is_resource_authorized(json_response, method_arn)

    def _is_resource_authorized(self, response: dict, method_arn: str) -> bool:
        """
        Validate if the current method ARN is actually authorized

        Follows IAM evaluation order: a matching Deny statement always wins; otherwise a
        matching Allow statement authorizes the request; no match means not authorized.

        Parameters
        ----------
        response: dict
            The response output from the Lambda authorizer (should be in IAM format)
        method_arn: str
            The route's method ARN

        Returns
        -------
        bool
            True if authorized
        """
        policy_document = response.get(_RESPONSE_POLICY_DOCUMENT, {})
        all_statements = policy_document.get(_RESPONSE_IAM_STATEMENT, [])

        is_allowed = False

        for statement in all_statements:
            if not self._statement_matches(statement, method_arn):
                continue

            effect = statement.get(_RESPONSE_IAM_EFFECT)
            if effect == self._IAM_EFFECT_DENY:
                return False
            if effect == _RESPONSE_IAM_EFFECT_ALLOW:
                is_allowed = True

        return is_allowed

    def _statement_matches(self, statement: dict, method_arn: str) -> bool:
        """
        True if the statement's Action covers the invoke action and its Resource covers `method_arn`
        """
        action = statement.get(_RESPONSE_IAM_ACTION, [])
        action_list = action if isinstance(action, list) else [action]

        # IAM action names are case-insensitive and may contain wildcards
        if not any(self._wildcard_match(pattern, _IAM_INVOKE_ACTION, ignore_case=True) for pattern in action_list):
            return False

        resource = statement.get(_RESPONSE_IAM_RESOURCE, [])
        resource_list = resource if isinstance(resource, list) else [resource]

        return any(self._wildcard_match(pattern, method_arn) for pattern in resource_list)

    @staticmethod
    def _wildcard_match(pattern: Any, value: str, ignore_case: bool = False) -> bool:
        """
        Match `value` against an IAM-style pattern where `*` matches any run of characters
        and `?` matches exactly one; every other character is literal (so ARNs containing
        `$`, `.`, `+`, ... match correctly). Non-string patterns never match.
        """
        if not isinstance(pattern, str):
            return False

        regex = "".join(".*" if char == "*" else "." if char == "?" else re.escape(char) for char in pattern)
        flags = re.IGNORECASE if ignore_case else 0

        return re.fullmatch(regex, value, flags) is not None

    def _validate_simple_response(self, response: dict) -> bool:
        """
        Helper method to validate if a Lambda authorizer response using simple responses is valid and authorized

        Parameters
        ----------
        response: dict
            JSON object containing required simple response parameters

        Returns
        -------
        bool
            True if the request is authorized

        Raises
        ------
        InvalidLambdaAuthorizerResponse
            If `isAuthorized` is missing or not a boolean
        """
        is_authorized = response.get(_SIMPLE_RESPONSE_IS_AUTH)

        # a missing key yields None, which also fails the bool check
        if not isinstance(is_authorized, bool):
            raise InvalidLambdaAuthorizerResponse(
                f"Authorizer {self.authorizer_name} is missing or contains an invalid " f"{_SIMPLE_RESPONSE_IS_AUTH}"
            )

        return cast(bool, is_authorized)

    def get_context(self, response: Union[str, bytes]) -> Dict[str, Any]:
        """
        Returns the context (if set) from the authorizer response and appends the principalId to it.

        Parameters
        ----------
        response: Union[str, bytes]
            Output from Lambda authorizer

        Returns
        -------
        Dict[str, Any]
            The built authorizer context object

        Raises
        ------
        InvalidLambdaAuthorizerResponse
            If the response is not a JSON object, or its context is not an object
        """
        invalid_message = f"Authorizer {self.authorizer_name} return an invalid response payload"

        try:
            json_response = loads(response)
        except (ValueError, JSONDecodeError) as ex:
            raise InvalidLambdaAuthorizerResponse(invalid_message) from ex

        if not isinstance(json_response, dict):
            raise InvalidLambdaAuthorizerResponse(invalid_message)

        built_context = json_response.get(_RESPONSE_CONTEXT, {})

        if not isinstance(built_context, dict):
            raise InvalidLambdaAuthorizerResponse(invalid_message)

        principal_id = json_response.get(_RESPONSE_PRINCIPAL_ID)
        if principal_id:
            # only V1 response contains this ID in the output
            built_context[_RESPONSE_PRINCIPAL_ID] = principal_id

        return built_context


@dataclass
class LambdaAuthorizerIAMPolicyPropertyValidator:
    """
    Checks that one key of an IAM-style response object is present and holds a
    value of one of the accepted types.
    """

    property_key: str
    property_types: List[Type]

    def is_valid(self, response: dict) -> bool:
        """
        Report whether `property_key` exists in `response` with an accepted type

        Parameters
        ----------
        response: dict
            The response output from the Lambda authorizer (should be in IAM format)

        Returns
        -------
        bool
            False when the key is missing, its value is None, or the value is not an
            instance of any of `property_types`; True otherwise
        """
        candidate = response.get(self.property_key)

        # a missing key and an explicit null are both treated as absent
        if candidate is None:
            return False

        return isinstance(candidate, tuple(self.property_types))


class LambdaAuthorizerIAMPolicyValidator:
    """
    Structural checks for an IAM-policy style Lambda authorizer response.
    Each check raises InvalidLambdaAuthorizerResponse on the first problem it finds.
    """

    @staticmethod
    def validate_policy_document(auth_name: str, response: dict) -> None:
        """
        Ensure the response root carries a string principal ID and an object policy document

        Parameters
        ----------
        auth_name: str
            Name of the authorizer
        response: dict
            The response output from the Lambda authorizer (should be in IAM format)

        Raises
        ------
        InvalidLambdaAuthorizerResponse
            On the first missing or wrongly-typed root property
        """
        root_checks = (
            LambdaAuthorizerIAMPolicyPropertyValidator(_RESPONSE_PRINCIPAL_ID, [str]),
            LambdaAuthorizerIAMPolicyPropertyValidator(_RESPONSE_POLICY_DOCUMENT, [dict]),
        )

        for check in root_checks:
            if check.is_valid(response):
                continue
            raise InvalidLambdaAuthorizerResponse(
                f"Authorizer '{auth_name}' contains an invalid or " f"missing '{check.property_key}' from response"
            )

    @staticmethod
    def validate_statement(auth_name: str, response: dict) -> None:
        """
        Ensure the policy document holds a non-empty list of well-formed statement objects

        Parameters
        ----------
        auth_name: str
            Name of the authorizer
        response: dict
            The response output from the Lambda authorizer (should be in IAM format)

        Raises
        ------
        InvalidLambdaAuthorizerResponse
            If `Statement` is missing/empty/not a list, an entry is not an object, or an
            entry has a missing or wrongly-typed Action, Effect or Resource
        """
        statements = response.get(_RESPONSE_POLICY_DOCUMENT, {}).get(_RESPONSE_IAM_STATEMENT)

        # only a non-empty list of statements is accepted
        if not isinstance(statements, list) or not statements:
            raise InvalidLambdaAuthorizerResponse(
                f"Authorizer '{auth_name}' contains an invalid or " f"missing '{_RESPONSE_IAM_STATEMENT}' from response"
            )

        statement_checks = (
            LambdaAuthorizerIAMPolicyPropertyValidator(_RESPONSE_IAM_ACTION, [str, list]),
            LambdaAuthorizerIAMPolicyPropertyValidator(_RESPONSE_IAM_EFFECT, [str]),
            LambdaAuthorizerIAMPolicyPropertyValidator(_RESPONSE_IAM_RESOURCE, [str, list]),
        )

        for statement in statements:
            if not isinstance(statement, dict):
                raise InvalidLambdaAuthorizerResponse(
                    f"Authorizer '{auth_name}' policy document must be a list of objects"
                )

            first_failure = next((check for check in statement_checks if not check.is_valid(statement)), None)
            if first_failure is not None:
                raise InvalidLambdaAuthorizerResponse(
                    f"Authorizer '{auth_name}' policy document contains an invalid '{first_failure.property_key}'"
                )
