in main/src/training-compute/training-compute.py [0:0]
def handler(event, context):
    """Lambda entry point: launch a SageMaker PyTorch training job for a training run.

    Expects in ``event``:
      - ``trainId``: identifier of the training run (looked up via the
        training-configuration service).
      - ``useSpot``: whether to use spot instances. Accepts a boolean or a
        string; only boolean ``False`` or the string ``'false'`` (any case)
        disables spot — every other value enables it, matching the original
        string-only behavior.

    Returns a dict with ``statusCode`` 200 and a ``body`` containing the
    SageMaker ``trainingJobName``.
    """
    trainId = event['trainId']

    # Accept either a boolean or a string flag. The original code assumed a
    # string and crashed on a real bool; this is backward-compatible.
    useSpotArg = event['useSpot']
    if isinstance(useSpotArg, bool):
        useSpot = useSpotArg
    else:
        useSpot = str(useSpotArg).lower() != 'false'

    uniqueId = su.uuid()  # presumably a short unique suffix helper — project util

    # Resolve the training run and its embedding configuration from the
    # training-configuration service.
    trainingConfigurationClient = bioims.client('training-configuration')
    trainInfo = trainingConfigurationClient.getTraining(trainId)
    embeddingName = trainInfo['embeddingName']
    embeddingInfo = trainingConfigurationClient.getEmbeddingInfo(embeddingName)

    # Download the model training script from S3 to Lambda's writable /tmp.
    trainScriptBucket = embeddingInfo['modelTrainingScriptBucket']
    trainScriptKey = embeddingInfo['modelTrainingScriptKey']
    localTrainingScript = '/tmp/bioims-training-script.py'
    getS3TextObjectWriteToPath(trainScriptBucket, trainScriptKey, localTrainingScript)

    trainListArtifactKey = bp.getTrainImageListArtifactPath(trainId)

    sagemaker_session = sagemaker.Session()
    sagemaker_bucket = sagemaker_session.default_bucket()
    sagemaker_role = sagemaker.get_execution_role()
    py_version = '1.6.0'  # NOTE(review): this is the framework version, not a Python version
    instance_type = embeddingInfo['trainingInstanceType']
    trainingHyperparameters = embeddingInfo['trainingHyperparameters']

    # FSx for Lustre filesystem carrying the training data.
    fsxInfo = getFsxInfo()
    print(fsxInfo)
    directory_path = '/' + fsxInfo['mountName']
    sgIds = [fsxInfo['securityGroup']]

    jobName = 'bioims-' + trainId + '-' + uniqueId
    checkpoint_s3_uri = "s3://" + sagemaker_bucket + "/checkpoints/" + jobName

    file_system_input = FileSystemInput(file_system_id=fsxInfo['fsxId'],
                                        file_system_type='FSxLustre',
                                        directory_path=directory_path,
                                        file_system_access_mode='ro')

    trainingHyperparameters['train_list_file'] = trainListArtifactKey

    # Shared estimator configuration; only the spot-related limits differ
    # between the spot and on-demand cases.
    estimator_kwargs = dict(
        entry_point=localTrainingScript,
        role=sagemaker_role,
        framework_version=py_version,
        instance_count=1,
        instance_type=instance_type,
        py_version='py36',
        image_name='763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:1.6.0-gpu-py36-cu101-ubuntu16.04',
        subnets=fsxInfo['subnetIds'],
        security_group_ids=sgIds,
        hyperparameters=trainingHyperparameters,
        train_use_spot_instances=useSpot,
        checkpoint_s3_uri=checkpoint_s3_uri,
        debugger_hook_config=False,
    )
    if useSpot:
        # Spot training requires bounding both total runtime and time spent
        # waiting for spot capacity.
        estimator_kwargs['train_max_wait'] = 100000
        estimator_kwargs['train_max_run'] = 100000
    estimator = PyTorch(**estimator_kwargs)

    # Record the job name before launching so the run can be tracked even if
    # the (asynchronous) fit call is still in flight.
    trainingConfigurationClient.updateTraining(trainId, 'sagemakerJobName', jobName)
    estimator.fit(file_system_input, wait=False, job_name=jobName)

    responseInfo = {
        'trainingJobName': jobName
    }
    response = {
        'statusCode': 200,
        'body': responseInfo
    }
    return response