# source/infrastructure/forecast/stack.py
def __init__(self, scope: Construct, construct_id: str, *args, **kwargs) -> None:
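        # Assemble the resources for the Improving Forecast Accuracy with Machine Learning forecast stack.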
super().__init__(scope, construct_id, *args, **kwargs)
# Parameters
self.parameters = Parameters(self, "ForecastStackParameters")
# Conditions
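        # Optional features (notebook, email notifications, QuickSight analysis, KMS encryption,
        # demo forecast data) are toggled by CloudFormation conditions built from the parameters above.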
create_notebook = CfnCondition(
self,
"CreateNotebook",
expression=Fn.condition_equals(self.parameters.notebook_deploy, "Yes"),
)
email_provided = CfnCondition(
self,
"EmailProvided",
expression=Fn.condition_not(Fn.condition_equals(self.parameters.email, "")),
)
create_analysis = CfnCondition(
self,
"CreateAnalysis",
expression=Fn.condition_not(
Fn.condition_equals(self.parameters.quicksight_analysis_owner, ""),
),
)
forecast_kms_enabled = CfnCondition(
self,
"ForecastSseKmsEnabled",
expression=Fn.condition_not(
Fn.condition_equals(self.parameters.forecast_kms_key_arn, "")
),
)
create_forecast_cdn = CfnCondition(
self,
"CreateForecast",
expression=Fn.condition_equals(self.parameters.forecast_deploy, "Yes"),
)
# Buckets
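        # S3 buckets: server access logs, the Athena bucket, and the data bucket
        # (with a generated, length-limited name).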
data_bucket_name_resource = ResourceName(
self,
"DataBucketName",
purpose="data-bucket",
max_length=63,
)
access_logs_bucket = AccessLogsBucket(self)
athena_bucket = AthenaBucket(
self,
server_access_logs_bucket=access_logs_bucket,
server_access_logs_prefix="athena-bucket-access-logs/",
)
data_bucket = DataBucket(
self,
bucket_name=data_bucket_name_resource.resource_name.to_string(),
server_access_logs_bucket=access_logs_bucket,
server_access_logs_prefix="forecast-bucket-access-logs/",
)
policy_factory = PolicyFactory(
self,
"ForecastPolicyFactory",
data_bucket=data_bucket,
kms_key_arn=self.parameters.forecast_kms_key_arn.value_as_string,
kms_enabled=forecast_kms_enabled,
)
# Lambda Functions
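        # One Lambda function per workflow step; all share the solution layer and most
        # use the three-minute default timeout.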
default_timeout = Duration.minutes(3)
solution_layer = ForecastSolutionLayer(self, "SolutionLayer")
create_dataset_group = CreateDatasetGroup(
self, "CreateDatasetGroup", layers=[solution_layer], timeout=default_timeout
)
create_dataset_import_job = CreateDatasetImportJob(
self,
"CreateDatasetImportJob",
layers=[solution_layer],
timeout=default_timeout,
)
create_predictor = CreatePredictor(
self, "CreatePredictor", layers=[solution_layer], timeout=default_timeout
)
create_forecast = CreateForecast(
self, "CreateForecast", layers=[solution_layer], timeout=default_timeout
)
create_forecast_export = CreateForecastExport(
self,
"CreateForecastExport",
layers=[solution_layer],
timeout=default_timeout,
)
create_predictor_backtest_export = CreatePredictorBacktestExport(
self,
"CreatePredictorBacktestExport",
layers=[solution_layer],
timeout=default_timeout,
)
create_glue_table_name = CreateGlueTableName(
self, "CreateGlueTableName", layers=[solution_layer]
)
create_quicksight_analysis = CreateQuickSightAnalysis(
self,
"CreateQuickSightAnalysis",
layers=[solution_layer],
timeout=Duration.minutes(15),
)
notifications = Notifications(
self,
"SNS Notification",
email=self.parameters.email,
email_provided=email_provided,
layers=[solution_layer],
)
# State Machine
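        # Individual workflow states; the chained state machine definition is assembled below.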
check_error = Choice(self, "Check for Error")
notify_failure = notifications.state(
self, "Notify on Failure", result_path=JsonPath.DISCARD
)
notify_success = notifications.state(
self, "Notify on Success", result_path=JsonPath.DISCARD
)
create_predictor_state = create_predictor.state(
self,
"Create Predictor",
result_path="$.PredictorArn", # NOSONAR (python:S1192) - string for clarity
max_attempts=100,
interval=Duration.seconds(120),
backoff_rate=1.02,
)
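        # If this execution is not for the most recent dataset update, end successfully;
        # retry while dataset imports are still in progress.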
create_predictor_state.start_state.add_catch(
Succeed(self, "Update Not Required"), errors=["NotMostRecentUpdate"]
)
create_predictor_state.start_state.add_retry(
backoff_rate=1.02,
interval=Duration.seconds(120),
max_attempts=100,
errors=["DatasetsImporting"],
)
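        # Run the solution's Glue ETL job synchronously (RUN_JOB integration pattern),
        # passing the dataset group and Glue table name as job arguments.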
forecast_etl = GlueStartJobRun(
self,
"Forecast ETL",
glue_job_name=f"{Aws.STACK_NAME}-Forecast-ETL",
integration_pattern=IntegrationPattern.RUN_JOB,
result_path=JsonPath.DISCARD,
arguments=TaskInput.from_object(
{
"--dataset_group": JsonPath.string_at("$.dataset_group_name"),
"--glue_table_name": JsonPath.string_at("$.glue_table_name"),
}
),
)
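        # Retry when a concurrent run of the Glue job is already in progress.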
forecast_etl.add_retry(
backoff_rate=1.02,
interval=Duration.seconds(120),
max_attempts=100,
errors=["Glue.ConcurrentRunsExceededException"],
)
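        # State machine definition: notify on upstream error, otherwise create the dataset
        # group, import data, then fan out per dataset group to train a predictor, create the
        # forecast, export the predictor backtest and forecast, run the ETL job, build the
        # QuickSight analysis, and send a success notification.
        # fmt: off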
definition = Chain.start(
check_error.when(
Condition.is_present("$.error.serviceError"), notify_failure
).otherwise(
Parallel(self, "Manage the Execution")
.branch(
create_dataset_group.state(
self,
"Create Dataset Group",
result_path="$.DatasetGroupNames",
)
.next(
create_dataset_import_job.state(
self,
"Create Dataset Import Job",
result_path="$.DatasetImportJobArn",
max_attempts=100,
interval=Duration.seconds(120),
backoff_rate=1.02,
)
)
.next(
Map(
self,
"Create Forecasts",
items_path="$.DatasetGroupNames",
parameters={
"bucket.$": "$.bucket",
"dataset_file.$": "$.dataset_file",
"dataset_group_name.$": "$$.Map.Item.Value",
"config.$": "$.config",
},
).iterator(
create_predictor_state.next(
create_forecast.state(
self,
"Create Forecast",
result_path="$.ForecastArn",
max_attempts=100,
)
).next(
Parallel(
self,
"Export Predictor Backtest and Forecast",
result_path=JsonPath.DISCARD,
)
.branch(
create_forecast_export.state(
self,
"Create Forecast Export",
result_path="$.PredictorArn", # NOSONAR (python:S1192) - string for clarity
max_attempts=100,
)
)
.branch(
create_predictor_backtest_export.state(
self,
"Create Predictor Backtest Export",
result_path="$.PredictorArn", # NOSONAR (python:S1192) - string for clarity
max_attempts=100,
)
)
.next(
create_glue_table_name.state(
self,
"Create Glue Table Name",
result_path="$.glue_table_name",
)
)
.next(forecast_etl)
.next(
create_quicksight_analysis.state(
self,
"Create QuickSight Analysis",
result_path=JsonPath.DISCARD,
)
)
.next(notify_success)
)
)
)
)
.add_catch(
notify_failure.next(Fail(self, "Failure")), result_path="$.error"
)
)
)
# fmt: on
state_machine_namer = ResourceName(
self, "StateMachineName", purpose="forecast-workflow", max_length=80
)
state_machine = StateMachine(
self,
"ForecastStateMachine",
state_machine_name=state_machine_namer.resource_name.to_string(),
definition=definition,
tracing_enabled=True,
)
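        # Suppress expected cfn_nag findings on the state machine role's generated default policy.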
add_cfn_nag_suppressions(
resource=state_machine.role.node.children[1].node.default_child,
suppressions=[
CfnNagSuppression(
"W76",
"Large step functions need larger IAM roles to access all managed lambda functions",
),
CfnNagSuppression(
"W12", "IAM policy for AWS X-Ray requires an allow on *"
),
],
)
# S3 Notifications
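        # Start the forecast workflow whenever new training data (train/*.csv) is uploaded
        # to the data bucket.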
s3_event_handler = S3EventHandler(
self,
"S3EventHandler",
state_machine=state_machine,
bucket=data_bucket,
layers=[solution_layer],
timeout=Duration.minutes(1),
)
s3_event_notification = LambdaDestination(s3_event_handler)
data_bucket.add_event_notification(
EventType.OBJECT_CREATED,
s3_event_notification,
NotificationKeyFilter(prefix="train/", suffix=".csv"),
)
# Handle suppressions for the notification handler resource generated by CDK
bucket_notification_handler = self.node.try_find_child(
"BucketNotificationsHandler050a0587b7544547bf325f094a3db834"
)
bucket_notification_policy = (
bucket_notification_handler.node.find_child("Role")
.node.find_child("DefaultPolicy")
.node.find_child("Resource")
)
add_cfn_nag_suppressions(
bucket_notification_policy,
[
CfnNagSuppression(
"W12",
"bucket resource is '*' due to circular dependency with bucket and role creation at the same time",
)
],
)
# ETL Components
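        # Glue database and jobs plus the Athena workgroup used by the ETL and analysis steps.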
glue = Glue(
self,
"GlueResources",
unique_name=data_bucket_name_resource.resource_id.to_string(),
forecast_bucket=data_bucket,
athena_bucket=athena_bucket,
glue_jobs_path=Path(__file__).parents[2] / "glue" / "jobs",
)
athena = Athena(self, "AthenaResources", athena_bucket=athena_bucket)
# Permissions
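        # Grant each function only the Forecast and S3 access it requires (read-only where possible).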
policy_factory.grant_forecast_read_write(create_dataset_group.function)
policy_factory.grant_forecast_read_write(create_dataset_import_job.function)
policy_factory.grant_forecast_read_write(create_predictor.function)
policy_factory.grant_forecast_read_write(
create_predictor_backtest_export.function
)
policy_factory.grant_forecast_read(create_forecast.function)
policy_factory.grant_forecast_read_write(create_forecast_export.function)
policy_factory.grant_forecast_read(create_quicksight_analysis.function)
policy_factory.quicksight_access(
create_quicksight_analysis.function,
catalog=glue.database,
workgroup=athena.workgroup,
quicksight_principal=self.parameters.quicksight_analysis_owner,
quicksight_source=self.mappings.source_mapping,
athena_bucket=athena_bucket,
data_bucket=data_bucket,
)
data_bucket.grant_read(create_dataset_group.function)
data_bucket.grant_read(create_dataset_import_job.function)
data_bucket.grant_read(create_predictor.function)
data_bucket.grant_read_write(create_predictor_backtest_export.function)
data_bucket.grant_read(create_forecast.function)
data_bucket.grant_read_write(create_forecast_export.function)
data_bucket.grant_read(s3_event_handler)
# Notebook
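        # Optional notebook resources, created only when the CreateNotebook condition is satisfied.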
Notebook(
self,
"Notebook",
buckets=[data_bucket],
instance_type=self.parameters.notebook_instance_type.value_as_string,
instance_volume_size=self.parameters.notebook_volume_size.value_as_number,
notebook_path=Path(__file__).parents[2]
/ "notebook"
/ "samples"
/ "notebooks",
notebook_destination_bucket=data_bucket,
notebook_destination_prefix="notebooks",
create_notebook=create_notebook,
)
# Demo components
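        # Resolve the demo dataset URLs and copy their contents into the data bucket through a
        # nested downloader stack (deployed only when the CreateForecast condition is true).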
self.forecast_defaults_url_info = UrlHelper(
self,
"ForecastDefaults",
self.parameters.forecast_defaults_url.value_as_string,
)
self.tts_url_info = UrlHelper(
self, "TTS", self.parameters.tts_url.value_as_string
)
self.rts_url_info = UrlHelper(
self, "RTS", self.parameters.rts_url.value_as_string
)
self.md_url_info = UrlHelper(self, "MD", self.parameters.md_url.value_as_string)
# prepare the nested stack that performs the actual downloads
downloader = Downloader(
self,
"Downloader",
description="Improving Forecast Accuracy with Machine Learning Data copier",
template_filename="improving-forecast-accuracy-with-machine-learning-downloader.template",
parameters={
"DestinationBucket": data_bucket.bucket_name,
"ForecastName": self.parameters.forecast_name.value_as_string,
"Version": self.node.try_get_context("SOLUTION_VERSION"),
**self.forecast_defaults_url_info.properties,
**self.tts_url_info.properties,
**self.rts_url_info.properties,
**self.md_url_info.properties,
},
)
downloader.nested_stack_resource.override_logical_id("Downloader")
Aspects.of(downloader.nested_stack_resource).add(
ConditionalResources(create_forecast_cdn)
)
# Tagging
Tags.of(self).add("SOLUTION_ID", self.node.try_get_context("SOLUTION_ID"))
Tags.of(self).add("SOLUTION_NAME", self.node.try_get_context("SOLUTION_NAME"))
Tags.of(self).add(
"SOLUTION_VERSION", self.node.try_get_context("SOLUTION_VERSION")
)
# Aspects
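        # Apply blanket cfn_nag suppressions to every Lambda function in the stack.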
Aspects.of(self).add(
CfnNagSuppressAll(
suppress=[
CfnNagSuppression(
"W89",
"functions deployed by this solution do not require VPC access",
),
CfnNagSuppression(
"W92",
"functions deployed by this solution do not require reserved concurrency",
),
CfnNagSuppression(
"W58",
"functions deployed by this solution use custom policy to write to CloudWatch logs",
),
],
resource_type="AWS::Lambda::Function",
)
)
# Metrics
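        # Solution metrics recording which optional components were deployed.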
self.metrics.update(
{
"Solution": self.mappings.solution_mapping.find_in_map("Data", "ID"),
"Version": self.mappings.solution_mapping.find_in_map(
"Data", "Version"
),
"Region": Aws.REGION,
"NotebookDeployed": Fn.condition_if(
create_notebook.node.id, "Yes", "No"
),
"NotebookType": Fn.condition_if(
create_notebook.node.id,
self.parameters.notebook_instance_type.value_as_string,
Aws.NO_VALUE,
),
"QuickSightDeployed": Fn.condition_if(
create_analysis.node.id, "Yes", "No"
),
"ForecastDeployed": Fn.condition_if(
create_forecast_cdn.node.id, "Yes", "No"
),
}
)
# Outputs
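        # Stack outputs, exported for cross-stack references.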
CfnOutput(
self,
"ForecastBucketName",
value=data_bucket.bucket_name,
export_name=f"{Aws.STACK_NAME}-ForecastBucketName",
)
CfnOutput(
self,
"AthenaBucketName",
value=athena_bucket.bucket_name,
export_name=f"{Aws.STACK_NAME}-AthenaBucketName",
)
CfnOutput(
self,
"StepFunctionsName",
value=state_machine.state_machine_name,
export_name=f"{Aws.STACK_NAME}-StepFunctionsName",
)