# handler() — excerpted from main/src/embedding-compute/embedding-compute.py

def handler(event, context):
    """Lambda entry point: compute ROI embeddings for one image of a plate.

    Downloads the model artifact (model.tar.gz) produced by the SageMaker
    training job named in ``trainInfo``, reconstructs the network via the
    training script's ``model_fn``, runs the ROI training data for
    (plateId, imageId) through it, and writes the per-ROI embeddings plus
    their mean back to S3 and the artifact/embedding services.

    Parameters
    ----------
    event : dict
        Must contain 'trainInfo', 'embeddingInfo', 'plateId' and 'imageId';
        may also carry a cached 'trainingJobInfo' (describe_training_job
        result) to skip the SageMaker API call.
    context : object
        Lambda context object (unused).

    Returns
    -------
    dict
        {'statusCode': 200, 'body': 'success'} on success (including the
        no-input case), or statusCode 400 when no model artifact exists.
    """
    trainInfo = event['trainInfo']
    embeddingInfo = event['embeddingInfo']
    embeddingName = embeddingInfo['embeddingName']
    trainId = trainInfo['trainId']
    plateId = event['plateId']
    imageId = event['imageId']
    print(trainInfo)
    print(embeddingInfo)
    print(embeddingName)
    print(plateId)
    print(imageId)
    trainingScriptBucket = embeddingInfo['modelTrainingScriptBucket']
    trainingScriptKey = embeddingInfo['modelTrainingScriptKey']
    trainingJobName = trainInfo['sagemakerJobName']
    print(trainingScriptBucket)
    print(trainingScriptKey)
    print(trainingJobName)

    # Prefer a caller-supplied describe result to avoid an extra SageMaker
    # API round-trip on every invocation.
    if 'trainingJobInfo' not in event:
        trainingJobInfo = smc.describe_training_job(TrainingJobName=trainingJobName)
    else:
        trainingJobInfo = event['trainingJobInfo']

    # Cache the extracted model under /tmp keyed by job name so warm Lambda
    # invocations can reuse it without re-downloading.
    localModelDir = os.path.join('/tmp/', trainingJobName)
    localModelGz = os.path.join(localModelDir, 'model.tar.gz')
    localModelPath = os.path.join(localModelDir, 'model.pth')
    if os.path.isfile(localModelPath):
        print("Using local model at location {}".format(localModelPath))
    else:
        print("Creating {} and downloading model.tar.gz".format(localModelDir))
        if not os.path.isdir(localModelDir):
            # Evict artifacts from other jobs first — presumably to stay
            # inside the Lambda /tmp size limit. NOTE(review): os.system with
            # a wildcard ignores failures; consider shutil-based cleanup.
            os.system("rm -r /tmp/*")
            os.mkdir(localModelDir)
        if 'ModelArtifacts' in trainingJobInfo:
            modelArtifacts = trainingJobInfo['ModelArtifacts']
            s3ModelPath = modelArtifacts['S3ModelArtifacts']
            print(s3ModelPath)
            copyS3ObjectPathToLocalPath(s3ModelPath, localModelGz)
            # Extract relative to localModelDir; the context manager closes
            # the archive even if extraction raises.
            os.chdir(localModelDir)
            with tarfile.open("model.tar.gz") as tar:
                tar.extractall()
        else:
            print("Model not available")
            response = {
                'statusCode': 400,
                'body': 'Could not obtain model or checkpoint'
            }
            return response

    localTrainScript = os.path.join(localModelDir, 'bioimstrain.py')
    if os.path.isfile(localTrainScript):
        print("Using train script {}".format(localTrainScript))
    else:
        print("Downloading trainscript from bucket={} key={}".format(trainingScriptBucket, trainingScriptKey))
        s3c.download_file(trainingScriptBucket, trainingScriptKey, localTrainScript)

    # Import the training script from the model dir so its model_fn can
    # rebuild the network and load the extracted weights.
    os.chdir(localModelDir)
    sys.path.insert(0, ".")
    import bioimstrain
    model = bioimstrain.model_fn(localModelDir)
    print(model)
    # Force batch-norm layers to eval mode so inference uses the trained
    # running statistics rather than batch statistics.
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()

    roiTrainKey = bp.getTrainKey(embeddingName, plateId, imageId)
    data = getNumpyArrayFromS3(BUCKET, roiTrainKey).astype(np.float32)

    ##########################################################################
    # TODO: Handle 3D data
    #
    # The dev model assumes input with 4 dimensions. It assumes 2D rather
    # than 3D data, with 3 channels:
    #    image#, channels, y, x
    #
    # With channels=3, x and y = 128
    #
    # However, actual input is 3D and will be <#>, 3, 1, 128, 128
    #
    ##########################################################################

    print("v7")

    if data.shape[0] > 0:
        dataMin = np.min(data)
        dataMax = np.max(data)
        print(data.shape)
        print("pre roi min={} max={}".format(dataMin, dataMax))
        # Normalize 16-bit pixel values into [0, 1].
        data /= 65535.0
        dataMin = np.min(data)
        dataMax = np.max(data)
        print("post roi min={} max={}".format(dataMin, dataMax))
        # Zero-pad the batch dimension up to a multiple of 8 (padded rows are
        # discarded after inference).
        bufferSize = data.shape[0]
        bufferRemainder = bufferSize % 8
        if bufferRemainder > 0:
            bufferSize += (8 - bufferRemainder)
        model_data = np.zeros((bufferSize, 3, 128, 128), dtype=np.float32)
        # Collapse the singleton z dimension (see TODO above): copy each
        # channel's first z-slice. Kept as an explicit loop — its broadcast
        # semantics for lower-rank inputs differ from a vectorized slice.
        for i in range(data.shape[0]):
            model_data[i][0] = data[i][0][0]
            model_data[i][1] = data[i][1][0]
            model_data[i][2] = data[i][2][0]

        print(model_data.shape)
        t1 = torch.tensor(model_data)
        print(t1.shape)
        # Run inference; detach to drop the autograd graph before .numpy().
        embeddingDataTensor = model(t1)
        embeddingData = embeddingDataTensor.detach().numpy()
        # Drop the zero-padded rows, then average across ROIs to obtain the
        # per-image embedding.
        embeddingResult = embeddingData[:data.shape[0]]
        ea = np.mean(embeddingResult, axis=0)
        print(embeddingResult.shape)
        print(ea.shape)
        roiEmbeddingKey = bp.getRoiEmbeddingKey(imageId, plateId, trainId)
        writeNumpyToS3(embeddingResult, BUCKET, roiEmbeddingKey)
        artifactStr = 's3key#' + roiEmbeddingKey
        roiEmbeddingArtifact = {
            'contextId': imageId,
            'trainId': trainId,
            'artifact': artifactStr
        }
        createArtifact(roiEmbeddingArtifact)
        # Base64-encode the mean embedding's raw float32 bytes for transport;
        # the consumer decodes with np.frombuffer(..., dtype=np.float32).
        ea64 = base64.b64encode(ea)
        applyEmbeddingResult(imageId, trainId, ea64, roiEmbeddingKey)
    else:
        print("No input data found - skipping")

    response = {
        'statusCode': 200,
        'body': 'success'
    }

    return response