def preprocess()

in MMS/dicom_featurization_service.py [0:0]


    def preprocess(self, data):
        """Fetch a DICOM from S3, publish its thumbnail and metadata, and build the model input.

        Args:
            data: request batch; only ``data[0]`` is used. Its values are ``bytes``
                (each is ``.decode()``-d) and the keys read are ``dicombucket``,
                ``dicomkey``, ``pngbucket``, ``prefix``, ``ddb_table`` and ``es_endpoint``.

        Side effects, only when the DICOM has a non-empty ViewPosition:
            * uploads a 1/10-scale 8-bit greyscale PNG thumbnail to
              ``<pngbucket>/<prefix>/<dicomkey with '.dcm' -> '.png'>`` unless that
              object already exists (best-effort: upload errors are only logged);
            * puts one item of DICOM metadata into the DynamoDB table ``ddb_table``.

        Returns:
            ``None`` when ViewPosition is missing or empty; otherwise a dict with
            ``id`` (``dicomkey`` up to its first '.'), ``image`` (a normalised
            ``1 x 3 x 512 x 512`` tensor), ``ViewPosition`` and ``ES`` (the
            ``es_endpoint`` from the request).
        """
        import io  # only needed for the in-memory DICOM buffer

        request = data[0]
        for k, v in request.items():
            logger.info('received data in post request: {0}: {1}'.format(k, v.decode()))

        bucket = request.get("dicombucket").decode()
        key = request.get("dicomkey").decode()
        s3_client, ddb_client = self._make_aws_clients()

        # Read the DICOM into memory. ``key`` is request input, so it is never used as a
        # local file path: keys containing '/' would fail and '..' could escape the
        # working directory. Nothing is left on disk if parsing raises either.
        dicom_buffer = io.BytesIO()
        s3_client.download_fileobj(bucket, key, dicom_buffer)
        dicom_buffer.seek(0)
        dicom = pydicom.dcmread(dicom_buffer)

        view_position = self._dicom_str(dicom, 'ViewPosition')
        if not view_position:
            return None

        dicom_array = dicom.pixel_array.astype(float)

        ## upload png thumbnail to s3 (skipped if it is already there)
        thumbnail_png = key.replace('.dcm', '.png')
        pngbucket = request.get("pngbucket").decode()
        prefix = request.get("prefix").decode()
        self._publish_thumbnail(s3_client, dicom_array, pngbucket, prefix + '/' + thumbnail_png)

        ## dicom metadata to be saved in dynamodb
        metadata_dict = {
            "ImageId": {"S": key.replace('.dcm', '')},
            "ViewPosition": {"S": view_position},
            "Bucket": {"S": pngbucket},
            "Key": {"S": thumbnail_png},
            # Tags the file lacks are written as '' (the same value an empty tag
            # produces) instead of crashing on ``None.value``.
            "ReportId": {"S": 's' + self._dicom_str(dicom, 'StudyID')},
            "Modality": {"S": self._dicom_str(dicom, 'Modality')},
            "BodyPartExamined": {"S": self._dicom_str(dicom, 'BodyPartExamined')},
        }
        ddb_table = request.get("ddb_table").decode()
        response = ddb_client.put_item(TableName=ddb_table, Item=metadata_dict)
        logger.info('Dynamodb create item status: {}'.format(response['ResponseMetadata']['HTTPStatusCode']))

        ## transform dicom array for pytorch model featurization
        img_trsfm = transforms.Compose([
            # PIL does not support multi-channel floating point data, so the single
            # channel is transformed first and only expanded to 3 channels at the end.
            transforms.ToPILImage(),
            transforms.Resize([516, 516], interpolation=2),  # 2 == PIL bilinear
            # Deterministic crop to 512x512: a RandomCrop here made the features of the
            # same DICOM differ from one request to the next.
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
            # classifier is fine-tuned on pretrained imagenet using mimic-jpeg images [HxWx3],
            # hence the ImageNet channel statistics
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # Divide by the full-scale value of a BitsStored-bit (unsigned) pixel.
        normalized = np.asarray(dicom_array, np.float32) / (2 ** dicom.BitsStored - 1)
        image = img_trsfm(normalized).unsqueeze(0)
        es_endpoint = request.get("es_endpoint").decode()
        return {'id': key.split('.')[0], 'image': image, 'ViewPosition': view_position, 'ES': es_endpoint}

    def _make_aws_clients(self):
        """Return ``(s3_client, dynamodb_client)`` for ``preprocess``.

        Uses ``self.access_key``/``self.secret_key``/``self.aws_region`` when all three
        are set; otherwise clients come from a fresh ``boto3.session.Session()``.
        """
        if all(v is not None for v in (self.access_key, self.secret_key, self.aws_region)):
            credentials = dict(
                region_name=self.aws_region,
                aws_access_key_id=self.access_key,
                aws_secret_access_key=self.secret_key,
            )
            return boto3.client('s3', **credentials), boto3.client('dynamodb', **credentials)
        session = boto3.session.Session()
        return session.client('s3'), session.client('dynamodb')

    @staticmethod
    def _dicom_str(dicom, keyword):
        """Return DICOM attribute ``keyword`` as a string, or '' if it is absent or None."""
        value = dicom.get(keyword)
        return '' if value is None else str(value)

    @staticmethod
    def _publish_thumbnail(s3_client, pixels, bucket, s3_key):
        """Upload a 1/10-scale 8-bit greyscale PNG of ``pixels`` to ``bucket/s3_key``.

        Does nothing if the object already exists. Upload errors are logged and
        swallowed, so a thumbnail failure never blocks featurization.
        """
        import io  # only needed for the in-memory PNG buffer

        try:
            s3_client.head_object(Bucket=bucket, Key=s3_key)
        except ClientError:
            pass  # not found (or not readable): fall through and upload it
        else:
            return  # already published

        peak = pixels.max()
        if peak > 0:
            # Clip negatives, then stretch to the 0..255 range.
            scaled = (np.maximum(pixels, 0) / peak) * 255.0
        else:
            scaled = np.zeros_like(pixels)  # blank image; avoids 0/0 -> NaN
        thumbnail = Image.fromarray(np.uint8(scaled))
        width, height = thumbnail.size
        # 1/10 of the original size, but never 0 pixels (resize rejects that).
        thumbnail = thumbnail.resize((max(1, width // 10), max(1, height // 10)))

        png_buffer = io.BytesIO()
        thumbnail.save(png_buffer, format='PNG')
        png_buffer.seek(0)
        try:
            s3_client.upload_fileobj(png_buffer, bucket, s3_key)
        except ClientError as e:
            logger.error('upload thumbnail png error: {}'.format(e))