samcli/local/apigw/authorizers/lambda_authorizer.py (233 lines of code) (raw):
"""
Custom Lambda Authorizer class definition
"""
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from json import JSONDecodeError, loads
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
from urllib.parse import parse_qsl
from samcli.commands.local.lib.validators.identity_source_validator import IdentitySourceValidator
from samcli.local.apigw.authorizers.authorizer import Authorizer
from samcli.local.apigw.exceptions import InvalidLambdaAuthorizerResponse, InvalidSecurityDefinition
from samcli.local.apigw.route import Route
# Top-level keys read from a Lambda authorizer response payload.
_RESPONSE_PRINCIPAL_ID = "principalId"
_RESPONSE_CONTEXT = "context"
_RESPONSE_POLICY_DOCUMENT = "policyDocument"
# Keys (and the one accepted effect value) read from the IAM-style policy document.
_RESPONSE_IAM_STATEMENT = "Statement"
_RESPONSE_IAM_EFFECT = "Effect"
_RESPONSE_IAM_EFFECT_ALLOW = "Allow"
_RESPONSE_IAM_ACTION = "Action"
_RESPONSE_IAM_RESOURCE = "Resource"
# Key read from a "simple response" payload; its value must be a boolean.
_SIMPLE_RESPONSE_IS_AUTH = "isAuthorized"
# The IAM action a statement must list for it to be considered when authorizing a route.
_IAM_INVOKE_ACTION = "execute-api:Invoke"
class IdentitySource(ABC):
    """
    Base class for identity source validators.

    Concrete subclasses implement `find_identity_value` to look up the value
    named by `identity_source` in the keyword arguments they are given.
    """

    def __init__(self, identity_source: str):
        """
        Abstract class representing an identity source validator

        Parameters
        ----------
        identity_source: str
            The identity source without any prefix
        """
        self.identity_source = identity_source

    def is_valid(self, **kwargs) -> bool:
        """
        Validates if the identity source is present

        Parameters
        ----------
        kwargs: dict
            Key word arguments to search in

        Returns
        -------
        bool:
            True if the identity source is present
        """
        found = self.find_identity_value(**kwargs)
        if found is None:
            return False
        return True

    @abstractmethod
    def find_identity_value(self, **kwargs) -> Any:
        """
        Returns the identity value, if found
        """

    def __eq__(self, other) -> bool:
        """
        Two identity sources are equal when they are instances of the exact same
        class and carry the same `identity_source` string.
        """
        if not isinstance(other, IdentitySource):
            return False
        if self.identity_source != other.identity_source:
            return False
        return self.__class__ == other.__class__
class HeaderIdentitySource(IdentitySource):
    """
    Identity source that is looked up in the `headers` keyword argument.
    """

    def find_identity_value(self, **kwargs) -> Optional[str]:
        """
        Finds the header value that the identity source corresponds to

        Parameters
        ----------
        kwargs
            Keyword arguments that should contain `headers`

        Returns
        -------
        Optional[str]
            The string value of the header if it is found, otherwise None
        """
        header_value = kwargs.get("headers", {}).get(self.identity_source)
        # a missing or falsy (e.g. empty) header is reported as not found
        if not header_value:
            return None
        return str(header_value)

    def is_valid(self, **kwargs) -> bool:
        """
        Validates whether the required header is present and matches the
        validation expression, if defined.

        Parameters
        ----------
        kwargs: dict
            Keyword arguments containing the incoming sources and validation expression

        Returns
        -------
        bool
            True if present and valid
        """
        found_value = self.find_identity_value(**kwargs)
        if found_value is None:
            return False

        pattern = kwargs.get("validation_expression")
        if not pattern:
            # nothing to validate against, presence alone is enough
            return True

        return re.match(pattern, found_value) is not None
class QueryIdentitySource(IdentitySource):
    """
    Identity source that is looked up in the `querystring` keyword argument.
    """

    def find_identity_value(self, **kwargs) -> Optional[str]:
        """
        Finds the query string value that the identity source corresponds to

        Parameters
        ----------
        kwargs
            Keyword arguments that should contain `querystring`

        Returns
        -------
        Optional[str]
            The string value of the query parameter if one is found, otherwise None
        """
        raw_query = kwargs.get("querystring", "")
        if not raw_query:
            return None

        parsed_pairs: List[Tuple[str, str]] = parse_qsl(raw_query)

        # first non-empty value whose key matches the identity source, if any
        return next(
            (param_value for param_name, param_value in parsed_pairs if param_name == self.identity_source and param_value),
            None,
        )
class ContextIdentitySource(IdentitySource):
    """
    Identity source that is looked up in the `context` keyword argument.
    """

    def find_identity_value(self, **kwargs) -> Optional[str]:
        """
        Finds the context value that the identity source corresponds to

        Parameters
        ----------
        kwargs
            Keyword arguments that should contain `context`

        Returns
        -------
        Optional[str]
            The string value of the context variable if it is found, otherwise None
        """
        context_value = kwargs.get("context", {}).get(self.identity_source)
        # a missing or falsy context entry is reported as not found
        if not context_value:
            return None
        return str(context_value)
class StageVariableIdentitySource(IdentitySource):
    """
    Identity source that is looked up in the `stageVariables` keyword argument.
    """

    def find_identity_value(self, **kwargs) -> Optional[str]:
        """
        Finds the stage variable value that the identity source corresponds to

        Parameters
        ----------
        kwargs
            Keyword arguments that should contain `stageVariables`

        Returns
        -------
        Optional[str]
            The stage variable if it is found, otherwise None
        """
        variable_value = kwargs.get("stageVariables", {}).get(self.identity_source)
        # a missing or falsy stage variable is reported as not found
        if not variable_value:
            return None
        return str(variable_value)
@dataclass
class LambdaAuthorizer(Authorizer):
    """
    Local representation of a custom Lambda authorizer.

    Holds the authorizer configuration, the parsed identity source validators and
    the logic to decide whether a Lambda authorizer response authorizes a route.
    """

    TOKEN = "token"
    REQUEST = "request"

    VALID_TYPES = [TOKEN, REQUEST]

    PAYLOAD_V1 = "1.0"
    PAYLOAD_V2 = "2.0"

    PAYLOAD_VERSIONS = [PAYLOAD_V1, PAYLOAD_V2]

    def __init__(
        self,
        authorizer_name: str,
        type: str,
        lambda_name: str,
        identity_sources: List[str],
        payload_version: str,
        validation_string: Optional[str] = None,
        use_simple_response: bool = False,
    ):
        """
        Creates a Lambda Authorizer class

        Parameters
        ----------
        authorizer_name: str
            The name of the Lambda Authorizer
        type: str
            The type of authorizer this is (token or request)
        lambda_name: str
            The name of the Lambda function this authorizer invokes
        identity_sources: List[str]
            A list of strings that this authorizer uses
        payload_version: str
            The payload format version (1.0 or 2.0)
        validation_string: Optional[str] = None
            The regular expression that can be used to validate headers
        use_simple_response: bool = False
            Boolean representing whether to return a simple response or not

        Raises
        ------
        InvalidSecurityDefinition
            If any of the identity sources is not a valid identity source
        """
        self.authorizer_name = authorizer_name
        self.lambda_name = lambda_name
        self.type = type
        self.validation_string = validation_string
        self.payload_version = payload_version
        self.use_simple_response = use_simple_response

        # authorizer_name must be set before this call, it is used in the error message
        self._parse_identity_sources(identity_sources)

    def __eq__(self, other):
        """
        Two authorizers are equal when all of their configuration matches; the raw
        identity sources are compared regardless of their order.
        """
        return (
            isinstance(other, LambdaAuthorizer)
            and self.lambda_name == other.lambda_name
            and sorted(self._identity_sources_raw) == sorted(other._identity_sources_raw)
            and self.validation_string == other.validation_string
            and self.use_simple_response == other.use_simple_response
            and self.payload_version == other.payload_version
            and self.authorizer_name == other.authorizer_name
            and self.type == other.type
        )

    @property
    def identity_sources(self) -> List[IdentitySource]:
        """
        The list of identity source validation objects

        Returns
        -------
        List[IdentitySource]
            A list of concrete identity source validation objects
        """
        return self._identity_sources

    @identity_sources.setter
    def identity_sources(self, identity_sources: List[str]) -> None:
        """
        Parses and sets the identity source validation objects

        Parameters
        ----------
        identity_sources: List[str]
            A list of strings of identity sources
        """
        self._parse_identity_sources(identity_sources)

    def _parse_identity_sources(self, identity_sources: List[str]) -> None:
        """
        Helper function to create identity source validation objects

        Parameters
        ----------
        identity_sources: List[str]
            A list of identity sources to parse

        Raises
        ------
        InvalidSecurityDefinition
            If an identity source is valid for neither `Route.API` nor `Route.HTTP`
        """
        # validate incoming identity sources first, so that no state is
        # touched when one of them is invalid
        for source in identity_sources:
            is_valid = IdentitySourceValidator.validate_identity_source(
                source, Route.API
            ) or IdentitySourceValidator.validate_identity_source(source, Route.HTTP)

            if not is_valid:
                raise InvalidSecurityDefinition(
                    f"Invalid identity source '{source}' for Lambda authorizer '{self.authorizer_name}'"
                )

        identity_source_type = {
            "method.request.header.": HeaderIdentitySource,
            "$request.header.": HeaderIdentitySource,
            "method.request.querystring.": QueryIdentitySource,
            "$request.querystring.": QueryIdentitySource,
            "context.": ContextIdentitySource,
            "$context.": ContextIdentitySource,
            "stageVariables.": StageVariableIdentitySource,
            "$stageVariables.": StageVariableIdentitySource,
        }

        self._identity_sources_raw = identity_sources
        self._identity_sources = []

        for identity_source in self._identity_sources_raw:
            for prefix, identity_source_object in identity_source_type.items():
                if identity_source.startswith(prefix):
                    # get the stuff after the prefix
                    # and create the corresponding identity source object
                    source_name = identity_source[len(prefix) :]

                    # NOTE (lucashuy):
                    # need to ignore the typing here so that mypy doesn't complain
                    # about instantiating an abstract class
                    #
                    # `identity_source_object` (which comes from `identity_source_type`)
                    # is always a concrete class
                    identity_source_validator = identity_source_object(identity_source=source_name)  # type: ignore

                    self._identity_sources.append(identity_source_validator)

                    # only the first matching prefix is used
                    break

    def is_valid_response(self, response: Union[str, bytes], method_arn: str) -> bool:
        """
        Validates whether a Lambda authorizer request is authenticated or not.

        Parameters
        ----------
        response: Union[str, bytes]
            JSON string containing the output from a Lambda authorizer
        method_arn: str
            The method ARN of the route that invoked the Lambda authorizer

        Returns
        -------
        bool
            True if the request is properly authenticated

        Raises
        ------
        InvalidLambdaAuthorizerResponse
            If the response is not valid JSON, is not a JSON object, or does not
            contain the properties required for the response format in use
        """
        invalid_message = f"Authorizer {self.authorizer_name} return an invalid response payload"

        try:
            json_response = loads(response)
        except (ValueError, JSONDecodeError) as ex:
            raise InvalidLambdaAuthorizerResponse(invalid_message) from ex

        # valid JSON that is not an object (list, string, number, ...) cannot be
        # inspected with `.get`, reject it the same way `get_context` does
        if not isinstance(json_response, dict):
            raise InvalidLambdaAuthorizerResponse(invalid_message)

        if self.payload_version == LambdaAuthorizer.PAYLOAD_V2 and self.use_simple_response:
            return self._validate_simple_response(json_response)

        # validate IAM policy document
        LambdaAuthorizerIAMPolicyValidator.validate_policy_document(self.authorizer_name, json_response)
        LambdaAuthorizerIAMPolicyValidator.validate_statement(self.authorizer_name, json_response)

        return self._is_resource_authorized(json_response, method_arn)

    def _is_resource_authorized(self, response: dict, method_arn: str) -> bool:
        """
        Validate if the current method ARN is actually authorized

        Parameters
        ----------
        response: dict
            The response output from the Lambda authorizer (should be in IAM format)
        method_arn: str
            The route's method ARN

        Returns
        -------
        bool
            True if authorized
        """
        policy_document = response.get(_RESPONSE_POLICY_DOCUMENT, {})
        all_statements = policy_document.get(_RESPONSE_IAM_STATEMENT, [])

        # NOTE(review): statements whose effect is not "Allow" are skipped rather
        # than being given precedence over a matching "Allow" - confirm this is intended
        for statement in all_statements:
            if statement.get(_RESPONSE_IAM_EFFECT) != _RESPONSE_IAM_EFFECT_ALLOW:
                continue

            action = statement.get(_RESPONSE_IAM_ACTION, [])
            action_list = action if isinstance(action, list) else [action]

            if _IAM_INVOKE_ACTION not in action_list:
                continue

            resource = statement.get(_RESPONSE_IAM_RESOURCE, [])
            resource_list = resource if isinstance(resource, list) else [resource]

            for resource_arn in resource_list:
                if not isinstance(resource_arn, str):
                    # a non-string entry can never match a method ARN
                    continue

                # form a regular expression from the possible wildcard resource ARN;
                # everything is escaped first so that only `*` and `?` act as wildcards
                # and other characters in the ARN (`.`, `+`, `(`, ...) are matched literally
                regex_method_arn = re.escape(resource_arn).replace(r"\*", ".*").replace(r"\?", ".")

                # the whole method ARN has to match, not just a prefix of it
                if re.fullmatch(regex_method_arn, method_arn):
                    return True

        return False

    def _validate_simple_response(self, response: dict) -> bool:
        """
        Helper method to validate if a Lambda authorizer response using simple responses is valid and authorized

        Parameters
        ----------
        response: dict
            JSON object containing required simple response parameters

        Returns
        -------
        bool
            True if the request is authorized

        Raises
        ------
        InvalidLambdaAuthorizerResponse
            If `isAuthorized` is missing from the response or is not a boolean
        """
        is_authorized = response.get(_SIMPLE_RESPONSE_IS_AUTH)

        # `isinstance(None, bool)` is False, so this covers the missing case as well
        if not isinstance(is_authorized, bool):
            raise InvalidLambdaAuthorizerResponse(
                f"Authorizer {self.authorizer_name} is missing or contains an invalid " f"{_SIMPLE_RESPONSE_IS_AUTH}"
            )

        return cast(bool, is_authorized)

    def get_context(self, response: Union[str, bytes]) -> Dict[str, Any]:
        """
        Returns the context (if set) from the authorizer response and appends the principalId to it.

        Parameters
        ----------
        response: Union[str, bytes]
            Output from Lambda authorizer

        Returns
        -------
        Dict[str, Any]
            The built authorizer context object

        Raises
        ------
        InvalidLambdaAuthorizerResponse
            If the response is not valid JSON, is not a JSON object, or its
            context is not a JSON object
        """
        invalid_message = f"Authorizer {self.authorizer_name} return an invalid response payload"

        try:
            json_response = loads(response)
        except (ValueError, JSONDecodeError) as ex:
            raise InvalidLambdaAuthorizerResponse(invalid_message) from ex

        if not isinstance(json_response, dict):
            raise InvalidLambdaAuthorizerResponse(invalid_message)

        built_context = json_response.get(_RESPONSE_CONTEXT, {})

        if not isinstance(built_context, dict):
            raise InvalidLambdaAuthorizerResponse(invalid_message)

        principal_id = json_response.get(_RESPONSE_PRINCIPAL_ID)
        if principal_id:
            # only V1 response contains this ID in the output
            built_context[_RESPONSE_PRINCIPAL_ID] = principal_id

        return built_context
@dataclass
class LambdaAuthorizerIAMPolicyPropertyValidator:
    """
    Checks that a single key of a response dictionary is present and holds a
    value of one of the accepted types.
    """

    # key to look up in the dictionary passed to `is_valid`
    property_key: str
    # types the looked up value is allowed to be an instance of
    property_types: List[Type]

    def is_valid(self, response: dict) -> bool:
        """
        Validates whether the property is present and of the correct type

        Parameters
        ----------
        response: dict
            The response output from the Lambda authorizer (should be in IAM format)

        Returns
        -------
        bool
            True if present and of correct type
        """
        candidate = response.get(self.property_key)

        # a missing key and an explicit None are both treated as invalid
        return candidate is not None and any(isinstance(candidate, expected) for expected in self.property_types)
class LambdaAuthorizerIAMPolicyValidator:
    """
    Validators for the structure of an IAM formatted Lambda authorizer response.
    """

    @staticmethod
    def validate_policy_document(auth_name: str, response: dict) -> None:
        """
        Validate the properties of a Lambda authorizer response at the root level

        Parameters
        ----------
        auth_name: str
            Name of the authorizer
        response: dict
            The response output from the Lambda authorizer (should be in IAM format)

        Raises
        ------
        InvalidLambdaAuthorizerResponse
            If the response is not a dictionary, or if `principalId` (string) or
            `policyDocument` (dictionary) is missing or of the wrong type
        """
        validators = {
            _RESPONSE_PRINCIPAL_ID: LambdaAuthorizerIAMPolicyPropertyValidator(_RESPONSE_PRINCIPAL_ID, [str]),
            _RESPONSE_POLICY_DOCUMENT: LambdaAuthorizerIAMPolicyPropertyValidator(_RESPONSE_POLICY_DOCUMENT, [dict]),
        }

        # a response that is not a dictionary cannot hold any of the required
        # properties; report the first one as missing instead of failing on `.get`
        is_dict_response = isinstance(response, dict)

        for prop_name, validator in validators.items():
            if not is_dict_response or not validator.is_valid(response):
                raise InvalidLambdaAuthorizerResponse(
                    f"Authorizer '{auth_name}' contains an invalid or missing '{prop_name}' from response"
                )

    @staticmethod
    def validate_statement(auth_name: str, response: dict) -> None:
        """
        Validate the Statement(s) of a Lambda authorizer response's policy document

        Parameters
        ----------
        auth_name: str
            Name of the authorizer
        response: dict
            The response output from the Lambda authorizer (should be in IAM format)

        Raises
        ------
        InvalidLambdaAuthorizerResponse
            If `Statement` is missing, is not a non-empty list, contains an entry
            that is not a dictionary, or contains an entry with a missing or
            wrongly typed `Action`, `Effect` or `Resource`
        """
        # guard both lookups so that a malformed response or policy document is
        # reported as an invalid response rather than raising an AttributeError
        policy_document = response.get(_RESPONSE_POLICY_DOCUMENT) if isinstance(response, dict) else None
        all_statements = policy_document.get(_RESPONSE_IAM_STATEMENT) if isinstance(policy_document, dict) else None

        if not isinstance(all_statements, list) or not all_statements:
            raise InvalidLambdaAuthorizerResponse(
                f"Authorizer '{auth_name}' contains an invalid or missing '{_RESPONSE_IAM_STATEMENT}' from response"
            )

        validators = {
            _RESPONSE_IAM_ACTION: LambdaAuthorizerIAMPolicyPropertyValidator(_RESPONSE_IAM_ACTION, [str, list]),
            _RESPONSE_IAM_EFFECT: LambdaAuthorizerIAMPolicyPropertyValidator(_RESPONSE_IAM_EFFECT, [str]),
            _RESPONSE_IAM_RESOURCE: LambdaAuthorizerIAMPolicyPropertyValidator(_RESPONSE_IAM_RESOURCE, [str, list]),
        }

        for statement in all_statements:
            if not isinstance(statement, dict):
                raise InvalidLambdaAuthorizerResponse(
                    f"Authorizer '{auth_name}' policy document must be a list of objects"
                )

            for prop_name, validator in validators.items():
                if not validator.is_valid(statement):
                    raise InvalidLambdaAuthorizerResponse(
                        f"Authorizer '{auth_name}' policy document contains an invalid '{prop_name}'"
                    )