# sdk/python/foundation-models/healthcare-ai/medimageparse/sematic-segmentation-demo/processing_utils.py

"""
This script contains utility functions for reading and processing
different imaging modalities.
"""

import base64
import json
import urllib.error
import urllib.request
from io import BytesIO

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pydicom
import SimpleITK as sitk
from PIL import Image
from skimage import measure, transform

# Site-specific CT intensity windows, as [lower, upper] bounds in Hounsfield units.
CT_WINDOWS = {
    "abdomen": [-150, 250],
    "lung": [-1000, 1000],
    "pelvis": [-55, 200],
    "liver": [-25, 230],
    "colon": [-68, 187],
    "pancreas": [-100, 200],
}


def process_intensity_image(image_data, is_CT, site=None):
    # Process an intensity-based image. If CT, apply site-specific windowing.
    # image_data: 2D numpy array of shape (H, W)
    # return: 3-channel numpy array of shape (H, W, 3) as model input
    if is_CT:
        # process image with windowing
        if site and site in CT_WINDOWS:
            window = CT_WINDOWS[site]
        else:
            raise ValueError(f"Please choose a CT site from {list(CT_WINDOWS.keys())}")
        lower_bound, upper_bound = window
    else:
        # process image with intensity range 0.5-99.5 percentile
        lower_bound = np.percentile(image_data[image_data > 0], 0.5)
        upper_bound = np.percentile(image_data[image_data > 0], 99.5)

    image_data_pre = np.clip(image_data, lower_bound, upper_bound)
    image_data_pre = (
        (image_data_pre - image_data_pre.min())
        / (image_data_pre.max() - image_data_pre.min())
        * 255.0
    )

    # pad to square with equal padding on both sides
    shape = image_data_pre.shape
    if shape[0] > shape[1]:
        pad = (shape[0] - shape[1]) // 2
        pad_width = ((0, 0), (pad, pad))
    elif shape[0] < shape[1]:
        pad = (shape[1] - shape[0]) // 2
        pad_width = ((pad, pad), (0, 0))
    else:
        pad_width = None

    if pad_width is not None:
        image_data_pre = np.pad(
            image_data_pre, pad_width, "constant", constant_values=0
        )

    # Important: resize image to 1024x1024
    image_size = 1024
    resize_image = transform.resize(
        image_data_pre,
        (image_size, image_size),
        order=3,
        mode="constant",
        preserve_range=True,
        anti_aliasing=True,
    )

    # convert to 3-channel image
    resize_image = np.stack([resize_image] * 3, axis=-1)
    return resize_image.astype(np.uint8)


def read_dicom(image_path, is_CT, site=None):
    # Read a DICOM file and return the processed slice as a PNG buffer.
    # image_path: str, path to DICOM file
    # is_CT: bool, whether the image is CT or not
    # site: str, one of CT_WINDOWS.keys()
    # return: BytesIO buffer holding the PNG-encoded model input
    ds = pydicom.dcmread(image_path)
    # Map stored values to physical units (Hounsfield units for CT).
    image_array = ds.pixel_array * ds.RescaleSlope + ds.RescaleIntercept
    image_array = process_intensity_image(image_array, is_CT, site)

    # Step 1: Convert NumPy array to an image
    image = Image.fromarray(image_array)

    # Step 2: Save image to a BytesIO buffer
    buffer = BytesIO()
    image.save(buffer, format="PNG")  # or "JPEG", depending on your preference
    buffer.seek(0)
    return buffer


def read_nifti(
    image_path, is_CT, slice_idx, site=None, HW_index=(0, 1), channel_idx=None
):
    # Read a NIfTI file and return the processed slice as a PNG buffer.
    # image_path: str, path to NIfTI file
    # is_CT: bool, whether the image is CT or not
    # slice_idx: int, slice index to read
    # site: str, one of CT_WINDOWS.keys()
    # HW_index: tuple, index of height and width in the image shape
    # channel_idx: int or None, channel index for 4D volumes
    # return: BytesIO buffer holding the PNG-encoded model input
    nii = nib.load(image_path)
    image_array = nii.get_fdata()

    if HW_index != (0, 1):
        image_array = np.moveaxis(image_array, HW_index, (0, 1))

    # get slice
    if channel_idx is None:
        image_array = image_array[:, :, slice_idx]
    else:
        image_array = image_array[:, :, slice_idx, channel_idx]

    image_array = process_intensity_image(image_array, is_CT, site)

    # Step 1: Convert NumPy array to an image
    image = Image.fromarray(image_array)

    # Step 2: Save image to a BytesIO buffer
    buffer = BytesIO()
    image.save(buffer, format="PNG")  # or "JPEG", depending on your preference
    buffer.seek(0)
    return buffer
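
# A minimal usage sketch for the readers above. The file name and the chosen
# window site are hypothetical placeholders, not assets shipped with this demo.
def _example_read_ct_dicom():
    # Window a lung CT slice and recover the 1024x1024 RGB model input.
    buffer = read_dicom("ct_slice.dcm", is_CT=True, site="lung")  # hypothetical path
    model_input = np.array(Image.open(buffer))
    print(model_input.shape, model_input.dtype)  # -> (1024, 1024, 3) uint8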

def read_rgb(image_path):
    # Read an RGB image and return the resized pixel data as a PNG buffer.
    # image_path: str, path to RGB image
    # return: BytesIO buffer holding the PNG-encoded model input

    # read image into numpy array
    image = Image.open(image_path)
    image = np.array(image)
    if len(image.shape) == 2:
        # grayscale -> RGB
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[2] == 4:
        # RGBA -> RGB (drop the alpha channel)
        image = image[:, :, :3]

    # pad to square with equal padding on both sides
    shape = image.shape
    if shape[0] > shape[1]:
        pad = (shape[0] - shape[1]) // 2
        pad_width = ((0, 0), (pad, pad), (0, 0))
    elif shape[0] < shape[1]:
        pad = (shape[1] - shape[0]) // 2
        pad_width = ((pad, pad), (0, 0), (0, 0))
    else:
        pad_width = None

    if pad_width is not None:
        image = np.pad(image, pad_width, "constant", constant_values=0)

    # resize image to 1024x1024 for each channel
    image_size = 1024
    resize_image = np.zeros((image_size, image_size, 3), dtype=np.uint8)
    for i in range(3):
        resize_image[:, :, i] = transform.resize(
            image[:, :, i],
            (image_size, image_size),
            order=3,
            mode="constant",
            preserve_range=True,
            anti_aliasing=True,
        )

    # Step 1: Convert NumPy array to an image
    resize_image = Image.fromarray(resize_image)

    # Step 2: Save image to a BytesIO buffer
    buffer = BytesIO()
    resize_image.save(buffer, format="PNG")  # or "JPEG", depending on your preference
    buffer.seek(0)
    return buffer


def get_instances(mask):
    # Split a binary mask into instances using a watershed on the distance map.
    seg = sitk.GetImageFromArray(mask)
    filled = sitk.BinaryFillhole(seg)
    d = sitk.SignedMaurerDistanceMap(
        filled, insideIsPositive=False, squaredDistance=False, useImageSpacing=False
    )
    ws = sitk.MorphologicalWatershed(d, markWatershedLine=False, level=1)
    ws = sitk.Mask(ws, sitk.Cast(seg, ws.GetPixelID()))
    ins_mask = sitk.GetArrayFromImage(ws)

    # filter out instances whose area is an outlier (more than 2 std below the mean)
    props = measure.regionprops_table(ins_mask, properties=("label", "area"))
    mean_area = np.mean(props["area"])
    std_area = np.std(props["area"])
    threshold = mean_area - 2 * std_area - 1

    ins_mask_filtered = ins_mask.copy()
    for i, area in zip(props["label"], props["area"]):
        if area < threshold:
            ins_mask_filtered[ins_mask == i] = 0
    return ins_mask_filtered


def read_image(image_path):
    """Read raw image bytes from a file path.

    Return the file contents as bytes.
    """
    with open(image_path, "rb") as f:
        return f.read()


def decode_json_to_array(json_encoded):
    """Decode an image pixel data array from its JSON encoding.

    Return the image pixel data as a NumPy array.
    """
    # Parse the JSON string
    array_metadata = json.loads(json_encoded)
    # Extract Base64 string, shape, and dtype
    base64_encoded = array_metadata["data"]
    shape = tuple(array_metadata["shape"])
    dtype = np.dtype(array_metadata["dtype"])
    # Decode Base64 to byte string
    array_bytes = base64.b64decode(base64_encoded)
    # Convert byte string back to NumPy array and reshape
    return np.frombuffer(array_bytes, dtype=dtype).reshape(shape)
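
# For reference, a minimal sketch of the matching encoder, inferred from the
# field names parsed above ("data", "shape", "dtype"); the payload the endpoint
# actually produces may differ.
def encode_array_to_json(array):
    # Serialize a NumPy array into the JSON layout read by decode_json_to_array.
    return json.dumps(
        {
            "data": base64.b64encode(array.tobytes()).decode("utf-8"),
            "shape": list(array.shape),
            "dtype": str(array.dtype),
        }
    )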

def plot_segmentation_masks(
    original_image, segmentation_masks, text_prompt=None, aspect_ratio="auto"
):
    """
    Plot a list of segmentation masks over an image with a controllable aspect ratio.

    Parameters:
    - original_image: numpy array
        The original image to be displayed as the background. It should be a 2D
        (grayscale) or 3D (RGB) array.
    - segmentation_masks: list of numpy arrays
        A list where each element is a segmentation mask corresponding to the
        original image. Each mask should be a 2D array with the same spatial
        dimensions as the original image.
    - text_prompt: string, optional
        A string containing mask names separated by '&'. If provided, these names
        will be used as titles for the masks.
        Example: 'Cell Nuclei & Cytoplasm & Background'
    - aspect_ratio: float or string, optional
        The aspect ratio for each subplot. Can be a numeric value, 'auto', or 'equal'.
        - If 'equal', each subplot will have an equal aspect ratio (no distortion).
        - If 'auto' (default), the aspect ratio is determined automatically.
        - If a numeric value is provided, it sets the aspect ratio as y/x.
          For example, aspect_ratio=1 makes the y-axis equal to the x-axis.

    The function displays the original image alongside each segmentation mask
    overlaid in red.
    """
    # Ensure the image has at least 3 channels (RGB)
    if original_image.ndim == 2:
        # Convert grayscale to RGB by stacking the 2D array into three channels
        original_image = np.stack((original_image,) * 3, axis=-1)
    elif original_image.shape[2] > 3:
        # If more than 3 channels, take the first three
        original_image = original_image[:, :, :3]

    num_masks = len(segmentation_masks)

    # Create subplots: one for the original image and one for each mask
    fig, ax = plt.subplots(1, num_masks + 1, figsize=(5 * (num_masks + 1), 5))

    # If there's only one subplot, wrap it in a list for consistency
    if num_masks == 0:
        ax = [ax]

    # Display the original image
    ax[0].imshow(original_image)
    ax[0].set_title("Original Image")
    ax[0].set_aspect(aspect_ratio)

    # Remove axes for all subplots
    for a in ax:
        a.axis("off")

    # Generate mask names
    if text_prompt:
        # Split the text prompt into mask names
        mask_names = [name.strip() for name in text_prompt.split("&")]
        # Check if the number of names matches the number of masks
        if len(mask_names) != num_masks:
            print(
                "Warning: Number of mask names does not match number of masks. "
                "Using default names."
            )
            mask_names = [f"Mask {i+1}" for i in range(num_masks)]
    else:
        # Default mask names if no text prompt is provided
        mask_names = [f"Mask {i+1}" for i in range(num_masks)]

    # Overlay each mask on the original image
    for i, mask in enumerate(segmentation_masks):
        # Set the title for the subplot
        ax[i + 1].set_title(mask_names[i])

        # Create an overlay with the same dimensions as the original image
        overlay = np.zeros_like(original_image, dtype=np.uint8)

        # Define the mask threshold (assumes masks are in the range [0, 255])
        threshold = 128

        # Set the red channel where the mask is greater than the threshold
        overlay[mask > threshold, 0] = 255  # Red channel

        # Display the original image
        ax[i + 1].imshow(original_image)

        # Overlay the mask with transparency
        ax[i + 1].imshow(overlay, alpha=0.5)

        # Set the aspect ratio for each subplot
        ax[i + 1].set_aspect(aspect_ratio)

    plt.tight_layout()
    plt.show()
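
# A small synthetic sketch of plot_segmentation_masks. The image and mask are
# fabricated placeholders, used only to illustrate the expected shapes and the
# [0, 255] mask range.
def _example_plot_masks():
    image = np.zeros((256, 256, 3), dtype=np.uint8)
    mask = np.zeros((256, 256), dtype=np.uint8)
    mask[64:128, 64:128] = 255  # region above the 128 threshold is drawn in red
    plot_segmentation_masks(image, [mask], text_prompt="lesion")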

# Combined inference function to handle NIfTI, DICOM, and RGB inputs
def run_inference(
    inference_config,
    file_path,
    text_prompt,
    is_CT=False,
    slice_idx=None,
    site=None,
    HW_index=(0, 1),
    channel_idx=None,
):
    """
    Runs inference on the provided image and text input using the specified
    configuration.

    Parameters:
    - inference_config: dict with endpoint URL, API key, and model deployment info.
    - file_path: str, path to the image file.
    - text_prompt: str, text prompt for the model input.
    - is_CT: bool, True if the image is a CT scan (only used for NIFTI and DICOM).
    - slice_idx: int, slice index for NIFTI images.
    - site: Optional, additional parameter for NIFTI and DICOM images.
    - HW_index: tuple, used for indexing height and width of NIFTI images.
    - channel_idx: Optional, channel index for NIFTI images.

    Returns:
    - sample_image_arr: np.ndarray, the original image as an array.
    - image_features: np.ndarray, the decoded image features.
    - text_features: the text features returned by the endpoint.
    """
    # Get file extension from file_path
    if file_path.lower().endswith(".nii.gz"):
        file_extension = "nii.gz"
    else:
        file_extension = file_path.split(".")[-1].lower()

    # Read and encode the image based on its type. The buffer is rewound after
    # Base64 encoding so the same bytes can be decoded again for plotting.
    if file_extension in ("nii", "nii.gz"):
        buffer = read_nifti(
            file_path,
            is_CT,
            slice_idx,
            site=site,
            HW_index=HW_index,
            channel_idx=channel_idx,
        )
    elif file_extension in ("png", "jpg", "jpeg"):
        buffer = read_rgb(file_path)
    elif file_extension == "dcm":
        buffer = read_dicom(file_path, is_CT, site=site)
    else:
        raise ValueError("Unsupported image type. Use a PNG/JPEG, DICOM, or NIfTI file.")

    image_data = base64.encodebytes(buffer.read()).decode("utf-8")
    buffer.seek(0)
    sample_image_arr = np.array(Image.open(buffer))

    # Prepare the request payload
    data = {
        "input_data": {
            "columns": ["image", "text"],
            "index": [0],
            "data": [[image_data, text_prompt]],
        }
    }
    body = str.encode(json.dumps(data))

    url = f"{inference_config['endpoint']}/score"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {inference_config['api_key']}",
    }
    deployment = inference_config.get("azureml_model_deployment", None)
    if deployment:
        headers["azureml-model-deployment"] = deployment

    # Send the request and handle the response
    req = urllib.request.Request(url, body, headers)
    try:
        response = urllib.request.urlopen(req)
        result = response.read()
        result_list = json.loads(result)

        # Decode image features from the response
        image_features_str = result_list[0]["image_features"]
        text_features = result_list[0]["text_features"]
        image_features = decode_json_to_array(image_features_str)

        # Plot the segmentation masks over the original image
        plot_segmentation_masks(sample_image_arr, image_features, text_prompt)
    except urllib.error.HTTPError as error:
        print(f"The request failed with status code: {error.code}")
        print(error.info())
        print(error.read().decode("utf8", "ignore"))
        raise  # re-raise: image_features/text_features are undefined on failure

    return sample_image_arr, image_features, text_features


def plot_instance_segmentation_masks(
    original_image, segmentation_masks, text_prompt=None
):
    """Plot a list of instance segmentation masks over an image."""
    original_image = original_image[:, :, :3]
    fig, ax = plt.subplots(1, len(segmentation_masks) + 1, figsize=(10, 5))
    ax[0].imshow(original_image, cmap="gray")
    ax[0].set_title("Original Image")
    # grid off
    for a in ax:
        a.axis("off")

    # Split each binary mask into instances before plotting.
    instance_masks = [get_instances(1 * (mask > 127)) for mask in segmentation_masks]

    mask_names = [f"Mask {i+1}" for i in range(len(segmentation_masks))]
    if text_prompt:
        mask_names = [name.strip() for name in text_prompt.split("&")]

    for i, mask in enumerate(instance_masks):
        ins_ids = np.unique(mask)
        count = len(ins_ids[ins_ids > 0])
        ax[i + 1].set_title(f"{mask_names[i]} ({count})")
        # Color each instance randomly; instance 1 is always drawn in red.
        mask_temp = np.zeros_like(original_image)
        for ins_id in ins_ids:
            if ins_id == 0:
                continue
            mask_temp[mask == ins_id] = np.random.randint(0, 255, 3)
            if ins_id == 1:
                mask_temp[mask == ins_id] = [255, 0, 0]
        ax[i + 1].imshow(mask_temp, alpha=1)
        ax[i + 1].imshow(original_image, cmap="gray", alpha=0.5)
    plt.show()
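
# A usage sketch for run_inference. The endpoint URL, API key, deployment name,
# file path, and prompt below are all hypothetical placeholders; substitute the
# values from your own Azure ML deployment before calling.
def _example_run_inference():
    inference_config = {
        "endpoint": "https://<your-endpoint>.inference.ml.azure.com",  # placeholder
        "api_key": "<your-api-key>",  # placeholder
        "azureml_model_deployment": "<your-deployment>",  # optional
    }
    return run_inference(
        inference_config,
        file_path="sample.png",  # placeholder path
        text_prompt="neoplastic cells & inflammatory cells",  # placeholder prompt
    )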