in main/cdk/train-stack.ts [32:290]
/**
 * Builds the training pipeline:
 *  - two Python worker Lambdas ("trainComputeFunction", "trainBuildFunction")
 *    that share the same collaborator ARNs and data-bucket environment,
 *  - the "Train StateMachine" Step Functions workflow, which looks up the
 *    training and embedding configuration, optionally fans out plate
 *    processing through a Map state, invokes the build and compute Lambdas and
 *    then polls the training job status every 240 s until it reaches a
 *    terminal state,
 *  - the Node.js "trainFunction" Lambda, which receives the state machine ARN
 *    and the build Lambda ARN through its environment.
 *
 * @param app   Scope the stack is created in.
 * @param id    Construct id of the stack.
 * @param props Collaborating Lambdas and the data bucket used by the pipeline.
 */
constructor(app: cdk.App, id: string, props: TrainStackProps) {
  super(app, id, props);

  ///////////////////////////////////////////
  //
  // Training worker Lambdas
  //
  ///////////////////////////////////////////

  // Both Python workers are configured with the same collaborator ARNs and
  // bucket name; keep them in one place so the two functions cannot drift.
  // Each Function below receives its own copy via spread.
  const workerEnvironment: Record<string, string> = {
    MESSAGE_LAMBDA_ARN: props.messageLambda.functionArn,
    IMAGE_MANAGEMENT_LAMBDA_ARN: props.imageManagementLambda.functionArn,
    TRAIN_CONFIGURATION_LAMBDA_ARN: props.trainingConfigurationLambda.functionArn,
    ARTIFACT_LAMBDA_ARN: props.artifactLambda.functionArn,
    BUCKET: props.dataBucket.bucketName,
  };

  // NOTE(review): PYTHON_3_8 (here) and NODEJS_12_X (trainFunction below) are
  // deprecated Lambda runtimes; upgrading needs the handler code re-verified,
  // so it is intentionally not changed here.
  this.trainComputeLambda = new lambda.Function(this, "trainComputeFunction", {
    code: lambda.Code.fromAsset("src/training-compute/build"),
    handler: "training-compute.handler",
    runtime: lambda.Runtime.PYTHON_3_8,
    memorySize: 256,
    timeout: cdk.Duration.minutes(5),
    environment: { ...workerEnvironment },
  });

  this.trainBuildLambda = new lambda.Function(this, "trainBuildFunction", {
    code: lambda.Code.fromAsset("src/training-build/build"),
    handler: "training-build.handler",
    runtime: lambda.Runtime.PYTHON_3_8,
    memorySize: 512,
    timeout: cdk.Duration.minutes(15),
    environment: { ...workerEnvironment },
  });

  ///////////////////////////////////////////
  //
  // Train StepFunction
  //
  ///////////////////////////////////////////

  // --- Configuration lookups (results are attached beside the input) ---
  const trainInfoRequest = new sfn.Pass(this, "Train Info Request", {
    parameters: {
      method: "getTraining",
      trainId: sfn.JsonPath.stringAt("$.trainId"),
    },
    resultPath: '$.trainInfoRequest',
  });
  const trainInfo = new tasks.LambdaInvoke(this, "Train Info", {
    lambdaFunction: props.trainingConfigurationLambda,
    resultPath: '$.trainInfo',
    inputPath: '$.trainInfoRequest',
  });
  const embeddingInfoRequest = new sfn.Pass(this, "Embedding Info Request", {
    parameters: {
      method: "getEmbeddingInfo",
      embeddingName: sfn.JsonPath.stringAt('$.trainInfo.Payload.body.embeddingName'),
    },
    resultPath: '$.embeddingInfoRequest',
  });
  // resultPath is a plain JSONPath string; JsonPath.stringAt() is meant for
  // parameter values, not for ResultPath.
  const embeddingInfo = new tasks.LambdaInvoke(this, "Embedding Info", {
    lambdaFunction: props.trainingConfigurationLambda,
    resultPath: '$.embeddingInfo',
    inputPath: '$.embeddingInfoRequest',
  });
  const plateSurveyRequest = new sfn.Pass(this, "Plate Survey Request", {
    parameters: {
      method: "listCompatiblePlates",
      width: sfn.JsonPath.stringAt('$.embeddingInfo.Payload.body.Item.inputWidth'),
      height: sfn.JsonPath.stringAt('$.embeddingInfo.Payload.body.Item.inputHeight'),
      depth: sfn.JsonPath.stringAt('$.embeddingInfo.Payload.body.Item.inputDepth'),
      channels: sfn.JsonPath.stringAt('$.embeddingInfo.Payload.body.Item.inputChannels'),
    },
    resultPath: '$.plateSurveyRequest',
  });
  const plateList = new tasks.LambdaInvoke(this, "Plate List", {
    lambdaFunction: props.imageManagementLambda,
    resultPath: '$.plateList',
    inputPath: '$.plateSurveyRequest',
  });

  // --- Per-plate processing (Map iterator): start, then poll until not RUNNING ---
  const plateProcessor = new tasks.LambdaInvoke(this, "Process Plate", {
    lambdaFunction: props.processPlateLambda,
    outputPath: '$.Payload.body',
  });
  const plateWait = new sfn.Wait(this, "Plate Wait", {
    time: sfn.WaitTime.duration(cdk.Duration.seconds(240)),
  });
  // No resultPath: the state is replaced by this request object.
  // NOTE(review): assumes the processPlate Lambda's body carries `executionArn`
  // (both after "Process Plate" and after "Plate Status") — confirm.
  const plateStatusInput = new sfn.Pass(this, "Plate Status Input", {
    parameters: {
      method: "describeExecution",
      executionArn: sfn.JsonPath.stringAt('$.executionArn'),
    },
  });
  const plateStatus = new tasks.LambdaInvoke(this, "Plate Status", {
    lambdaFunction: props.processPlateLambda,
    outputPath: '$.Payload.body',
  });
  const plateNotRunning = new sfn.Pass(this, "Plate Not Running", {
    parameters: {
      status: sfn.JsonPath.stringAt('$.status'),
    },
  });
  const plateSequence = plateProcessor
    .next(plateWait)
    .next(plateStatusInput)
    .next(plateStatus)
    .next(new sfn.Choice(this, 'Plate Sfn Complete?')
      .when(sfn.Condition.stringEquals('$.status', 'RUNNING'), plateWait)
      .otherwise(plateNotRunning));
  const plateProcessMap = new sfn.Map(this, "Plate Process Map", {
    maxConcurrency: 5,
    itemsPath: '$.plateList.Payload.body',
    resultPath: '$.plateProcessMapResult',
    parameters: {
      method: "processPlate",
      embeddingName: sfn.JsonPath.stringAt('$.trainInfo.Payload.body.embeddingName'),
      // Raw ".$" key referencing the Map context object ($$) for the current item.
      'plateId.$': "$$.Map.Item.Value.plateId",
    },
  });
  plateProcessMap.iterator(plateSequence);

  // --- Build and compute ---
  // No resultPath: the state is narrowed to { trainId, useSpot }.
  const trainBuildInput = new sfn.Pass(this, "Train Build Input", {
    parameters: {
      trainId: sfn.JsonPath.stringAt("$.trainId"),
      useSpot: sfn.JsonPath.stringAt("$.useSpot"),
    },
  });
  const trainBuild = new tasks.LambdaInvoke(this, "Train Build", {
    lambdaFunction: this.trainBuildLambda,
    resultPath: '$.trainBuildOutput',
  });
  const trainComputeInput = new sfn.Pass(this, "Train Compute Input", {
    parameters: {
      trainId: sfn.JsonPath.stringAt("$.trainId"),
      useSpot: sfn.JsonPath.stringAt("$.useSpot"),
    },
  });
  const trainCompute = new tasks.LambdaInvoke(this, "Train Compute", {
    lambdaFunction: this.trainComputeLambda,
    resultPath: '$.trainComputeOutput',
  });
  const skipProcessPlate = new sfn.Pass(this, "Skip Process Plate", {});

  // --- Training job polling loop ---
  const trainWait = new sfn.Wait(this, "Train Wait", {
    time: sfn.WaitTime.duration(cdk.Duration.seconds(240)),
  });
  // No resultPath: the state is replaced by the status request; "Train Status"
  // then attaches its result under $.trainStatus on every iteration.
  const trainStatusInput = new sfn.Pass(this, "Train Status Input", {
    parameters: {
      method: "getTrainingJobInfo",
      trainingJobName: sfn.JsonPath.stringAt('$.trainComputeOutput.Payload.body.trainingJobName'),
    },
  });
  const trainStatus = new tasks.LambdaInvoke(this, "Train Status", {
    lambdaFunction: props.trainingConfigurationLambda,
    resultPath: '$.trainStatus',
  });
  const trainNotRunning = new sfn.Pass(this, "Train Not Running", {
    parameters: {
      status: sfn.JsonPath.stringAt('$.trainStatus'),
    },
  });

  // TrainingJobStatus values: 'InProgress'|'Completed'|'Failed'|'Stopping'|'Stopped'.
  // The three terminal ones end the loop; anything else waits and polls again.
  // NOTE(review): 'Failed' and 'Stopped' currently end the execution
  // successfully; consider routing them to an sfn.Fail state so callers can
  // detect a failed training job from the execution status.
  const trainingJobStatusPath = '$.trainStatus.Payload.body.TrainingJobStatus';
  const trainJobFinished = sfn.Condition.or(
    ...['Completed', 'Stopped', 'Failed'].map((terminalStatus) =>
      sfn.Condition.stringEquals(trainingJobStatusPath, terminalStatus)),
  );
  const trainSequence = trainComputeInput
    .next(trainCompute)
    .next(trainStatusInput)
    .next(trainWait)
    .next(trainStatus)
    .next(new sfn.Choice(this, 'Train Complete?')
      .when(trainJobFinished, trainNotRunning)
      .otherwise(trainWait));

  // --- Full workflow ---
  // Plate processing runs only when the training record's executeProcessPlate
  // equals the string 'true'; both branches rejoin before the build step.
  const trainStepFunctionDef = trainInfoRequest
    .next(trainInfo)
    .next(embeddingInfoRequest)
    .next(embeddingInfo)
    .next(plateSurveyRequest)
    .next(plateList)
    .next(new sfn.Choice(this, 'Skip ProcessPlate?')
      .when(sfn.Condition.stringEquals('$.trainInfo.Payload.body.executeProcessPlate', 'true'), plateProcessMap)
      .otherwise(skipProcessPlate)
      .afterwards())
    .next(trainBuildInput)
    .next(trainBuild)
    .next(trainSequence);

  const trainLogGroup = new logs.LogGroup(this, "TrainLogGroup");
  this.trainStateMachine = new sfn.StateMachine(this, "Train StateMachine", {
    definition: trainStepFunctionDef,
    timeout: cdk.Duration.hours(48),
    logs: {
      destination: trainLogGroup,
      level: sfn.LogLevel.ALL,
    },
  });

  ///////////////////////////////////////////
  //
  // Train entry-point Lambda
  //
  ///////////////////////////////////////////
  this.trainLambda = new lambda.Function(this, "trainFunction", {
    code: lambda.Code.fromAsset("src/train/build"),
    handler: "train.handler",
    runtime: lambda.Runtime.NODEJS_12_X,
    memorySize: 512,
    timeout: cdk.Duration.minutes(15),
    environment: {
      MESSAGE_LAMBDA_ARN: props.messageLambda.functionArn,
      IMAGE_MANAGEMENT_LAMBDA_ARN: props.imageManagementLambda.functionArn,
      TRAIN_CONFIGURATION_LAMBDA_ARN: props.trainingConfigurationLambda.functionArn,
      TRAIN_SFN_ARN: this.trainStateMachine.stateMachineArn,
      TRAIN_BUILD_LAMBDA_ARN: this.trainBuildLambda.functionArn,
    },
  });

  // The CfnOutput below is unnecessary when this function is included in the
  // external IAM policy in resource-permissions-stack.
  //const trainLambdaOutput = new cdk.CfnOutput(this, 'trainLambda', { value: this.trainLambda.functionArn } )
}