sdk/python/foundation-models/healthcare-ai/medimageinsight/exam-parameter-demo/exam_parameter_helpers.py (233 lines of code) (raw):
# Exam Parameter Detection notebook helper functions
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy, pandas as pd
import os
import pickle
import json
from openai import AzureOpenAI
from umap import UMAP
import plotly.graph_objs as go
from scipy.spatial import distance
from azureml.core import Workspace
from azureml.core.keyvault import Keyvault
def plot_parameter_distribution_categorical(df, parameter_name, plot_title, height=5):
"""Use matplotlib to plot parameter distribution"""
df[parameter_name].value_counts(dropna=False).plot(
kind="barh", figsize=(8, height), color="#86bf91", zorder=2, width=0.9
)
plt.title(plot_title)
# Create labels for the plot and include NaN values
labels = df[parameter_name].value_counts(dropna=False)
for i, v in enumerate(labels):
clr = "black"
if pd.isnull(labels.index[i]):
clr = "red"
plt.text(v + 1, i, str(v), color=clr, va="center")
def plot_parameter_distribution_histogram(
df, parameter_name, plot_title, bin_count, logscale=False
):
"""Use matplotlib to plot a histogram of data"""
# Plot the histogram directly using pandas
df[parameter_name].hist(bins=bin_count, edgecolor="black", log=logscale)
# Customize the plot
plt.title(plot_title)
plt.xlabel(parameter_name)
plt.ylabel("Frequency")
plt.grid(False)
# Show the plot
plt.show()
def sample_holdout_set(df, param_name, param_values, n_sample=5):
"""Samples a subset from the dataframe based on the parameter values provided"""
out_pd = pd.DataFrame()
for v in param_values:
sampled = df[
df[param_name].isnull() if v == None else df[param_name] == v
].sample(n_sample, random_state=42, replace=True)
out_pd = pd.concat([out_pd, sampled])
return out_pd
def create_exam_param_struct_from_dicom_tags(df_item):
"""Pack DICOM fields into a JSON object that can be sent to GPT"""
exam_params = {}
exam_params["Body Part Examined"] = df_item["BodyPartExamined"]
exam_params["Protocol Name"] = df_item["ProtocolName"]
exam_params["Series Description"] = df_item["SeriesDescription"]
exam_params["Image Type"] = df_item["ImageType"]
exam_params["Sequence Variant"] = df_item["SequenceVariant"]
return json.dumps(exam_params)
def load_environment_variables(json_file_path):
with open(json_file_path, "r") as file:
env_vars = json.load(file)
for key, value in env_vars.items():
os.environ[key] = value
def create_openai_client():
"""Plumbing to create the OpenAI client"""
# Try to load endpoint URL and API key from the JSON file
# (and load as environment variables)
load_environment_variables("environment.json")
# Try to get the key from environment
endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "")
api_key = os.environ.get("AZURE_OPENAI_API_KEY", "")
if api_key == "":
# Try to get the key from AML workspace
# Load the workspace
ws = Workspace.from_config()
# Access the linked key vault
keyvault = ws.get_default_keyvault()
# Get the secret
api_key = keyvault.get_secret("azure-openai-api-key-westus")
client = AzureOpenAI(
azure_endpoint=endpoint,
api_key=api_key,
api_version="2024-02-01",
)
return client
def create_oai_assistant(client):
"""Creates assistant to keep track of prior responses"""
# Assistant API example: https://github.com/openai/openai-python/blob/main/examples/assistant.py
# Available in limited regions
deployment = "gpt-4o"
assistant = client.beta.assistants.create(
name="Math Tutor",
instructions="You are a categorizer. For each question answered, extract entities related to people's names and "
" jobs and categorize them. You always return result in JSON. You reuse categories from past responses when possible",
model=deployment,
tools=[{"type": "code_interpreter"}],
)
return assistant.id
def plot_clusters(df, column):
# Create a list of unique values for the column
colors = [
"red",
"blue",
"green",
"orange",
"purple",
"brown",
"pink",
"gray",
"olive",
"cyan",
"magenta",
"yellow",
"black",
"darkred",
"darkblue",
"darkgreen",
]
traces = []
for protocol, color in zip(df[column].unique(), colors):
trace = go.Scatter(
x=df[df[column] == protocol]["embedding_p1"],
y=df[df[column] == protocol]["embedding_p2"],
mode="markers",
marker=dict(color=color),
name=protocol,
text=df[df[column] == protocol][column].values,
)
traces.append(trace)
# Create a layout for the plot
layout = go.Layout(
title="MedImageInsight dimensionality reduction vs " + column + " (MRI).",
xaxis=dict(title="Projection 1"),
yaxis=dict(title="Projection 2"),
)
# Create a figure with the traces and layout
fig = go.Figure(data=traces, layout=layout)
# Show the figure
fig.show()
def prepare_feature_maps(df, dataset_root=""):
"""Sets up feature maps for series based on MedImageInsight's individual image embeddings"""
feat_mean_matrix = []
feat_std_matrix = []
for index, row in df.iterrows():
feat_path = row["features"]
feat_file = os.path.join(dataset_root, feat_path)
feat_dict = pd.read_pickle(feat_file)
# Sort features by slice number
feat_dict = dict(sorted(feat_dict.items()))
# creat list of dict items values
feat_list = list(feat_dict.values())
# Select the center slice + 10 slices before and after
center_slice = int(len(feat_dict) / 2)
slice_range = list(range(center_slice - 10, center_slice + 10))
# print("number of slices: ", len(feat_dict))
# Extract features for the selected slices and create a 2D matrix
feat_subject_matrix = []
# Check if slice_range is within the number of slices
slice_range = [x for x in slice_range if x < len(feat_list)]
for slice_num in slice_range:
feat_subject_matrix.append(feat_list[slice_num])
feat_subject_matrix = numpy.array(feat_subject_matrix)
# Calculate the mean and standard deviation of each feature
feat_mean = numpy.mean(feat_subject_matrix, axis=0)
feat_std = numpy.std(feat_subject_matrix, axis=0)
feat_center_slice = feat_list[center_slice]
# save feat_mean as pickle file
feat_mean_path = feat_file.replace(".pkl", ".mean.pkl")
with open(feat_mean_path, "wb") as handle:
pickle.dump(feat_mean, handle)
# save feat_center_slice as pickle file
feat_center_slice_path = feat_file.replace(".pkl", ".center_slice.pkl")
with open(feat_center_slice_path, "wb") as handle:
pickle.dump(feat_center_slice, handle)
feat_mean_matrix.append(feat_mean)
feat_std_matrix.append(feat_std)
feat_mean_matrix = numpy.array(feat_mean_matrix)
# Not used yet
feat_std_matrix = numpy.array(feat_std_matrix)
return feat_mean_matrix
def read_image(path):
"""Reads an image from a file and returns it as a numpy array"""
return mpimg.imread(path)
def plot_image(df, parameter, dataset_root=""):
# load the image and display it
para_vals = df[parameter].dropna().unique()
dict_para = {}
# subplot parameters (rows, columns)
(n_rows, n_cols) = (4, 5)
# Create a subplots of images for each value of the parameter
fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 10))
fig.suptitle("MRI images for each " + parameter)
for i, para in enumerate(para_vals):
# Select a row from the dataframe corresponding each value of of the given column
row = df[df[parameter] == para].iloc[0]
# Load the image
# list files in row['png_path'] and select the middle one
png_path = row["png_path"]
if dataset_root != "":
png_path = os.path.join(dataset_root, png_path)
files = os.listdir(png_path)
files.sort()
png_file = png_path + os.sep + files[len(files) // 2]
img = read_image(png_file)
# Display the image
axs[i // 5, i % 5].imshow(img, cmap="gray")
axs[i // 5, i % 5].set_title(para)
axs[i // 5, i % 5].axis("off")
dict_para[para] = img
if i == n_rows * n_cols - 1:
break
plt.show()
def display_closest_images(
df, column, image_series, text_flag=False, images_path="", features_path=""
):
"""Display the top 5 closest images to a given image given a dataframe and a given column"""
# The following code randomly picks an image with a given value and uses it to find the closest images
# random_index = numpy.random.randint(0, len(df[df[column] == value]))
# print('Random index: ', random_index)
# # Select a row from the dataframe corresponding each value of of the given column
# row = df[df[column] == value].iloc[random_index]
# The following code uses supplied dataframe row to find closest images
row = image_series
pickle_path = row["features"].replace(".pkl", ".center_slice.pkl")
if features_path != "":
pickle_path = os.path.join(features_path, pickle_path)
image_feat_mean = pd.read_pickle(pickle_path)
# Compute text if text_flag is True
if text_flag:
# compare the image_feat_mean with the values of text_parameter_feat_dict
image_text_similarity_list = []
for key, value in text_parameter_feat_dict.items():
# Cross product between the image_feat_mean and the values of text_parameter_feat_dict
cross_product = numpy.dot(image_feat_mean, value)
image_text_similarity_list.append(cross_product)
# Create a dataframe with the similarity values
df_distances = pd.DataFrame(
{
"parameter": list(text_parameter_feat_dict.keys()),
"distance": image_text_similarity_list,
}
)
# Sort the dataframe by distance
df_distances.sort_values(by="distance", inplace=True)
# Print the top 5 df_distances values
print("Top 5 closest exam parameters based on text similarity:")
print(df_distances.head(5))
embedding_p1 = row["embedding_p1"]
embedding_p2 = row["embedding_p2"]
png_path = row["png_path"]
if images_path != "":
png_path = os.path.join(images_path, png_path)
files = os.listdir(png_path)
files.sort()
png_file = png_path + os.sep + files[len(files) // 2]
img = read_image(png_file)
# String containing the value of the parameter of interest or "Not present" if the value is not present
value_string = (
row[column] if isinstance(row[column], str) and row[column] else "Not present"
)
plt.imshow(img, cmap="gray")
plt.title(f"{column}: " + value_string)
# Deep copy the dataframe
df_copy = df.copy()
distances = []
for index, row in df.iterrows():
dist = distance.euclidean(
[embedding_p1, embedding_p2], [row["embedding_p1"], row["embedding_p2"]]
)
distances.append(dist)
df_copy["distance"] = distances
df_copy.sort_values(by="distance", inplace=True)
# Display the top 5 closest images but not the original image
fig, axs = plt.subplots(1, 5, figsize=(20, 10))
for i in range(5):
row = df_copy.iloc[i + 1]
png_path = row["png_path"]
if images_path != "":
png_path = os.path.join(images_path, png_path)
files = os.listdir(png_path)
files.sort()
png_file = png_path + os.sep + files[len(files) // 2]
img = read_image(png_file)
axs[i].imshow(img, cmap="gray")
value_string = (
row[column]
if isinstance(row[column], str) and row[column]
else "Not present"
)
axs[i].set_title(
f"{column}: " + value_string + "\n" + "Distance: " + str(row["distance"])
)
plt.show()