# handler() — AWS Lambda entry point defined below.
# Source location: main/src/training-compute/training-compute.py

def _parse_use_spot(value):
    """Interpret the event's 'useSpot' field as a boolean.

    Accepts a real bool or a string; any string other than
    (case-insensitive) 'false' counts as True, matching the
    original behavior where only 'false' disabled spot instances.
    """
    if isinstance(value, bool):
        return value
    return str(value).lower() != 'false'


def handler(event, context):
    """Launch a SageMaker PyTorch training job for the given training id.

    Expects ``event`` to contain:
      - 'trainId': id of the training registered in the
        training-configuration service
      - 'useSpot': bool or string; 'false' disables spot instances

    Side effects: downloads the model training script from S3 to /tmp,
    records the SageMaker job name on the training record, and starts
    the training job asynchronously (``wait=False``).

    Returns a dict with 'statusCode' 200 and a 'body' containing the
    generated 'trainingJobName'.
    """
    trainId = event['trainId']
    useSpot = _parse_use_spot(event['useSpot'])
    uniqueId = su.uuid()

    # Resolve the embedding's training configuration (script location,
    # instance type, hyperparameters) from the training record.
    trainingConfigurationClient = bioims.client('training-configuration')
    trainInfo = trainingConfigurationClient.getTraining(trainId)
    embeddingName = trainInfo['embeddingName']
    embeddingInfo = trainingConfigurationClient.getEmbeddingInfo(embeddingName)

    # Fetch the training entry-point script from S3; /tmp is the only
    # writable path in Lambda.
    trainScriptBucket = embeddingInfo['modelTrainingScriptBucket']
    trainScriptKey = embeddingInfo['modelTrainingScriptKey']
    localTrainingScript = '/tmp/bioims-training-script.py'
    getS3TextObjectWriteToPath(trainScriptBucket, trainScriptKey, localTrainingScript)

    trainListArtifactKey = bp.getTrainImageListArtifactPath(trainId)

    sagemaker_session = sagemaker.Session()
    sagemaker_bucket = sagemaker_session.default_bucket()
    sagemaker_role = sagemaker.get_execution_role()

    fsxInfo = getFsxInfo()
    print(fsxInfo)

    jobName = 'bioims-' + trainId + '-' + uniqueId
    checkpoint_s3_uri = "s3://" + sagemaker_bucket + "/checkpoints/" + jobName

    # Training data is read from an FSx for Lustre filesystem mounted
    # read-only into the training container.
    file_system_input = FileSystemInput(file_system_id=fsxInfo['fsxId'],
                                        file_system_type='FSxLustre',
                                        directory_path='/' + fsxInfo['mountName'],
                                        file_system_access_mode='ro')

    trainingHyperparameters = embeddingInfo['trainingHyperparameters']
    trainingHyperparameters['train_list_file'] = trainListArtifactKey

    # Single estimator configuration; the spot/on-demand variants differ
    # only in the spot flags, so build the shared kwargs once instead of
    # duplicating the whole constructor call.
    # NOTE: '1.6.0' is the PyTorch *framework* version; the Python
    # runtime version is the separate py_version='py36' argument.
    estimator_kwargs = dict(
        entry_point=localTrainingScript,
        role=sagemaker_role,
        framework_version='1.6.0',
        instance_count=1,
        instance_type=embeddingInfo['trainingInstanceType'],
        py_version='py36',
        image_name='763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:1.6.0-gpu-py36-cu101-ubuntu16.04',
        subnets=fsxInfo['subnetIds'],
        security_group_ids=[fsxInfo['securityGroup']],
        hyperparameters=trainingHyperparameters,
        train_use_spot_instances=useSpot,
        checkpoint_s3_uri=checkpoint_s3_uri,
        debugger_hook_config=False,
    )
    if useSpot:
        # Spot training requires explicit wait/run limits so SageMaker
        # can resume from checkpoints after interruptions.
        estimator_kwargs['train_max_wait'] = 100000
        estimator_kwargs['train_max_run'] = 100000

    estimator = PyTorch(**estimator_kwargs)

    # Record the job name before launching so the training record can be
    # polled for status even if fit() raises.
    trainingConfigurationClient.updateTraining(trainId, 'sagemakerJobName', jobName)

    estimator.fit(file_system_input, wait=False, job_name=jobName)

    return {
        'statusCode': 200,
        'body': {
            'trainingJobName': jobName
        }
    }