ez_wsi_dicomweb/patch_embedding_ensemble_methods.py (293 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Methods to generate embeddings for patches dim >= embedding input dim."""

import abc
import enum
from typing import Iterable, Iterator, Sequence, Tuple, Union
import uuid

from ez_wsi_dicomweb import ez_wsi_errors
from ez_wsi_dicomweb import patch_embedding_endpoints
from ez_wsi_dicomweb import patch_embedding_types
import numpy as np


class SinglePatchEnsemblePosition(enum.Enum):
  """Position of a single endpoint-sized sub-patch within a larger patch."""

  UPPER_LEFT = 'UPPER_LEFT'
  UPPER_RIGHT = 'UPPER_RIGHT'
  CENTER = 'CENTER'
  LOWER_LEFT = 'LOWER_LEFT'
  LOWER_RIGHT = 'LOWER_RIGHT'


_ReducedType = Union[
    Sequence[patch_embedding_types.EmbeddingResult],
    Sequence[patch_embedding_types.PatchEmbeddingEnsembleResult],
]


class PatchEnsembleMethod(metaclass=abc.ABCMeta):
  """Defines operation to define & combine regions of patch to gen embeddings.

  Subclasses split an input patch into one or more endpoint-sized sub-patches
  (generate_ensemble) and combine the embeddings returned for those
  sub-patches into a single embedding for the patch (reduce_ensemble).
  """

  def __init__(self):
    # Ensemble ids have the form '<uuid4>-<counter>': a random per-instance
    # base plus a counter that increments with every generated ensemble.
    self._ensemble_id_base = str(uuid.uuid4())
    self._ensemble_id = 1

  def _get_ensemble_id(self) -> str:
    """Returns the next ensemble id for this instance and advances counter."""
    val = self._ensemble_id
    self._ensemble_id += 1
    return f'{self._ensemble_id_base}-{val}'

  def _validate_patch_dimensions(
      self,
      endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
      patch: patch_embedding_types.EmbeddingPatch,
  ) -> None:
    """Validates that patch position and input dimensions are valid.

    Args:
      endpoint: Patch embedding endpoint.
      patch: Patch to embed.

    Raises:
      ez_wsi_errors.PatchEmbeddingDimensionError: Patch width or height is
        smaller than the endpoint input width or height, the patch origin is
        negative, or the patch width/height is <= 0.
    """
    # Test patch dimensions >= embedding endpoint input dimensions.
    if (
        endpoint.patch_width() > patch.width
        or endpoint.patch_height() > patch.height
    ):
      raise ez_wsi_errors.PatchEmbeddingDimensionError(
          f'Patch dimensions ({patch.width}, {patch.height}) are less than '
          f' embedding input dimensions ({endpoint.patch_width()},'
          f' {endpoint.patch_height()}).'
      )
    if patch.x < 0 or patch.y < 0 or patch.width <= 0 or patch.height <= 0:
      raise ez_wsi_errors.PatchEmbeddingDimensionError(
          f'Invalid patch dimensions ({patch.x}, {patch.y} to'
          f' {patch.x + patch.width -1}, {patch.y + patch.height -1} ).'
      )

  @abc.abstractmethod
  def generate_ensemble(
      self,
      endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
      patch: patch_embedding_types.EmbeddingPatch,
  ) -> Iterator[patch_embedding_types.PatchEmbeddingSource]:
    """Yields iterator of patches of embedding dim to gen embedding for patch.

    Args:
      endpoint: Embedding endpoint used to generate patch embeddings.
      patch: Input pixel region to generate an embedding.

    Yields:
      PatchEmbeddingSource that define one or more sub patches that are
      required to generate an embedding for the patch.
    """

  @abc.abstractmethod
  def reduce_ensemble(
      self,
      patch: patch_embedding_types.EmbeddingPatch,
      ensemble_list: _ReducedType,
  ) -> patch_embedding_types.EmbeddingResult:
    """Returns single embedding result from ensemble of patch embeddings.

    Args:
      patch: Input pixel region embedding was generated from.
      ensemble_list: List of embedding results generated within patch.

    Returns:
      Single embedding result for patch.
    """


def _raise_if_error(
    ensemble_result: Union[
        patch_embedding_types.EmbeddingResult,
        patch_embedding_types.PatchEmbeddingEnsembleResult,
    ],
) -> None:
  """Raises if patch_embedding_types.PatchEmbeddingEnsembleResult has error.

  Args:
    ensemble_result: Result to check; plain EmbeddingResult values are never
      treated as errors.

  Raises:
    ez_wsi_errors.PatchEmbeddingEndpointError: Result carries an error.
  """
  if (
      isinstance(
          ensemble_result, patch_embedding_types.PatchEmbeddingEnsembleResult
      )
      and ensemble_result.error is not None
  ):
    raise ez_wsi_errors.PatchEmbeddingEndpointError(
        ensemble_result.error.error_message
    )


def _get_sub_patch_position(
    endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
    patch: patch_embedding_types.EmbeddingPatch,
    position: SinglePatchEnsemblePosition,
) -> Tuple[int, int]:
  """Return sub_patch position within a patch defined by enum.

  Args:
    endpoint: Embedding endpoint; its input width/height size the sub-patch.
    patch: Patch the sub-patch is positioned within.
    position: Which corner (or the center) of the patch to sample.

  Returns:
    (x, y) origin of the sub-patch, clamped so the sub-patch starts no
    earlier than the patch origin and ends no later than the patch edge.

  Raises:
    ez_wsi_errors.SinglePatchEmbeddingEnsemblePositionError: Unknown position.
  """
  if position == SinglePatchEnsemblePosition.UPPER_LEFT:
    pos_x, pos_y = patch.x, patch.y
  elif position == SinglePatchEnsemblePosition.UPPER_RIGHT:
    pos_x, pos_y = patch.x + patch.width - endpoint.patch_width(), patch.y
  elif position == SinglePatchEnsemblePosition.LOWER_LEFT:
    pos_x, pos_y = patch.x, patch.y + patch.height - endpoint.patch_height()
  elif position == SinglePatchEnsemblePosition.LOWER_RIGHT:
    pos_x = patch.x + patch.width - endpoint.patch_width()
    pos_y = patch.y + patch.height - endpoint.patch_height()
  elif position == SinglePatchEnsemblePosition.CENTER:
    pos_x = int(patch.x + (patch.width - endpoint.patch_width()) / 2)
    pos_y = int(patch.y + (patch.height - endpoint.patch_height()) / 2)
  else:
    raise ez_wsi_errors.SinglePatchEmbeddingEnsemblePositionError(
        'Invalid SinglePatchEnsemblePosition.'
    )
  # Clamp so the sub-patch always lies within the patch bounds.
  pos_x = int(
      max(min(pos_x, patch.x + patch.width - endpoint.patch_width()), patch.x)
  )
  pos_y = int(
      max(
          min(pos_y, patch.y + patch.height - endpoint.patch_height()),
          patch.y,
      )
  )
  return pos_x, pos_y


class SinglePatchEnsemble(PatchEnsembleMethod):
  """Returns embedding generated from a single patch."""

  def __init__(self, position: SinglePatchEnsemblePosition):
    """SinglePatchEnsemble Constructor.

    Args:
      position: Position of patch to generate embedding.

    Raises:
      ez_wsi_errors.SinglePatchEmbeddingEnsemblePositionError: Invalid
        SinglePatchEnsemblePosition.
    """
    super().__init__()
    # isinstance gives the same answer on every Python version; `x in Enum`
    # raises TypeError for non-members before 3.12 and accepts raw values
    # (e.g. 'CENTER') from 3.12 on, which would then fail later when
    # _get_sub_patch_position compares against enum members.
    if not isinstance(position, SinglePatchEnsemblePosition):
      raise ez_wsi_errors.SinglePatchEmbeddingEnsemblePositionError(
          'Invalid SinglePatchEnsemblePosition.'
      )
    self._position = position

  def generate_ensemble(
      self,
      endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
      patch: patch_embedding_types.EmbeddingPatch,
  ) -> Iterator[patch_embedding_types.PatchEmbeddingSource]:
    """Yields iterator of patches of embedding dim to gen embedding for patch.

    Args:
      endpoint: Embedding endpoint used to generate patch embeddings.
      patch: Input pixel region to generate an embedding.

    Yields:
      A single PatchEmbeddingSource for the endpoint-sized sub-patch at the
      configured position.

    Raises:
      ez_wsi_errors.SinglePatchEmbeddingEnsemblePositionError: Invalid
        SinglePatchEnsemblePosition.
      ez_wsi_errors.PatchEmbeddingDimensionError: Patch is smaller than the
        endpoint input dimensions or has an invalid origin/size.
    """
    self._validate_patch_dimensions(endpoint, patch)
    ensemble_id = self._get_ensemble_id()
    pos_x, pos_y = _get_sub_patch_position(endpoint, patch, self._position)
    yield patch_embedding_types.PatchEmbeddingSource(
        patch.get_patch(
            pos_x,
            pos_y,
            endpoint.patch_width(),
            endpoint.patch_height(),
        ),
        patch,
        ensemble_id,
    )

  def reduce_ensemble(
      self,
      patch: patch_embedding_types.EmbeddingPatch,
      ensemble_list: _ReducedType,
  ) -> patch_embedding_types.EmbeddingResult:
    """Returns single embedding result from ensemble of patch embeddings.

    Args:
      patch: Input pixel region embedding was generated from.
      ensemble_list: List of embedding results generated within patch.

    Returns:
      Single embedding result for patch.

    Raises:
      ez_wsi_errors.SinglePatchEmbeddingEnsembleError: Ensemble results did
        not return exactly one embedding.
      ez_wsi_errors.PatchEmbeddingEndpointError: The result carries an error.
    """
    if len(ensemble_list) != 1:
      raise ez_wsi_errors.SinglePatchEmbeddingEnsembleError(
          'SinglePatchEnsemble requires exactly one embedding result.'
      )
    _raise_if_error(ensemble_list[0])
    return patch_embedding_types.EmbeddingResult(
        patch, ensemble_list[0].embedding
    )


class DefaultSinglePatchEnsemble(SinglePatchEnsemble):
  """Returns single embedding for patch, validates patch dim = embedding dim."""

  def __init__(self):
    super().__init__(SinglePatchEnsemblePosition.UPPER_LEFT)

  def generate_ensemble(
      self,
      endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
      patch: patch_embedding_types.EmbeddingPatch,
  ) -> Iterator[patch_embedding_types.PatchEmbeddingSource]:
    """Returns iterator yielding the patch itself as the single sub-patch.

    This is a plain function (not a generator), so the dimension check below
    raises when called rather than when the iterator is first advanced.

    Args:
      endpoint: Embedding endpoint used to generate patch embeddings.
      patch: Input pixel region; must match endpoint input dimensions.

    Returns:
      Iterator over a single PatchEmbeddingSource.

    Raises:
      ez_wsi_errors.PatchEmbeddingDimensionError: Patch dimensions differ from
        the endpoint input dimensions.
    """
    if (
        endpoint.patch_width() != patch.width
        or endpoint.patch_height() != patch.height
    ):
      raise ez_wsi_errors.PatchEmbeddingDimensionError(
          f'Patch dimensions ({patch.width}, {patch.height}) do not match'
          f' endpoint embedding input dimensions ({endpoint.patch_width()},'
          f' {endpoint.patch_height()}). To generate embeddings from patches'
          ' that are not the same as the endpoint input dimensions, set the'
          ' embedding method "ensemble_method" parameter'
          ' (e.g., MeanPatchEmbeddingEnsemble, FivePatchMeanEnsemble or'
          ' SinglePatchEnsemble).'
      )
    return super().generate_ensemble(endpoint, patch)


class FivePatchMeanEnsemble(PatchEnsembleMethod):
  """Returns mean embedding from five patches sampled across the patch.

  Sub-patches are taken at the four corners and the center of the patch. If
  the patch is exactly endpoint-sized a single sub-patch is used.
  """

  def generate_ensemble(
      self,
      endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
      patch: patch_embedding_types.EmbeddingPatch,
  ) -> Iterator[patch_embedding_types.PatchEmbeddingSource]:
    """Yields corner and center sub-patches of endpoint dimensions.

    Args:
      endpoint: Embedding endpoint used to generate patch embeddings.
      patch: Input pixel region to generate an embedding.

    Yields:
      PatchEmbeddingSource per sampled sub-patch, all sharing one ensemble id.

    Raises:
      ez_wsi_errors.PatchEmbeddingDimensionError: Patch is smaller than the
        endpoint input dimensions or has an invalid origin/size.
    """
    self._validate_patch_dimensions(endpoint, patch)
    ensemble_id = self._get_ensemble_id()
    endpoint_width = endpoint.patch_width()
    endpoint_height = endpoint.patch_height()
    if patch.width == endpoint_width and patch.height == endpoint_height:
      # if patches overlap perfectly, just yield one.
      sampling_positions = [SinglePatchEnsemblePosition.UPPER_LEFT]
    else:
      sampling_positions = [
          SinglePatchEnsemblePosition.UPPER_LEFT,
          SinglePatchEnsemblePosition.UPPER_RIGHT,
          SinglePatchEnsemblePosition.CENTER,
          SinglePatchEnsemblePosition.LOWER_LEFT,
          SinglePatchEnsemblePosition.LOWER_RIGHT,
      ]
    for position in sampling_positions:
      pos_x, pos_y = _get_sub_patch_position(endpoint, patch, position)
      yield patch_embedding_types.PatchEmbeddingSource(
          patch.get_patch(
              pos_x,
              pos_y,
              endpoint_width,
              endpoint_height,
          ),
          patch,
          ensemble_id,
      )

  def reduce_ensemble(
      self,
      patch: patch_embedding_types.EmbeddingPatch,
      ensemble_list: _ReducedType,
  ) -> patch_embedding_types.EmbeddingResult:
    """Returns the element-wise mean of the ensemble embeddings.

    Args:
      patch: Input pixel region embedding was generated from.
      ensemble_list: Embedding results generated within patch.

    Returns:
      EmbeddingResult holding the mean embedding for patch.

    Raises:
      ez_wsi_errors.MeanPatchEmbeddingEnsembleError: ensemble_list is empty.
      ez_wsi_errors.PatchEmbeddingEndpointError: Any result carries an error.
    """
    if not ensemble_list:
      raise ez_wsi_errors.MeanPatchEmbeddingEnsembleError(
          'MeanPatchEmbeddingEnsemble requires at least one embedding result.'
      )
    # Every member must be checked, not only the first: an errored result
    # would otherwise be summed into the mean.
    for result in ensemble_list:
      _raise_if_error(result)
    return patch_embedding_types.EmbeddingResult(
        patch, mean_patch_embedding(ensemble_list)
    )


class MeanPatchEmbeddingEnsemble(FivePatchMeanEnsemble):
  """Returns mean embedding from set of embeddings sampled across the patch.

  Sub-patch origins start at the patch origin and advance by the configured
  step up to and including the last origin whose sub-patch still fits inside
  the patch. When the step does not evenly divide that range, the
  right/bottom edge of the patch may not be sampled.
  """

  def __init__(self, step_x_px: int, step_y_px: int):
    """MeanPatchEmbeddingEnsemble Constructor.

    Args:
      step_x_px: Step size in x direction to sample patch for embedding.
      step_y_px: Step size in y direction to sample patch for embedding.

    Raises:
      ez_wsi_errors.MeanPatchEmbeddingEnsembleError: Step size < 1 px.
    """
    super().__init__()
    self._step_x = step_x_px
    self._step_y = step_y_px
    if self._step_x <= 0 or self._step_y <= 0:
      raise ez_wsi_errors.MeanPatchEmbeddingEnsembleError(
          'MeanPatchEmbeddingEnsemble requires a minimum of 1 px patch step.'
      )

  def generate_ensemble(
      self,
      endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
      patch: patch_embedding_types.EmbeddingPatch,
  ) -> Iterator[patch_embedding_types.PatchEmbeddingSource]:
    """Yields a grid of endpoint-sized sub-patches stepped across the patch.

    Args:
      endpoint: Embedding endpoint used to generate patch embeddings.
      patch: Input pixel region to generate an embedding.

    Yields:
      PatchEmbeddingSource per grid position (row-major), all sharing one
      ensemble id. At least one is always yielded for a valid patch.

    Raises:
      ez_wsi_errors.PatchEmbeddingDimensionError: Patch is smaller than the
        endpoint input dimensions or has an invalid origin/size.
    """
    self._validate_patch_dimensions(endpoint, patch)
    embedding_width = endpoint.patch_width()
    embedding_height = endpoint.patch_height()
    # Last sub-patch origins that still fit inside the patch (inclusive).
    # Validation above guarantees these are >= the patch origin.
    last_x = max(patch.x, patch.x + patch.width - embedding_width)
    last_y = max(patch.y, patch.y + patch.height - embedding_height)
    ensemble_id = self._get_ensemble_id()
    # range() excludes its stop value, so +1 is required to include the last
    # valid origin; without it an endpoint-sized patch yields nothing.
    for y in range(patch.y, last_y + 1, self._step_y):
      for x in range(patch.x, last_x + 1, self._step_x):
        yield patch_embedding_types.PatchEmbeddingSource(
            patch.get_patch(
                x,
                y,
                embedding_width,
                embedding_height,
            ),
            patch,
            ensemble_id,
        )


def mean_patch_embedding(
    embeddings: Iterable[patch_embedding_types.EmbeddingResult],
) -> np.ndarray:
  """Returns mean embedding from list of or iterator of embedding results.

  Args:
    embeddings: Embedding results to average. Any iterable (list, tuple,
      iterator, generator) is accepted; it is consumed exactly once.

  Returns:
    Element-wise mean of the embeddings. Embeddings whose dtype is not a
    floating type, or is a float narrower than 64 bits, are accumulated in
    float64 and cast back to the first embedding's dtype.

  Raises:
    ez_wsi_errors.MeanPatchEmbeddingEnsembleError: No embeddings provided.
  """
  # iter() returns iterators unchanged and works for every iterable, not only
  # types registered as typing.Sequence.
  iterator = iter(embeddings)
  try:
    first = next(iterator)
  except StopIteration:
    raise ez_wsi_errors.MeanPatchEmbeddingEnsembleError(
        'MeanPatchEmbeddingEnsemble requires at least one embedding result.'
    ) from None
  embedding_dtype = first.embedding.dtype
  # Accumulate in float64 unless the input is already a >= 64-bit float, to
  # limit precision loss when summing many embeddings.
  dtype_cast = (
      not np.issubdtype(embedding_dtype, np.floating)
      or embedding_dtype.itemsize < 8
  )
  accumulator = np.zeros(
      first.embedding.shape,
      dtype=np.float64 if dtype_cast else embedding_dtype,
  )
  accumulator += first.embedding
  count = 1
  for result in iterator:
    accumulator += result.embedding
    count += 1
  accumulator /= float(count)
  if dtype_cast:
    accumulator = accumulator.astype(embedding_dtype)
  return accumulator