def create_graph()

in model/run_pipeline.py [0:0]


def create_graph(create_experiment_step, baseline_step, training_step):
    """Build the workflow graph: the experiment step, then the parallel jobs.

    A ``Parallel`` state named "SageMaker Jobs" is created with two
    branches (``baseline_step`` and ``training_step``, added in that order).
    A single catch for ``States.TaskFailed`` routes the parallel state to a
    ``Fail`` state named "SageMaker Jobs failed".

    Args:
        create_experiment_step: First element of the returned chain.
            Presumably a Step Functions step object -- confirm with caller.
        baseline_step: First branch added to the parallel state.
        training_step: Second branch added to the parallel state.

    Returns:
        A ``steps.states.Chain`` built from
        ``[create_experiment_step, <parallel state>]``.
    """
    parallel_jobs = steps.states.Parallel("SageMaker Jobs")
    for branch in (baseline_step, training_step):
        parallel_jobs.add_branch(branch)

    # Open question carried over from the original author: do the grouped
    # jobs need a more specific failure than this single catch-all target?
    failure_state = stepfunctions.steps.states.Fail(
        "SageMaker Jobs failed", cause="SageMakerJobsFailed"
    )
    task_failed_catch = stepfunctions.steps.states.Catch(
        error_equals=["States.TaskFailed"],
        next_step=failure_state,
    )
    parallel_jobs.add_catch(task_failed_catch)

    # The experiment step comes first, followed by the parallel job group.
    workflow_steps = [create_experiment_step, parallel_jobs]
    return steps.states.Chain(workflow_steps)