in model/run_pipeline.py [0:0]
def create_graph(create_experiment_step, baseline_step, training_step, *, error_equals=None):
    """Build the workflow graph for the model pipeline.

    The graph runs ``create_experiment_step`` first and then a Parallel state
    named "SageMaker Jobs". That Parallel state has two branches,
    ``baseline_step`` and ``training_step``. If the Parallel state fails with
    an error listed in ``error_equals``, execution transitions to a Fail state
    named "SageMaker Jobs failed".

    Args:
        create_experiment_step: Step executed before the SageMaker jobs.
        baseline_step: Step (or chain) added as the first Parallel branch.
        training_step: Step (or chain) added as the second Parallel branch.
        error_equals: Error names caught on the Parallel state. Defaults to
            ``["States.TaskFailed"]`` (the original hard-coded value). Pass
            e.g. ``["States.ALL"]`` to also route timeouts and other
            failures to the Fail state.

    Returns:
        A ``steps.states.Chain`` of ``create_experiment_step`` followed by
        the Parallel state.
    """
    # None sentinel instead of a list default avoids a shared mutable default;
    # copy so later mutation of the caller's list cannot alter the catcher.
    caught_errors = ["States.TaskFailed"] if error_equals is None else list(error_equals)

    sagemaker_jobs = steps.states.Parallel("SageMaker Jobs")
    sagemaker_jobs.add_branch(baseline_step)
    sagemaker_jobs.add_branch(training_step)

    # One catcher on the Parallel state covers both branches: per Step
    # Functions semantics, an unhandled failure in any branch fails the whole
    # Parallel state. Add per-branch catchers if jobs need distinct handling.
    sagemaker_jobs.add_catch(
        steps.states.Catch(
            error_equals=caught_errors,
            # NOTE(review): "SageMakerJobsFailed" reads like an error code and
            # may belong in `error=` rather than `cause=` — confirm intent.
            next_step=steps.states.Fail(
                "SageMaker Jobs failed", cause="SageMakerJobsFailed"
            ),
        )
    )

    # Return the workflow graph: experiment setup, then the parallel jobs.
    return steps.states.Chain([create_experiment_step, sagemaker_jobs])