samcli/commands/local/lib/swagger/parser.py (252 lines of code) (raw):
"""Handles Swagger Parsing"""
import logging
from typing import Dict, List, Union
from samcli.commands.local.lib.swagger.integration_uri import IntegrationType, LambdaUri
from samcli.commands.local.lib.validators.identity_source_validator import IdentitySourceValidator
from samcli.local.apigw.authorizers.authorizer import Authorizer
from samcli.local.apigw.authorizers.lambda_authorizer import LambdaAuthorizer
from samcli.local.apigw.exceptions import (
IncorrectOasWithDefaultAuthorizerException,
InvalidOasVersion,
InvalidSecurityDefinition,
MultipleAuthorizerException,
)
from samcli.local.apigw.route import Route
# Module-level logger used for the debug/warning messages emitted while parsing Swagger documents.
LOG = logging.getLogger(__name__)
class SwaggerParser:
    """
    Parses a Swagger 2.x / OpenAPI 3.x document (already loaded into a dictionary) to extract routes,
    Lambda authorizers, the root level default authorizer and binary media types.
    """

    _AUTHORIZER_KEY = "x-amazon-apigateway-authorizer"
    _INTEGRATION_KEY = "x-amazon-apigateway-integration"
    _ANY_METHOD_EXTENSION_KEY = "x-amazon-apigateway-any-method"
    _BINARY_MEDIA_TYPES_EXTENSION_KEY = "x-amazon-apigateway-binary-media-types"  # pylint: disable=C0103
    _ANY_METHOD = "ANY"

    # Top level keys holding the document version, and the supported version prefixes
    _SWAGGER = "swagger"
    _OPENAPI = "openapi"
    _2_X_VERSION = "2."
    _3_X_VERSION = "3."

    # Keys used to locate security (authorizer) definitions and security requirements
    _SWAGGER_COMPONENTS = "components"
    _SWAGGER_SECURITY = "security"
    _SWAGGER_SECURITY_SCHEMES = "securitySchemes"
    _SWAGGER_SECURITY_DEFINITIONS = "securityDefinitions"

    # Keys read from a security scheme and from its x-amazon-apigateway-authorizer object
    _AUTHORIZER_TYPE = "type"
    _AUTHORIZER_PAYLOAD_VERSION = "authorizerPayloadFormatVersion"
    _AUTHORIZER_LAMBDA_URI = "authorizerUri"
    _AUTHORIZER_LAMBDA_VALIDATION = "identityValidationExpression"
    _AUTHORIZER_NAME = "name"
    _AUTHORIZER_IN = "in"
    _AUTHORIZER_IDENTITY_SOURCE = "identitySource"
    _AUTHORIZER_SIMPLE_RESPONSES = "enableSimpleResponses"

    def __init__(self, stack_path: str, swagger):
        """
        Constructs a Swagger Parser object

        :param str stack_path: Path of the stack the resource is located
        :param dict swagger: Dictionary representation of a Swagger document;
            a falsy value (e.g. None) is treated as an empty document
        """
        self.swagger = swagger or {}
        self.stack_path = stack_path

    def get_binary_media_types(self):
        """
        Get the list of Binary Media Types from Swagger

        Returns
        -------
        list of str
            List of strings that represent the Binary Media Types for the API, defaulting to an empty list
            if the extension is missing or None
        """
        return self.swagger.get(self._BINARY_MEDIA_TYPES_EXTENSION_KEY) or []

    def get_authorizers(self, event_type: str = Route.API) -> Dict[str, Authorizer]:
        """
        Parse Swagger document and returns a map of Lambda Authorizer objects

        Security schemes without an x-amazon-apigateway-authorizer object, with an unparsable authorizerUri,
        with an unsupported type, or (for token authorizers) without identity sources are skipped with a warning.

        Parameters
        ----------
        event_type: str
            String indicating what type of API Gateway this is (Route.API or Route.HTTP)

        Returns
        -------
        dict[str, Authorizer]
            A map of authorizer names and Authorizer objects found in the body definition

        Raises
        ------
        InvalidOasVersion
            If the document version is neither 2.x nor 3.x
        InvalidSecurityDefinition
            If an HTTP API authorizer has an unsupported payload version, or an identity source is malformed
        """
        authorizers: Dict[str, Authorizer] = {}
        document_version = self._get_document_version()

        if document_version.startswith(SwaggerParser._2_X_VERSION):
            LOG.debug("Parsing Swagger document using 2.0 specification")
            authorizer_dict = self.swagger.get(SwaggerParser._SWAGGER_SECURITY_DEFINITIONS) or {}
        elif document_version.startswith(SwaggerParser._3_X_VERSION):
            LOG.debug("Parsing Swagger document using 3.0 specification")
            # `or {}` also guards against keys that are present but explicitly null
            components = self.swagger.get(SwaggerParser._SWAGGER_COMPONENTS) or {}
            authorizer_dict = components.get(SwaggerParser._SWAGGER_SECURITY_SCHEMES) or {}
        else:
            raise InvalidOasVersion(
                f"An invalid OpenApi version was detected: '{document_version}', must be one of 2.x or 3.x",
            )

        for auth_name, properties in authorizer_dict.items():
            authorizer_object = properties.get(self._AUTHORIZER_KEY) if isinstance(properties, dict) else None

            if not isinstance(authorizer_object, dict) or not authorizer_object:
                LOG.warning("Skip parsing unsupported authorizer '%s'", auth_name)
                continue

            # An explicit null type is treated like a missing one
            authorizer_type = str(authorizer_object.get(SwaggerParser._AUTHORIZER_TYPE) or "").lower()
            payload_version = authorizer_object.get(SwaggerParser._AUTHORIZER_PAYLOAD_VERSION)

            if event_type == Route.HTTP and payload_version not in LambdaAuthorizer.PAYLOAD_VERSIONS:
                raise InvalidSecurityDefinition(f"Authorizer '{auth_name}' contains an invalid payload version")

            # REST APIs always use the 1.0 payload format, whatever the document says
            if event_type == Route.API:
                payload_version = LambdaAuthorizer.PAYLOAD_V1

            authorizer_uri = authorizer_object.get(SwaggerParser._AUTHORIZER_LAMBDA_URI)
            lambda_name = LambdaUri.get_function_name(authorizer_uri)

            if not lambda_name:
                # log the URI that failed to parse (lambda_name is always empty here)
                LOG.warning(
                    "Unable to parse authorizerUri '%s' for authorizer '%s', skipping", authorizer_uri, auth_name
                )
                continue

            # only add authorizer if it is Lambda token or request based (not jwt)
            if authorizer_type not in LambdaAuthorizer.VALID_TYPES:
                LOG.warning("Lambda authorizer '%s' type '%s' is unsupported, skipping", auth_name, authorizer_type)
                continue

            identity_sources = self._get_lambda_identity_sources(
                auth_name, authorizer_type, event_type, properties, authorizer_object
            )

            validation_expression = authorizer_object.get(SwaggerParser._AUTHORIZER_LAMBDA_VALIDATION)

            if event_type == Route.HTTP and validation_expression:
                validation_expression = None
                LOG.warning(
                    "Validation expressions is only available on REST APIs, ignoring for Lambda authorizer '%s'",
                    auth_name,
                )

            enable_simple_response = authorizer_object.get(SwaggerParser._AUTHORIZER_SIMPLE_RESPONSES, False)

            # simple responses are only honoured for HTTP APIs using payload 2.0 and a real boolean value
            if (
                event_type != Route.HTTP
                or payload_version != LambdaAuthorizer.PAYLOAD_V2
                or not isinstance(enable_simple_response, bool)
            ):
                enable_simple_response = False

                if authorizer_object.get(SwaggerParser._AUTHORIZER_SIMPLE_RESPONSES) is not None:
                    LOG.warning(
                        "Simple responses are only available on HTTP APIs with payload version "
                        "2.0, ignoring for Lambda authorizer '%s'",
                        auth_name,
                    )

            # token based authorizers must have an identity source defined
            # this is determined by taking the header key in the properties
            # to form the identity source in a previous method call
            if not identity_sources and authorizer_type == LambdaAuthorizer.TOKEN:
                LOG.warning(
                    "Skip parsing Lambda authorizer '%s', must contain valid "
                    "identity sources for Rest Api based token authorizers",
                    auth_name,
                )
                continue

            lambda_authorizer = LambdaAuthorizer(
                authorizer_name=auth_name,
                type=authorizer_type,
                payload_version=payload_version,
                lambda_name=lambda_name,
                identity_sources=identity_sources,
                validation_string=validation_expression,
                use_simple_response=enable_simple_response,
            )

            authorizers[auth_name] = lambda_authorizer

            LOG.debug("Parsing Lambda authorizer '%s' type '%s'", auth_name, authorizer_type)

        return authorizers

    @staticmethod
    def _get_lambda_identity_sources(
        auth_name: str, auth_type: str, event_type: str, properties: dict, authorizer_object: dict
    ) -> List[str]:
        """
        Parses the properties depending on the Lambda Authorizer type (token or request) and retrieves identity sources

        Parameters
        ----------
        auth_name: str
            Name of the authorizer used for logging
        auth_type: str
            Type of authorizer (token, request)
        event_type: str
            API Gateway type (API, HTTP API)
        properties: dict
            Swagger Lambda Authorizer properties
        authorizer_object: dict
            Lambda Authorizer integration properties

        Returns
        -------
        List[str]
            A list of identity sources; empty for token authorizers that are misconfigured or used on HTTP APIs

        Raises
        ------
        InvalidSecurityDefinition
            If a request authorizer identity source fails validation
        """
        identity_sources: List[str] = []

        if auth_type == LambdaAuthorizer.TOKEN:
            header_name = properties.get(SwaggerParser._AUTHORIZER_NAME)

            if properties.get(SwaggerParser._AUTHORIZER_IN) != "header" or not header_name:
                LOG.warning(
                    "Missing properties for Lambda Authorizer '%s', "
                    "property 'in' must be set to 'header' and "
                    "property 'name' must be provided",
                    auth_name,
                )
            elif event_type == Route.HTTP:
                LOG.info("Type 'token' for Lambda Authorizer '%s' is unsupported ", auth_name)
            else:
                identity_sources.append(f"method.request.header.{header_name}")
        else:
            identity_source_string = authorizer_object.get(SwaggerParser._AUTHORIZER_IDENTITY_SOURCE, "")

            # split the identity sources, remove any trailing spaces, and validate
            # we check for false-y string since .split() will return [""] instead of [] on an empty string
            split_identity_source: List[str] = identity_source_string.split(",") if identity_source_string else []

            for identity in split_identity_source:
                trimmed_identity = identity.strip()

                if not IdentitySourceValidator.validate_identity_source(trimmed_identity, event_type):
                    raise InvalidSecurityDefinition(
                        f"Identity source '{trimmed_identity}' for Lambda Authorizer '{auth_name}' "
                        "is not a valid identity source, check the spelling/format."
                    )

                identity_sources.append(trimmed_identity)

        return identity_sources

    def _get_document_version(self) -> str:
        """
        Helper method to fetch the Swagger document version

        Returns
        -------
        str
            A string representing a version ("swagger" key first, then "openapi"), blank if not found
        """
        document_version = self.swagger.get(SwaggerParser._SWAGGER) or self.swagger.get(SwaggerParser._OPENAPI) or ""

        return str(document_version)

    def get_default_authorizer(self, event_type: str) -> Union[str, None]:
        """
        Parses the body definition to find root level Authorizer definitions

        Parameters
        ----------
        event_type: str
            String representing the type of API the definition body is defined as (currently unused)

        Returns
        -------
        Union[str, None]
            Returns the name of the authorizer, if there is one defined, otherwise None

        Raises
        ------
        IncorrectOasWithDefaultAuthorizerException
            If a root level security requirement is defined in a non 3.x document
        MultipleAuthorizerException
            If more than one root level security requirement is defined
        InvalidSecurityDefinition
            If the root level security requirements are malformed
        """
        document_version = self._get_document_version()
        authorizers = self.swagger.get(SwaggerParser._SWAGGER_SECURITY, [])

        if not authorizers:
            return None

        if not document_version.startswith(SwaggerParser._3_X_VERSION):
            raise IncorrectOasWithDefaultAuthorizerException(
                "Root level definition of default authorizers are only supported for API "
                "resources using an OpenApi 3.x body"
            )

        if not isinstance(authorizers, list):
            raise InvalidSecurityDefinition("Invalid default security definition found, security must be a list.")

        if len(authorizers) > 1:
            raise MultipleAuthorizerException(
                f"There must only be a single authorizer defined for a single route, found '{len(authorizers)}'"
            )

        # exactly one security requirement remains at this point
        authorizer_name = self._get_security_requirement_name(
            authorizers[0], "Invalid default security definition found, there must be an authorizer defined."
        )

        LOG.debug("Found default authorizer: %s", authorizer_name)

        return authorizer_name

    def get_routes(self, event_type=Route.API) -> List[Route]:
        """
        Parses a swagger document and returns a list of APIs configured in the document.

        Swagger documents have the following structure
        {
            "/path1": {    # path
                "get": {   # method
                    "x-amazon-apigateway-integration": {   # integration
                        "type": "aws_proxy",

                        # URI contains the Lambda function ARN that needs to be parsed to get Function Name
                        "uri": {
                            "Fn::Sub":
                                "arn:aws:apigateway:aws:lambda:path/2015-03-31/functions/${LambdaFunction.Arn}/..."
                        }
                    }
                },
                "post": {
                },
            },
            "/path2": {
                ...
            }
        }

        Parameters
        ----------
        event_type: str
            Type of API Gateway the routes belong to, passed through to every Route

        Returns
        -------
        list of samcli.local.apigw.route.Route
            List of APIs that are configured in the Swagger document

        Raises
        ------
        InvalidSecurityDefinition
            If a method level security definition is not a list or does not reference an authorizer
        MultipleAuthorizerException
            If a method defines more than one security requirement
        """
        result: List[Route] = []
        paths_dict = self.swagger.get("paths") or {}

        for full_path, path_config in paths_dict.items():
            for method, method_config in path_config.items():
                function_name = self._get_integration_function_name(method_config)

                if not function_name:
                    LOG.debug(
                        "Lambda function integration not found in Swagger document at path='%s' method='%s'",
                        full_path,
                        method,
                    )
                    continue

                normalized_method = method
                if normalized_method.lower() == self._ANY_METHOD_EXTENSION_KEY:
                    # Convert to a more commonly used method notation
                    normalized_method = self._ANY_METHOD

                payload_format_version = self._get_payload_format_version(method_config)

                authorizers = method_config.get(SwaggerParser._SWAGGER_SECURITY, None)
                authorizer_name = None
                use_default_authorizer = True

                if authorizers is not None:
                    if not isinstance(authorizers, list):
                        raise InvalidSecurityDefinition(
                            "Invalid security definition found, authorizers for "
                            f"path='{full_path}' method='{method}' must be a list"
                        )

                    if len(authorizers) > 1:
                        raise MultipleAuthorizerException(
                            "There must only be a single authorizer defined "
                            f"for path='{full_path}' method='{method}', found '{len(authorizers)}'"
                        )

                    if len(authorizers) == 1 and authorizers[0] != {}:
                        # user has authorizer defined
                        authorizer_name = self._get_security_requirement_name(
                            authorizers[0],
                            "Invalid security definition found, authorizers for "
                            f"path='{full_path}' method='{method}' must contain an authorizer",
                        )
                    else:
                        # customer provided empty list (or [{}]), do not use default authorizer
                        use_default_authorizer = False

                route = Route(
                    function_name,
                    full_path,
                    methods=[normalized_method],
                    event_type=event_type,
                    payload_format_version=payload_format_version,
                    operation_name=method_config.get("operationId"),
                    stack_path=self.stack_path,
                    authorizer_name=authorizer_name,
                    use_default_authorizer=use_default_authorizer,
                )
                result.append(route)

        return result

    @staticmethod
    def _get_security_requirement_name(security_requirement, error_message: str) -> str:
        """
        Return the name of the single authorizer referenced by a security requirement object.

        A security requirement is expected to be a dictionary with exactly one key, the authorizer name,
        e.g. ``{"MyAuthorizer": []}``.

        Parameters
        ----------
        security_requirement
            One entry of a "security" list
        error_message: str
            Message of the exception raised when the requirement is malformed

        Returns
        -------
        str
            The authorizer name

        Raises
        ------
        InvalidSecurityDefinition
            If the requirement is not a dictionary or does not contain exactly one key
        """
        if not isinstance(security_requirement, dict) or len(security_requirement) != 1:
            raise InvalidSecurityDefinition(error_message)

        return str(next(iter(security_requirement)))

    def _get_integration(self, method_config):
        """
        Get Integration defined in the method configuration.
        Integration configuration is defined under the special "x-amazon-apigateway-integration" key. We care only
        about Lambda integrations, which are of type aws_proxy, and ignore the rest.

        Parameters
        ----------
        method_config : dict
            Dictionary containing the method configuration which might contain integration settings

        Returns
        -------
        dict or None
            integration, if possible. None, if not.
        """
        if not isinstance(method_config, dict):
            return None

        integration = method_config.get(self._INTEGRATION_KEY)
        if not isinstance(integration, dict):
            return None

        # "type" may be missing or not a string; neither can be an aws_proxy integration
        integration_type = integration.get("type")
        if isinstance(integration_type, str) and integration_type.lower() == IntegrationType.aws_proxy.value:
            # Integration must be "aws_proxy" otherwise we don't care about it
            return integration

        return None

    def _get_integration_function_name(self, method_config):
        """
        Tries to parse the Lambda Function name from the Integration defined in the method configuration.
        Integration configuration is defined under the special "x-amazon-apigateway-integration" key. We care only
        about Lambda integrations, which are of type aws_proxy, and ignore the rest. Integration URI is complex and
        hard to parse. Hence we do our best to extract function name out of integration URI. If not possible, we
        return None.

        Parameters
        ----------
        method_config : dict
            Dictionary containing the method configuration which might contain integration settings

        Returns
        -------
        string or None
            Lambda function name, if possible. None, if not.
        """
        integration = self._get_integration(method_config)
        if integration is None:
            return None

        return LambdaUri.get_function_name(integration.get("uri"))

    def _get_payload_format_version(self, method_config):
        """
        Get the "payloadFormatVersion" from the Integration defined in the method configuration.

        Parameters
        ----------
        method_config : dict
            Dictionary containing the method configuration which might contain integration settings

        Returns
        -------
        string or None
            Payload format version, if exists. None, if not.
        """
        integration = self._get_integration(method_config)
        if integration is None:
            return None

        return integration.get("payloadFormatVersion")