# Scalable MedImageInsight Endpoint Usage

**Requirements** - To run this notebook, you will need:
- A basic understanding of Python and medical image processing
- Access to an Azure Machine Learning workspace and an online endpoint
- The Python packages listed below

**Learning Objectives** - By the end of this tutorial, you will learn how to:
- Read and process DICOM images into NumPy arrays
- Convert processed images into image byte arrays
- Submit requests to an Azure Machine Learning endpoint with retry and rate limit handling
- Use `joblib` and `tqdm` for parallel processing and progress monitoring

**Motivation** - This notebook demonstrates how to generate embeddings of medical images at scale using the MedImageInsight API while handling potential network issues gracefully.


## Prerequisites

### Create MedImageInsight endpoint
* Follow the instructions in the [deploy](./deploy.ipynb) notebook.

### Download data

`azcopy copy --recursive https://azuremlexampledata.blob.core.windows.net/data/healthcare-ai/ /home/azureuser/data/`

### Install Required Packages

Install the packages used for retries, rate limiting, progress reporting, DICOM reading, and parallel execution:

`pip install 'tenacity~=9.0.0' 'ratelimit~=2.2.0' 'tqdm~=4.66.0' 'simpleitk~=2.4.0' 'joblib>1.4.0'`

## 1. Import Libraries

Import all the required libraries for image processing, handling requests, and parallel processing.


In [1]:
import numpy as np
from io import BytesIO
from PIL import Image
import itertools
import SimpleITK as sitk
import tempfile
from base64 import encodebytes
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential
import glob
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception
from ratelimit import limits, sleep_and_retry, RateLimitException
from requests.exceptions import ConnectionError, Timeout, HTTPError
import requests
from joblib import Parallel, delayed
from tqdm import tqdm

# Suppress SimpleITK warnings (static method on the ProcessObject class)
sitk.ProcessObject.SetGlobalWarningDisplay(False)

## 2. Image Processing Functions

Define functions to load, read, and convert DICOM files into image byte arrays.


In [2]:
def load_databytes(path):
    """
    Load data bytes from a file path.

    Parameters
    ----------
    path : str
        File path.

    Returns
    -------
    bytes
        Data bytes.
    """
    with open(path, "rb") as f:
        return f.read()


def read_dicom_bytes_to_numpy(dicom_bytes: bytes) -> np.ndarray:
    """
    Read DICOM data from bytes to a NumPy array, applying windowing and normalization.

    Parameters
    ----------
    dicom_bytes : bytes
        The DICOM file content in bytes.

    Returns
    -------
    np.ndarray
        The windowed image as a NumPy array.
    """
    with tempfile.NamedTemporaryFile(suffix=".dcm") as temp_file:
        temp_file.write(dicom_bytes)
        temp_file.flush()
        img = sitk.ReadImage(temp_file.name)
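        # Take the first frame; assumes the array is laid out as (frames, rows, cols).
        img_array = sitk.GetArrayFromImage(img).astype(np.float32)[0, :, :]
        # Window to the 10th-90th percentile range.
        img_array = np.clip(img_array, *np.percentile(img_array, [10, 90]))
        value_range = img_array.max() - img_array.min()
        if value_range == 0:
            # Constant image after windowing: avoid dividing by zero.
            return np.zeros(img_array.shape, dtype=np.uint8)
        # Rescale the windowed intensities to 0-255.
        img_array = ((img_array - img_array.min()) * 255 / value_range).astype(np.uint8)
        return img_array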


def numpy_to_image_bytearray(img: np.ndarray, format: str = "PNG") -> bytes:
    """
    Convert a NumPy array to an image byte array.

    Parameters
    ----------
    img : np.ndarray
        The image as a NumPy array.
    format : str, optional
        The image format, by default "PNG".

    Returns
    -------
    bytes
        The image in byte array format.
    """
    byte_io = BytesIO()
    pil_image = Image.fromarray(img)
    if pil_image.mode == "L":
        pil_image = pil_image.convert("RGB")
    pil_image.save(byte_io, format=format)
    return byte_io.getvalue()


def read_to_imagebytes(dcm_bytes):
    """
    Convert DICOM bytes to image byte array.

    Parameters
    ----------
    dcm_bytes : bytes
        DICOM file content in bytes.

    Returns
    -------
    bytes
        Image data in bytes.
    """
    np_img = read_dicom_bytes_to_numpy(dcm_bytes)
    return numpy_to_image_bytearray(np_img)


def path_to_imagebytes(path):
    """
    Convert a DICOM file at a given path to image byte array.

    Parameters
    ----------
    path : str
        File path to the DICOM file.

    Returns
    -------
    bytes
        Image data in bytes.
    """
    bytes_data = load_databytes(path)
    return read_to_imagebytes(bytes_data)
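
The windowing step in `read_dicom_bytes_to_numpy` can be tried on a synthetic array without any DICOM file. The sketch below is illustrative only: `window_to_uint8` is a hypothetical helper, not part of the notebook, and mapping a constant image to zeros is an assumption about how such inputs should be handled.

```python
import numpy as np


def window_to_uint8(img, lower_pct=10, upper_pct=90):
    """Clip to a percentile window and rescale to 0-255 (hypothetical helper)."""
    img = np.asarray(img, dtype=np.float32)
    lo, hi = np.percentile(img, [lower_pct, upper_pct])
    clipped = np.clip(img, lo, hi)
    value_range = clipped.max() - clipped.min()
    if value_range == 0:
        # Assumption: map a constant image to all zeros instead of dividing by zero.
        return np.zeros(img.shape, dtype=np.uint8)
    return ((clipped - clipped.min()) * 255 / value_range).astype(np.uint8)


# Synthetic 10x10 "image" with intensities 0..99.
demo = np.arange(100, dtype=np.float32).reshape(10, 10)
windowed = window_to_uint8(demo)
print(windowed.dtype, windowed.shape, windowed.min())
```

Everything below the 10th percentile and above the 90th is clipped, so roughly a fifth of the pixels saturate at the ends of the output range. Widen the percentiles if that discards detail you need.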

## 3. Request Submission Functions

### 3.1 Creating Post Function with Retries and Rate Limiting

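Define a factory that returns a `post` function. The returned function retries connection errors, timeouts, HTTP 429, and HTTP 5xx responses with exponential backoff, and limits the call rate on the client side.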


In [3]:
def create_post_func(
    retries=5, rate_calls=60, rate_period=60, exp_multiplier=1, exp_min=2, exp_max=60
):
    """
    Create a post function with retries and rate limiting.

    Parameters
    ----------
    retries : int
        Number of retry attempts.
    rate_calls : int
        Number of allowed calls in the period.
    rate_period : int
        Period in seconds for rate limiting.
    exp_multiplier : int
        Multiplier for exponential backoff.
    exp_min : int
        Minimum wait time in seconds.
    exp_max : int
        Maximum wait time in seconds.

    Returns
    -------
    function
        Configured post function.
    """

    def is_retryable_exception(exc):
        if isinstance(exc, (ConnectionError, Timeout, RateLimitException)):
            return True
        elif isinstance(exc, HTTPError) and exc.response is not None:
            if 500 <= exc.response.status_code < 600 or exc.response.status_code == 429:
                return True
        return False

    @retry(
        retry=retry_if_exception(is_retryable_exception),
        wait=wait_exponential(multiplier=exp_multiplier, min=exp_min, max=exp_max),
        stop=stop_after_attempt(retries),
    )
    @sleep_and_retry
    @limits(calls=rate_calls, period=rate_period)
    def post(*args, **kwargs):
        response = requests.post(*args, **kwargs)
        response.raise_for_status()
        return response.json()

    return post
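
To make the retry behaviour concrete, here is a minimal standard-library sketch of the same idea: retry only transient failures, and wait longer after each one. It is a hand-rolled stand-in, not how `tenacity` or `ratelimit` are implemented. `TransientError`, `is_retryable_status`, `backoff_delays`, and `call_with_retries` are hypothetical names, and the exact delays `tenacity` produces may differ.

```python
import time


class TransientError(Exception):
    """Hypothetical stand-in for a timeout, connection error, HTTP 429, or HTTP 5xx."""


def is_retryable_status(status_code):
    """Mirror the status-code rule used above: retry 429 and any 5xx."""
    return status_code == 429 or 500 <= status_code < 600


def backoff_delays(attempts, multiplier=1.0, min_wait=2.0, max_wait=60.0):
    """Waits between attempts: multiplier * 2**n, clamped to [min_wait, max_wait]."""
    return [
        max(min_wait, min(max_wait, multiplier * 2**n)) for n in range(attempts - 1)
    ]


def call_with_retries(func, attempts=5, sleep=time.sleep):
    """Call func until it succeeds or the attempts are used up."""
    delays = backoff_delays(attempts)
    for attempt in range(attempts):
        try:
            return func()
        except TransientError:
            if attempt == attempts - 1:
                raise  # Out of attempts: surface the last error.
            sleep(delays[attempt])


# Demo: fail twice, then succeed. Sleeps are recorded instead of performed.
calls = {"count": 0}
slept = []


def flaky():
    calls["count"] += 1
    if calls["count"] < 3:
        raise TransientError("simulated 503")
    return "ok"


result = call_with_retries(flaky, attempts=5, sleep=slept.append)
print(result, slept, backoff_delays(8))
```

Under these hypothetical settings, eight attempts wait at most 2 + 2 + 4 + 8 + 16 + 32 + 60 = 124 seconds in total before giving up, which is worth keeping in mind when choosing `retries` for a long job.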

### 3.2 Submitting Requests

Define functions that submit image data to the Azure ML online endpoint, either as a batch or one image at a time.


In [4]:
def submit_batch_request(list_image_databytes, params, target, headers, post_func):
    """
    Submit a batch of image data to the endpoint.

    Parameters
    ----------
    list_image_databytes : list
        List of image data in bytes.
    params : dict
        Additional parameters for the request.
    target : str
        Endpoint URL.
    headers : dict
        Request headers.
    post_func : function
        Function to post the request.

    Returns
    -------
    list
        List of results from the endpoint.
    """
    text_data = ""

    def encode_data(image_databytes, text_data):
        return [encodebytes(image_databytes).decode("utf-8"), text_data]

    payload = {
        "input_data": {
            "columns": ["image", "text"],
            "index": list(range(len(list_image_databytes))),
            "data": [
                encode_data(image_databytes, text_data)
                for image_databytes in list_image_databytes
            ],
        },
        "params": params,
    }

    response_json = post_func(target, json=payload, headers=headers)
    result = [r["image_features"] for r in response_json]
    return result


def submit_request(image_databytes, params, target, headers, post_func):
    """
    Submit a single image data to the endpoint.

    Parameters
    ----------
    image_databytes : bytes
        Image data in bytes.
    params : dict
        Additional parameters for the request.
    target : str
        Endpoint URL.
    headers : dict
        Request headers.
    post_func : function
        Function to post the request.

    Returns
    -------
    Any
        Result from the endpoint.
    """
    return submit_batch_request([image_databytes], params, target, headers, post_func)[
        0
    ]
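
The request body built by `submit_batch_request` can be inspected offline. The sketch below rebuilds the same structure with a hypothetical `build_payload` helper and parses a made-up response. The fake image bytes and embedding values are placeholders, and the response shape is inferred from the `r["image_features"]` lookup above rather than from the endpoint's documentation.

```python
import json
from base64 import decodebytes, encodebytes


def build_payload(list_image_databytes, params, text_data=""):
    """Build the request body: one [image, text] row per image (hypothetical helper)."""
    return {
        "input_data": {
            "columns": ["image", "text"],
            "index": list(range(len(list_image_databytes))),
            "data": [
                [encodebytes(image).decode("utf-8"), text_data]
                for image in list_image_databytes
            ],
        },
        "params": params,
    }


# Placeholder bytes standing in for two encoded PNG images.
fake_images = [b"\x89PNG-first", b"\x89PNG-second" * 10]
payload = build_payload(fake_images, params={})

# The payload must survive JSON serialization, and the images must round-trip.
body = json.loads(json.dumps(payload))
decoded = [decodebytes(row[0].encode("utf-8")) for row in body["input_data"]["data"]]

# Made-up response in the shape the parsing code expects.
fake_response = [{"image_features": [0.1, 0.2]}, {"image_features": [0.3, 0.4]}]
embeddings = [r["image_features"] for r in fake_response]
print(body["input_data"]["index"], len(embeddings))
```

Note that `encodebytes` inserts a newline after every 76 characters of output; `decodebytes` ignores these when decoding.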

## 4. Configure Azure ML Client

Set up the Azure ML client to interact with the online endpoint.


In [None]:
# Azure ML endpoint name
endpoint_name = ""  # Set this to the name of the endpoint you wish to use.

# Initialize MLClient with DefaultAzureCredential
ml_client = MLClient.from_config(DefaultAzureCredential())

# Get endpoint details
endpoint = ml_client.online_endpoints.get(name=endpoint_name)
keys = ml_client.online_endpoints.get_keys(name=endpoint_name)

# Set target URL and headers
target = endpoint.scoring_uri
api_key = keys.primary_key
headers = {"Authorization": f"Bearer {api_key}"}

## 5. Processing Images in Parallel
### 5.1 Retrieve DICOM File Paths

Use `glob` to collect all DICOM file paths from a directory.


In [None]:
filelist = list(
    glob.glob(
        "/home/azureuser/data/healthcare-ai/medimageinsight-zeroshot/**/*.dcm",
        recursive=True,
    )
)
print(f"Total DICOM files found: {len(filelist)}")
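
The recursive pattern can be checked against a throwaway directory tree. In this sketch the directory layout and file names are invented for illustration. It also sorts the matches, since `glob` does not guarantee any particular order.

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as root:
    # Invented layout: DICOM files at three depths, plus one unrelated file.
    for relative in ["a.dcm", "study1/b.dcm", "study1/series1/c.dcm", "notes.txt"]:
        full_path = os.path.join(root, relative)
        os.makedirs(os.path.dirname(full_path), exist_ok=True)
        with open(full_path, "wb") as f:
            f.write(b"")

    # With recursive=True, "**" matches zero or more directory levels.
    pattern = os.path.join(root, "**", "*.dcm")
    matches = sorted(glob.glob(pattern, recursive=True))
    found = [os.path.relpath(path, root).replace(os.sep, "/") for path in matches]

print(found)
```

Sorting the file list makes repeated runs process files in the same order, which helps when comparing or resuming jobs.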

### 5.2 Process Images

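Process the DICOM images in parallel threads and collect the results. With `return_as="generator_unordered"`, results are yielded as each request completes, so the order of `results` does not follow the order of `filelist`.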


In [None]:
# Create one post function with retries and rate limiting, shared by all worker threads
request_post_w_retry = create_post_func(retries=8, rate_calls=60, rate_period=60)


def process_path(path):
    image_databytes = path_to_imagebytes(path)
    return submit_request(image_databytes, {}, target, headers, request_post_w_retry)


# Number of parallel jobs
njobs = 3

results = []
with tqdm(total=len(filelist)) as pbar:
    # Process files in parallel and collect results
    results_gen = Parallel(
        n_jobs=njobs, prefer="threads", return_as="generator_unordered"
    )(delayed(process_path)(path=path) for path in filelist)
    for res in results_gen:
        pbar.update(1)
        results.append(res)
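
Because `generator_unordered` yields results in completion order, the position of an embedding in `results` says nothing about which file it came from. One way to keep the association is to store each result under its path. The sketch below shows the idea with the standard library's `concurrent.futures` instead of `joblib`, so it runs without an endpoint; `fake_embed` and `embed_all` are hypothetical names, with `fake_embed` standing in for `process_path`.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed


def fake_embed(path):
    """Hypothetical stand-in for process_path: returns a tiny fake embedding."""
    time.sleep(0.01 * (len(path) % 3))  # Uneven latency, so completion order varies.
    return [float(len(path))]


def embed_all(paths, embed, max_workers=3):
    """Run embed over paths in threads and key each result by its path."""
    results = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        future_to_path = {pool.submit(embed, path): path for path in paths}
        for future in as_completed(future_to_path):
            results[future_to_path[future]] = future.result()
    return results


paths = [f"/data/img_{i}.dcm" for i in range(6)] + ["/data/longer_name.dcm"]
embeddings = embed_all(paths, fake_embed)

# Restore the original order whenever it is needed.
ordered = [embeddings[path] for path in paths]
print(len(ordered))
```

The same pattern carries over to `joblib` by having the worker function return a `(path, result)` tuple.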

### 5.3 Process Images in Batches

Process the DICOM images in batches using `submit_batch_request` directly, splitting the data into chunks of 10.


In [None]:
def batchify(iterable, batch_size=10):
    """Yield successive chunks of a specified size from an iterable."""
    iterator = iter(iterable)
    while True:
        chunk = list(itertools.islice(iterator, batch_size))
        if not chunk:
            break
        yield chunk


def process_batch(batch_paths):
    list_image_databytes = [path_to_imagebytes(path) for path in batch_paths]
    batch_results = submit_batch_request(
        list_image_databytes,
        params={},  # Same empty parameter set as in section 5.2
        target=target,
        headers=headers,
        post_func=request_post_w_retry,
    )
    return batch_results


# Number of parallel jobs
njobs = 3

results = []
total_files = len(filelist)
batch_paths_list = list(batchify(filelist, batch_size=10))

with tqdm(total=total_files) as pbar:
    # Process batches in parallel and collect results
    results_gen = Parallel(
        n_jobs=njobs, prefer="threads", return_as="generator_unordered"
    )(delayed(process_batch)(batch_paths) for batch_paths in batch_paths_list)
    for batch_results in results_gen:
        results.extend(batch_results)
        pbar.update(len(batch_results))
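
A quick check of `batchify` on plain integers shows how the last, smaller batch is handled and how batching changes the number of requests. The function is repeated here so the snippet runs on its own; the file count of 25 is an arbitrary example.

```python
import itertools
import math


def batchify(iterable, batch_size=10):
    """Yield successive chunks of a specified size from an iterable."""
    iterator = iter(iterable)
    while True:
        chunk = list(itertools.islice(iterator, batch_size))
        if not chunk:
            break
        yield chunk


n_files = 25  # Arbitrary example size.
batches = list(batchify(range(n_files), batch_size=10))
batch_sizes = [len(batch) for batch in batches]

# One request per batch instead of one request per file.
n_requests = math.ceil(n_files / 10)
print(batch_sizes, n_requests)
```

Since the rate limit in `create_post_func` counts calls to `post`, not images, a limit of 60 calls per 60 seconds allows up to 600 images per minute at a batch size of 10, subject to what the endpoint can actually serve. On Python 3.12 and later, `itertools.batched` provides similar chunking and yields tuples.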

## 6. Conclusion

This notebook built a request function that retries transient failures with exponential backoff and respects a client-side rate limit, so that temporary network errors and throttling responses (HTTP 429) do not abort a long-running job.

Issuing requests from several threads with `joblib`, and tracking progress with `tqdm`, lets image reading and conversion overlap with waiting on the endpoint. Sending images in batches further reduces the number of requests needed for a large dataset.

Together, these techniques form a workflow for generating MedImageInsight embeddings from large collections of DICOM images.

---

**Next Steps**:

- Reuse these request functions with other Azure ML models, such as **"MedImageParse"** and **"CXRReportGen"**, to broaden the scope of your medical imaging analysis.
- Set up [autoscaling](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-autoscale-endpoints?view=azureml-api-2&tabs=python) for your endpoint so that it can add capacity as request volume grows.
