in main/src/embedding-compute/embedding-compute.py [0:0]
def handler(event, context):
    """Compute and store the ROI embeddings of one image using a trained model.

    Steps, in order:

    1. Resolve the training-job description, either from
       ``event['trainingJobInfo']`` or via ``smc.describe_training_job``.
    2. Make sure ``model.pth`` is available under ``/tmp/<sagemakerJobName>``;
       if it is not, download and unpack the job's ``model.tar.gz``.
    3. Make sure the training script is available next to the model as
       ``bioimstrain.py``, import it and build the model with its
       ``model_fn``.
    4. Load the ROI array for the image, scale it by ``1/65535``, run it
       through the model in a zero-padded batch, and persist the per-ROI
       embeddings plus their base64-encoded mean.

    Args:
        event: dict with the keys ``trainInfo`` (``trainId``,
            ``sagemakerJobName``), ``embeddingInfo`` (``embeddingName``,
            ``modelTrainingScriptBucket``, ``modelTrainingScriptKey``),
            ``plateId``, ``imageId`` and, optionally, ``trainingJobInfo``.
        context: Not used by this function.

    Returns:
        dict: ``{'statusCode': 200, 'body': 'success'}`` on completion (also
        when the ROI array is empty), or a ``statusCode`` 400 response when
        the training job description has no ``ModelArtifacts`` entry.

    Raises:
        KeyError: If a required key is missing from ``event``.
    """
    trainInfo = event['trainInfo']
    embeddingInfo = event['embeddingInfo']
    embeddingName = embeddingInfo['embeddingName']
    trainId = trainInfo['trainId']
    plateId = event['plateId']
    imageId = event['imageId']
    print(trainInfo)
    print(embeddingInfo)
    print(embeddingName)
    print(plateId)
    print(imageId)

    trainingScriptBucket = embeddingInfo['modelTrainingScriptBucket']
    trainingScriptKey = embeddingInfo['modelTrainingScriptKey']
    trainingJobName = trainInfo['sagemakerJobName']
    print(trainingScriptBucket)
    print(trainingScriptKey)
    print(trainingJobName)

    # The caller may pass the job description along to avoid the extra lookup.
    if 'trainingJobInfo' not in event:
        trainingJobInfo = smc.describe_training_job(TrainingJobName=trainingJobName)
    else:
        trainingJobInfo = event['trainingJobInfo']

    # Model files are kept in a per-job directory so that a later call for the
    # same job can reuse them.
    localModelDir = os.path.join('/tmp/', trainingJobName)
    localModelGz = os.path.join(localModelDir, 'model.tar.gz')
    localModelPath = os.path.join(localModelDir, 'model.pth')

    if os.path.isfile(localModelPath):
        print("Using local model at location {}".format(localModelPath))
    else:
        print("Creating {} and downloading model.tar.gz".format(localModelDir))
        if not os.path.isdir(localModelDir):
            # A new job directory is needed: clear everything else out of /tmp
            # first (presumably to free space taken by previous jobs).
            os.system("rm -r /tmp/*")
            os.mkdir(localModelDir)
        if 'ModelArtifacts' in trainingJobInfo:
            modelArtifacts = trainingJobInfo['ModelArtifacts']
            s3ModelPath = modelArtifacts['S3ModelArtifacts']
            print(s3ModelPath)
            copyS3ObjectPathToLocalPath(s3ModelPath, localModelGz)
            os.chdir(localModelDir)
            # NOTE(review): extractall() trusts the member paths stored in the
            # archive; confirm the artifact source is trusted.
            # The context manager closes the archive even if extraction fails.
            with tarfile.open("model.tar.gz") as modelArchive:
                modelArchive.extractall()
        else:
            print("Model not available")
            response = {
                'statusCode': 400,
                'body': 'Could not obtain model or checkpoint'
            }
            return response

    localTrainScript = os.path.join(localModelDir, 'bioimstrain.py')
    if os.path.isfile(localTrainScript):
        print("Using train script {}".format(localTrainScript))
    else:
        print("Downloading trainscript from bucket={} key={}".format(trainingScriptBucket, trainingScriptKey))
        s3c.download_file(trainingScriptBucket, trainingScriptKey, localTrainScript)

    # The training script is imported from the current directory. Only put "."
    # at the front of sys.path when it is not already there, so repeated
    # invocations in the same process do not grow sys.path without bound.
    os.chdir(localModelDir)
    if not sys.path or sys.path[0] != ".":
        sys.path.insert(0, ".")
    import bioimstrain

    model = bioimstrain.model_fn(localModelDir)
    print(model)

    # Only the BatchNorm2d layers are switched to eval mode; all other modules
    # keep whatever mode model_fn left them in.
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()

    roiTrainKey = bp.getTrainKey(embeddingName, plateId, imageId)
    data = getNumpyArrayFromS3(BUCKET, roiTrainKey).astype(np.float32)

    ##########################################################################
    # TODO: Handle 3D data
    #
    # The dev model assumes input with 4 dimensions. It assumes 2D rather than 3D data, with 3 channels:
    #    image#, channels, y, x
    #
    # With channels=3, x and y = 128
    #
    # However, actual input is 3D and will be <#>, 3, 1, 128, 128
    #
    ##########################################################################

    print("v7")

    roiCount = data.shape[0]
    if roiCount > 0:
        dataMin = np.min(data)
        dataMax = np.max(data)
        print(data.shape)
        print("pre roi min={} max={}".format(dataMin, dataMax))
        # Scale by the 16-bit maximum (in place on the float32 copy).
        data /= 65535.0
        dataMin = np.min(data)
        dataMax = np.max(data)
        print("post roi min={} max={}".format(dataMin, dataMax))

        # Round the batch size up to the next multiple of 8; the extra rows
        # stay zero and their outputs are discarded below.
        # NOTE(review): presumably the model needs batches in multiples of 8
        # -- confirm.
        bufferSize = roiCount
        bufferRemainder = bufferSize % 8
        if bufferRemainder > 0:
            bufferSize += (8 - bufferRemainder)
        model_data = np.zeros((bufferSize, 3, 128, 128), dtype=np.float32)

        # Drop the singleton z axis: for each of the 3 channels copy plane 0
        # of every ROI into the 4-D model input (see the TODO above for the
        # expected input layout).
        for channel in range(3):
            model_data[:roiCount, channel] = data[:, channel, 0]
        print(model_data.shape)

        t1 = torch.tensor(model_data)
        print(t1.shape)
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            embeddingDataTensor = model(t1)
        embeddingData = embeddingDataTensor.detach().numpy()

        # Keep only the rows that correspond to real ROIs (not padding).
        embeddingResult = embeddingData[:roiCount]
        ea = np.mean(embeddingResult, axis=0)
        print(embeddingResult.shape)
        print(ea.shape)

        roiEmbeddingKey = bp.getRoiEmbeddingKey(imageId, plateId, trainId)
        writeNumpyToS3(embeddingResult, BUCKET, roiEmbeddingKey)
        artifactStr = 's3key#' + roiEmbeddingKey
        roiEmbeddingArtifact = {
            'contextId': imageId,
            'trainId': trainId,
            'artifact': artifactStr
        }
        createArtifact(roiEmbeddingArtifact)

        # The mean embedding is passed on as base64 of its raw bytes; it can
        # be recovered with
        #   np.frombuffer(base64.decodebytes(ea64), dtype=np.float32)
        ea64 = base64.b64encode(ea)
        applyEmbeddingResult(imageId, trainId, ea64, roiEmbeddingKey)
    else:
        print("No input data found - skipping")

    response = {
        'statusCode': 200,
        'body': 'success'
    }
    return response