# ez_wsi_dicomweb/patch_embedding_ensemble_methods.py (293 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Methods to generate embeddings for patches dim >= embedding input dim."""
import abc
import enum
from typing import Iterator, Sequence, Tuple, Union
import uuid
from ez_wsi_dicomweb import ez_wsi_errors
from ez_wsi_dicomweb import patch_embedding_endpoints
from ez_wsi_dicomweb import patch_embedding_types
import numpy as np
class SinglePatchEnsemblePosition(enum.Enum):
  """Named locations of a single sub-patch sampled from within a patch.

  Each member's value is the member's own name as a string.
  """

  # Corners and center, in the order the members are iterated.
  UPPER_LEFT = 'UPPER_LEFT'
  UPPER_RIGHT = 'UPPER_RIGHT'
  CENTER = 'CENTER'
  LOWER_LEFT = 'LOWER_LEFT'
  LOWER_RIGHT = 'LOWER_RIGHT'
# Type of the `ensemble_list` argument accepted by `reduce_ensemble`: a
# sequence of either plain embedding results or ensemble embedding results.
_ReducedType = Union[
    Sequence[patch_embedding_types.EmbeddingResult],
    Sequence[patch_embedding_types.PatchEmbeddingEnsembleResult],
]
class PatchEnsembleMethod(metaclass=abc.ABCMeta):
  """Abstract strategy for embedding a patch through one or more sub-patches.

  Subclasses decide which regions of a patch are sent to the embedding endpoint
  (`generate_ensemble`) and how the embeddings returned for those regions are
  collapsed into a single result (`reduce_ensemble`).
  """

  def __init__(self):
    # Ensemble ids issued by an instance share one random UUID prefix and are
    # told apart by a counter that starts at 1.
    self._ensemble_id = 1
    self._ensemble_id_base = str(uuid.uuid4())

  def _get_ensemble_id(self) -> str:
    """Returns the next ensemble id and advances the per-instance counter."""
    current = self._ensemble_id
    self._ensemble_id = current + 1
    return f'{self._ensemble_id_base}-{current}'

  def _validate_patch_dimensions(
      self,
      endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
      patch: patch_embedding_types.EmbeddingPatch,
  ) -> None:
    """Checks that a patch is large enough and has a valid origin and size.

    Args:
      endpoint: Patch embedding endpoint.
      patch: Patch to embed.

    Raises:
      ez_wsi_errors.PatchEmbeddingDimensionError: Patch is narrower or shorter
        than the endpoint input dimensions, has a negative x or y coordinate,
        or has a non-positive width or height.
    """
    # The patch must be at least as large as the endpoint input.
    smaller_than_input = (
        patch.width < endpoint.patch_width()
        or patch.height < endpoint.patch_height()
    )
    if smaller_than_input:
      raise ez_wsi_errors.PatchEmbeddingDimensionError(
          f'Patch dimensions ({patch.width}, {patch.height}) are less than '
          f' embedding input dimensions ({endpoint.patch_width()},'
          f' {endpoint.patch_height()}).'
      )
    negative_origin = patch.x < 0 or patch.y < 0
    empty_extent = patch.width <= 0 or patch.height <= 0
    if negative_origin or empty_extent:
      raise ez_wsi_errors.PatchEmbeddingDimensionError(
          f'Invalid patch dimensions ({patch.x}, {patch.y} to'
          f' {patch.x + patch.width -1}, {patch.y + patch.height -1} ).'
      )

  @abc.abstractmethod
  def generate_ensemble(
      self,
      endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
      patch: patch_embedding_types.EmbeddingPatch,
  ) -> Iterator[patch_embedding_types.PatchEmbeddingSource]:
    """Yields the sub-patches needed to compute an embedding for a patch.

    Args:
      endpoint: Embedding endpoint used to generate patch embeddings.
      patch: Input pixel region to generate an embedding for.

    Yields:
      PatchEmbeddingSource instances, each describing one sub-patch required
      to generate the embedding for the patch.
    """

  @abc.abstractmethod
  def reduce_ensemble(
      self,
      patch: patch_embedding_types.EmbeddingPatch,
      ensemble_list: _ReducedType,
  ) -> patch_embedding_types.EmbeddingResult:
    """Combines the embeddings of an ensemble into one result for the patch.

    Args:
      patch: Input pixel region the embeddings were generated from.
      ensemble_list: Embedding results generated within the patch.

    Returns:
      Single embedding result for the patch.
    """
def _raise_if_error(
    ensemble_result: Union[
        patch_embedding_types.EmbeddingResult,
        patch_embedding_types.PatchEmbeddingEnsembleResult,
    ],
) -> None:
  """Raises when an ensemble result carries an error.

  Only PatchEmbeddingEnsembleResult instances are inspected; any other input
  passes through without a check.

  Args:
    ensemble_result: Result to inspect.

  Raises:
    ez_wsi_errors.PatchEmbeddingEndpointError: Result is a
      PatchEmbeddingEnsembleResult whose error is not None.
  """
  if not isinstance(
      ensemble_result, patch_embedding_types.PatchEmbeddingEnsembleResult
  ):
    return
  error = ensemble_result.error
  if error is None:
    return
  raise ez_wsi_errors.PatchEmbeddingEndpointError(error.error_message)
def _get_sub_patch_position(
    endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
    patch: patch_embedding_types.EmbeddingPatch,
    position: SinglePatchEnsemblePosition,
) -> Tuple[int, int]:
  """Returns the (x, y) origin of an endpoint-sized sub-patch within a patch.

  Args:
    endpoint: Embedding endpoint; supplies the sub-patch width and height.
    patch: Patch the sub-patch is sampled from.
    position: Which corner (or the center) of the patch to sample.

  Returns:
    Tuple of integer (x, y) coordinates, clamped to lie between the patch
    origin and the largest origin that keeps the sub-patch inside the patch.

  Raises:
    ez_wsi_errors.SinglePatchEmbeddingEnsemblePositionError: Position does not
      equal any SinglePatchEnsemblePosition member.
  """
  sub_width = endpoint.patch_width()
  sub_height = endpoint.patch_height()
  left, top = patch.x, patch.y
  # Largest origin at which a sub-patch still ends at the patch's far edge.
  right = left + patch.width - sub_width
  bottom = top + patch.height - sub_height

  if position == SinglePatchEnsemblePosition.UPPER_LEFT:
    pos_x, pos_y = left, top
  elif position == SinglePatchEnsemblePosition.UPPER_RIGHT:
    pos_x, pos_y = right, top
  elif position == SinglePatchEnsemblePosition.LOWER_LEFT:
    pos_x, pos_y = left, bottom
  elif position == SinglePatchEnsemblePosition.LOWER_RIGHT:
    pos_x, pos_y = right, bottom
  elif position == SinglePatchEnsemblePosition.CENTER:
    pos_x = int(left + (patch.width - sub_width) / 2)
    pos_y = int(top + (patch.height - sub_height) / 2)
  else:
    raise ez_wsi_errors.SinglePatchEmbeddingEnsemblePositionError(
        'Invalid SinglePatchEnsemblePosition.'
    )

  # Clamp so the result never falls before the patch origin; when the patch is
  # smaller than the endpoint input this resolves to the patch origin.
  pos_x = int(max(min(pos_x, right), left))
  pos_y = int(max(min(pos_y, bottom), top))
  return pos_x, pos_y
class SinglePatchEnsemble(PatchEnsembleMethod):
  """Returns embedding generated from a single patch."""

  def __init__(self, position: SinglePatchEnsemblePosition):
    """SinglePatchEnsemble Constructor.

    Args:
      position: Position of patch to generate embedding.

    Raises:
      ez_wsi_errors.SinglePatchEmbeddingEnsemblePositionError: Position is not
        a SinglePatchEnsemblePosition member.
    """
    super().__init__()
    # Use isinstance rather than `position in SinglePatchEnsemblePosition`:
    # enum containment semantics differ across Python versions (non-members
    # raise TypeError before 3.12; raw member values such as 'CENTER' are
    # accepted from 3.12 on). isinstance rejects anything that is not an
    # actual member on every version, so invalid input fails here rather than
    # later in generate_ensemble.
    if not isinstance(position, SinglePatchEnsemblePosition):
      raise ez_wsi_errors.SinglePatchEmbeddingEnsemblePositionError(
          'Invalid SinglePatchEnsemblePosition.'
      )
    self._position = position

  def generate_ensemble(
      self,
      endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
      patch: patch_embedding_types.EmbeddingPatch,
  ) -> Iterator[patch_embedding_types.PatchEmbeddingSource]:
    """Yields the single endpoint-sized sub-patch used to embed the patch.

    This is a generator: validation runs when iteration starts, not when the
    method is called.

    Args:
      endpoint: Embedding endpoint used to generate patch embeddings.
      patch: Input pixel region to generate an embedding.

    Yields:
      One PatchEmbeddingSource for the sub-patch located at the configured
      position within the patch.

    Raises:
      ez_wsi_errors.SinglePatchEmbeddingEnsemblePositionError: Invalid
        SinglePatchEnsemblePosition.
      ez_wsi_errors.PatchEmbeddingDimensionError: Patch has a negative origin,
        a non-positive size, or is smaller than the endpoint input dimensions.
    """
    self._validate_patch_dimensions(endpoint, patch)
    ensemble_id = self._get_ensemble_id()
    pos_x, pos_y = _get_sub_patch_position(endpoint, patch, self._position)
    yield patch_embedding_types.PatchEmbeddingSource(
        patch.get_patch(
            pos_x,
            pos_y,
            endpoint.patch_width(),
            endpoint.patch_height(),
        ),
        patch,
        ensemble_id,
    )

  def reduce_ensemble(
      self,
      patch: patch_embedding_types.EmbeddingPatch,
      ensemble_list: _ReducedType,
  ) -> patch_embedding_types.EmbeddingResult:
    """Returns single embedding result from ensemble of patch embeddings.

    Args:
      patch: Input pixel region embedding was generated from.
      ensemble_list: List of embedding results generated within patch.

    Returns:
      Single embedding result for patch.

    Raises:
      ez_wsi_errors.SinglePatchEmbeddingEnsembleError: Ensemble results did not
        return exactly one embedding.
      ez_wsi_errors.PatchEmbeddingEndpointError: The ensemble result carries an
        endpoint error.
    """
    if len(ensemble_list) != 1:
      raise ez_wsi_errors.SinglePatchEmbeddingEnsembleError(
          'SinglePatchEnsemble requires exactly one embedding result.'
      )
    result = ensemble_list[0]
    _raise_if_error(result)
    return patch_embedding_types.EmbeddingResult(patch, result.embedding)
class DefaultSinglePatchEnsemble(SinglePatchEnsemble):
  """Single-patch ensemble that requires patch size == endpoint input size."""

  def __init__(self):
    super().__init__(position=SinglePatchEnsemblePosition.UPPER_LEFT)

  def generate_ensemble(
      self,
      endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
      patch: patch_embedding_types.EmbeddingPatch,
  ) -> Iterator[patch_embedding_types.PatchEmbeddingSource]:
    """Returns the parent's sub-patch iterator after an exact-size check.

    This method is not itself a generator, so the size check runs at call time
    rather than when the returned iterator is first advanced.

    Args:
      endpoint: Embedding endpoint used to generate patch embeddings.
      patch: Input pixel region to generate an embedding.

    Returns:
      Iterator produced by SinglePatchEnsemble.generate_ensemble.

    Raises:
      ez_wsi_errors.PatchEmbeddingDimensionError: Patch width or height differs
        from the endpoint input dimensions.
    """
    width_matches = endpoint.patch_width() == patch.width
    if width_matches and endpoint.patch_height() == patch.height:
      return super().generate_ensemble(endpoint, patch)
    raise ez_wsi_errors.PatchEmbeddingDimensionError(
        f'Patch dimensions ({patch.width}, {patch.height}) do not match'
        f' endpoint embedding input dimensions ({endpoint.patch_width()},'
        f' {endpoint.patch_height()}). To generate embeddings from patches'
        ' that are not the same as the endpoint input dimensions, set the'
        ' embedding method "ensemble_method" parameter'
        ' (e.g., MeanPatchEmbeddingEnsemble, FivePatchMeanEnsemble or'
        ' SinglePatchEnsemble).'
    )
class FivePatchMeanEnsemble(PatchEnsembleMethod):
  """Returns mean embedding from five patches sampled across the patch."""

  def generate_ensemble(
      self,
      endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
      patch: patch_embedding_types.EmbeddingPatch,
  ) -> Iterator[patch_embedding_types.PatchEmbeddingSource]:
    """Yields endpoint-sized sub-patches at the corners and center of a patch.

    When the patch is exactly the endpoint input size, all five positions
    coincide, so only the upper-left sub-patch is yielded.

    Args:
      endpoint: Embedding endpoint used to generate patch embeddings.
      patch: Input pixel region to generate an embedding.

    Yields:
      PatchEmbeddingSource for each sampled sub-patch; all share one ensemble
      id.

    Raises:
      ez_wsi_errors.PatchEmbeddingDimensionError: Patch has a negative origin,
        a non-positive size, or is smaller than the endpoint input dimensions.
    """
    self._validate_patch_dimensions(endpoint, patch)
    ensemble_id = self._get_ensemble_id()
    endpoint_width = endpoint.patch_width()
    endpoint_height = endpoint.patch_height()
    if patch.width == endpoint_width and patch.height == endpoint_height:
      # if patches overlap perfectly, just yield one.
      sampling_positions = [SinglePatchEnsemblePosition.UPPER_LEFT]
    else:
      sampling_positions = [
          SinglePatchEnsemblePosition.UPPER_LEFT,
          SinglePatchEnsemblePosition.UPPER_RIGHT,
          SinglePatchEnsemblePosition.CENTER,
          SinglePatchEnsemblePosition.LOWER_LEFT,
          SinglePatchEnsemblePosition.LOWER_RIGHT,
      ]
    for position in sampling_positions:
      pos_x, pos_y = _get_sub_patch_position(endpoint, patch, position)
      yield patch_embedding_types.PatchEmbeddingSource(
          patch.get_patch(
              pos_x,
              pos_y,
              endpoint_width,
              endpoint_height,
          ),
          patch,
          ensemble_id,
      )

  def reduce_ensemble(
      self,
      patch: patch_embedding_types.EmbeddingPatch,
      ensemble_list: _ReducedType,
  ) -> patch_embedding_types.EmbeddingResult:
    """Returns the element-wise mean of the ensemble's embeddings.

    The sum is accumulated in float64 unless the embedding dtype is a floating
    type at least 8 bytes wide, and the mean is cast back to the dtype of the
    first embedding.

    Args:
      patch: Input pixel region embedding was generated from.
      ensemble_list: List of embedding results generated within patch.

    Returns:
      Single embedding result for patch.

    Raises:
      ez_wsi_errors.MeanPatchEmbeddingEnsembleError: Ensemble list is empty.
      ez_wsi_errors.PatchEmbeddingEndpointError: Any ensemble result carries an
        endpoint error.
    """
    if not ensemble_list:
      raise ez_wsi_errors.MeanPatchEmbeddingEnsembleError(
          'MeanPatchEmbeddingEnsemble requires at least one embedding result.'
      )
    # Check every result, not just the first, so an endpoint failure on any
    # sub-patch is reported as such before its embedding is touched.
    for result in ensemble_list:
      _raise_if_error(result)
    first_result = ensemble_list[0]
    embedding_dtype = first_result.embedding.dtype
    # Accumulate in float64 for anything narrower or non-floating.
    dtype_cast = (
        not np.issubdtype(embedding_dtype, np.floating)
        or embedding_dtype.itemsize < 8
    )
    embedding = np.zeros(
        first_result.embedding.shape,
        dtype=np.float64 if dtype_cast else embedding_dtype,
    )
    for result in ensemble_list:
      embedding += result.embedding
    embedding /= float(len(ensemble_list))
    if dtype_cast:
      embedding = embedding.astype(embedding_dtype)
    return patch_embedding_types.EmbeddingResult(patch, embedding)
class MeanPatchEmbeddingEnsemble(FivePatchMeanEnsemble):
  """Returns mean embedding from set of embeddings sampled across the patch."""

  def __init__(self, step_x_px: int, step_y_px: int):
    """MeanPatchEmbeddingEnsemble Constructor.

    Args:
      step_x_px: Step size in x direction to sample patch for embedding.
      step_y_px: Step size in y direction to sample patch for embedding.

    Raises:
      ez_wsi_errors.MeanPatchEmbeddingEnsembleError: A step size is less than
        1 px.
    """
    super().__init__()
    self._step_x = step_x_px
    self._step_y = step_y_px
    if self._step_x <= 0 or self._step_y <= 0:
      raise ez_wsi_errors.MeanPatchEmbeddingEnsembleError(
          'MeanPatchEmbeddingEnsemble requires a minimum of 1 px patch step.'
      )

  def generate_ensemble(
      self,
      endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
      patch: patch_embedding_types.EmbeddingPatch,
  ) -> Iterator[patch_embedding_types.PatchEmbeddingSource]:
    """Yields endpoint-sized sub-patches sampled on a grid across the patch.

    Sub-patch origins start at the patch origin and advance by the configured
    step in x and y. Every origin at which the sub-patch still fits inside the
    patch is sampled, including the last one. A patch exactly the size of the
    endpoint input yields a single sub-patch.

    Args:
      endpoint: Embedding endpoint used to generate patch embeddings.
      patch: Input pixel region to generate an embedding.

    Yields:
      PatchEmbeddingSource for each sampled sub-patch; all share one ensemble
      id.

    Raises:
      ez_wsi_errors.PatchEmbeddingDimensionError: Patch has a negative origin,
        a non-positive size, or is smaller than the endpoint input dimensions.
    """
    self._validate_patch_dimensions(endpoint, patch)
    embedding_width = endpoint.patch_width()
    embedding_height = endpoint.patch_height()
    start_x = patch.x
    start_y = patch.y
    # Last origin at which a sub-patch fits; validation above guarantees the
    # patch is at least endpoint-sized, so these are never below the start.
    last_x = start_x + patch.width - embedding_width
    last_y = start_y + patch.height - embedding_height
    ensemble_id = self._get_ensemble_id()
    # range() excludes its stop value, so add 1 to include the last origin.
    # Without it an endpoint-sized patch would produce no sub-patches.
    for y in range(start_y, last_y + 1, self._step_y):
      for x in range(start_x, last_x + 1, self._step_x):
        yield patch_embedding_types.PatchEmbeddingSource(
            patch.get_patch(
                x,
                y,
                embedding_width,
                embedding_height,
            ),
            patch,
            ensemble_id,
        )
def mean_patch_embedding(
    embeddings: Union[
        Iterator[patch_embedding_types.EmbeddingResult],
        Sequence[patch_embedding_types.EmbeddingResult],
    ],
) -> np.ndarray:
  """Returns mean embedding from list of or iterator of embedding results.

  The sum is accumulated in float64 unless the embedding dtype is a floating
  type at least 8 bytes wide, and the mean is cast back to the dtype of the
  first embedding.

  Args:
    embeddings: Embedding results to average. Sequences, iterators and any
      other iterable are accepted; the input is consumed in a single pass.

  Returns:
    Element-wise mean of the embeddings, with the shape and dtype of the first
    embedding.

  Raises:
    ez_wsi_errors.MeanPatchEmbeddingEnsembleError: No embedding results were
      provided.
  """
  # iter() returns an iterator unchanged, so it is safe to call on every
  # input; this also supports iterables that are neither Sequence nor Iterator.
  results = iter(embeddings)
  try:
    first = next(results)
  except StopIteration:
    raise ez_wsi_errors.MeanPatchEmbeddingEnsembleError(
        'MeanPatchEmbeddingEnsemble requires at least one embedding result.'
    ) from None
  embedding_dtype = first.embedding.dtype
  # Accumulate in float64 for anything narrower or non-floating.
  dtype_cast = (
      not np.issubdtype(embedding_dtype, np.floating)
      or embedding_dtype.itemsize < 8
  )
  embedding = np.zeros(
      first.embedding.shape,
      dtype=np.float64 if dtype_cast else embedding_dtype,
  )
  embedding += first.embedding
  count = 1
  for result in results:
    embedding += result.embedding
    count += 1
  embedding /= float(count)
  if dtype_cast:
    embedding = embedding.astype(embedding_dtype)
  return embedding