# ez_wsi_dicomweb/patch_embedding_endpoints.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core methods to return endpoint generated embeddings for patch pixels."""
from __future__ import annotations
import abc
from collections.abc import Sequence
import concurrent
import copy
import dataclasses
import enum
import json
import math
import threading
import typing
from typing import Any, Callable, Generic, List, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar, Union
import cachetools
from ez_wsi_dicomweb import credential_factory as credential_factory_module
from ez_wsi_dicomweb import dicom_slide
from ez_wsi_dicomweb import error_retry_util
from ez_wsi_dicomweb import ez_wsi_errors
from ez_wsi_dicomweb import gcs_image
from ez_wsi_dicomweb import patch_embedding_types
from ez_wsi_dicomweb import slide_level_map
from ez_wsi_dicomweb.ml_toolkit import tags
import google.auth
import numpy as np
from PIL import ImageCms
import requests
import retrying
class EndpointJsonKeys:
"""JSON Keys for pathology v1 and v2 encoder endpoint."""
# V2 encoder key types
IMAGE_FILE_URI = 'image_file_uri'
RAW_IMAGE_BYTES = 'raw_image_bytes'
DICOM_PATH = 'dicom_path'
SERIES_PATH = 'series_path'
INSTANCE_UIDS = 'instance_uids'
BEARER_TOKEN = 'bearer_token'
INSTANCES = 'instances'
PATCH_COORDINATES = 'patch_coordinates'
X_ORIGIN = 'x_origin'
Y_ORIGIN = 'y_origin'
WIDTH = 'width'
HEIGHT = 'height'
EXTENSIONS = 'extensions'
IMAGE_DIMENSIONS = 'image_dimensions'
TRANSFORM_IMAGING_TO_ICC_PROFILE = 'transform_imaging_to_icc_profile'
REQUIRE_PATCHES_FULLY_IN_SOURCE_IMAGE = (
'require_patches_fully_in_source_image'
)
EZ_WSI_STATE = 'ez_wsi_state'
# key for list of patch bytes, base64 encoded
PATCHES = 'patches'
# key for whole image bytes, base64 encoded
IMAGE = 'image'
# icc profile norm performed on patch imaging.
ICC_PROFILE_METADATA_NORMALIZATION = 'icc_profile_metadata_normalization'
# height and width of source image for patch
SOURCE_IMAGE_WIDTH_PX = 'source_image_width_px'
SOURCE_IMAGE_HEIGHT_PX = 'source_image_height_px'
# V1 encoder only
DICOM_WEB_STORE_URL = 'dicom_web_store_url'
DICOM_STUDY_UID = 'dicom_study_uid' # V1 encoder only
DICOM_SERIES_UID = 'dicom_series_uid' # V1 encoder only
GCS_IMAGE_URL = 'gcs_image_url' # V1 encoder only
PROJECT_NAME = 'project_name' # V1 encoder only
PARAMETERS = 'parameters' # V1 encoder only
MODEL_SIZE = 'model_size' # V1 encoder only
MODEL_KIND = 'model_kind' # V1 encoder only
# embedding encoder response
PREDICTIONS = 'predictions'
VERTEXAI_ERROR = 'error'
ERROR = 'error'
ERROR_CODE = 'code'
ERROR_CODE_DESCRIPTION = 'description'
RESULT = 'result'
EMBEDDINGS = 'embeddings' # V1 encoder only
ERROR_RESPONSE = 'error_response' # V1 encoder only
EMBEDDING_RESULT = 'embedding_result' # V1 encoder only
EMBEDDING_VECTOR = 'embedding_vector' # V2 encoder
PATCH_EMBEDDINGS = 'patch_embeddings'
PATCH_COORDINATE = 'patch_coordinate'
# Retryable error codes
INVALID_CREDENTIALS = 'INVALID_CREDENTIALS_ERROR'
# Maximum request size in bytes for the endpoint. Less than the Vertex AI max
# to provide a safety margin.
_MAX_VERTEX_AI_V1_REQUEST_SIZE_BYTES = 1300000
_DEFAULT_RETRY_COUNT = 5
_DEFAULT_ENDPOINT_THREADS = 5
_DEFAULT_MAX_PATCHES_PER_REQUEST = 100
_MAX_ENDPOINT_THREADS = 10
_ITERATOR_MAX_ENDPOINT_PATCHES_PER_REQUEST = 100
_MAX_V1_ENDPOINT_PATCHES_PER_REQUEST = 3000
_MAX_V2_ENDPOINT_PATCHES_PER_REQUEST = 3000
_DEFAULT_DICOM_INSTANCE_ICC_PROFILE_CACHE_COUNT = 20
# Size safety buffer that whole-image encodings may not exceed.
# Max size = _MAX_VERTEX_AI_V1_REQUEST_SIZE_BYTES - _WHOLE_IMAGE_SIZE_SAFTY_MARGIN
# Images exceeding this are encoded as patches, which enables them to be
# split across multiple Vertex AI requests.
_WHOLE_IMAGE_SIZE_SAFTY_MARGIN = 300000
# Pyramid ICC profiles are optimally serialized in JSON to avoid repetitive
# re-initialization. However, some digital pathology DICOM, e.g. Leica, have
# very large ICC profiles, e.g., 12 MB. This constant caps the size of the
# ICC profile that will be serialized in JSON.
_MAX_DICOM_SLIDE_ICCPROFILE_METADATA_SIZE = min(
204800, slide_level_map.DEFAULT_MAX_JSON_ENCODED_ICC_PROFILE_SIZE_IN_BYTES
)
_DEFAULT_TIMEOUT = None # Time out in seconds; None for no timeout.
class IccProfileNormalization(enum.Enum):
"""ICC Profile To Normalize Embedding Patches To."""
NONE = 'NONE'
SRGB = 'SRGB'
ADOBERGB = 'ADOBERGB'
ROMMRGB = 'ROMMRGB'
_UINT8_MAX_VALUE = 255.0
def _test_patch_coordinates_match(
pc: Mapping[str, Any], x: int, y: int, width: int, height: int
) -> bool:
"""Test if dict encoded coordinates match expected coordinates."""
try:
if pc[EndpointJsonKeys.X_ORIGIN] != x or pc[EndpointJsonKeys.Y_ORIGIN] != y:
return False
if (
pc.get(EndpointJsonKeys.WIDTH, width) != width
or pc.get(EndpointJsonKeys.HEIGHT, height) != height
):
return False
return True
except (IndexError, KeyError, ValueError, TypeError) as _:
return False
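# Illustrative example (hypothetical values, not part of the library): a
# response coordinate dict such as
#   {'x_origin': 10, 'y_origin': 20, 'width': 224, 'height': 224}
# matches _test_patch_coordinates_match(pc, 10, 20, 224, 224); width and
# height default to the expected values when those keys are absent.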
def _get_icc_profile_bytes(
icc_profile_normalization: IccProfileNormalization,
) -> bytes:
"""Returns ICC Profile bytes for endpoint."""
if icc_profile_normalization == IccProfileNormalization.NONE:
return b''
if icc_profile_normalization == IccProfileNormalization.SRGB:
return dicom_slide.get_srgb_icc_profile_bytes()
if icc_profile_normalization == IccProfileNormalization.ADOBERGB:
return dicom_slide.get_adobergb_icc_profile_bytes()
if icc_profile_normalization == IccProfileNormalization.ROMMRGB:
return dicom_slide.get_rommrgb_icc_profile_bytes()
raise ez_wsi_errors.InternalError('ICC Profile not supported')
RequestResponseType = TypeVar('RequestResponseType')
def normalized_patch_channels(
width: int, height: int, patch: np.ndarray
) -> np.ndarray:
"""Normalize monochrome and RGBA imaging to RGB."""
if patch.shape == (height, width, 3):
return patch
if patch.shape == (height, width):
patch = np.expand_dims(patch, axis=-1)
if patch.shape == (height, width, 1):
mem = np.zeros((height, width, 3), dtype=patch.dtype)
mem[..., np.arange(3)] = patch[...]
return mem
if patch.shape == (height, width, 4):
return patch[..., :3]
raise ez_wsi_errors.PatchEmbeddingDimensionError
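# Illustrative sketch (hypothetical values): expanding a monochrome patch to
# the three-channel layout expected downstream.
#
#   gray = np.zeros((224, 224), dtype=np.uint8)
#   rgb = normalized_patch_channels(224, 224, gray)  # shape (224, 224, 3)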
class AbstractPreparedEmbeddingRequest(
Generic[RequestResponseType], metaclass=abc.ABCMeta
):
"""Base class for prepared embedding requests."""
def __init__(
self,
slide_embedding_source: patch_embedding_types.SlideEmbeddingSource,
):
self._slide_embedding_source = slide_embedding_source
@property
def slide_embedding_source(
self,
) -> patch_embedding_types.SlideEmbeddingSource:
return self._slide_embedding_source
@property
@abc.abstractmethod
def json_size_in_bytes(self) -> int:
"""Return size in bytes of json sent to endpoint."""
@abc.abstractmethod
def finalize(self) -> None:
"""finalize after this there will be no more changes."""
@abc.abstractmethod
def split(
self,
endpoint: AbstractPatchEmbeddingEndpoint[RequestResponseType],
max_size: int,
) -> Tuple[
Optional[AbstractPreparedEmbeddingRequest[RequestResponseType]],
patch_embedding_types.SlideEmbeddingSource,
]:
"""Splits object into parts which meet size and exceed size req."""
def _copy_dict_excluding_keys(
state: Mapping[str, Any],
exclude: List[List[str]],
base_keys: Optional[List[str]] = None,
) -> MutableMapping[str, Any]:
"""Duplicates str keyed dict excluding predfined keys."""
if base_keys is None:
base_keys = []
copy_dict = {}
for key, value in state.items():
base_keys.append(key)
if base_keys in exclude:
base_keys.pop()
continue
if not isinstance(value, Mapping):
copy_dict[key] = value
else:
copy_dict[key] = _copy_dict_excluding_keys(value, exclude, base_keys)
base_keys.pop()
return copy_dict
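# Illustrative sketch of the exclusion semantics (hypothetical values); nested
# keys are addressed by their key path:
#
#   state = {'a': 1, 'b': {'c': 2, 'd': 3}}
#   _copy_dict_excluding_keys(state, [['b', 'c']])  # -> {'a': 1, 'b': {'d': 3}}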
@dataclasses.dataclass
class _VertexModelResult:
instances: List[Mapping[str, Any]]
class PreparedVertexEmbeddingRequest(
AbstractPreparedEmbeddingRequest[_VertexModelResult]
):
"""Internral respresentation of embedding json embedding requests."""
def __init__(
self,
prepared_request: Mapping[str, Any],
slide_embedding_source: patch_embedding_types.SlideEmbeddingSource,
):
super().__init__(slide_embedding_source)
self._embedding_request: Mapping[str, Any] = prepared_request
self._embedding_json: Optional[str] = None
@classmethod
def init_from_json_finalized(
cls,
json_str: str,
slide_embedding_source: patch_embedding_types.SlideEmbeddingSource,
) -> PreparedVertexEmbeddingRequest:
"""create finalized vertex embedding request from json and source."""
instance = PreparedVertexEmbeddingRequest.__new__(
PreparedVertexEmbeddingRequest
)
super(PreparedVertexEmbeddingRequest, instance).__init__(
slide_embedding_source
)
instance._embedding_json = json_str
instance._embedding_request = None
return instance
@property
def embedding_request(self) -> Mapping[str, Any]:
if self._embedding_request is None:
raise ez_wsi_errors.InternalError('Request has been finalized.')
return self._embedding_request
@property
def json(self) -> str:
if self._embedding_json is None:
self._embedding_json = json.dumps(self._embedding_request)
return self._embedding_json
@property
def json_size_in_bytes(self) -> int:
return len(self.json)
def finalize(self) -> None:
if self._embedding_json is None:
self._embedding_json = json.dumps(self._embedding_request)
self._embedding_request = None
def _non_splitable(
self,
) -> Tuple[None, patch_embedding_types.SlideEmbeddingSource]:
return None, self.slide_embedding_source
def _split_results(self, split_request: str, end_split_index: int) -> Tuple[
Optional[PreparedVertexEmbeddingRequest],
patch_embedding_types.SlideEmbeddingSource,
]:
"""Returns split prepared vertex embedding result."""
split_half_slide_embedding_source = (
patch_embedding_types.SlideEmbeddingSource(
self.slide_embedding_source.patches[:end_split_index]
)
)
finalized_split_half = (
PreparedVertexEmbeddingRequest.init_from_json_finalized(
split_request, split_half_slide_embedding_source
)
)
slide_embedding_source = patch_embedding_types.SlideEmbeddingSource(
self.slide_embedding_source.patches[end_split_index:]
)
return (
finalized_split_half,
slide_embedding_source,
)
def _split_on_patch_coordinates_only(
self,
endpoint: AbstractPatchEmbeddingEndpoint[_VertexModelResult],
max_size: int,
) -> Tuple[
Optional[PreparedVertexEmbeddingRequest],
patch_embedding_types.SlideEmbeddingSource,
]:
"""If possilbe splits prepared request to meet size req."""
if self.json_size_in_bytes < endpoint.max_request_size_bytes():
# It is much less desirable to split DicomPatches or GcsImages into
# multiple requests. The majority of the metadata could be in state
# which cannot be split and would be sent in duplicate across multiple
# requests. Test whether the entire message can be encoded by itself;
# if it can, defer sending.
return self._non_splitable()
try:
coordinates = self.embedding_request[EndpointJsonKeys.PATCH_COORDINATES]
except KeyError:
return self._non_splitable()
# Make a copy of the request, excluding the patch coordinates.
base_request = _copy_dict_excluding_keys(
self.embedding_request,
[[EndpointJsonKeys.PATCH_COORDINATES]],
)
request_size = len(json.dumps(base_request))
end_split_index = 0
for coordinate in coordinates:
patch_md_size = len(json.dumps(coordinate))
if request_size + patch_md_size >= max_size:
break
request_size += patch_md_size
end_split_index += 1
while True:
if end_split_index <= 0:
return self._non_splitable()
split_request = copy.copy(base_request)
split_request[EndpointJsonKeys.PATCH_COORDINATES] = coordinates[
:end_split_index
]
split_request = json.dumps(split_request)
if len(split_request) <= max_size:
break
end_split_index -= 1
return self._split_results(split_request, end_split_index)
def split(
self,
endpoint: AbstractPatchEmbeddingEndpoint[_VertexModelResult],
max_size: int,
) -> Tuple[
Optional[PreparedVertexEmbeddingRequest],
patch_embedding_types.SlideEmbeddingSource,
]:
"""If possible splits object into parts which meet size and exceed size req."""
if EndpointJsonKeys.DICOM_PATH in self.embedding_request:
return self._split_on_patch_coordinates_only(endpoint, max_size)
elif EndpointJsonKeys.IMAGE_FILE_URI in self.embedding_request:
try:
ez_wsi_state = self.embedding_request[EndpointJsonKeys.EXTENSIONS][
EndpointJsonKeys.EZ_WSI_STATE
]
patches = ez_wsi_state[EndpointJsonKeys.PATCHES]
coordinates = self.embedding_request[EndpointJsonKeys.PATCH_COORDINATES]
except KeyError:
return self._split_on_patch_coordinates_only(endpoint, max_size)
if len(patches) <= 1:
return self._non_splitable()
if len(coordinates) != len(patches):
raise ez_wsi_errors.InternalError(
'Patch state and coordinate counts do not match.'
)
# Make a copy of the request, excluding the patch coordinates and cached patches.
base_request = _copy_dict_excluding_keys(
self.embedding_request,
[
[EndpointJsonKeys.PATCH_COORDINATES],
[
EndpointJsonKeys.EXTENSIONS,
EndpointJsonKeys.EZ_WSI_STATE,
EndpointJsonKeys.PATCHES,
],
],
)
request_size = len(json.dumps(base_request))
end_split_index = 0
for patch_metadata in patches:
patch_md_size = len(patch_metadata)
if request_size + patch_md_size >= max_size:
break
request_size += patch_md_size
end_split_index += 1
while True:
if end_split_index <= 0:
return self._non_splitable()
patch_coordinate_size = len(json.dumps(coordinates[:end_split_index]))
if patch_coordinate_size + request_size < max_size:
break
end_split_index -= 1
if end_split_index > 0:
request_size -= len(patches[end_split_index])
while True:
if end_split_index <= 0:
return self._non_splitable()
split_request = copy.copy(base_request)
split_request[EndpointJsonKeys.PATCH_COORDINATES] = coordinates[
:end_split_index
]
split_request[EndpointJsonKeys.EXTENSIONS][
EndpointJsonKeys.EZ_WSI_STATE
][EndpointJsonKeys.PATCHES] = patches[:end_split_index]
split_request = json.dumps(split_request)
if len(split_request) <= max_size:
break
end_split_index -= 1
return self._split_results(split_request, end_split_index)
raise ez_wsi_errors.InternalError('unidentified JSON')
class AbstractPatchEmbeddingEndpoint(
Generic[RequestResponseType], metaclass=abc.ABCMeta
):
"""Abstract class for embedding endpoint."""
def __init__(
self,
icc_profile_normalization: IccProfileNormalization,
timeout: Optional[Union[float, int]],
):
self._icc_profile_normalization = icc_profile_normalization
self._icc_profile_bytes = None
self._timeout = timeout
@abc.abstractmethod
def max_request_size_bytes(self) -> int:
"""Maximum size in bytes that can be sent in single request."""
@abc.abstractmethod
def max_threads(self) -> int:
"""Returns maximum number of threads to spawn."""
@abc.abstractmethod
def patch_width(self) -> int:
"""Returns embedding endpoint input size width in pixels."""
@abc.abstractmethod
def patch_height(self) -> int:
"""Returns embedding endpoint input size height in pixels."""
@abc.abstractmethod
def max_number_of_patches_per_request(self) -> int:
"""Maximum number of patches to send endpoint in a request."""
@abc.abstractmethod
def endpoint_max_number_of_patches_per_request(self) -> int:
"""Maximum number of patches that the endpoint supports."""
@abc.abstractmethod
def retry_count(self) -> int:
"""Maximum number of get_embedding attempts before endpoint raises.."""
@abc.abstractmethod
def prepare_embedding_request(
self,
slide_embedding: patch_embedding_types.SlideEmbeddingSource,
) -> AbstractPreparedEmbeddingRequest[RequestResponseType]:
"""Converts slide embedding source to request to JSON formatted request."""
@abc.abstractmethod
def request_embeddings(
self,
embedding_inputs: Sequence[
AbstractPreparedEmbeddingRequest[RequestResponseType]
],
) -> RequestResponseType:
"""Returns raw embedding result."""
@property
def request_embedding_return_type(self) -> Type[RequestResponseType]:
return RequestResponseType
@abc.abstractmethod
def process_response(
self,
embedding_inputs: Sequence[patch_embedding_types.SlideEmbeddingSource],
msg: RequestResponseType,
) -> List[patch_embedding_types.PatchEmbeddingEnsembleResult]:
"""Converts raw embedding response to list of embedding results."""
@property
def icc_profile_normalization(self) -> IccProfileNormalization:
"""Returns ICC Profile bytes for endpoint will transform imaging to."""
return self._icc_profile_normalization
def icc_profile_bytes(self) -> bytes:
"""Returns ICC Profile bytes for endpoint will transform imaging to."""
if self._icc_profile_bytes is None:
self._icc_profile_bytes = _get_icc_profile_bytes(
self._icc_profile_normalization
)
return self._icc_profile_bytes
@property
def timeout(self) -> Optional[Union[float, int]]:
return self._timeout
@timeout.setter
def timeout(self, val: Optional[Union[float, int]]) -> None:
self._timeout = val
def _get_gcs_image_md_size(
json_metadata: Mapping[str, Union[int, str, List[str]]],
) -> int:
"""returns size in bytes of data encoded in metadata."""
size = 0
for value in json_metadata.values():
if isinstance(value, int):
size += len(str(value))
elif isinstance(value, str):
size += len(value)
elif isinstance(value, list):
size += sum(len(md) for md in value)
else:
raise ez_wsi_errors.InternalError(
f'Unsupported metadata value type: {value}'
)
return size
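# Illustrative sketch (hypothetical values): the size is the sum of the string
# lengths of the encoded values, e.g.
#
#   _get_gcs_image_md_size({'width': 100, 'image': 'aGVsbG8='})  # -> 3 + 8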
def _patch_pixel_area(
slide_embedding: patch_embedding_types.SlideEmbeddingSource,
) -> int:
patch_pixel_area = 0
for ip in slide_embedding.patches:
patch_pixel_area += ip.patch.width * ip.patch.height
return patch_pixel_area
def _get_gcs_whole_image_metadata(
source_image: gcs_image.GcsImage,
icc_profile_normalization: IccProfileNormalization,
c_transform: Optional[ImageCms.ImageCmsTransform],
max_image_size_bytes: int,
) -> Tuple[Mapping[str, Any], int]:
"""Return whole image metadata and metadata size in bytes."""
if source_image.size_bytes_of_source_image is not None:
# Image bytes are base64 encoded; estimate the encoded size as 8/6 of the
# original plus the length of the ICC profile normalization value.
estimated_size = int(
math.ceil(source_image.size_bytes_of_source_image * 8 / 6)
+ len(IccProfileNormalization.NONE.value)
)
if estimated_size >= max_image_size_bytes:
return {}, estimated_size
try:
source_image_metadata = {
EndpointJsonKeys.IMAGE: source_image.source_image_bytes_json_metadata(),
EndpointJsonKeys.ICC_PROFILE_METADATA_NORMALIZATION: (
IccProfileNormalization.NONE.value
),
}
except ez_wsi_errors.GcsImageError:
# Image bytes were initialized from an in-memory representation.
source_image_md_size = int(
math.ceil(
source_image.width
* source_image.height
* source_image.bytes_pre_pixel
* 8 # bits per channel
/ 12 # base 64 encoding + est 2x reduction due to PNG compression
)
+ len(icc_profile_normalization.value)
)
if source_image_md_size >= max_image_size_bytes:
return {}, source_image_md_size
source_image_metadata = {
EndpointJsonKeys.IMAGE: source_image.json_metadata(c_transform),
EndpointJsonKeys.ICC_PROFILE_METADATA_NORMALIZATION: (
icc_profile_normalization.value
),
}
source_image_md_size = _get_gcs_image_md_size(source_image_metadata)
if source_image_md_size >= max_image_size_bytes:
return {}, source_image_md_size
return source_image_metadata, source_image_md_size
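# Illustrative size estimate (hypothetical numbers): a 900,000 byte source
# image base64-encodes to roughly 900,000 * 8 / 6 = 1,200,000 bytes, which is
# why images near the request size limit fall back to per-patch encoding.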
def _gcs_image_json_metadata(
slide_embedding: patch_embedding_types.SlideEmbeddingSource,
icc_profile_normalization: IccProfileNormalization,
c_transform: Optional[ImageCms.ImageCmsTransform],
max_endpoint_request_size_bytes: int,
) -> Mapping[str, Any]:
"""Returns metadata for GCS images."""
patch = typing.cast(gcs_image.GcsPatch, slide_embedding.patches[0].patch)
source_image = patch.source
max_image_size_bytes = max(
0, max_endpoint_request_size_bytes - _WHOLE_IMAGE_SIZE_SAFTY_MARGIN
)
source_image_metadata, source_image_md_size = _get_gcs_whole_image_metadata(
source_image,
icc_profile_normalization,
c_transform,
max_image_size_bytes,
)
if not source_image_metadata or _patch_pixel_area(slide_embedding) < int(
0.95 * float(source_image.width * source_image.height)
):
# If the patch area is smaller than 95% of the source image area, compute
# patch metadata. (The 5% factor accounts for the overhead associated with
# defining multiple patches instead of a single image.)
patch_metadata = {
EndpointJsonKeys.SOURCE_IMAGE_WIDTH_PX: int(patch.source.width),
EndpointJsonKeys.SOURCE_IMAGE_HEIGHT_PX: int(patch.source.height),
EndpointJsonKeys.ICC_PROFILE_METADATA_NORMALIZATION: (
icc_profile_normalization.value
),
EndpointJsonKeys.PATCHES: [
typing.cast(gcs_image.GcsPatch, ip.patch).json_metadata(c_transform)
for ip in slide_embedding.patches
],
}
if (
not source_image_metadata
or _get_gcs_image_md_size(patch_metadata) < source_image_md_size
):
return patch_metadata
return source_image_metadata
def _get_gcs_image_metadata(
max_endpoint_request_size_bytes: int,
encode_patch_data_in_request: bool,
icc_profile_normalization: IccProfileNormalization,
icc_profile_bytes: bytes,
slide_embedding: patch_embedding_types.SlideEmbeddingSource,
) -> Mapping[str, Any]:
"""Returns metadata for GCS images."""
source_image = typing.cast(
gcs_image.GcsPatch, slide_embedding.patches[0].patch
).source
c_transform = source_image.create_icc_profile_transformation(
icc_profile_bytes
)
gcs_image_url = source_image.uri
if not gcs_image_url:
# Always send image data if there is no image URL.
return _gcs_image_json_metadata(
slide_embedding,
icc_profile_normalization,
c_transform,
max_endpoint_request_size_bytes,
)
if encode_patch_data_in_request and source_image.are_image_bytes_loaded:
json_metadata = _gcs_image_json_metadata(
slide_embedding,
icc_profile_normalization,
c_transform,
max_endpoint_request_size_bytes,
)
if (
source_image.size_bytes_of_source_image is None
or _get_gcs_image_md_size(json_metadata)
<= source_image.size_bytes_of_source_image
):
# Send JSON metadata if the image was initialized from raw bytes or
# the size of the JSON is smaller than the source image.
return json_metadata
return {}
def _get_dicom_instance_uids_and_required_levels(
slide_embedding: patch_embedding_types.SlideEmbeddingSource,
) -> Tuple[List[str], List[str]]:
"""Returns list of DICOM instances and leveles required for WSI patches."""
instance_uids = set()
required_levels = []
for patch in slide_embedding.patches:
level = typing.cast(
dicom_slide.DicomPatch, patch.patch
).get_pyramid_imaging_source_level()
if level in required_levels:
continue
required_levels.append(level)
for instance in level.instances.values():
instance_uids.add(instance.dicom_object.get_value(tags.SOP_INSTANCE_UID))
return list(instance_uids), required_levels
class AbstractVertexPatchEmbeddingEndpointBase(
AbstractPatchEmbeddingEndpoint[_VertexModelResult]
):
"""Shared implementation of patch embedding endpoint for V1 and V2 endpoints."""
def __init__(
self,
patch_width: int,
patch_height: int,
icc_profile_normalization: IccProfileNormalization,
send_gcs_patch_bytes_from_client_to_server: bool,
end_point_url: str,
max_threads: int,
max_patches_per_request: int,
endpoint_max_patches_per_request: int,
retry_count: int,
credential_factory: Optional[
credential_factory_module.AbstractCredentialFactory
],
timeout: Optional[Union[int, float]],
):
super().__init__(icc_profile_normalization, timeout)
self._patch_width = patch_width
self._patch_height = patch_height
self._credentials = None
self._credentials_factory = (
credential_factory
if credential_factory is not None
else credential_factory_module.DefaultCredentialFactory()
)
self._send_gcs_patch_bytes_from_client_to_server = (
send_gcs_patch_bytes_from_client_to_server
)
self._end_point_url = end_point_url
self._max_threads = max(1, min(max_threads, _MAX_ENDPOINT_THREADS))
self._endpoint_max_patches_per_request = int(
max(1, endpoint_max_patches_per_request)
)
self._max_patches_per_request = int(
max(
1,
min(
max_patches_per_request, self._endpoint_max_patches_per_request
),
)
)
self._retry_count = max(0, retry_count)
@property
def end_point_url(self) -> str:
return self._end_point_url
def max_request_size_bytes(self) -> int:
"""Maximum size in bytes that can be sent in single request."""
return _MAX_VERTEX_AI_V1_REQUEST_SIZE_BYTES
@property
def credentials(self) -> google.auth.credentials.Credentials:
if self._credentials is None:
self._credentials = self._credentials_factory.get_credentials()
else:
self._credentials = credential_factory_module.refresh_credentials(
self._credentials, self._credentials_factory
)
return self._credentials
def vertex_endpoint_authentication_header(self) -> MutableMapping[str, str]:
headers = {}
self.credentials.apply(headers)
return headers
def retry_count(self) -> int:
return self._retry_count
def max_threads(self) -> int:
return self._max_threads
def patch_width(self) -> int:
return self._patch_width
def patch_height(self) -> int:
return self._patch_height
def max_number_of_patches_per_request(self) -> int:
return self._max_patches_per_request
def endpoint_max_number_of_patches_per_request(self) -> int:
"""Maximum number of patches that can be sent to the endpoint at once."""
return self._endpoint_max_patches_per_request
@abc.abstractmethod
def _dicom_patch_embedding_request(
self,
bearer_token: str,
slide_embedding: patch_embedding_types.SlideEmbeddingSource,
) -> Mapping[str, Any]:
"""Returns DICOM patch embedding request.
Args:
bearer_token: Bearer for store requests.
slide_embedding: DICOM embedding inputs.
Returns:
JSON formatted embedding request.
"""
@abc.abstractmethod
def get_embedding_request(
self, embedding_inputs: Sequence[PreparedVertexEmbeddingRequest]
) -> str:
"""Returns patch embedding request.
Args:
embedding_inputs: Embedding inputs.
Returns:
JSON formatted embedding request.
"""
@abc.abstractmethod
def _gcs_patch_embedding_request(
self,
bearer_token: str,
slide_embedding: patch_embedding_types.SlideEmbeddingSource,
) -> Mapping[str, Any]:
"""Returns GCS patch embedding request.
Args:
bearer_token: Bearer for store requests.
slide_embedding: GCS embedding inputs.
Returns:
JSON formatted embedding request.
"""
def prepare_embedding_request(
self,
slide_embedding: patch_embedding_types.SlideEmbeddingSource,
) -> PreparedVertexEmbeddingRequest:
first_patch = slide_embedding.patches[0].patch
if len(slide_embedding.patches) > 1:
first_patch_type = type(first_patch)
for p in slide_embedding.patches:
if not isinstance(p.patch, first_patch_type):
raise ez_wsi_errors.InternalError(
'Patches in request are not all the same type.'
)
if isinstance(first_patch, dicom_slide.DicomPatch):
bearer_token = slide_embedding.get_bearer_token()
return PreparedVertexEmbeddingRequest(
self._dicom_patch_embedding_request(
bearer_token,
slide_embedding,
),
slide_embedding,
)
elif isinstance(first_patch, gcs_image.GcsPatch):
bearer_token = slide_embedding.get_bearer_token()
return PreparedVertexEmbeddingRequest(
self._gcs_patch_embedding_request(
bearer_token,
slide_embedding,
),
slide_embedding,
)
raise ez_wsi_errors.InternalError(
'Patch is not a dicom_slide.DicomPatch or gcs_image.GcsPatch.'
)
@property
def send_gcs_patch_bytes_from_client_to_server(self) -> bool:
return self._send_gcs_patch_bytes_from_client_to_server
@retrying.retry(**error_retry_util.HTTP_AUTH_ERROR_RETRY_CONFIG)
@retrying.retry(**error_retry_util.HTTP_SERVER_ERROR_RETRY_CONFIG)
def _request_embeddings(self, json_msg: str) -> str:
"""Sends json request to Vertex AI endpoint."""
try:
headers = self.vertex_endpoint_authentication_header()
headers['Content-Length'] = f'{len(json_msg)}'
headers['Content-Type'] = 'application/json'
response = requests.post(
self._end_point_url,
headers=headers,
data=json_msg,
timeout=self.timeout,
)
# Raises a HTTPError if the response code was not 200
response.raise_for_status()
return response.text
except requests.exceptions.HTTPError as exp:
ez_wsi_errors.raise_ez_wsi_http_exception(exp.response.reason, exp)
except requests.exceptions.Timeout as timeout_error:
raise ez_wsi_errors.HttpRequestTimeoutError(
str(timeout_error), 'Request Timeout'
) from timeout_error
@abc.abstractmethod
def _is_request_error_retryable(
self, json_response: Mapping[str, Any]
) -> bool:
"""Returns true if error at request level is retryable."""
@abc.abstractmethod
def _decode_response(
self,
embedding_inputs: Sequence[PreparedVertexEmbeddingRequest],
json_response: Mapping[str, Any],
) -> _VertexModelResult:
"""Decodes json_response response from Vertex AI endpoint into _VertexModelResult."""
@abc.abstractmethod
def _instance_has_retryable_error(self, json_dict: Mapping[str, Any]) -> bool:
"""Decodes response from Vertex AI endpoint into _VertexModelResult."""
def _regenerate_instance_with_new_auth_token(
self, prepared_request: PreparedVertexEmbeddingRequest
) -> PreparedVertexEmbeddingRequest:
prepared_request = self.prepare_embedding_request(
prepared_request.slide_embedding_source
)
prepared_request.finalize()
return prepared_request
def _merge_embedding_input_results(
self, model_result: _VertexModelResult, partial_result: _VertexModelResult
) -> _VertexModelResult:
partial_result_counter = 0
for index in range(len(model_result.instances)):
if self._instance_has_retryable_error(model_result.instances[index]):
model_result.instances[index] = partial_result.instances[
partial_result_counter
]
partial_result_counter += 1
return model_result
def _regenerate_embedding_input_requests_with_new_auth_token(
self, embedding_inputs: Sequence[PreparedVertexEmbeddingRequest]
) -> Tuple[str, Sequence[PreparedVertexEmbeddingRequest]]:
embedding_inputs = [
self._regenerate_instance_with_new_auth_token(ei)
for ei in embedding_inputs
]
return self.get_embedding_request(embedding_inputs), embedding_inputs
def _retry_failed_embedding_input_requests(
self,
model_result: _VertexModelResult,
embedding_inputs: Sequence[PreparedVertexEmbeddingRequest],
):
retry_list = []
for index, embedding_input in enumerate(embedding_inputs):
if self._instance_has_retryable_error(model_result.instances[index]):
retry_list.append(embedding_input)
if not retry_list:
return '', retry_list
return self._regenerate_embedding_input_requests_with_new_auth_token(
retry_list
)
def request_embeddings(
self,
embedding_inputs: Sequence[
AbstractPreparedEmbeddingRequest[_VertexModelResult]
],
) -> _VertexModelResult:
"""Requests embeddings from Vertex AI endpoint.
Method is robust to expiration of authentication tokens in either instance
or endpoint. Makes up to three attempts to correct authentication
issues.
Args:
embedding_inputs: Prepared list of instances (images) with patches to get
embeddings.
Returns:
_VertexModelResult: a list of JSON results, one per instance.
"""
if not embedding_inputs:
return _VertexModelResult([])
embedding_inputs = typing.cast(
Sequence[PreparedVertexEmbeddingRequest], embedding_inputs
)
attempts = 0
# generates initial embedding request
json_msg = self.get_embedding_request(embedding_inputs)
if not json_msg and not embedding_inputs:
return _VertexModelResult([])
request_embedding_inputs = embedding_inputs
model_result = None
# This retry loop handles authentication errors encountered by
# the endpoint attempting to connect to data sources for which
# expired tokens have been provided. In this case we are connecting
# to the endpoint, but the endpoint is unable to fulfill the request
# completely and is returning errors in one or more of its responses.
# More general authentication and retry logic for connecting to the endpoint
# is handled by the decorators attached to _request_embeddings.
while True:
# Get JSON response from Vertex AI endpoint.
json_response = self._request_embeddings(json_msg)
attempts += 1
# Decode JSON response.
try:
json_response = json.loads(json_response)
except json.JSONDecodeError as exp:
raise ez_wsi_errors.PatchEmbeddingEndpointError(
'Endpoint returned invalid JSON.'
) from exp
# If a retryable error is found at the response level, regenerate the
# whole request with a new authentication token and retry.
if self._is_request_error_retryable(json_response):
json_msg, request_embedding_inputs = (
self._regenerate_embedding_input_requests_with_new_auth_token(
request_embedding_inputs
)
)
continue
# Decode json response into _VertexModelResult.
result = self._decode_response(request_embedding_inputs, json_response)
if model_result is None:
model_result = result
else:
# If results were already processed, merge with preexisting results.
self._merge_embedding_input_results(model_result, result)
if attempts >= 3:
# If on the third attempt, just return the result; it should never take
# more than two attempts.
return model_result
# Retry retrieving embeddings only for the instances that failed with
# retryable errors.
json_msg, request_embedding_inputs = (
self._retry_failed_embedding_input_requests(
model_result, embedding_inputs
)
)
if not json_msg:
# If the JSON message is empty there is nothing to do but return.
if model_result is None:
# Return an empty result; this should never occur.
return _VertexModelResult([])
else:
return model_result
class V1PatchEmbeddingEndpoint(AbstractVertexPatchEmbeddingEndpointBase):
"""Implements Patch embedding V1 API."""
def __init__(
self,
patch_width: int = 224,
patch_height: int = 224,
endpoint_api_version: str = 'v1', # Vertex API version
project_id: str = 'hai-cd3-foundations',
endpoint_location: str = 'us-central1',
endpoint_id: str = '160',
max_threads: int = _DEFAULT_ENDPOINT_THREADS,
max_patches_per_request: int = _DEFAULT_MAX_PATCHES_PER_REQUEST,
retry_count: int = _DEFAULT_RETRY_COUNT,
send_gcs_patch_bytes_from_client_to_server: bool = False,
credential_factory: Optional[
credential_factory_module.AbstractCredentialFactory
] = None,
timeout: Optional[Union[int, float]] = _DEFAULT_TIMEOUT,
):
end_point: List[str] = [
f'https://{endpoint_location}-aiplatform.googleapis.com',
endpoint_api_version,
'projects',
project_id,
'locations',
endpoint_location,
'endpoints',
f'{endpoint_id}:predict',
]
end_point_url = '/'.join([ep.strip('/') for ep in end_point])
super().__init__(
patch_width,
patch_height,
IccProfileNormalization.NONE,
send_gcs_patch_bytes_from_client_to_server,
end_point_url,
max_threads,
max_patches_per_request,
_MAX_V1_ENDPOINT_PATCHES_PER_REQUEST,
retry_count,
credential_factory,
timeout,
)
self._model_size = 'MEDIUM'
self._model_kind = 'LOW_PIXEL_SPACING'
def _dicom_patch_embedding_request(
self,
bearer_token: str,
slide_embedding: patch_embedding_types.SlideEmbeddingSource,
) -> Mapping[str, Any]:
"""Generates Embedding request for image stored in DICOM Store."""
patch = typing.cast(
dicom_slide.DicomPatch, slide_embedding.patches[0].patch
)
source_series = patch.source
path = source_series.path
instance_uids, required_levels = (
_get_dicom_instance_uids_and_required_levels(slide_embedding)
)
if patch.is_resized:
raise ez_wsi_errors.PatchEmbeddingEndpointError(
'V1 encoder does not support image level resizing.'
)
if not bearer_token:
raise ez_wsi_errors.PatchEmbeddingEndpointError(
'V1 encoder does not support empty bearer tokens.'
)
dicom_store_url = path.GetStorePath().complete_url
if dicom_store_url.startswith('https://healthcare.googleapis.com/v1/'):
# When calling a Google DICOM store through the healthcare V1 API, use
# the legacy path format expected by the V1 embedding endpoint.
dicom_store_url = f'projects/{path.project_id}/locations/{path.location}/datasets/{path.dataset_id}/dicomStores/{path.store_id}'
return {
EndpointJsonKeys.DICOM_WEB_STORE_URL: dicom_store_url,
EndpointJsonKeys.DICOM_STUDY_UID: path.study_uid,
EndpointJsonKeys.DICOM_SERIES_UID: path.series_uid,
EndpointJsonKeys.BEARER_TOKEN: bearer_token,
EndpointJsonKeys.EZ_WSI_STATE: source_series.json_metadata_dict(
level_subset=required_levels,
max_json_encoded_icc_profile_size=_MAX_DICOM_SLIDE_ICCPROFILE_METADATA_SIZE,
),
EndpointJsonKeys.INSTANCE_UIDS: instance_uids,
EndpointJsonKeys.PATCH_COORDINATES: [
{
EndpointJsonKeys.X_ORIGIN: int(input.patch.x),
EndpointJsonKeys.Y_ORIGIN: int(input.patch.y),
EndpointJsonKeys.WIDTH: int(input.patch.width),
EndpointJsonKeys.HEIGHT: int(input.patch.height),
}
for input in slide_embedding.patches
],
}
def _gcs_patch_embedding_request(
self,
bearer_token: str,
slide_embedding: patch_embedding_types.SlideEmbeddingSource,
) -> Mapping[str, Any]:
"""Generates Embedding request for image stored in DICOM Store."""
json_metadata = _get_gcs_image_metadata(
self.max_request_size_bytes(),
self.send_gcs_patch_bytes_from_client_to_server,
self.icc_profile_normalization,
self.icc_profile_bytes(),
slide_embedding,
)
gcs_patch = typing.cast(
gcs_image.GcsPatch, slide_embedding.patches[0].patch
)
uri = gcs_patch.source.uri
if gcs_patch.is_resized:
raise ez_wsi_errors.PatchEmbeddingEndpointError(
'V1 encoder does not support image resizing.'
)
if not bearer_token and gcs_patch.source.uri:
raise ez_wsi_errors.PatchEmbeddingEndpointError(
'V1 encoder does not support empty bearer tokens.'
)
return {
EndpointJsonKeys.PROJECT_NAME: 'placeholder',
EndpointJsonKeys.GCS_IMAGE_URL: uri,
EndpointJsonKeys.BEARER_TOKEN: bearer_token,
EndpointJsonKeys.EZ_WSI_STATE: json_metadata,
EndpointJsonKeys.PATCH_COORDINATES: [
{
EndpointJsonKeys.X_ORIGIN: int(input.patch.x),
EndpointJsonKeys.Y_ORIGIN: int(input.patch.y),
EndpointJsonKeys.WIDTH: int(input.patch.width),
EndpointJsonKeys.HEIGHT: int(input.patch.height),
}
for input in slide_embedding.patches
],
}
def _validate_embedding_response(
self,
embedding_source: patch_embedding_types.SlideEmbeddingSource,
embedding_response: Mapping[str, Any],
):
"""Validate embedding DICOM UID match patch request."""
if embedding_source.patches and isinstance(
embedding_source.patches[0].patch, dicom_slide.DicomPatch
):
# Test StudyInstanceUID and SeriesInstanceUID from request and response
# match if returning result for DICOM slide.
patch = typing.cast(
dicom_slide.DicomPatch, embedding_source.patches[0].patch
)
path = patch.source.path
if (
embedding_response[EndpointJsonKeys.DICOM_STUDY_UID] != path.study_uid
or embedding_response[EndpointJsonKeys.DICOM_SERIES_UID]
!= path.series_uid
):
raise ez_wsi_errors.PatchEmbeddingEndpointError(
'Study or Series UID of embedding does not match request.'
)
def _is_request_error_retryable(
self, json_response: Mapping[str, Any]
) -> bool:
"""Returns true if error at request level is retryable."""
try:
predictions = json_response[EndpointJsonKeys.PREDICTIONS]
if isinstance(predictions, Mapping):
# This is the raw response; the Vertex endpoint translates responses
# into an alternative format.
error = predictions[EndpointJsonKeys.ERROR_RESPONSE]
else:
# This is the response as translated by the Vertex endpoint.
returned_slide_embeddings, error, ml_version = predictions
del returned_slide_embeddings, ml_version
return error == EndpointJsonKeys.INVALID_CREDENTIALS
except (KeyError, ValueError, TypeError) as _:
return False
def _decode_response(
self,
embedding_inputs: Sequence[PreparedVertexEmbeddingRequest],
json_response: Mapping[str, Any],
) -> _VertexModelResult:
"""Decodes response from Vertex AI endpoint into _VertexModelResult."""
try:
predictions = json_response[EndpointJsonKeys.PREDICTIONS]
if isinstance(predictions, Mapping):
# This is the raw response; the Vertex endpoint translates responses
# into an alternative format.
returned_slide_embeddings = predictions[
EndpointJsonKeys.EMBEDDING_RESULT
]
error = predictions[EndpointJsonKeys.ERROR_RESPONSE]
else:
# This is the response as translated by the Vertex endpoint.
returned_slide_embeddings, error, _ = predictions
except (KeyError, ValueError, TypeError) as exp:
raise ez_wsi_errors.PatchEmbeddingEndpointError(
'Endpoint returned incorrectly formatted JSON.'
) from exp
if error is not None and error:
raise ez_wsi_errors.PatchEmbeddingEndpointError(
f'Endpoint returned error: {error}'
)
# Test the number of slide embedding responses matches the request.
if len(embedding_inputs) != len(returned_slide_embeddings):
raise ez_wsi_errors.PatchEmbeddingEndpointError(
'Number of embedding responses received does not match number of'
f' embedding requests; expected: {len(embedding_inputs)}; received:'
f' {len(returned_slide_embeddings)}.'
)
return _VertexModelResult(returned_slide_embeddings)
def _instance_has_retryable_error(self, json_dict: Mapping[str, Any]) -> bool:
return False
def process_response(
self,
embedding_inputs: Sequence[patch_embedding_types.SlideEmbeddingSource],
msg: _VertexModelResult,
) -> List[patch_embedding_types.PatchEmbeddingEnsembleResult]:
"""Returns patch embedding results for input and returned embeddings."""
result = []
endpoint_patch_width = self.patch_width()
endpoint_patch_height = self.patch_height()
for patch_embeddings, instance_input in zip(
msg.instances, embedding_inputs
):
self._validate_embedding_response(instance_input, patch_embeddings)
patch_embeddings = patch_embeddings[EndpointJsonKeys.PATCH_EMBEDDINGS]
# Test the number of patches received for the slide matches the request.
if len(patch_embeddings) != len(instance_input.patches):
raise ez_wsi_errors.PatchEmbeddingEndpointError(
'Number of patches in embedding response does not match request;'
f' expected: {len(instance_input.patches)}; received:'
f' {len(patch_embeddings)}.'
)
for patch_embedding, source in zip(
patch_embeddings, instance_input.patches
):
pc = patch_embedding[EndpointJsonKeys.PATCH_COORDINATE]
# Test that the coordinates of the patch match the request.
if not _test_patch_coordinates_match(
pc,
source.patch.x,
source.patch.y,
endpoint_patch_width,
endpoint_patch_height,
):
raise ez_wsi_errors.PatchEmbeddingEndpointError(
'Embedding patch coordinates or dimensions do not match request.'
)
embedding_value = np.asarray(
patch_embedding[EndpointJsonKeys.EMBEDDINGS]
)
result.append(
patch_embedding_types.PatchEmbeddingEnsembleResult(
source, embedding_value, None
)
)
return result
def get_embedding_request(
self, embedding_inputs: Sequence[PreparedVertexEmbeddingRequest]
) -> str:
"""Returns JSON formmatted embedding request."""
instances = ','.join(i.json for i in embedding_inputs)
return f'{{"{EndpointJsonKeys.PARAMETERS}":{{"{EndpointJsonKeys.MODEL_SIZE}":"{self._model_size}","{EndpointJsonKeys.MODEL_KIND}":"{self._model_kind}"}},"{EndpointJsonKeys.INSTANCES}":[{instances}]}}'
def _gen_v2_extensions(
require_fully_in_source_image: bool,
image_dimensions: Optional[dicom_slide.ImageDimensions],
icc_profile: IccProfileNormalization,
ez_wsi_state: Mapping[str, Any],
) -> Mapping[str, Any]:
"""Returns extensions for pathology embeddings."""
extension = {
EndpointJsonKeys.REQUIRE_PATCHES_FULLY_IN_SOURCE_IMAGE: str(
require_fully_in_source_image
)
}
if image_dimensions is not None:
extension[EndpointJsonKeys.IMAGE_DIMENSIONS] = dataclasses.asdict(
image_dimensions
)
if icc_profile:
extension[EndpointJsonKeys.TRANSFORM_IMAGING_TO_ICC_PROFILE] = str(
icc_profile.value
)
if ez_wsi_state:
extension[EndpointJsonKeys.EZ_WSI_STATE] = ez_wsi_state
return extension
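# Illustrative sketch (hypothetical values) of the extensions dict produced by
# _gen_v2_extensions; image_dimensions, when present, is the
# dataclasses.asdict() form of dicom_slide.ImageDimensions:
#
#   {'require_patches_fully_in_source_image': 'True',
#    'transform_imaging_to_icc_profile': 'SRGB',
#    'ez_wsi_state': {...}}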
def _format_error_message(error_code: str, error_description: str) -> str:
if not error_description:
return f'Error code: {error_code}'
return f'Error code: {error_code}; {error_description}'
class V2PatchEmbeddingEndpoint(AbstractVertexPatchEmbeddingEndpointBase):
"""Implements Patch embedding V2 API."""
def __init__(
self,
patch_width: int = 224,
patch_height: int = 224,
endpoint_api_version: str = 'v1', # Vertex API version
project_id: str = 'hai-cd3-foundations',
endpoint_location: str = 'us-central1',
endpoint_id: str = '162',
max_threads: int = _DEFAULT_ENDPOINT_THREADS,
max_patches_per_request: int = _DEFAULT_MAX_PATCHES_PER_REQUEST,
retry_count: int = _DEFAULT_RETRY_COUNT,
icc_profile_normalization: IccProfileNormalization = (
IccProfileNormalization.NONE
),
send_gcs_patch_bytes_from_client_to_server: bool = False,
require_fully_in_source_image: bool = True,
credential_factory: Optional[
credential_factory_module.AbstractCredentialFactory
] = None,
timeout: Optional[Union[int, float]] = _DEFAULT_TIMEOUT,
):
end_point: List[str] = [
f'https://{endpoint_location}-aiplatform.googleapis.com',
endpoint_api_version,
'projects',
project_id,
'locations',
endpoint_location,
'endpoints',
f'{endpoint_id}:predict',
]
end_point_url = '/'.join([ep.strip('/') for ep in end_point])
super().__init__(
patch_width,
patch_height,
icc_profile_normalization,
send_gcs_patch_bytes_from_client_to_server,
end_point_url,
max_threads,
max_patches_per_request,
_MAX_V2_ENDPOINT_PATCHES_PER_REQUEST,
retry_count,
credential_factory,
timeout,
)
self._require_fully_in_source_image = require_fully_in_source_image
def _dicom_patch_embedding_request(
self,
bearer_token: str,
slide_embedding: patch_embedding_types.SlideEmbeddingSource,
) -> Mapping[str, Any]:
"""Generates Embedding request for image stored in DICOM Store."""
patch = typing.cast(
dicom_slide.DicomPatch, slide_embedding.patches[0].patch
)
instance_uids, required_levels = (
_get_dicom_instance_uids_and_required_levels(slide_embedding)
)
if patch.is_resized:
image_dimensions = dicom_slide.ImageDimensions(
patch.level.width, patch.level.height
)
else:
image_dimensions = None
request = {
EndpointJsonKeys.DICOM_PATH: {
EndpointJsonKeys.SERIES_PATH: str(
patch.source.path.GetSeriesPath()
),
EndpointJsonKeys.INSTANCE_UIDS: instance_uids,
},
EndpointJsonKeys.EXTENSIONS: _gen_v2_extensions(
self._require_fully_in_source_image,
image_dimensions,
self._icc_profile_normalization,
patch.source.json_metadata_dict(
level_subset=required_levels,
max_json_encoded_icc_profile_size=_MAX_DICOM_SLIDE_ICCPROFILE_METADATA_SIZE,
),
),
EndpointJsonKeys.PATCH_COORDINATES: [
{
EndpointJsonKeys.X_ORIGIN: int(input.patch.x),
EndpointJsonKeys.Y_ORIGIN: int(input.patch.y),
EndpointJsonKeys.WIDTH: int(input.patch.width),
EndpointJsonKeys.HEIGHT: int(input.patch.height),
}
for input in slide_embedding.patches
],
}
if bearer_token:
request[EndpointJsonKeys.BEARER_TOKEN] = bearer_token
return request
def _gcs_patch_embedding_request(
self,
bearer_token: str,
slide_embedding: patch_embedding_types.SlideEmbeddingSource,
) -> Mapping[str, Any]:
"""Generates Embedding request for image stored in DICOM Store."""
json_metadata = _get_gcs_image_metadata(
self.max_request_size_bytes(),
self.send_gcs_patch_bytes_from_client_to_server,
self.icc_profile_normalization,
self.icc_profile_bytes(),
slide_embedding,
)
gcs_patch = typing.cast(
gcs_image.GcsPatch, slide_embedding.patches[0].patch
)
uri = gcs_patch.source.uri
request = {
EndpointJsonKeys.IMAGE_FILE_URI: (
uri if uri else 'gs:///placeholder.png'
),
EndpointJsonKeys.EXTENSIONS: _gen_v2_extensions(
self._require_fully_in_source_image,
gcs_patch.source.resize_dimensions,
self._icc_profile_normalization,
json_metadata,
),
EndpointJsonKeys.PATCH_COORDINATES: [
{
EndpointJsonKeys.X_ORIGIN: int(input.patch.x),
EndpointJsonKeys.Y_ORIGIN: int(input.patch.y),
EndpointJsonKeys.WIDTH: int(input.patch.width),
EndpointJsonKeys.HEIGHT: int(input.patch.height),
}
for input in slide_embedding.patches
],
}
if bearer_token:
request[EndpointJsonKeys.BEARER_TOKEN] = bearer_token
return request
def get_embedding_request(
self, embedding_inputs: Sequence[PreparedVertexEmbeddingRequest]
) -> str:
"""Returns JSON formmatted embedding request."""
instances = ','.join(i.json for i in embedding_inputs)
return f'{{"{EndpointJsonKeys.INSTANCES}":[{instances}]}}'
def _is_request_error_retryable(
self, json_response: Mapping[str, Any]
) -> bool:
"""Returns true if error at request level is retryable."""
try:
error_code = json_response[EndpointJsonKeys.VERTEXAI_ERROR]
if isinstance(error_code, dict):
error_code = error_code[EndpointJsonKeys.ERROR_CODE]
return error_code == EndpointJsonKeys.INVALID_CREDENTIALS
except (KeyError, ValueError, TypeError):
return False
def _decode_response(
self,
embedding_inputs: Sequence[PreparedVertexEmbeddingRequest],
json_response: Mapping[str, Any],
) -> _VertexModelResult:
"""Decodes response from Vertex AI endpoint into _VertexModelResult."""
try:
returned_slide_embeddings = json_response[EndpointJsonKeys.PREDICTIONS]
except (KeyError, ValueError, TypeError):
try:
error_code = json_response[EndpointJsonKeys.VERTEXAI_ERROR]
error_description = ''
if isinstance(error_code, dict):
error_description = error_code.get(
EndpointJsonKeys.ERROR_CODE_DESCRIPTION, ''
)
error_code = error_code[EndpointJsonKeys.ERROR_CODE]
if isinstance(error_code, str):
msg = _format_error_message(error_code, error_description)
raise ez_wsi_errors.PatchEmbeddingEndpointError( # pylint: disable=raise-missing-from
f'Endpoint error; {msg}'
)
except (KeyError, ValueError, TypeError):
pass
raise ez_wsi_errors.PatchEmbeddingEndpointError( # pylint: disable=raise-missing-from
'Endpoint did not return a valid JSON response.'
)
# Test the number of slide embedding responses matches the request.
if len(embedding_inputs) != len(returned_slide_embeddings):
raise ez_wsi_errors.PatchEmbeddingEndpointError(
'Number of embedding responses received does not match number of'
f' embedding requests; expected: {len(embedding_inputs)}; received:'
f' {len(returned_slide_embeddings)}.'
)
return _VertexModelResult(returned_slide_embeddings)
def _instance_has_retryable_error(self, json_dict: Mapping[str, Any]) -> bool:
"""Decodes response from Vertex AI endpoint into _VertexModelResult."""
error = json_dict.get(EndpointJsonKeys.ERROR)
if error is None:
return False
try:
return (
error[EndpointJsonKeys.ERROR_CODE]
== EndpointJsonKeys.INVALID_CREDENTIALS
)
except (KeyError, TypeError, IndexError) as _:
return False
def process_response(
self,
embedding_inputs: Sequence[patch_embedding_types.SlideEmbeddingSource],
msg: _VertexModelResult,
) -> List[patch_embedding_types.PatchEmbeddingEnsembleResult]:
"""Returns patch embedding results for input and returned embeddings."""
result = []
endpoint_patch_width = self.patch_width()
endpoint_patch_height = self.patch_height()
for returned_instance, instance_input in zip(
msg.instances, embedding_inputs
):
try:
error = returned_instance.get(EndpointJsonKeys.ERROR)
if error is not None:
error_code = error[EndpointJsonKeys.ERROR_CODE]
error_description = error.get(
EndpointJsonKeys.ERROR_CODE_DESCRIPTION, ''
)
error_message = '\n'.join([
'Endpoint error generating instance embeddings.',
f'Endpoint: {self.end_point_url}',
f'{_format_error_message(error_code, error_description)}',
])
error = patch_embedding_types.PatchEmbeddingError(
error_code, error_message
)
# Return a PatchEmbeddingEnsembleResult with an error for each expected
# patch. Errors will be raised when embedding values from the patches
# are accessed. Typically this will occur almost immediately afterwards
# during ensemble reduction. However, this also enables callers using
# PatchEmbeddingSequence (index based) to access data for instances
# which may succeed after an instance fails.
for patch_source in instance_input.patches:
result.append(
patch_embedding_types.PatchEmbeddingEnsembleResult(
patch_source, None, error
)
)
continue
patch_embeddings = returned_instance[EndpointJsonKeys.RESULT][
EndpointJsonKeys.PATCH_EMBEDDINGS
]
# Test the number of patches received for the slide matches the request.
if len(patch_embeddings) != len(instance_input.patches):
raise ez_wsi_errors.PatchEmbeddingEndpointError(
'Number of patches in embedding response does not match request;'
f' expected: {len(instance_input.patches)}; received:'
f' {len(patch_embeddings)}.'
)
for patch_embedding, patch_source in zip(
patch_embeddings, instance_input.patches
):
pc = patch_embedding[EndpointJsonKeys.PATCH_COORDINATE]
# Test that the coordinates of the patch match the request.
if not _test_patch_coordinates_match(
pc,
patch_source.patch.x,
patch_source.patch.y,
endpoint_patch_width,
endpoint_patch_height,
):
raise ez_wsi_errors.PatchEmbeddingEndpointError(
'Embedding patch coordinates or dimensions do not match'
' request.'
)
embedding_value = np.asarray(
patch_embedding[EndpointJsonKeys.EMBEDDING_VECTOR]
)
result.append(
patch_embedding_types.PatchEmbeddingEnsembleResult(
patch_source, embedding_value, None
)
)
except (KeyError, IndexError, TypeError, ValueError) as exp:
raise ez_wsi_errors.PatchEmbeddingEndpointError(
'Endpoint returned an unexpected response.'
) from exp
return result
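# Illustrative usage sketch (not a definitive recipe); assumes the default
# Vertex AI project/endpoint settings, application-default credentials, and a
# previously constructed patch_embedding_types.SlideEmbeddingSource `source`:
#
#   endpoint = V2PatchEmbeddingEndpoint(
#       icc_profile_normalization=IccProfileNormalization.SRGB
#   )
#   prepared = endpoint.prepare_embedding_request(source)
#   prepared.finalize()
#   raw = endpoint.request_embeddings([prepared])
#   results = endpoint.process_response([prepared.slide_embedding_source], raw)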
class PreparedLocalEmbeddingRequest(
AbstractPreparedEmbeddingRequest[np.ndarray]
):
"""Base class for prepared embedding requests."""
def __init__(
self,
slide_embedding_source: patch_embedding_types.SlideEmbeddingSource,
endpoint_thread_pool: Optional[concurrent.futures.ThreadPoolExecutor],
icc_profile_bytes: bytes,
icc_profile_cache: cachetools.LRUCache,
icc_profile_cache_lock: threading.Lock,
require_fully_in_source_image: bool,
):
super().__init__(slide_embedding_source)
self._input_patch_bytes = []
self._input_patch_bytes_future = None
self._thread_pool = endpoint_thread_pool
self._target_icc_profile_bytes = icc_profile_bytes
self._icc_profile_cache = icc_profile_cache
self._icc_profile_cache_lock = icc_profile_cache_lock
self._require_fully_in_source_image = require_fully_in_source_image
@property
def json_size_in_bytes(self) -> int:
return 0
def _load_patch_bytes(self) -> List[np.ndarray]:
"""Loads patch bytes for each request."""
if (
self._slide_embedding_source is None
or not self._slide_embedding_source.patches
):
return []
first_patch = self._slide_embedding_source.patches[0].patch
if isinstance(first_patch, gcs_image.GcsPatch):
# Make a shallow copy of the patch's source GcsImage to ensure
# image bytes are not retained after the request is processed.
source_copy = copy.copy(first_patch.source)
image_bytes = []
icc_profile_normalization = source_copy.create_icc_profile_transformation(
self._target_icc_profile_bytes
)
for patch in self._slide_embedding_source.patches:
p = patch.patch
if (
self._require_fully_in_source_image
and not p.is_patch_fully_in_source_image()
):
raise ez_wsi_errors.PatchOutsideOfImageDimensionsError(
'A portion of the patch does not overlap the image.'
)
# Create a patch with the same coordinates on the temporary source.
# Requiring the patch to be in the source image is not relevant to the
# temp patch; set it to False.
temp_patch = gcs_image.GcsPatch(
source_copy, p.x, p.y, p.width, p.height, False
)
image_bytes.append(temp_patch.image_bytes(icc_profile_normalization))
return image_bytes
if isinstance(first_patch, dicom_slide.DicomPatch):
# Load the DICOM slide imaging using a frame cache that is not shared,
# to scope imaging bytes to the embedding request and avoid possible
# LRU cache eviction across parallel reads.
# Make a shallow copy of the patch's source DicomSlide or
# DicomMicroscopeImage.
source_copy = copy.copy(first_patch.source)
# Init new frame cache on the shallow copy source.
# Source and copy will no longer share the cache.
fc = source_copy.init_slide_frame_cache()
# Construct list of patches to return embedding for.
patch_list = typing.cast(
List[dicom_slide.DicomPatch],
[patch.patch for patch in self._slide_embedding_source.patches],
)
# Preload the list of patches into the frame cache. Copy across any
# imaging that was already loaded in the original cache.
source_copy.preload_patches_in_frame_cache(
patch_list, False, first_patch.source.slide_frame_cache
)
if (
self._target_icc_profile_bytes is None
or not self._target_icc_profile_bytes
):
icc_profile_normalization = None
else:
dicom_path = str(source_copy.path)
with self._icc_profile_cache_lock:
source_icc_profile_bytes = self._icc_profile_cache.get(dicom_path)
if source_icc_profile_bytes is None:
if isinstance(source_copy, dicom_slide.DicomSlide):
source_icc_profile_bytes = source_copy.get_icc_profile_bytes()
elif isinstance(source_copy, dicom_slide.DicomMicroscopeImage):
source_icc_profile_bytes = source_copy.get_level_icc_profile_bytes(
first_patch.level
)
else:
raise ValueError('Unexpected object')
with self._icc_profile_cache_lock:
self._icc_profile_cache[dicom_path] = source_icc_profile_bytes
icc_profile_normalization = (
dicom_slide.create_icc_profile_transformation(
source_icc_profile_bytes, self._target_icc_profile_bytes
)
)
fc.block_until_frames_are_loaded()
# Generate image bytes for each patch.
image_bytes = []
for p in patch_list:
if (
self._require_fully_in_source_image
and not p.is_patch_fully_in_source_image()
):
raise ez_wsi_errors.PatchOutsideOfImageDimensionsError(
'A portion of the patch does not overlap the image.'
)
# Create a copy of the patch, pointing it at the copied source so that
# patch image retrieval reads from the copied source's frame cache.
temp_patch = dicom_slide.DicomPatch(
p.get_pyramid_imaging_source_level(),
p.x,
p.y,
p.width,
p.height,
source_copy,
p.level,
# ensuring patch falls inside image dimensions is not relevant
# to the temp patch.
False,
)
image_bytes.append(temp_patch.image_bytes(icc_profile_normalization))
return image_bytes
raise ValueError('Unexpected object')
@property
def input_patch_bytes(self) -> List[np.ndarray]:
if self._input_patch_bytes_future is not None:
self._input_patch_bytes = self._input_patch_bytes_future.result()
self._input_patch_bytes_future = None
return self._input_patch_bytes
def finalize(self) -> None:
if self._thread_pool is None:
self._input_patch_bytes_future = None
self._input_patch_bytes = self._load_patch_bytes()
return
self._input_patch_bytes_future = self._thread_pool.submit(
self._load_patch_bytes
)
def split(
self, endpoint: AbstractPatchEmbeddingEndpoint[np.ndarray], max_size: int
) -> Tuple[
Optional[AbstractPreparedEmbeddingRequest[np.ndarray]],
patch_embedding_types.SlideEmbeddingSource,
]:
# splitting is not relevant for local endpoints.
# return an unsplit response.
if self._slide_embedding_source is None:
raise ValueError('Slide embedding source is None.')
return None, self._slide_embedding_source
class LocalEndpoint(AbstractPatchEmbeddingEndpoint[np.ndarray]):
"""Endpoint for generating embeddings with a locally loaded model."""
def __init__(
self,
model: Callable[[np.ndarray], np.ndarray],
icc_profile_normalization: IccProfileNormalization = (
IccProfileNormalization.NONE
),
patch_width: int = 224,
patch_height: int = 224,
require_fully_in_source_image: bool = True,
max_threads: int = _DEFAULT_ENDPOINT_THREADS,
retry_count: int = _DEFAULT_RETRY_COUNT,
max_patches_per_request: int = _DEFAULT_MAX_PATCHES_PER_REQUEST,
dicom_instance_icc_profile_cache_count: int = _DEFAULT_DICOM_INSTANCE_ICC_PROFILE_CACHE_COUNT,
timeout: Optional[Union[int, float]] = _DEFAULT_TIMEOUT,
):
super().__init__(icc_profile_normalization, timeout)
self._require_fully_in_source_image = require_fully_in_source_image
self._patch_width = patch_width
self._patch_height = patch_height
self._max_threads = max(1, max_threads)
self._max_patches_per_request = int(max(1, max_patches_per_request))
self._retry_count = max(0, retry_count)
if self._max_threads < 2:
self._thread_pool = None
else:
self._thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=self._max_threads
)
self._dicom_instance_icc_profile_cache_count = (
dicom_instance_icc_profile_cache_count
)
self._model = model
self._icc_profile_cache = cachetools.LRUCache(
self._dicom_instance_icc_profile_cache_count
)
self._icc_profile_cache_lock = threading.Lock()
def __del__(self):
if self._thread_pool is not None:
self._thread_pool.shutdown(wait=False, cancel_futures=True) # pylint: disable=attribute-error
def __getstate__(self) -> MutableMapping[str, Any]:
"""Returns class state for pickle serialization."""
state = copy.copy(self.__dict__)
del state['_thread_pool']
del state['_icc_profile_cache']
del state['_icc_profile_cache_lock']
return state
def __setstate__(self, dct: MutableMapping[str, Any]) -> None:
"""Init class state from pickle serialization."""
self.__dict__ = dct
if self._max_threads < 2:
self._thread_pool = None
else:
self._thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=self._max_threads
)
self._icc_profile_cache = cachetools.LRUCache(
self._dicom_instance_icc_profile_cache_count
)
self._icc_profile_cache_lock = threading.Lock()
def max_request_size_bytes(self) -> int:
"""Maximum size in bytes that can be sent in single request."""
return 0xFFFFFFFFFFFFFFFF # max size uint64
def retry_count(self) -> int:
return self._retry_count
def max_threads(self) -> int:
return self._max_threads
def patch_width(self) -> int:
return self._patch_width
def patch_height(self) -> int:
return self._patch_height
def max_number_of_patches_per_request(self) -> int:
return self._max_patches_per_request
def endpoint_max_number_of_patches_per_request(self) -> int:
"""Maximum number of patches that can be sent to the endpoint at once."""
return self._max_patches_per_request
def prepare_embedding_request(
self,
slide_embedding: patch_embedding_types.SlideEmbeddingSource,
) -> PreparedLocalEmbeddingRequest:
return PreparedLocalEmbeddingRequest(
slide_embedding,
self._thread_pool,
self.icc_profile_bytes(),
self._icc_profile_cache,
self._icc_profile_cache_lock,
self._require_fully_in_source_image,
)
def normalize_imaging(self, input_patch_bytes: np.ndarray) -> np.ndarray:
"""Normalizes input patch bytes to float32 in range [0, 1]."""
return input_patch_bytes.astype(np.float32) / _UINT8_MAX_VALUE
def generate_ml_input(
self,
embedding_inputs: Sequence[AbstractPreparedEmbeddingRequest[np.ndarray]],
) -> np.ndarray:
"""Generates ML input for local model."""
if not embedding_inputs:
return np.zeros((), dtype=np.float32)
normalized_imaging_list = []
patch_width = self.patch_width()
patch_height = self.patch_height()
for e in embedding_inputs:
e = typing.cast(PreparedLocalEmbeddingRequest, e)
for single_patch in e.input_patch_bytes:
normalized_imaging_list.append(
np.expand_dims(
normalized_patch_channels(
patch_width,
patch_height,
self.normalize_imaging(single_patch),
),
axis=0,
)
)
return np.concatenate(normalized_imaging_list, axis=0)
def request_embeddings(
self,
embedding_inputs: Sequence[AbstractPreparedEmbeddingRequest[np.ndarray]],
) -> np.ndarray:
"""Returns embeddings for input patches."""
ml_input = self.generate_ml_input(embedding_inputs)
if not ml_input.shape:
return ml_input
return self._model(ml_input)
def process_response(
self,
embedding_inputs: Sequence[patch_embedding_types.SlideEmbeddingSource],
msg: np.ndarray,
) -> List[patch_embedding_types.PatchEmbeddingEnsembleResult]:
"""Converts raw embedding response to list of embedding results."""
if not bool(msg.shape):
generated_embedding_count = 0
else:
generated_embedding_count = msg.shape[0]
total_patches = sum(len(i.patches) for i in embedding_inputs)
if total_patches != generated_embedding_count:
raise ez_wsi_errors.PatchEmbeddingEndpointError(
'Number of patches in embedding response does not match request.'
)
results = []
index = 0
for instance_input in embedding_inputs:
for patch_source in instance_input.patches:
results.append(
patch_embedding_types.PatchEmbeddingEnsembleResult(
patch_source, msg[index, ...], None
)
)
index += 1
return results
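# Illustrative usage sketch (hypothetical model): LocalEndpoint wraps any
# callable mapping an (N, H, W, 3) float32 batch to an (N, D) embedding array.
# `source` is assumed to be a patch_embedding_types.SlideEmbeddingSource.
#
#   def _mean_pixel_model(batch: np.ndarray) -> np.ndarray:
#     return batch.mean(axis=(1, 2))  # toy "embedding": per-channel means
#
#   local_endpoint = LocalEndpoint(_mean_pixel_model)
#   prepared = local_endpoint.prepare_embedding_request(source)
#   prepared.finalize()
#   embeddings = local_endpoint.request_embeddings([prepared])
#   results = local_endpoint.process_response([source], embeddings)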