ez_wsi_dicomweb/gcs_image.py (491 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Creates Images and Patches from tradtional image formats stored on GCS."""
from __future__ import annotations
import base64
import binascii
import copy
import dataclasses
import io
import threading
import typing
from typing import Any, Dict, MutableMapping, Optional, Union
import cv2
from ez_wsi_dicomweb import credential_factory as credential_factory_module
from ez_wsi_dicomweb import dicom_slide
from ez_wsi_dicomweb import error_retry_util
from ez_wsi_dicomweb import ez_wsi_errors
from ez_wsi_dicomweb import slide_level_map
import google.api_core.exceptions
import google.auth
import google.cloud.storage
import numpy as np
import PIL
from PIL import ImageCms
import retrying
ImageDimensions = slide_level_map.ImageDimensions
_RGB = 'RGB'
@dataclasses.dataclass(frozen=True)
class _GcsImageState:
width: int
height: int
image_bytes: np.ndarray
icc_color_profile: Optional[bytes] = None
_CoreImageTypes = Union[str, np.ndarray, bytes]
GcsImageSourceTypes = Union[
_CoreImageTypes,
google.cloud.storage.Blob,
google.cloud.storage.blob.Blob,
]
def _gcs_image_json_metadata(
image: Union[GcsPatch, GcsImage],
color_transform: Optional[ImageCms.ImageCmsTransform] = None,
) -> str:
"""Converts uncompressed RGB image to PNG and returns base64 encoded bytes."""
image_bytes = image.image_bytes(color_transform)
mode = {1: 'L', 3: _RGB}.get(
dicom_slide.get_image_bytes_samples_per_pixel(image_bytes)
)
if mode is None:
raise ez_wsi_errors.GcsImageError(
f'Unsupported image samples per pixel; image shape: {image_bytes.shape}'
)
with io.BytesIO() as compressed_bytes:
with PIL.Image.frombytes(
mode=mode,
size=(image.width, image.height),
data=image_bytes.tobytes(),
decoder_name='raw',
) as pil_image:
pil_image.save(compressed_bytes, format='PNG')
return base64.b64encode(compressed_bytes.getvalue()).decode('utf-8')
class GcsPatch(dicom_slide.BasePatch):
"""Represents a patch stored in GCS."""
def __init__(
self,
source: GcsImage,
x: int,
y: int,
width: int,
height: int,
require_fully_in_source_image: bool = False,
):
super().__init__(x, y, width, height)
self._source = source
if (
require_fully_in_source_image
and not self.is_patch_fully_in_source_image()
):
raise ez_wsi_errors.PatchOutsideOfImageDimensionsError(
'A portion of the patch does not overlap the image.'
)
self._require_fully_in_source_image = require_fully_in_source_image
@classmethod
def create_from_json(
cls,
json_metadata: str,
require_fully_in_source_image: bool = False,
source_image_dimension: Optional[ImageDimensions] = None,
) -> GcsPatch:
try:
img = GcsImage(base64.b64decode(json_metadata, validate=True))
except (binascii.Error, ValueError) as exp:
raise ez_wsi_errors.GcsImageError('Error decoding image bytes') from exp
patch = GcsPatch(
img,
0,
0,
img.width,
img.height,
require_fully_in_source_image=require_fully_in_source_image,
)
if (
require_fully_in_source_image
and source_image_dimension is not None
and not patch.is_patch_fully_in_source_image_dim(
source_image_dimension.width_px, source_image_dimension.height_px
)
):
raise ez_wsi_errors.PatchOutsideOfImageDimensionsError(
'A portion of the patch does not overlap the image.'
)
return patch
def __eq__(self, other: Any) -> bool:
if not isinstance(other, GcsPatch):
return False
return (
self.x == other.x
and self.y == other.y
and self.width == other.width
and self.height == other.height
and self._source == other._source
)
@property
def is_resized(self) -> bool:
return self._source.is_resized
@property
def source(self) -> GcsImage:
return self._source
def image_bytes(
self, color_transform: Optional[ImageCms.ImageCmsTransform] = None
) -> np.ndarray:
"""Returns the patch's image bytes."""
image_bytes = self._source.image_bytes()
if len(image_bytes.shape) == 2:
image_bytes = np.expand_dims(image_bytes, axis=-1)
samples_per_pixel = dicom_slide.get_image_bytes_samples_per_pixel(
image_bytes
)
cropped_image = np.zeros(
(self.height, self.width, samples_per_pixel), dtype=image_bytes.dtype
)
y_start = min(max(self.y, 0), self._source.height)
x_start = min(max(self.x, 0), self._source.width)
y_end = min(max(self.y + self.height, 0), self._source.height)
x_end = min(max(self.x + self.width, 0), self._source.width)
copy_width = x_end - x_start
copy_height = y_end - y_start
if copy_width > 0 and copy_height > 0:
dx = max(0, -self.x)
dy = max(0, -self.y)
sx = max(0, self.x)
sy = max(0, self.y)
cropped_image[dy : dy + copy_height, dx : dx + copy_width, :] = (
image_bytes[sy : sy + copy_height, sx : sx + copy_width, :]
)
return dicom_slide.transform_image_bytes_color(
cropped_image, color_transform
)
def get_patch(
self,
x: int,
y: int,
width: int,
height: int,
require_fully_in_source_image: Optional[bool] = None,
) -> GcsPatch:
"""Returns a patch at the specified location and size."""
require_fully_in_source_image = (
self._require_fully_in_source_image
if require_fully_in_source_image is None
else require_fully_in_source_image
)
return GcsPatch(
self._source,
x,
y,
width,
height,
require_fully_in_source_image=require_fully_in_source_image,
)
def is_patch_fully_in_source_image(self) -> bool:
return self.is_patch_fully_in_source_image_dim(
self._source.width, self._source.height
)
def get_gcp_data_credential_header(
self, credential: Optional[google.auth.credentials.Credentials] = None
) -> Dict[str, str]:
"""Returns the credential header patch requests."""
return self.source.get_credential_header(credential)
def json_metadata(
self,
color_transform: Optional[ImageCms.ImageCmsTransform] = None,
) -> str:
return _gcs_image_json_metadata(self, color_transform)
class GcsImage:
"""Represents an image stored in GCS."""
def __init__(
self,
image_source: GcsImageSourceTypes,
credential_factory: Optional[
credential_factory_module.AbstractCredentialFactory
] = None,
image_dimensions: Optional[ImageDimensions] = None,
):
"""Initializes the GcsImage.
GCS image represents a image in tradtional image format (PNG, JPEG, etc)
that is stored in GCS or passed in directly.
Args:
image_source: Image source can be a str representing a gs style path, e.g.
gs://bucket/path/to/image.png; or numpy array containing uncompressed
RGB or single channel image; or bytes that contain the compressed
bytes of a tradtional image format, e.g. PNG or jpeg bytes.
credential_factory: Credential factory that returns credentials to use
reading from GCS.
image_dimensions: Image dimensions to resize images to.
Raises:
GcsImageError: If image source is not supported.
"""
self._gcs_image_lock = threading.RLock()
self._icc_color_profile = None
self._credentials = None
self._source_image_compressed_bytes_size = 0
self._source_image_compressed_bytes = b''
self._image_resize_dims = (
None if image_dimensions is None else image_dimensions.copy()
)
self._are_image_bytes_resized = False
if not isinstance(image_source, _CoreImageTypes):
# unit testing framwork makes direct testing of blob type difficult,
# test that assume image is blob if not other expected types.
image_source = typing.cast(google.cloud.storage.Blob, image_source)
image_source = f'gs://{image_source.bucket.name}/{image_source.name}'
if isinstance(image_source, np.ndarray):
self._gcs_uri = ''
self._credential_factory = (
credential_factory_module.NoAuthCredentialsFactory()
)
self._image_bytes = image_source.copy()
try:
self._height, self._width = self._image_bytes.shape[:2]
except ValueError as exp:
raise ez_wsi_errors.GcsImageError(
f'Unsupported image shape: {self._image_bytes.shape}'
) from exp
if image_source.dtype != np.uint8:
raise ez_wsi_errors.GcsImageError(
f'Unsupported image dtype: {image_source.dtype}'
)
samples_per_pixel = dicom_slide.get_image_bytes_samples_per_pixel(
self._image_bytes
)
if samples_per_pixel not in (1, 3, 4):
raise ez_wsi_errors.GcsImageError(
f'Unsupported image samples per pixel: {samples_per_pixel}'
)
if samples_per_pixel == 4:
# If present Remove alpha channel
self._image_bytes = self._image_bytes[:, :, :3]
self._bytes_pre_pixel = samples_per_pixel
self._resize()
return
# uninitialized values
self._image_bytes = None
self._width = -1
self._height = -1
self._bytes_pre_pixel = -1
if isinstance(image_source, bytes):
if not image_source:
raise ez_wsi_errors.GcsImageError('Image bytes is empty.')
self._gcs_uri = ''
self._credential_factory = (
credential_factory_module.NoAuthCredentialsFactory()
)
self._source_image_compressed_bytes = image_source
self._init_compressed_image_bytes(image_source)
return
try:
# test gcs url formatting looks correct.
google.cloud.storage.Blob.from_string(image_source)
except ValueError as exp:
raise ez_wsi_errors.GcsImagePathFormatError(
f'Invalid GCS URI: {image_source}'
) from exp
self._gcs_uri = image_source
self._credential_factory = (
credential_factory_module.DefaultCredentialFactory()
if credential_factory is None
else credential_factory
)
if image_dimensions is not None:
self._width = image_dimensions.width_px
self._height = image_dimensions.height_px
@property
def is_resized(self) -> bool:
"""Returns true if image dimensions have/maybe resized."""
with self._gcs_image_lock:
if self.are_image_bytes_loaded:
# Image dimensions were resized.
return self._are_image_bytes_resized
# If image bytes are not loaded then use the definition of the image
# dimensions as a proxy to avoid loading actual image bytes.
# return true if image dimensions have been defined. Its clearly possible
# the defined resize dimensions and the actual image dimensions may
# be the same.
return self._image_resize_dims is not None # safe to access outside of lock
@property
def resize_dimensions(self) -> Optional[ImageDimensions]:
# Never modified outside of constructor. Safe to access outside of lock.
return self._image_resize_dims
def _resize(self):
"""Resizes image to resize_dims if provided."""
resize_dims = self._image_resize_dims
if resize_dims is None or self._image_bytes is None:
return
height, width = self._image_bytes.shape[:2]
if width == resize_dims.width_px and height == resize_dims.height_px:
return
if resize_dims.width_px > width or resize_dims.height_px > height:
resize_method = cv2.INTER_CUBIC
else:
resize_method = cv2.INTER_AREA
self._are_image_bytes_resized = True
self._width = resize_dims.width_px
self._height = resize_dims.height_px
self._image_bytes = cv2.resize(
self._image_bytes, (self._width, self._height), resize_method
)
@property
def are_image_bytes_loaded(self) -> bool:
return self._image_bytes is not None
def _init_compressed_image_bytes(
self, source_image_compressed_bytes: bytes
) -> None:
"""Initialize image bytes from images compressed bytes."""
self._source_image_compressed_bytes_size = len(
source_image_compressed_bytes
)
with io.BytesIO(source_image_compressed_bytes) as image_bytes:
try:
with PIL.Image.open(image_bytes) as image:
if image.mode in ('YCbCr', 'CMYK', 'HSV', 'LAB', 'RGBA'):
image = image.convert('RGB')
if image.mode == 'L':
self._bytes_pre_pixel = 1
elif image.mode == 'RGB':
self._bytes_pre_pixel = 3
else:
raise ez_wsi_errors.GcsImageError(
f'Unsupported image mode: {image.mode}'
)
self._icc_color_profile = image.info.get('icc_profile')
self._image_bytes = np.asarray(image)
self._width, self._height = image.size
self._resize()
except PIL.UnidentifiedImageError as exp:
raise ez_wsi_errors.GcsImageError(
'Error decoding image bytes.'
) from exp
@property
def credentials(self) -> google.auth.credentials.Credentials:
with self._gcs_image_lock:
if self._credentials is None:
self._credentials = self._credential_factory.get_credentials()
else:
credential_factory_module.refresh_credentials(
self._credentials, self._credential_factory
)
return self._credentials
@retrying.retry(**error_retry_util.HTTP_AUTH_ERROR_RETRY_CONFIG)
@retrying.retry(**error_retry_util.HTTP_SERVER_ERROR_RETRY_CONFIG)
def _get_gcs_image(self) -> _GcsImageState:
"""Returns image width, height, bytes, and icc profile."""
with self._gcs_image_lock:
if self._image_bytes is not None:
return _GcsImageState(
self._width,
self._height,
self._image_bytes,
self._icc_color_profile,
)
try:
credentials = self.credentials
if not credentials.token or isinstance(
self._credential_factory,
credential_factory_module.NoAuthCredentialsFactory,
):
client = google.cloud.storage.Client.create_anonymous_client()
else:
client = google.cloud.storage.Client(credentials=self.credentials)
gcs_blob = google.cloud.storage.Blob.from_string(
self._gcs_uri,
client=client,
)
raw_bytes = gcs_blob.download_as_bytes(raw_download=True)
except google.api_core.exceptions.GoogleAPICallError as exp:
raise ez_wsi_errors.raise_ez_wsi_http_exception(exp.message, exp)
self._init_compressed_image_bytes(raw_bytes)
return _GcsImageState(
self._width, self._height, self._image_bytes, self._icc_color_profile
)
@classmethod
def create_from_json(cls, json_metadata: str) -> GcsImage:
try:
return GcsImage(base64.b64decode(json_metadata, validate=True))
except (binascii.Error, ValueError) as exp:
raise ez_wsi_errors.GcsImageError('Error decoding image bytes') from exp
@property
def size_bytes_of_source_image(self) -> Optional[int]:
with self._gcs_image_lock:
if self._source_image_compressed_bytes_size == 0:
return None
return self._source_image_compressed_bytes_size
def _get_source_image_bytes_from_file(self) -> bytes:
return b''
def source_image_bytes_json_metadata(self) -> str:
"""Returns bytes encoding source image.
Raises:
GcsImageError: If source image bytes are not set.
"""
with self._gcs_image_lock:
if self._are_image_bytes_resized:
raise ez_wsi_errors.GcsImageError(
'Source image bytes have been resized. Source image metadata is not'
' available.'
)
image_bytes = self._source_image_compressed_bytes
if not image_bytes:
image_bytes = self._get_source_image_bytes_from_file()
if image_bytes:
return base64.b64encode(image_bytes).decode('utf-8')
raise ez_wsi_errors.GcsImageError(
'Source image bytes are not set. Source image metadata is not'
' available.'
)
def clear_source_image_compressed_bytes(self) -> None:
"""Clears source image compressed bytes."""
# Source image compressed bytes are used to call embedding api
# when it is optimal pass a representation of the whole image.
# This function enables the source bytes to be cleared to save
# working memory.
with self._gcs_image_lock:
self._source_image_compressed_bytes = b''
def json_metadata(
self,
color_transform: Optional[ImageCms.ImageCmsTransform] = None,
) -> str:
return _gcs_image_json_metadata(self, color_transform)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, GcsImage):
return False
if self._gcs_uri and other._gcs_uri:
return (
self._gcs_uri == other._gcs_uri
and self._image_resize_dims == other._image_resize_dims
)
if self._gcs_uri and not self.are_image_bytes_loaded:
self._get_gcs_image() # load images from GCS if required.
if other._gcs_uri and not other.are_image_bytes_loaded:
other._get_gcs_image()
if self.width != other.width or self.height != other.height:
return False
# if loaded image bytes never cleared safe to access outside of lock.
return np.array_equal(self._image_bytes, other._image_bytes)
def __getstate__(self) -> MutableMapping[str, Any]:
"""Returns class state for pickle serialization."""
state = copy.copy(self.__dict__)
del state['_credentials']
del state['_gcs_image_lock']
return state
def __setstate__(self, dct: MutableMapping[str, Any]) -> None:
"""Init class state from pickle serialization."""
self.__dict__ = dct
self._credentials = None
self._gcs_image_lock = threading.RLock()
def get_credential_header(
self, credentials: Optional[google.auth.credentials.Credentials] = None
) -> Dict[str, str]:
"""Returns credential header for retrieval of GCS image.
Args:
credentials: Optional credential to use if not provided will use
credentials provided by credential factory encode the bearer token
provided to the GcsImage constructor.
Raises:
InvalidCredentialsError: If credential factory is not initialized.
"""
headers = {}
if credentials is None:
credentials = self.credentials
else:
credential_factory_module.refresh_credentials(credentials)
credentials.apply(headers)
return headers
@property
def uri(self) -> str:
# Once initialized never changed. Safe to access outside of lock.
return self._gcs_uri
def get_patch(
self,
x: int,
y: int,
width: int,
height: int,
require_fully_in_source_image: bool = False,
) -> GcsPatch:
"""Returns a patch of the image."""
return GcsPatch(
self,
x,
y,
width,
height,
require_fully_in_source_image=require_fully_in_source_image,
)
def get_image_as_patch(self) -> GcsPatch:
return self.get_patch(
0,
0,
self.width,
self.height,
require_fully_in_source_image=True, # patch matches source image dim.
)
@property
def width(self) -> int:
if self._width == -1:
self._get_gcs_image() # initializes self._width
return self._width
@property
def height(self) -> int:
if self._height == -1:
self._get_gcs_image() # initializes self._height
return self._height
@property
def bytes_pre_pixel(self) -> int:
if self._bytes_pre_pixel == -1:
self._get_gcs_image()
return self._bytes_pre_pixel
def create_icc_profile_transformation(
self,
icc_profile: Union[bytes, ImageCms.core.CmsProfile, None],
rendering_intent: ImageCms.Intent = ImageCms.Intent.PERCEPTUAL,
) -> Optional[ImageCms.ImageCmsTransform]:
"""Returns transformation to from pyramid colorspace to icc_profile.
Args:
icc_profile: ICC Profile to DICOM Pyramid imaging to.
rendering_intent: Rendering intent to use in transformation.
Returns:
PIL.ImageCmsTransformation to transform pixel imaging or None.
"""
if icc_profile is None or not icc_profile:
return None
return dicom_slide.create_icc_profile_transformation(
self._get_gcs_image().icc_color_profile, icc_profile, rendering_intent
)
def image_bytes(
self, color_transform: Optional[ImageCms.ImageCmsTransform] = None
) -> np.ndarray:
"""Loads the pixel bytes of the DICOM Image.
Args:
color_transform: Optional ICC Profile color transformation to perform on
image.
Returns:
Numpy array representing the DICOM Image.
"""
# Internally reuses the Patch implementation for bytes fetching.
# An image can be represented as a giant patch starting from (0, 0)
# and spans the whole slide.
return dicom_slide.transform_image_bytes_color(
self._get_gcs_image().image_bytes, color_transform
)