samtranslator/model/api/api_generator.py (1,159 lines of code) (raw):
import logging
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
from samtranslator.feature_toggle.feature_toggle import FeatureToggle
from samtranslator.metrics.method_decorator import cw_timer
from samtranslator.model import Resource
from samtranslator.model.apigateway import (
ApiGatewayApiKey,
ApiGatewayAuthorizer,
ApiGatewayBasePathMapping,
ApiGatewayBasePathMappingV2,
ApiGatewayDeployment,
ApiGatewayDomainName,
ApiGatewayDomainNameAccessAssociation,
ApiGatewayDomainNameV2,
ApiGatewayResponse,
ApiGatewayRestApi,
ApiGatewayStage,
ApiGatewayUsagePlan,
ApiGatewayUsagePlanKey,
)
from samtranslator.model.exceptions import (
ExpectedType,
InvalidDocumentException,
InvalidResourceException,
InvalidTemplateException,
)
from samtranslator.model.intrinsics import fnGetAtt, fnSub, is_intrinsic, make_or_condition, ref
from samtranslator.model.lambda_ import LambdaPermission
from samtranslator.model.route53 import Route53RecordSetGroup
from samtranslator.model.s3_utils.uri_parser import parse_s3_uri
from samtranslator.model.tags.resource_tagging import get_tag_list
from samtranslator.model.types import PassThrough
from samtranslator.region_configuration import RegionConfiguration
from samtranslator.swagger.swagger import SwaggerEditor
from samtranslator.translator.arn_generator import ArnGenerator
from samtranslator.translator.logical_id_generator import LogicalIdGenerator
from samtranslator.utils.py27hash_fix import Py27Dict, Py27UniStr
from samtranslator.utils.types import Intrinsicable
from samtranslator.utils.utils import InvalidValueType, dict_deep_get
from samtranslator.validator.value_validator import sam_expect
LOG = logging.getLogger(__name__)

# Name of the feature flag for normalized OpenAPI version handling (consumed outside this view).
FEATURE_FLAG_NORMALIZED_OPENAPI_VERSION = "normalized_open_api_version"

# Single-quoted literal wildcard used as the default CORS AllowOrigin value.
_CORS_WILDCARD = "'*'"

# Properties of the SAM Api `Cors` setting.
# AllowOrigin defaults to the '*' wildcard and AllowCredentials to False; every other field is optional (None).
CorsProperties = namedtuple(
    "CorsProperties",
    ["AllowMethods", "AllowHeaders", "AllowOrigin", "MaxAge", "AllowCredentials"],
    defaults=(None, None, _CORS_WILDCARD, None, False),
)

# Properties of the SAM Api `Auth` setting.
# Both *ToCorsPreflight flags default to True; every other field defaults to None.
AuthProperties = namedtuple(
    "AuthProperties",
    [
        "Authorizers",
        "DefaultAuthorizer",
        "InvokeRole",
        "AddDefaultAuthorizerToCorsPreflight",
        "AddApiKeyRequiredToCorsPreflight",
        "ApiKeyRequired",
        "ResourcePolicy",
        "UsagePlan",
    ],
    defaults=(None, None, None, True, True, None, None, None),
)

# Properties of the `Auth.UsagePlan` setting; every field defaults to None.
UsagePlanProperties = namedtuple(
    "UsagePlanProperties",
    ["CreateUsagePlan", "Description", "Quota", "Tags", "Throttle", "UsagePlanName"],
    defaults=(None, None, None, None, None, None),
)

# Keys accepted on each entry of the SAM Api `GatewayResponses` setting.
GatewayResponseProperties = ["ResponseParameters", "ResponseTemplates", "StatusCode"]
@dataclass
class ApiDomainResponse:
    """Resources produced for a SAM Api custom domain by the (V1) ApiGatewayDomainName builder.

    Every field is None when the Api has no Domain configured.
    """

    # The ApiGatewayDomainName resource, or None when no Domain is configured
    domain: Optional[ApiGatewayDomainName]
    # Base path mappings that connect the domain to this Api's stage
    apigw_basepath_mapping_list: Optional[List[ApiGatewayBasePathMapping]]
    # Route53 RecordSetGroup for the domain, or None when Domain.Route53 is not set
    recordset_group: Any
@dataclass
class ApiDomainResponseV2:
    """Resources produced for a SAM Api custom domain by the V2 (ApiGatewayDomainNameV2) builder.

    Every field is None when the Api has no Domain configured.
    """

    # The ApiGatewayDomainNameV2 resource, or None when no Domain is configured
    domain: Optional[ApiGatewayDomainNameV2]
    # V2 base path mappings that connect the domain to this Api's stage
    apigw_basepath_mapping_list: Optional[List[ApiGatewayBasePathMappingV2]]
    # Route53 RecordSetGroup for the domain, or None when Domain.Route53 is not set
    recordset_group: Any
    # Resource generated from Domain.AccessAssociation, or None when that setting is absent
    domain_access_association: Any
class SharedApiUsagePlan:
    """
    Collects API information from different API resources in the same template,
    so that this information can be used in the shared usage plan.

    Resource-level attributes of every API that joins the shared plan are merged:
    - DeletionPolicy: the first value wins unless a later API asks for Retain.
    - UpdateReplacePolicy: precedence is Retain > Snapshot > Delete.
    - Condition: the plan is conditioned on the OR of all API conditions, but only
      while every API seen so far has a Condition.
    """

    SHARED_USAGE_PLAN_CONDITION_NAME = "SharedUsagePlanCondition"

    def __init__(self) -> None:
        self.usage_plan_shared = False
        self.stage_keys_shared: List[str] = []
        self.api_stages_shared: List[str] = []
        self.depends_on_shared: List[str] = []

        # shared resource level attributes
        self.conditions: Set[str] = set()
        self.any_api_without_condition = False
        self.deletion_policy: Optional[str] = None
        self.update_replace_policy: Optional[str] = None

    def get_combined_resource_attributes(self, resource_attributes, conditions):  # type: ignore[no-untyped-def]
        """
        Merge one API's resource attributes into the shared state and return the combined
        'DeletionPolicy', 'UpdateReplacePolicy' and 'Condition' values that could be used in
        Shared Usage Plan resources.

        Parameters
        ----------
        resource_attributes: Dict[str]
            A dictionary of resource level attributes of the API resource
        conditions: Dict[str]
            Conditions section of the template. The SharedUsagePlanCondition entry may be
            written to it or removed from it.

        Returns
        -------
        dict
            Only the attributes that are set. 'Condition' is present only when every API seen
            so far has a Condition.
        """
        self._set_deletion_policy(resource_attributes.get("DeletionPolicy"))  # type: ignore[no-untyped-call]
        self._set_update_replace_policy(resource_attributes.get("UpdateReplacePolicy"))  # type: ignore[no-untyped-call]
        self._set_condition(resource_attributes.get("Condition"), conditions)  # type: ignore[no-untyped-call]

        combined_resource_attributes = {}
        if self.deletion_policy:
            combined_resource_attributes["DeletionPolicy"] = self.deletion_policy
        if self.update_replace_policy:
            combined_resource_attributes["UpdateReplacePolicy"] = self.update_replace_policy
        # do not set Condition if any of the API resource does not have Condition in it
        if self.conditions and not self.any_api_without_condition:
            combined_resource_attributes["Condition"] = SharedApiUsagePlan.SHARED_USAGE_PLAN_CONDITION_NAME

        return combined_resource_attributes

    def _set_deletion_policy(self, deletion_policy):  # type: ignore[no-untyped-def]
        """Keep the first DeletionPolicy seen, but let a later 'Retain' override it."""
        if deletion_policy:
            if self.deletion_policy:
                # update only if new deletion policy is Retain
                if deletion_policy == "Retain":
                    self.deletion_policy = deletion_policy
            else:
                self.deletion_policy = deletion_policy

    def _set_update_replace_policy(self, update_replace_policy):  # type: ignore[no-untyped-def]
        """Merge UpdateReplacePolicy values with precedence Retain > Snapshot > Delete."""
        if update_replace_policy:
            if self.update_replace_policy:
                # Retain always wins; Snapshot only replaces an existing Delete.
                if (update_replace_policy == "Retain") or (
                    update_replace_policy == "Snapshot" and self.update_replace_policy == "Delete"
                ):
                    self.update_replace_policy = update_replace_policy
            else:
                self.update_replace_policy = update_replace_policy

    def _set_condition(self, condition, template_conditions):  # type: ignore[no-untyped-def]
        """
        Record one API's Condition and keep the shared condition in the template in sync.

        One distinct condition maps the shared condition to that condition name. Several distinct
        conditions map it to an OR of all of them. An API without a Condition removes the shared
        condition permanently.

        Raises InvalidTemplateException if a Condition is used but the template has no Conditions section.
        """
        # if there are any API without condition, then skip
        if self.any_api_without_condition:
            return

        if condition and condition not in self.conditions:
            if template_conditions is None:
                raise InvalidTemplateException(
                    "Can't have condition without having 'Conditions' section in the template"
                )

            self.conditions.add(condition)
            if len(self.conditions) > 1:
                # Sort so the generated Fn::Or is identical on every run: iteration order of a
                # set of strings depends on per-process hash randomization.
                or_condition = make_or_condition(sorted(self.conditions))
                template_conditions[SharedApiUsagePlan.SHARED_USAGE_PLAN_CONDITION_NAME] = or_condition
            else:
                template_conditions[SharedApiUsagePlan.SHARED_USAGE_PLAN_CONDITION_NAME] = condition
        elif condition is None:
            self.any_api_without_condition = True
            if template_conditions and SharedApiUsagePlan.SHARED_USAGE_PLAN_CONDITION_NAME in template_conditions:
                del template_conditions[SharedApiUsagePlan.SHARED_USAGE_PLAN_CONDITION_NAME]
class ApiGenerator:
def __init__(  # noqa: PLR0913
    self,
    logical_id: str,
    cache_cluster_enabled: Optional[Intrinsicable[bool]],
    cache_cluster_size: Optional[Intrinsicable[str]],
    variables: Optional[Dict[str, Any]],
    depends_on: Optional[List[str]],
    definition_body: Optional[Dict[str, Any]],
    definition_uri: Optional[Intrinsicable[str]],
    name: Optional[Intrinsicable[str]],
    stage_name: Optional[Intrinsicable[str]],
    shared_api_usage_plan: Any,
    template_conditions: Any,
    merge_definitions: Optional[bool] = None,
    tags: Optional[Dict[str, Any]] = None,
    endpoint_configuration: Optional[Dict[str, Any]] = None,
    method_settings: Optional[List[Any]] = None,
    binary_media: Optional[List[Any]] = None,
    minimum_compression_size: Optional[Intrinsicable[int]] = None,
    disable_execute_api_endpoint: Optional[Intrinsicable[bool]] = None,
    cors: Optional[Intrinsicable[str]] = None,
    auth: Optional[Dict[str, Any]] = None,
    gateway_responses: Optional[Dict[str, Any]] = None,
    access_log_setting: Optional[Dict[str, Any]] = None,
    canary_setting: Optional[Dict[str, Any]] = None,
    tracing_enabled: Optional[Intrinsicable[bool]] = None,
    resource_attributes: Optional[Dict[str, Any]] = None,
    passthrough_resource_attributes: Optional[Dict[str, Any]] = None,
    open_api_version: Optional[Intrinsicable[str]] = None,
    models: Optional[Dict[str, Any]] = None,
    domain: Optional[Dict[str, Any]] = None,
    fail_on_warnings: Optional[Intrinsicable[bool]] = None,
    description: Optional[Intrinsicable[str]] = None,
    mode: Optional[Intrinsicable[str]] = None,
    api_key_source_type: Optional[Intrinsicable[str]] = None,
    always_deploy: Optional[bool] = False,
    feature_toggle: Optional[FeatureToggle] = None,
    policy: Optional[Union[Dict[str, Any], Intrinsicable[str]]] = None,
):
    """Record the settings of one SAM Api resource so API Gateway resources can be generated from it.

    Nothing is generated here; the values are only stored on the instance.

    :param logical_id: Logical id of the SAM Api resource; prefix for the generated logical ids
    :param cache_cluster_enabled: Stage CacheClusterEnabled
    :param cache_cluster_size: Stage CacheClusterSize
    :param variables: Stage Variables
    :param depends_on: DependsOn of the generated RestApi
    :param definition_body: inline OpenAPI/Swagger definition
    :param definition_uri: S3 location of the definition (dict or s3:// string)
    :param name: RestApi Name
    :param stage_name: Stage name; a plain string also becomes part of the Stage logical id
    :param shared_api_usage_plan: shared usage-plan collector (presumably a SharedApiUsagePlan)
    :param template_conditions: the template's Conditions section (presumably)
    :param merge_definitions: MergeDefinitions flag; rejected together with DefinitionUri
    :param tags: Stage Tags
    :param endpoint_configuration: RestApi endpoint configuration
    :param method_settings: Stage MethodSettings
    :param binary_media: RestApi BinaryMediaTypes
    :param minimum_compression_size: RestApi MinimumCompressionSize
    :param disable_execute_api_endpoint: written into the OpenAPI definition when set
    :param cors: Cors setting, applied to the definition by a helper
    :param auth: Auth setting, applied to the definition by a helper
    :param gateway_responses: GatewayResponses setting, applied to the definition by a helper
    :param access_log_setting: Stage AccessLogSetting
    :param canary_setting: Stage CanarySetting
    :param tracing_enabled: Stage TracingEnabled (X-Ray)
    :param resource_attributes: resource attributes of the RestApi
    :param passthrough_resource_attributes: attributes such as `Condition` copied to derived resources
    :param open_api_version: OpenApiVersion; when set, the implicit "Stage" stage is not created
    :param models: Model definitions to be used by API methods
    :param domain: custom Domain configuration
    :param fail_on_warnings: RestApi FailOnWarnings
    :param description: RestApi Description
    :param mode: RestApi Mode
    :param api_key_source_type: RestApi ApiKeySourceType
    :param always_deploy: forwarded when making the deployment auto-deployable
    :param feature_toggle: feature toggle object (not used in the code visible here)
    :param policy: RestApi Policy
    """
    # --- identity and API definition ---
    self.logical_id = logical_id
    self.definition_body = definition_body
    self.definition_uri = definition_uri
    self.merge_definitions = merge_definitions
    self.open_api_version = open_api_version
    # A truthy OpenApiVersion also suppresses the legacy implicit "Stage" stage on the Deployment.
    self.remove_extra_stage = open_api_version
    self.models = models

    # --- RestApi-level settings ---
    self.name = name
    self.description = description
    self.mode = mode
    self.api_key_source_type = api_key_source_type
    self.policy = policy
    self.endpoint_configuration = endpoint_configuration
    self.binary_media = binary_media
    self.minimum_compression_size = minimum_compression_size
    self.disable_execute_api_endpoint = disable_execute_api_endpoint
    self.fail_on_warnings = fail_on_warnings
    self.cors = cors
    self.auth = auth
    self.gateway_responses = gateway_responses

    # --- Stage-level settings ---
    self.stage_name = stage_name
    self.cache_cluster_enabled = cache_cluster_enabled
    self.cache_cluster_size = cache_cluster_size
    self.variables = variables
    self.method_settings = method_settings
    self.access_log_setting = access_log_setting
    self.canary_setting = canary_setting
    self.tracing_enabled = tracing_enabled
    self.tags = tags
    self.always_deploy = always_deploy

    # --- custom domain ---
    self.domain = domain

    # --- CloudFormation resource attributes and template context ---
    self.depends_on = depends_on
    self.resource_attributes = resource_attributes
    self.passthrough_resource_attributes = passthrough_resource_attributes
    self.template_conditions = template_conditions

    # --- shared state / toggles ---
    self.shared_api_usage_plan = shared_api_usage_plan
    self.feature_toggle = feature_toggle
def _construct_rest_api(self) -> ApiGatewayRestApi:
    """Build the ApiGateway RestApi for this SAM Api.

    The properties are validated first. The definition-level features (CORS, auth, gateway
    responses, binary media types, models and the execute-api endpoint switch) are then
    applied. Finally, either the S3 location or the inline body is attached, along with the
    optional top-level properties.

    :returns: the RestApi to which this SAM Api corresponds
    :rtype: model.apigateway.ApiGatewayRestApi
    """
    self._validate_properties()

    api = ApiGatewayRestApi(self.logical_id, depends_on=self.depends_on, attributes=self.resource_attributes)

    # BinaryMediaTypes is kept on the CloudFormation property for backwards compatibility.
    # Relying on x-amazon-apigateway-binary-media-types alone causes other issues.
    api.BinaryMediaTypes = self.binary_media
    api.MinimumCompressionSize = self.minimum_compression_size

    if self.endpoint_configuration:
        self._set_endpoint_configuration(api, self.endpoint_configuration)
    elif not RegionConfiguration.is_apigw_edge_configuration_supported():
        # EDGE is unavailable in this region, so REGIONAL has to be set explicitly.
        self._set_endpoint_configuration(api, "REGIONAL")

    # Order matters here. These helpers presumably edit self.definition_body (as
    # _add_endpoint_extension does), so they must run before the body is attached below.
    for apply_feature in (
        self._add_cors,
        self._add_auth,
        self._add_gateway_responses,
        self._add_binary_media_types,
        self._add_models,
    ):
        apply_feature()

    if self.fail_on_warnings:
        api.FailOnWarnings = self.fail_on_warnings

    if self.disable_execute_api_endpoint is not None:
        self._add_endpoint_extension()

    if self.definition_uri:
        api.BodyS3Location = self._construct_body_s3_dict()
    elif self.definition_body:
        # Post-process OpenApi auth settings before the body is attached
        self.definition_body = self._openapi_postprocess(self.definition_body)
        api.Body = self.definition_body

    optional_properties = (
        ("Name", self.name),
        ("Description", self.description),
        ("Mode", self.mode),
        ("ApiKeySourceType", self.api_key_source_type),
        ("Policy", self.policy),
    )
    for property_name, property_value in optional_properties:
        if property_value:
            setattr(api, property_name, property_value)

    return api
def _validate_properties(self) -> None:
    """Reject property combinations the generator cannot handle.

    :raises InvalidResourceException: if DefinitionUri is combined with DefinitionBody or with
        MergeDefinitions, or if OpenApiVersion does not match the supported version pattern.
    """
    uses_definition_uri = bool(self.definition_uri)

    if uses_definition_uri and self.definition_body:
        raise InvalidResourceException(
            self.logical_id, "Specify either 'DefinitionUri' or 'DefinitionBody' property and not both."
        )

    if uses_definition_uri and self.merge_definitions:
        raise InvalidResourceException(
            self.logical_id, "Cannot set 'MergeDefinitions' to True when using `DefinitionUri`."
        )

    requested_version = self.open_api_version
    if not requested_version:
        return

    supported_pattern = SwaggerEditor.get_openapi_versions_supported_regex()
    if not SwaggerEditor.safe_compare_regex_with_string(supported_pattern, requested_version):
        raise InvalidResourceException(self.logical_id, "The OpenApiVersion value must be of the format '3.0.0'.")
def _add_endpoint_extension(self) -> None:
    """Write the DisableExecuteApiEndpoint setting into the OpenAPI definition.

    Note:
        If neither DefinitionUri nor DefinitionBody is specified, SAM generates an OpenAPI
        definition body from the template configuration:
        https://docs.aws.amazon.com/serverless-application-model/latest/developerguide/sam-resource-api.html#sam-api-definitionbody
        The setting is therefore always written to the OpenAPI object, regardless of where
        DefinitionBody came from.

    :raises InvalidResourceException: if the setting is present but there is no DefinitionBody
    """
    disable_flag = self.disable_execute_api_endpoint
    if disable_flag is not None and not self.definition_body:
        raise InvalidResourceException(
            self.logical_id, "DisableExecuteApiEndpoint works only within 'DefinitionBody' property."
        )

    openapi_editor = SwaggerEditor(self.definition_body)
    openapi_editor.add_disable_execute_api_endpoint_extension(disable_flag)
    self.definition_body = openapi_editor.swagger
def _construct_body_s3_dict(self) -> Dict[str, Any]:
    """Build the RestApi's ``BodyS3Location`` property from the SAM Api's DefinitionUri.

    DefinitionUri is either a dict with ``Bucket``/``Key`` (plus an optional ``Version``) or
    an ``s3://bucket/key`` string with an optional versionId query parameter.

    :returns: a Py27Dict holding Bucket, Key and, when present, Version
    :rtype: dict
    :raises InvalidResourceException: if the dict lacks Bucket/Key or the string is not a valid S3 URI
    """
    uri = self.definition_uri

    if isinstance(uri, dict):
        if not (uri.get("Bucket", None) and uri.get("Key", None)):
            # DefinitionUri is a dictionary but does not contain Bucket or Key property
            raise InvalidResourceException(
                self.logical_id, "'DefinitionUri' requires Bucket and Key properties to be specified."
            )
        location = uri
    else:
        parsed_location = parse_s3_uri(uri)
        if parsed_location is None:
            raise InvalidResourceException(
                self.logical_id,
                "'DefinitionUri' is not a valid S3 Uri of the form "
                "'s3://bucket/key' with optional versionId query parameter.",
            )
        location = parsed_location

    # A DefinitionUri string that came from the template is a Py27UniStr; its parts keep that type.
    keep_py27_str = isinstance(uri, Py27UniStr)

    body_s3 = Py27Dict()
    for part in ("Bucket", "Key", "Version"):
        # Bucket and Key are always present at this point; Version is optional.
        if part == "Version" and part not in location:
            continue
        value = location[part]
        body_s3[part] = Py27UniStr(value) if keep_py27_str else value
    return body_s3
def _construct_deployment(self, rest_api: ApiGatewayRestApi) -> ApiGatewayDeployment:
    """Build the ApiGateway Deployment for ``rest_api``.

    :param model.apigateway.ApiGatewayRestApi rest_api: the RestApi for this Deployment
    :returns: the Deployment to which this SAM Api corresponds
    :rtype: model.apigateway.ApiGatewayDeployment
    """
    deployment_logical_id = self.logical_id + "Deployment"
    deployment = ApiGatewayDeployment(deployment_logical_id, attributes=self.passthrough_resource_attributes)
    deployment.RestApiId = rest_api.get_runtime_attr("rest_api_id")

    # Unless OpenApiVersion is set, keep the legacy implicit stage named "Stage".
    keep_legacy_stage = not self.remove_extra_stage
    if keep_legacy_stage:
        deployment.StageName = "Stage"

    return deployment
def _construct_stage(
    self, deployment: ApiGatewayDeployment, swagger: Optional[Dict[str, Any]], redeploy_restapi_parameters: Any
) -> ApiGatewayStage:
    """Build the ApiGateway Stage for this SAM Api and point it at ``deployment``.

    When ``swagger`` is given, the deployment is also made auto-deployable against it.

    :param model.apigateway.ApiGatewayDeployment deployment: the Deployment for this Stage
    :returns: the Stage to which this SAM Api corresponds
    :rtype: model.apigateway.ApiGatewayStage
    """
    # Only a plain-string StageName contributes to the logical ID; intrinsic ones are ignored.
    # This cannot create duplicates because an API resource has exactly one stage.
    name_part = self.stage_name if isinstance(self.stage_name, str) else ""
    stage_logical_id = (
        self.logical_id + name_part + "Stage"
        if name_part.isalnum()
        else LogicalIdGenerator(self.logical_id + "Stage", name_part).gen()
    )

    stage = ApiGatewayStage(stage_logical_id, attributes=self.passthrough_resource_attributes)
    stage.RestApiId = ref(self.logical_id)
    # Must precede make_auto_deployable below, which works with this deployment reference.
    stage.update_deployment_ref(deployment.logical_id)

    for property_name, property_value in (
        ("StageName", self.stage_name),
        ("CacheClusterEnabled", self.cache_cluster_enabled),
        ("CacheClusterSize", self.cache_cluster_size),
        ("Variables", self.variables),
        ("MethodSettings", self.method_settings),
        ("AccessLogSetting", self.access_log_setting),
        ("CanarySetting", self.canary_setting),
        ("TracingEnabled", self.tracing_enabled),
    ):
        setattr(stage, property_name, property_value)

    if swagger is not None:
        deployment.make_auto_deployable(
            stage,
            self.remove_extra_stage,
            swagger,
            self.domain,
            redeploy_restapi_parameters,
            self.always_deploy,
        )

    if self.tags is not None:
        stage.Tags = get_tag_list(self.tags)

    return stage
def _construct_api_domain(  # noqa: PLR0912, PLR0915
    self, rest_api: ApiGatewayRestApi, route53_record_set_groups: Any
) -> ApiDomainResponse:
    """
    Constructs and returns the ApiGateway Domain, its BasePathMappings and, when Domain.Route53
    is configured, the Route53 RecordSetGroup for it.

    Side effects: writes the generated domain logical ID to ``self.domain["ApiDomainName"]``,
    sets ``self.domain["EndpointConfiguration"]`` to REGIONAL when it is missing, and may add
    a RecordSetGroup to ``route53_record_set_groups`` (keyed by hosted zone, so presumably
    shared with other APIs in the template).

    :param rest_api: the RestApi the base path mappings point to
    :param route53_record_set_groups: logical ID -> Route53RecordSetGroup map
    :returns: ApiDomainResponse; every field is None when no Domain is configured
    :raises InvalidResourceException: on an invalid Domain configuration
    """
    if self.domain is None:
        return ApiDomainResponse(None, None, None)

    sam_expect(self.domain, self.logical_id, "Domain").to_be_a_map()
    domain_name: PassThrough = sam_expect(
        self.domain.get("DomainName"), self.logical_id, "Domain.DomainName"
    ).to_not_be_none()
    certificate_arn: PassThrough = sam_expect(
        self.domain.get("CertificateArn"), self.logical_id, "Domain.CertificateArn"
    ).to_not_be_none()

    api_domain_name = "{}{}".format("ApiGatewayDomainName", LogicalIdGenerator("", domain_name).gen())
    self.domain["ApiDomainName"] = api_domain_name
    domain = ApiGatewayDomainName(api_domain_name, attributes=self.passthrough_resource_attributes)
    domain.DomainName = domain_name

    endpoint = self.domain.get("EndpointConfiguration")
    if endpoint is None:
        # Default to REGIONAL and write it back so _construct_alias_target (which reads
        # self.domain) sees the same endpoint type.
        endpoint = "REGIONAL"
        self.domain["EndpointConfiguration"] = "REGIONAL"
    elif endpoint not in ["EDGE", "REGIONAL", "PRIVATE"]:
        raise InvalidResourceException(
            self.logical_id,
            "EndpointConfiguration for Custom Domains must be"
            " one of {}.".format(["EDGE", "REGIONAL", "PRIVATE"]),
        )

    if endpoint == "REGIONAL":
        domain.RegionalCertificateArn = certificate_arn
    else:
        domain.CertificateArn = certificate_arn

    domain.EndpointConfiguration = {"Types": [endpoint]}

    mutual_tls_auth = self.domain.get("MutualTlsAuthentication", None)
    if mutual_tls_auth:
        sam_expect(mutual_tls_auth, self.logical_id, "Domain.MutualTlsAuthentication").to_be_a_map()
        if not set(mutual_tls_auth.keys()).issubset({"TruststoreUri", "TruststoreVersion"}):
            raise InvalidResourceException(
                self.logical_id,
                "Available Domain.MutualTlsAuthentication fields are {}.".format(
                    ["TruststoreUri", "TruststoreVersion"]
                ),
            )
        domain.MutualTlsAuthentication = {}
        if mutual_tls_auth.get("TruststoreUri", None):
            domain.MutualTlsAuthentication["TruststoreUri"] = mutual_tls_auth["TruststoreUri"]
        if mutual_tls_auth.get("TruststoreVersion", None):
            domain.MutualTlsAuthentication["TruststoreVersion"] = mutual_tls_auth["TruststoreVersion"]

    self._set_optional_domain_properties(domain)

    # Create BasepathMappings
    basepaths: Optional[List[str]] = self._get_basepaths()

    # Boolean to allow/disallow symbols in BasePath property
    normalize_basepath = self.domain.get("NormalizeBasePath", True)

    basepath_resource_list: List[ApiGatewayBasePathMapping] = []
    if basepaths is None:
        basepath_resource_list.append(self._create_basepath_mapping(api_domain_name, rest_api, None, None))
    else:
        sam_expect(basepaths, self.logical_id, "Domain.BasePath").to_be_a_list_of(ExpectedType.STRING)
        for basepath in basepaths:
            # Keep only alphanumerics. This strips leading/trailing '/' and the other symbols a
            # base path may contain ("$-_.+!*'()"), giving a logical-ID-safe form.
            path = "".join(e for e in basepath if e.isalnum())
            mapping_basepath = path if normalize_basepath else basepath
            logical_id = "{}{}{}".format(self.logical_id, path, "BasePathMapping")
            basepath_resource_list.append(
                self._create_basepath_mapping(api_domain_name, rest_api, logical_id, mapping_basepath)
            )

    # Create the Route53 RecordSetGroup resource
    record_set_group = None
    route53 = self.domain.get("Route53")
    if route53 is not None:
        sam_expect(route53, self.logical_id, "Domain.Route53").to_be_a_map()
        if route53.get("HostedZoneId") is None and route53.get("HostedZoneName") is None:
            raise InvalidResourceException(
                self.logical_id,
                "HostedZoneId or HostedZoneName is required to enable Route53 support on Custom Domains.",
            )

        logical_id_suffix = LogicalIdGenerator(
            "", route53.get("HostedZoneId") or route53.get("HostedZoneName")
        ).gen()
        logical_id = "RecordSetGroup" + logical_id_suffix

        record_set_group = route53_record_set_groups.get(logical_id)

        if route53.get("SeparateRecordSetGroup"):
            sam_expect(
                route53.get("SeparateRecordSetGroup"), self.logical_id, "Domain.Route53.SeparateRecordSetGroup"
            ).to_be_a_bool()
            return ApiDomainResponse(
                domain,
                basepath_resource_list,
                self._construct_single_record_set_group(self.domain, api_domain_name, route53),
            )

        # Reuse the group already created for this hosted zone, if any.
        if not record_set_group:
            record_set_group = self._get_record_set_group(logical_id, route53)
            route53_record_set_groups[logical_id] = record_set_group

        record_set_group.RecordSets += self._construct_record_sets_for_domain(self.domain, api_domain_name, route53)

    return ApiDomainResponse(domain, basepath_resource_list, record_set_group)
def _construct_api_domain_v2(
    self, rest_api: ApiGatewayRestApi, route53_record_set_groups: Any
) -> ApiDomainResponseV2:
    """
    Constructs and returns the ApiGateway DomainNameV2, its V2 BasePathMappings, the optional
    DomainNameAccessAssociation and, when Domain.Route53 is configured, the Route53 RecordSetGroup.

    to_cloudformation uses this builder when Domain.EndpointConfiguration is PRIVATE.

    :param rest_api: the RestApi the base path mappings point to
    :param route53_record_set_groups: logical ID -> Route53RecordSetGroup map; may gain an entry
    :returns: ApiDomainResponseV2; every field is None when no Domain is configured
    :raises InvalidResourceException: on an invalid Domain configuration
    """
    if self.domain is None:
        return ApiDomainResponseV2(None, None, None, None)

    sam_expect(self.domain, self.logical_id, "Domain").to_be_a_map()
    domain_name: PassThrough = sam_expect(
        self.domain.get("DomainName"), self.logical_id, "Domain.DomainName"
    ).to_not_be_none()
    certificate_arn: PassThrough = sam_expect(
        self.domain.get("CertificateArn"), self.logical_id, "Domain.CertificateArn"
    ).to_not_be_none()

    api_domain_name = "{}{}".format("ApiGatewayDomainNameV2", LogicalIdGenerator("", domain_name).gen())
    domain_name_arn = ref(api_domain_name)
    domain = ApiGatewayDomainNameV2(api_domain_name, attributes=self.passthrough_resource_attributes)

    domain.DomainName = domain_name
    endpoint = self.domain.get("EndpointConfiguration")

    if endpoint not in ["EDGE", "REGIONAL", "PRIVATE"]:
        raise InvalidResourceException(
            self.logical_id,
            "EndpointConfiguration for Custom Domains must be"
            " one of {}.".format(["EDGE", "REGIONAL", "PRIVATE"]),
        )

    domain.CertificateArn = certificate_arn

    domain.EndpointConfiguration = {"Types": [endpoint]}

    self._set_optional_domain_properties(domain)

    basepaths: Optional[List[str]] = self._get_basepaths()

    # Boolean to allow/disallow symbols in BasePath property
    normalize_basepath = self.domain.get("NormalizeBasePath", True)

    basepath_resource_list: List[ApiGatewayBasePathMappingV2] = []
    if basepaths is None:
        basepath_resource_list.append(self._create_basepath_mapping_v2(domain_name_arn, rest_api))
    else:
        sam_expect(basepaths, self.logical_id, "Domain.BasePath").to_be_a_list_of(ExpectedType.STRING)
        for basepath in basepaths:
            # Keep only alphanumerics. This strips leading/trailing '/' and the other symbols a
            # base path may contain ("$-_.+!*'()"), giving a logical-ID-safe form.
            path = "".join(e for e in basepath if e.isalnum())
            logical_id = "{}{}{}".format(self.logical_id, path, "BasePathMapping")
            basepath_mapping = ApiGatewayBasePathMappingV2(
                logical_id, attributes=self.passthrough_resource_attributes
            )
            basepath_mapping.DomainNameArn = domain_name_arn
            basepath_mapping.RestApiId = ref(rest_api.logical_id)
            basepath_mapping.Stage = ref(rest_api.logical_id + ".Stage")
            basepath_mapping.BasePath = path if normalize_basepath else basepath
            basepath_resource_list.append(basepath_mapping)

    # Create the DomainNameAccessAssociation
    domain_access_association = self.domain.get("AccessAssociation")
    domain_access_association_resource = None
    if domain_access_association is not None:
        domain_access_association_resource = self._generate_domain_access_association(
            domain_access_association, domain_name_arn, api_domain_name
        )

    # Create the Route53 RecordSetGroup resource
    record_set_group = None
    route53 = self.domain.get("Route53")
    if route53 is not None:
        sam_expect(route53, self.logical_id, "Domain.Route53").to_be_a_map()
        if route53.get("HostedZoneId") is None and route53.get("HostedZoneName") is None:
            raise InvalidResourceException(
                self.logical_id,
                "HostedZoneId or HostedZoneName is required to enable Route53 support on Custom Domains.",
            )

        logical_id_suffix = LogicalIdGenerator(
            "", route53.get("HostedZoneId") or route53.get("HostedZoneName")
        ).gen()
        logical_id = "RecordSetGroup" + logical_id_suffix

        record_set_group = route53_record_set_groups.get(logical_id)

        # The record-set helpers expect the domain resource's *logical ID* (it is used with
        # Fn::GetAtt for REGIONAL/EDGE targets), not the DNS name. This matches the V1 builder.
        if route53.get("SeparateRecordSetGroup"):
            sam_expect(
                route53.get("SeparateRecordSetGroup"), self.logical_id, "Domain.Route53.SeparateRecordSetGroup"
            ).to_be_a_bool()
            return ApiDomainResponseV2(
                domain,
                basepath_resource_list,
                self._construct_single_record_set_group(self.domain, api_domain_name, route53),
                domain_access_association_resource,
            )

        # Reuse the group already created for this hosted zone, if any.
        if not record_set_group:
            record_set_group = self._get_record_set_group(logical_id, route53)
            route53_record_set_groups[logical_id] = record_set_group

        record_set_group.RecordSets += self._construct_record_sets_for_domain(self.domain, api_domain_name, route53)

    return ApiDomainResponseV2(domain, basepath_resource_list, record_set_group, domain_access_association_resource)
def _get_basepaths(self) -> Optional[List[str]]:
if self.domain is None:
return None
basepath_value = self.domain.get("BasePath")
if self.domain.get("BasePath") and isinstance(basepath_value, str):
return [basepath_value]
if self.domain.get("BasePath") and isinstance(basepath_value, list):
return cast(Optional[List[Any]], basepath_value)
return None
def _set_optional_domain_properties(self, domain: Union[ApiGatewayDomainName, ApiGatewayDomainNameV2]) -> None:
    """Copy the truthy optional settings from the Domain config onto ``domain``.

    The settings are SecurityPolicy, Policy and OwnershipVerificationCertificateArn.
    """
    domain_config = self.domain
    if domain_config is None:
        return

    for property_name in ("SecurityPolicy", "Policy", "OwnershipVerificationCertificateArn"):
        configured_value = domain_config.get(property_name, None)
        if configured_value:
            setattr(domain, property_name, configured_value)
def _get_record_set_group(self, logical_id: str, route53: Dict[str, Any]) -> Route53RecordSetGroup:
    """Create an empty Route53 RecordSetGroup for the hosted zone named in the Route53 config.

    HostedZoneId and/or HostedZoneName are copied whenever the key is present in ``route53``.
    """
    group = Route53RecordSetGroup(logical_id, attributes=self.passthrough_resource_attributes)
    for zone_key in ("HostedZoneId", "HostedZoneName"):
        if zone_key in route53:
            setattr(group, zone_key, route53.get(zone_key))
    group.RecordSets = []
    return group
def _construct_single_record_set_group(
    self, domain: Dict[str, Any], api_domain_name: str, route53: Any
) -> Route53RecordSetGroup:
    """Build a dedicated RecordSetGroup holding only this domain's record sets.

    The logical ID is derived from the hosted zone (id preferred over name) and the domain name.
    """
    zone_id = route53.get("HostedZoneId")
    zone_name = route53.get("HostedZoneName")
    group_logical_id = LogicalIdGenerator("RecordSetGroup", [zone_id or zone_name, domain.get("DomainName")]).gen()

    group = Route53RecordSetGroup(group_logical_id, attributes=self.passthrough_resource_attributes)
    if zone_id:
        group.HostedZoneId = zone_id
    if zone_name:
        group.HostedZoneName = zone_name
    group.RecordSets = self._construct_record_sets_for_domain(domain, api_domain_name, route53)
    return group
def _construct_record_sets_for_domain(
self, custom_domain_config: Dict[str, Any], api_domain_name: str, route53_config: Dict[str, Any]
) -> List[Dict[str, Any]]:
recordset_list = []
alias_target = self._construct_alias_target(custom_domain_config, api_domain_name, route53_config)
recordset = {}
recordset["Name"] = custom_domain_config.get("DomainName")
recordset["Type"] = "A"
recordset["AliasTarget"] = alias_target
self._update_route53_routing_policy_properties(route53_config, recordset)
recordset_list.append(recordset)
if route53_config.get("IpV6") is not None and route53_config.get("IpV6") is True:
recordset_ipv6 = {}
recordset_ipv6["Name"] = custom_domain_config.get("DomainName")
recordset_ipv6["Type"] = "AAAA"
recordset_ipv6["AliasTarget"] = alias_target
self._update_route53_routing_policy_properties(route53_config, recordset_ipv6)
recordset_list.append(recordset_ipv6)
return recordset_list
@staticmethod
def _update_route53_routing_policy_properties(route53_config: Dict[str, Any], recordset: Dict[str, Any]) -> None:
if route53_config.get("Region") is not None:
recordset["Region"] = route53_config.get("Region")
if route53_config.get("SetIdentifier") is not None:
recordset["SetIdentifier"] = route53_config.get("SetIdentifier")
def _construct_alias_target(self, domain: Dict[str, Any], api_domain_name: str, route53: Any) -> Dict[str, Any]:
alias_target = {}
target_health = route53.get("EvaluateTargetHealth")
if target_health is not None:
alias_target["EvaluateTargetHealth"] = target_health
if domain.get("EndpointConfiguration") == "REGIONAL":
alias_target["HostedZoneId"] = fnGetAtt(api_domain_name, "RegionalHostedZoneId")
alias_target["DNSName"] = fnGetAtt(api_domain_name, "RegionalDomainName")
elif domain.get("EndpointConfiguration") == "EDGE":
if route53.get("DistributionDomainName") is None:
route53["DistributionDomainName"] = fnGetAtt(api_domain_name, "DistributionDomainName")
alias_target["HostedZoneId"] = "Z2FDTNDATAQYW2"
alias_target["DNSName"] = route53.get("DistributionDomainName")
else:
alias_target["HostedZoneId"] = route53.get("VpcEndpointHostedZoneId")
alias_target["DNSName"] = route53.get("VpcEndpointDomainName")
return alias_target
def _create_basepath_mapping(
    self,
    api_domain_name: PassThrough,
    rest_api: ApiGatewayRestApi,
    logical_id: Optional[str],
    basepath: Optional[str],
) -> ApiGatewayBasePathMapping:
    """Create a V1 base path mapping from the custom domain to ``rest_api``'s stage.

    An empty/None ``logical_id`` falls back to ``<Api logical id>BasePathMapping``. BasePath
    is set only when ``basepath`` is not None.
    """
    mapping_logical_id = logical_id if logical_id else self.logical_id + "BasePathMapping"
    mapping = ApiGatewayBasePathMapping(mapping_logical_id, attributes=self.passthrough_resource_attributes)

    api_id = rest_api.logical_id
    mapping.DomainName = ref(api_domain_name)
    mapping.RestApiId = ref(api_id)
    mapping.Stage = ref(api_id + ".Stage")

    if basepath is not None:
        mapping.BasePath = basepath
    return mapping
def _create_basepath_mapping_v2(
    self, domain_name_arn: PassThrough, rest_api: ApiGatewayRestApi
) -> ApiGatewayBasePathMappingV2:
    """Build the V2 base path mapping (``<ApiLogicalId>BasePathMapping``) for a domain given by ARN.

    :param domain_name_arn: value assigned verbatim to ``DomainNameArn``
    :param rest_api: the RestApi whose logical id is referenced for ``RestApiId`` and ``Stage``
    :returns: the ApiGatewayBasePathMappingV2 resource
    """
    mapping = ApiGatewayBasePathMappingV2(
        self.logical_id + "BasePathMapping", attributes=self.passthrough_resource_attributes
    )
    api_logical_id = rest_api.logical_id
    mapping.DomainNameArn = domain_name_arn
    mapping.RestApiId = ref(api_logical_id)
    mapping.Stage = ref(api_logical_id + ".Stage")
    return mapping
@cw_timer(prefix="Generator", name="Api")
def to_cloudformation(
    self, redeploy_restapi_parameters: Optional[Any], route53_record_set_groups: Dict[str, Route53RecordSetGroup]
) -> List[Resource]:
    """Generates CloudFormation resources from a SAM API resource.

    :param redeploy_restapi_parameters: forwarded to ``_construct_stage``
    :param route53_record_set_groups: forwarded to the custom-domain constructor
        (``_construct_api_domain_v2`` for a ``PRIVATE`` domain, ``_construct_api_domain`` otherwise)
    :returns: a flat list of the generated resources, in this order: the RestApi, Deployment and Stage,
        the authorizer Lambda permissions, the custom domain, its base path mappings, the Route53 record
        set group, the usage plan resources and, for private domains, the domain name access association.
        Entries that were not generated (falsy values) are left out.
    :rtype: list of model.Resource
    """
    api_domain_response: Union[ApiDomainResponseV2, ApiDomainResponse]
    domain: Union[Resource, None]
    basepath_mapping: Union[List[ApiGatewayBasePathMapping], List[ApiGatewayBasePathMappingV2], None]

    rest_api = self._construct_rest_api()
    is_private_domain = isinstance(self.domain, dict) and self.domain.get("EndpointConfiguration") == "PRIVATE"
    api_domain_response = (
        self._construct_api_domain_v2(rest_api, route53_record_set_groups)
        if is_private_domain
        else self._construct_api_domain(rest_api, route53_record_set_groups)
    )

    domain_access_association = None
    if is_private_domain:
        domain_access_association = cast(ApiDomainResponseV2, api_domain_response).domain_access_association

    domain = api_domain_response.domain
    basepath_mapping = api_domain_response.apigw_basepath_mapping_list
    route53_recordset_group = api_domain_response.recordset_group

    deployment = self._construct_deployment(rest_api)

    # The stage construction needs whichever definition source the RestApi ended up with
    swagger = None
    if rest_api.Body is not None:
        swagger = rest_api.Body
    elif rest_api.BodyS3Location is not None:
        swagger = rest_api.BodyS3Location

    stage = self._construct_stage(deployment, swagger, redeploy_restapi_parameters)
    permissions = self._construct_authorizer_lambda_permission()
    usage_plan = self._construct_usage_plan(rest_api_stage=stage)

    # mypy complains if the type in List doesn't match exactly
    # TODO: refactor to have a list of single resource
    generated_resources: List[
        Union[
            Optional[Resource],
            List[Resource],
            Tuple[Resource],
            List[LambdaPermission],
            List[ApiGatewayBasePathMapping],
            List[ApiGatewayBasePathMappingV2],
        ],
    ] = []

    generated_resources.extend(
        [
            rest_api,
            deployment,
            stage,
            permissions,
            domain,
            basepath_mapping,
            route53_recordset_group,
            usage_plan,
        ]
    )
    if domain_access_association is not None:
        generated_resources.append(domain_access_association)

    # Flatten into a list of single resources, dropping anything that was not generated
    generated_resources_list: List[Resource] = []
    for resource in generated_resources:
        if not resource:
            continue
        if isinstance(resource, (list, tuple)):
            generated_resources_list.extend(resource)
        else:
            generated_resources_list.append(resource)
    return generated_resources_list
def _add_cors(self) -> None:
"""
Add CORS configuration to the Swagger file, if necessary
"""
INVALID_ERROR = "Invalid value for 'Cors' property"
if not self.cors:
return
if self.cors and not self.definition_body:
raise InvalidResourceException(
self.logical_id, "Cors works only with inline Swagger specified in 'DefinitionBody' property."
)
if isinstance(self.cors, str) or is_intrinsic(self.cors):
# Just set Origin property. Others will be defaults
properties = CorsProperties(AllowOrigin=self.cors) # type: ignore[call-arg]
elif isinstance(self.cors, dict):
# Make sure keys in the dict are recognized
for key in self.cors:
if key not in CorsProperties._fields:
raise InvalidResourceException(self.logical_id, f"Invalid key '{key}' for 'Cors' property.")
properties = CorsProperties(**self.cors)
else:
raise InvalidResourceException(self.logical_id, INVALID_ERROR)
if not SwaggerEditor.is_valid(self.definition_body):
raise InvalidResourceException(
self.logical_id,
"Unable to add Cors configuration because "
"'DefinitionBody' does not contain a valid Swagger definition.",
)
if properties.AllowCredentials is True and properties.AllowOrigin == _CORS_WILDCARD:
raise InvalidResourceException(
self.logical_id,
"Unable to add Cors configuration because "
"'AllowCredentials' can not be true when "
"'AllowOrigin' is \"'*'\" or not set",
)
editor = SwaggerEditor(self.definition_body)
for path in editor.iter_on_path():
try:
editor.add_cors( # type: ignore[no-untyped-call]
path,
properties.AllowOrigin,
properties.AllowHeaders,
properties.AllowMethods,
max_age=properties.MaxAge,
allow_credentials=properties.AllowCredentials,
)
except InvalidTemplateException as ex:
raise InvalidResourceException(self.logical_id, ex.message) from ex
# Assign the Swagger back to template
self.definition_body = editor.swagger
def _add_binary_media_types(self) -> None:
"""
Add binary media types to Swagger
"""
if not self.binary_media:
return
# We don't raise an error here like we do for similar cases because that would be backwards incompatible
if self.binary_media and not self.definition_body:
return
editor = SwaggerEditor(self.definition_body)
editor.add_binary_media_types(self.binary_media) # type: ignore[no-untyped-call]
# Assign the Swagger back to template
self.definition_body = editor.swagger
def _add_auth(self) -> None:
"""
Add Auth configuration to the Swagger file, if necessary
"""
if not self.auth:
return
if self.auth and not self.definition_body:
raise InvalidResourceException(
self.logical_id, "Auth works only with inline Swagger specified in 'DefinitionBody' property."
)
# Make sure keys in the dict are recognized
if not all(key in AuthProperties._fields for key in self.auth):
raise InvalidResourceException(self.logical_id, "Invalid value for 'Auth' property")
if not SwaggerEditor.is_valid(self.definition_body):
raise InvalidResourceException(
self.logical_id,
"Unable to add Auth configuration because "
"'DefinitionBody' does not contain a valid Swagger definition.",
)
swagger_editor = SwaggerEditor(self.definition_body)
auth_properties = AuthProperties(**self.auth)
authorizers = self._get_authorizers(auth_properties.Authorizers, auth_properties.DefaultAuthorizer) # type: ignore[no-untyped-call]
if authorizers:
swagger_editor.add_authorizers_security_definitions(authorizers) # type: ignore[no-untyped-call]
self._set_default_authorizer(
swagger_editor,
authorizers,
auth_properties.DefaultAuthorizer,
auth_properties.AddDefaultAuthorizerToCorsPreflight,
)
if auth_properties.ApiKeyRequired:
swagger_editor.add_apikey_security_definition()
self._set_default_apikey_required(swagger_editor, auth_properties.AddApiKeyRequiredToCorsPreflight)
if auth_properties.ResourcePolicy:
SwaggerEditor.validate_is_dict(
auth_properties.ResourcePolicy, "ResourcePolicy must be a map (ResourcePolicyStatement)."
)
for path in swagger_editor.iter_on_path():
swagger_editor.add_resource_policy(auth_properties.ResourcePolicy, path, self.stage_name)
if auth_properties.ResourcePolicy.get("CustomStatements"):
swagger_editor.add_custom_statements(auth_properties.ResourcePolicy.get("CustomStatements")) # type: ignore[no-untyped-call]
self.definition_body = self._openapi_postprocess(swagger_editor.swagger)
def _construct_usage_plan(self, rest_api_stage: Optional[ApiGatewayStage] = None) -> Any:  # noqa: PLR0912
    """Constructs the ApiGateway UsagePlan, ApiGateway ApiKey and ApiGateway UsagePlanKey for ``Auth.UsagePlan``.

    With ``CreateUsagePlan: PER_API`` the resources belong to this API only (``<ApiLogicalId>UsagePlan``).
    With ``SHARED`` a single ``ServerlessUsagePlan`` is shared by every API. Its dependencies and stages
    accumulate in ``self.shared_api_usage_plan``.

    :param rest_api_stage: the stage of the rest api (model.apigateway.ApiGatewayStage)
    :returns: an empty list when nothing has to be generated. This happens with no ``Auth``, no
        ``Auth.UsagePlan``, ``CreateUsagePlan: NONE``, or no ``rest_api_stage``. Otherwise it returns the
        tuple ``(usage_plan, api_key, usage_plan_key)``.
    :rtype: list or tuple(model.apigateway.ApiGatewayUsagePlan, model.apigateway.ApiGatewayApiKey,
        model.apigateway.ApiGatewayUsagePlanKey)
    :raises InvalidResourceException: if ``UsagePlan`` is not a dictionary, contains an unsupported property,
        or ``CreateUsagePlan`` is missing or not one of SHARED, PER_API, NONE
    """
    create_usage_plans_accepted_values = ["SHARED", "PER_API", "NONE"]
    if not self.auth:
        return []

    auth_properties = AuthProperties(**self.auth)
    if auth_properties.UsagePlan is None:
        return []
    usage_plan_properties = auth_properties.UsagePlan

    # throws error if UsagePlan is not a dict
    if not isinstance(usage_plan_properties, dict):
        raise InvalidResourceException(self.logical_id, "'UsagePlan' must be a dictionary")

    # throws error if the property invalid/ unsupported for UsagePlan
    if not all(key in UsagePlanProperties._fields for key in usage_plan_properties):
        raise InvalidResourceException(self.logical_id, "Invalid property for 'UsagePlan'")

    create_usage_plan = usage_plan_properties.get("CreateUsagePlan")
    usage_plan: Optional[ApiGatewayUsagePlan] = None
    api_key = None
    usage_plan_key = None

    if create_usage_plan is None:
        raise InvalidResourceException(self.logical_id, "'CreateUsagePlan' is a required field for UsagePlan.")
    if create_usage_plan not in create_usage_plans_accepted_values:
        raise InvalidResourceException(
            self.logical_id, f"'CreateUsagePlan' accepts one of {create_usage_plans_accepted_values}."
        )

    if create_usage_plan == "NONE":
        return []

    if not rest_api_stage:
        return []

    # create usage plan for this api only
    if create_usage_plan == "PER_API":
        usage_plan_logical_id = self.logical_id + "UsagePlan"
        usage_plan = ApiGatewayUsagePlan(
            logical_id=usage_plan_logical_id,
            depends_on=[self.logical_id],
            attributes=self.passthrough_resource_attributes,
        )
        api_stages = []
        api_stage = {}
        api_stage["ApiId"] = ref(self.logical_id)
        api_stage["Stage"] = ref(rest_api_stage.logical_id)
        api_stages.append(api_stage)
        usage_plan.ApiStages = api_stages

        api_key = self._construct_api_key(usage_plan_logical_id, create_usage_plan, rest_api_stage)
        usage_plan_key = self._construct_usage_plan_key(usage_plan_logical_id, create_usage_plan, api_key)

    # create a usage plan for all the Apis
    elif create_usage_plan == "SHARED":
        LOG.info("Creating SHARED usage plan for all the Apis")
        usage_plan_logical_id = "ServerlessUsagePlan"
        if self.logical_id not in self.shared_api_usage_plan.depends_on_shared:
            self.shared_api_usage_plan.depends_on_shared.append(self.logical_id)
        # The shared lists are passed by reference on purpose so that every API adds to the same resource
        usage_plan = ApiGatewayUsagePlan(
            logical_id=usage_plan_logical_id,
            depends_on=self.shared_api_usage_plan.depends_on_shared,
            attributes=self.shared_api_usage_plan.get_combined_resource_attributes(
                self.passthrough_resource_attributes, self.template_conditions
            ),
        )
        api_stage = {}
        api_stage["ApiId"] = ref(self.logical_id)
        api_stage["Stage"] = ref(rest_api_stage.logical_id)
        if api_stage not in self.shared_api_usage_plan.api_stages_shared:
            self.shared_api_usage_plan.api_stages_shared.append(api_stage)
        usage_plan.ApiStages = self.shared_api_usage_plan.api_stages_shared

        api_key = self._construct_api_key(usage_plan_logical_id, create_usage_plan, rest_api_stage)
        usage_plan_key = self._construct_usage_plan_key(usage_plan_logical_id, create_usage_plan, api_key)

    for name in ["UsagePlanName", "Description", "Quota", "Tags", "Throttle"]:
        if usage_plan and usage_plan_properties.get(name):
            setattr(usage_plan, name, usage_plan_properties.get(name))

    return usage_plan, api_key, usage_plan_key
def _construct_api_key(
    self, usage_plan_logical_id: str, create_usage_plan: Any, rest_api_stage: ApiGatewayStage
) -> ApiGatewayApiKey:
    """
    Build the enabled API key that is bound to this API's stage.

    :param usage_plan_logical_id: logical id of the usage plan the key depends on
    :param create_usage_plan: the ``CreateUsagePlan`` value. ``"SHARED"`` produces the shared
        ``ServerlessApiKey``, whose stage keys accumulate across APIs in ``self.shared_api_usage_plan``.
        Any other value produces ``<ApiLogicalId>ApiKey`` for this API only.
    :param rest_api_stage: model.apigateway.ApiGatewayStage stage of the rest api
    :return: the model.apigateway.ApiGatewayApiKey resource created for the given usage plan
    """
    this_stage_key = {
        "RestApiId": ref(self.logical_id),
        "StageName": ref(rest_api_stage.logical_id),
    }

    if create_usage_plan != "SHARED":
        # PER_API: a dedicated key for this API only
        api_key = ApiGatewayApiKey(
            logical_id=self.logical_id + "ApiKey",
            depends_on=[usage_plan_logical_id],
            attributes=self.passthrough_resource_attributes,
        )
        api_key.Enabled = True
        api_key.StageKeys = [this_stage_key]
        return api_key

    # SHARED: one key for every API; register this API's stage only once
    LOG.info("Creating api key resource for all the Apis from SHARED usage plan")
    api_key = ApiGatewayApiKey(
        logical_id="ServerlessApiKey",
        depends_on=[usage_plan_logical_id],
        attributes=self.shared_api_usage_plan.get_combined_resource_attributes(
            self.passthrough_resource_attributes, self.template_conditions
        ),
    )
    api_key.Enabled = True
    shared_stage_keys = self.shared_api_usage_plan.stage_keys_shared
    if this_stage_key not in shared_stage_keys:
        shared_stage_keys.append(this_stage_key)
    api_key.StageKeys = shared_stage_keys
    return api_key
def _construct_usage_plan_key(
    self, usage_plan_logical_id: str, create_usage_plan: Any, api_key: ApiGatewayApiKey
) -> ApiGatewayUsagePlanKey:
    """
    Build the resource that maps ``api_key`` onto the usage plan.

    :param usage_plan_logical_id: logical id of the usage plan, referenced for ``UsagePlanId``
    :param create_usage_plan: the ``CreateUsagePlan`` value. ``"SHARED"`` produces ``ServerlessUsagePlanKey``
        with the combined shared resource attributes. Any other value produces ``<ApiLogicalId>UsagePlanKey``.
    :param api_key: model.apigateway.ApiGatewayApiKey resource, referenced for ``KeyId`` and depended on
    :return: model.apigateway.ApiGatewayUsagePlanKey resource that contains the mapping between usage plan and api key
    """
    is_shared = create_usage_plan == "SHARED"
    plan_key_logical_id = "ServerlessUsagePlanKey" if is_shared else self.logical_id + "UsagePlanKey"
    plan_key_attributes = (
        self.shared_api_usage_plan.get_combined_resource_attributes(
            self.passthrough_resource_attributes, self.template_conditions
        )
        if is_shared
        else self.passthrough_resource_attributes
    )

    plan_key = ApiGatewayUsagePlanKey(
        logical_id=plan_key_logical_id,
        depends_on=[api_key.logical_id],
        attributes=plan_key_attributes,
    )
    plan_key.KeyId = ref(api_key.logical_id)
    plan_key.KeyType = "API_KEY"
    plan_key.UsagePlanId = ref(usage_plan_logical_id)
    return plan_key
def _add_gateway_responses(self) -> None:
"""
Add Gateway Response configuration to the Swagger file, if necessary
"""
if not self.gateway_responses:
return
if self.gateway_responses and not self.definition_body:
raise InvalidResourceException(
self.logical_id,
"GatewayResponses works only with inline Swagger specified in 'DefinitionBody' property.",
)
# Make sure keys in the dict are recognized
for responses_key, responses_value in self.gateway_responses.items():
if is_intrinsic(responses_value):
# TODO: Add intrinsic support for this field.
raise InvalidResourceException(
self.logical_id,
"Unable to set GatewayResponses attribute because "
"intrinsic functions are not supported for this field.",
)
if not isinstance(responses_value, dict):
raise InvalidResourceException(
self.logical_id,
f"Invalid property type '{type(responses_value).__name__}' for GatewayResponses. "
"Expected an object of type 'GatewayResponse'.",
)
for response_key in responses_value:
if response_key not in GatewayResponseProperties:
raise InvalidResourceException(
self.logical_id,
f"Invalid property '{response_key}' in 'GatewayResponses' property '{responses_key}'.",
)
if not SwaggerEditor.is_valid(self.definition_body):
raise InvalidResourceException(
self.logical_id,
"Unable to add Auth configuration because "
"'DefinitionBody' does not contain a valid Swagger definition.",
)
swagger_editor = SwaggerEditor(self.definition_body)
# The dicts below will eventually become part of swagger/openapi definition, thus requires using Py27Dict()
gateway_responses = Py27Dict()
for response_type, response in self.gateway_responses.items():
sam_expect(response, self.logical_id, f"GatewayResponses.{response_type}").to_be_a_map()
response_parameters = response.get("ResponseParameters", Py27Dict())
response_templates = response.get("ResponseTemplates", Py27Dict())
if response_parameters:
sam_expect(
response_parameters, self.logical_id, f"GatewayResponses.{response_type}.ResponseParameters"
).to_be_a_map()
gateway_responses[response_type] = ApiGatewayResponse(
api_logical_id=self.logical_id,
response_parameters=response_parameters,
response_templates=response_templates,
status_code=response.get("StatusCode", None),
)
if gateway_responses:
swagger_editor.add_gateway_responses(gateway_responses) # type: ignore[no-untyped-call]
# Assign the Swagger back to template
self.definition_body = swagger_editor.swagger
def _add_models(self) -> None:
"""
Add Model definitions to the Swagger file, if necessary
:return:
"""
if not self.models:
return
if self.models and not self.definition_body:
raise InvalidResourceException(
self.logical_id, "Models works only with inline Swagger specified in 'DefinitionBody' property."
)
if not SwaggerEditor.is_valid(self.definition_body):
raise InvalidResourceException(
self.logical_id,
"Unable to add Models definitions because "
"'DefinitionBody' does not contain a valid Swagger definition.",
)
if not all(isinstance(model, dict) for model in self.models.values()):
raise InvalidResourceException(self.logical_id, "Invalid value for 'Models' property")
swagger_editor = SwaggerEditor(self.definition_body)
swagger_editor.add_models(self.models) # type: ignore[no-untyped-call]
# Assign the Swagger back to template
self.definition_body = self._openapi_postprocess(swagger_editor.swagger)
def _openapi_postprocess(self, definition_body: Dict[str, Any]) -> Dict[str, Any]: # noqa: PLR0912
"""
Convert definitions to openapi 3 in definition body if OpenApiVersion flag is specified.
If the is swagger defined in the definition body, we treat it as a swagger spec and do not
make any openapi 3 changes to it
"""
if definition_body.get("swagger") is not None:
return definition_body
if self.feature_toggle and self.feature_toggle.is_enabled(FEATURE_FLAG_NORMALIZED_OPENAPI_VERSION):
normalized_open_api_version = definition_body.get("openapi", self.open_api_version)
elif definition_body.get("openapi") is not None and self.open_api_version is None:
normalized_open_api_version = definition_body.get("openapi")
else:
normalized_open_api_version = self.open_api_version
if normalized_open_api_version and SwaggerEditor.safe_compare_regex_with_string(
SwaggerEditor._OPENAPI_VERSION_3_REGEX, normalized_open_api_version
):
if definition_body.get("securityDefinitions"):
components = definition_body.get("components", Py27Dict())
# In the previous line, the default value `Py27Dict()` will be only returned only if `components`
# property is not in definition_body dict, but if it exist, and its value is None, so None will be
# returned and not the default value. That is why the below line is required.
components = components if components else Py27Dict()
components["securitySchemes"] = definition_body["securityDefinitions"]
definition_body["components"] = components
del definition_body["securityDefinitions"]
if definition_body.get("definitions"):
components = definition_body.get("components", Py27Dict())
# the following line to check if components is None
# is copied from the previous if...
# In the previous line, the default value `Py27Dict()` will be only returned only if `components`
# property is not in definition_body dict, but if it exist, and its value is None, so None will be
# returned and not the default value. That is why the below line is required.
components = components if components else Py27Dict()
components["schemas"] = definition_body["definitions"]
definition_body["components"] = components
del definition_body["definitions"]
# removes `consumes` and `produces` options for CORS in openapi3 and
# adds `schema` for the headers in responses for openapi3
paths = definition_body.get("paths")
if paths:
SwaggerEditor.validate_is_dict(
paths,
"Value of paths must be a dictionary according to Swagger spec.",
)
for path, path_item in paths.items():
SwaggerEditor.validate_path_item_is_dict(path_item, path)
if path_item.get("options"):
SwaggerEditor.validate_is_dict(
path_item.get("options"),
f"Value of options method for path {path} must be a "
"dictionary according to Swagger spec.",
)
options = path_item.get("options").copy()
for field, field_val in options.items():
# remove unsupported produces and consumes in options for openapi3
if field in ["produces", "consumes"]:
del definition_body["paths"][path]["options"][field]
# add schema for the headers in options section for openapi3
if field in ["responses"]:
try:
response_200_headers = dict_deep_get(field_val, "200.headers")
except InvalidValueType as ex:
raise InvalidDocumentException(
[
InvalidTemplateException(
f"Invalid responses in options method for path {path}: {ex!s}.",
)
]
) from ex
if not response_200_headers:
continue
SwaggerEditor.validate_is_dict(
response_200_headers,
f"Value of response's headers in options method for path {path} must be a "
"dictionary according to Swagger spec.",
)
for header, header_val in response_200_headers.items():
new_header_val_with_schema = Py27Dict()
new_header_val_with_schema["schema"] = header_val
definition_body["paths"][path]["options"][field]["200"]["headers"][
header
] = new_header_val_with_schema
return definition_body
def _get_authorizers(self, authorizers_config, default_authorizer=None):  # type: ignore[no-untyped-def]
    """Build ``ApiGatewayAuthorizer`` objects for the ``Auth.Authorizers`` configuration.

    :param authorizers_config: the ``Auth.Authorizers`` map (authorizer name -> authorizer properties)
    :param default_authorizer: the ``Auth.DefaultAuthorizer`` value. ``"AWS_IAM"`` adds an implicit IAM
        authorizer under that name.
    :returns: a Py27Dict of authorizer name -> ApiGatewayAuthorizer. Returns None when
        ``authorizers_config`` is empty and no implicit IAM authorizer was added.
    """
    # The mapping ends up inside the swagger/openapi definition, hence Py27Dict()
    result = Py27Dict()
    if default_authorizer == "AWS_IAM":
        result[default_authorizer] = ApiGatewayAuthorizer(
            api_logical_id=self.logical_id, name=default_authorizer, is_aws_iam_authorizer=True
        )

    if not authorizers_config:
        return result if "AWS_IAM" in result else None

    sam_expect(authorizers_config, self.logical_id, "Auth.Authorizers").to_be_a_map()
    for auth_name, auth_config in authorizers_config.items():
        sam_expect(auth_config, self.logical_id, f"Auth.Authorizers.{auth_name}").to_be_a_map()
        result[auth_name] = ApiGatewayAuthorizer(
            api_logical_id=self.logical_id,
            name=auth_name,
            user_pool_arn=auth_config.get("UserPoolArn"),
            function_arn=auth_config.get("FunctionArn"),
            identity=auth_config.get("Identity"),
            function_payload_type=auth_config.get("FunctionPayloadType"),
            function_invoke_role=auth_config.get("FunctionInvokeRole"),
            authorization_scopes=auth_config.get("AuthorizationScopes"),
            disable_function_default_permissions=auth_config.get("DisableFunctionDefaultPermissions"),
        )
    return result
def _get_permission(self, authorizer_name, authorizer_lambda_function_arn):  # type: ignore[no-untyped-def]
    """Constructs the Lambda Permission resource allowing the Authorizer to invoke the function.

    The permission is granted to ``apigateway.amazonaws.com`` with a ``SourceArn`` of
    ``execute-api`` resource ``<rest api id>/authorizers/*`` for this API.

    :returns: the permission resource, with logical id ``<ApiLogicalId><AuthorizerName>AuthorizerPermission``
    :rtype: model.lambda_.LambdaPermission
    """
    api_resource = ApiGatewayRestApi(self.logical_id, depends_on=self.depends_on, attributes=self.resource_attributes)
    rest_api_id = api_resource.get_runtime_attr("rest_api_id")

    arn_template = ArnGenerator.generate_arn(
        partition=ArnGenerator.get_partition_name(),
        service="execute-api",
        resource="${__ApiId__}/authorizers/*",
    )
    source_arn = fnSub(arn_template, {"__ApiId__": rest_api_id})

    permission = LambdaPermission(
        self.logical_id + authorizer_name + "AuthorizerPermission", attributes=self.passthrough_resource_attributes
    )
    permission.Action = "lambda:InvokeFunction"
    permission.FunctionName = authorizer_lambda_function_arn
    permission.Principal = "apigateway.amazonaws.com"
    permission.SourceArn = source_arn
    return permission
def _construct_authorizer_lambda_permission(self) -> List[LambdaPermission]:
    """Build the Lambda permissions that let API Gateway invoke this API's Lambda authorizers.

    Authorizers without a ``FunctionArn`` (i.e. non-Lambda authorizers), and those with
    ``DisableFunctionDefaultPermissions`` set, get no permission.
    """
    if not self.auth:
        return []

    authorizers = self._get_authorizers(AuthProperties(**self.auth).Authorizers)  # type: ignore[no-untyped-call]
    if not authorizers:
        return []

    return [
        self._get_permission(name, authorizer.function_arn)  # type: ignore[no-untyped-call]
        for name, authorizer in authorizers.items()
        if authorizer.function_arn and not authorizer.disable_function_default_permissions
    ]
def _set_default_authorizer(
self,
swagger_editor: SwaggerEditor,
authorizers: Dict[str, ApiGatewayAuthorizer],
default_authorizer: str,
add_default_auth_to_preflight: bool = True,
) -> None:
if not default_authorizer:
return
if not isinstance(default_authorizer, str):
raise InvalidResourceException(
self.logical_id,
"DefaultAuthorizer is not a string.",
)
if not authorizers.get(default_authorizer) and default_authorizer != "AWS_IAM":
raise InvalidResourceException(
self.logical_id,
"Unable to set DefaultAuthorizer because '"
+ default_authorizer
+ "' was not defined in 'Authorizers'.",
)
for path in swagger_editor.iter_on_path():
swagger_editor.set_path_default_authorizer(
path,
default_authorizer,
authorizers=authorizers,
add_default_auth_to_preflight=add_default_auth_to_preflight,
)
def _set_default_apikey_required(self, swagger_editor: SwaggerEditor, required_options_api_key: bool) -> None:
for path in swagger_editor.iter_on_path():
swagger_editor.set_path_default_apikey_required(path, required_options_api_key)
def _set_endpoint_configuration(self, rest_api: ApiGatewayRestApi, value: Union[str, Dict[str, Any]]) -> None:
"""
Sets endpoint configuration property of AWS::ApiGateway::RestApi resource
:param rest_api: RestApi resource
:param string/dict value: Value to be set
"""
if isinstance(value, dict) and value.get("Type"):
rest_api.Parameters = {"endpointConfigurationTypes": value.get("Type")}
rest_api.EndpointConfiguration = {"Types": [value.get("Type")]}
# SAM API `EndpointConfiguration` uses `VPCEndpointIds` but APIGateway RestApi uses `VpcEndpointIds`.
# Deny when both properties are defined at the same time.
if "VPCEndpointIds" in value and "VpcEndpointIds" in value:
raise InvalidResourceException(
rest_api.logical_id,
"'VPCEndpointIds' and 'VpcEndpointIds' cannot be used together in EndpointConfiguration.",
)
# Accept when either `VPCEndpointIds` or `VpcEndpointIds` is defined by users
if "VPCEndpointIds" in value or "VpcEndpointIds" in value:
rest_api.EndpointConfiguration["VpcEndpointIds"] = value.get("VPCEndpointIds") or value.get(
"VpcEndpointIds"
)
else:
rest_api.EndpointConfiguration = {"Types": [value]}
rest_api.Parameters = {"endpointConfigurationTypes": value}
def _generate_domain_access_association(
    self,
    domain_access_association: Dict[str, Any],
    domain_name_arn: Dict[str, str],
    domain_logical_id: str,
) -> ApiGatewayDomainNameAccessAssociation:
    """
    Build the domain name access association that associates a VPC endpoint with the private domain.

    :param domain_access_association: config map; its ``VpcEndpointId`` is used as the (``VPCE``) source
    :param domain_name_arn: value assigned verbatim to ``DomainNameArn``
    :param domain_logical_id: logical id of the domain, mixed into the generated logical id
    :returns: the ApiGatewayDomainNameAccessAssociation resource, whose logical id is derived from
        the VPC endpoint id and ``domain_logical_id``
    """
    vpc_endpoint_id = domain_access_association.get("VpcEndpointId")
    association_logical_id = LogicalIdGenerator(
        "DomainNameAccessAssociation", [vpc_endpoint_id, domain_logical_id]
    ).gen()

    association = ApiGatewayDomainNameAccessAssociation(
        association_logical_id, attributes=self.passthrough_resource_attributes
    )
    association.DomainNameArn = domain_name_arn
    association.AccessAssociationSourceType = "VPCE"
    association.AccessAssociationSource = vpc_endpoint_id
    return association