in stepfn-batch/app.py [0:0]
def __init__(self, app: core.App, id: str, **kwargs) -> None:
super().__init__(app, id, **kwargs)
# Create both lambdas
with open("lambda-submit.py", encoding="utf8") as fp:
lambda_submit_code = fp.read()
lambdaFn1 = lambda_.Function(
self,
"submitsmbatch",
code=lambda_.InlineCode(lambda_submit_code),
handler="index.lambda_handler",
timeout=core.Duration.seconds(300),
runtime=lambda_.Runtime.PYTHON_3_7,
environment={
"transform_job_name":transform_job_name,
"model_name":model_name,
"max_concurrent":max_concurrent,
"max_payload_size":max_payload_size,
"s3_uri_in":s3_uri_in,
"s3_uri_out":s3_uri_out,
"instance_type":instance_type,
"instance_count":instance_count,
}
)
# Add perms
lambdaFn1.add_to_role_policy(aws_iam.PolicyStatement(
actions = ['sagemaker:CreateTransformJob',],
resources = ['arn:aws:sagemaker:{}:{}:transform-job/{}*'.format(my_region,my_acc_id,transform_job_name),]
))
with open("lambda-check.py", encoding="utf8") as fp:
lambda_check_code = fp.read()
lambdaFn2 = lambda_.Function(
self,
"checksmbatch",
code=lambda_.InlineCode(lambda_check_code),
handler="index.lambda_handler",
timeout=core.Duration.seconds(300),
runtime=lambda_.Runtime.PYTHON_3_7,
environment={"model_name":model_name, # CHANGE TO YOUR ENDPOINT NAME!!
"content_type":"text/csv"}
)
# Add perms
lambdaFn2.add_to_role_policy(aws_iam.PolicyStatement(
actions = ['sagemaker:DescribeTransformJob',],
resources = ['arn:aws:sagemaker:{}:{}:transform-job/{}*'.format(my_region,my_acc_id,transform_job_name),]
))
# Define state machine
# submit_job_activity = sfn.Activity(
# self, "SubmitJob"
# )
# check_job_activity = sfn.Activity(
# self, "CheckJob"
# )
submit_job = sfn.Task(
self, "Submit Job",
task=sfn_tasks.InvokeFunction(lambdaFn1),
)
wait_x = sfn.Wait(
self, "Wait 1 minute",
time=sfn.WaitTime.duration(core.Duration.minutes(1)),
)
get_status = sfn.Task(
self, "Get Job Status",
task=sfn_tasks.InvokeFunction(lambdaFn2),
)
is_complete = sfn.Choice(
self, "Job Complete?"
)
job_failed = sfn.Fail(
self, "Job Failed",
cause="AWS Batch Job Failed",
error="DescribeJob returned FAILED"
)
final_status = sfn.Task(
self, "Get Final Job Status",
task=sfn_tasks.InvokeFunction(lambdaFn2),
)
definition = submit_job\
.next(wait_x)\
.next(get_status)\
.next(is_complete
.when(sfn.Condition.string_equals(
"$.status", "Failed"), job_failed)
.when(sfn.Condition.string_equals(
"$.status", "Completed"), final_status)
.otherwise(wait_x))
sfn.StateMachine(
self, "SMbatchInference",
definition=definition,
)