# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main functions to return embeddings for Patches or Images."""
import collections.abc
from concurrent import futures
import dataclasses
import functools
import os
import threading
import time
import typing
from typing import Any, Iterator, List, Optional, Sequence, Union
from ez_wsi_dicomweb import credential_factory as credential_factory_module
from ez_wsi_dicomweb import dicom_slide
from ez_wsi_dicomweb import error_retry_util
from ez_wsi_dicomweb import ez_wsi_errors
from ez_wsi_dicomweb import gcs_image
from ez_wsi_dicomweb import local_image
from ez_wsi_dicomweb import patch_embedding_endpoints
from ez_wsi_dicomweb import patch_embedding_ensemble_methods
from ez_wsi_dicomweb import patch_embedding_types
from ez_wsi_dicomweb import patch_generator
from ez_wsi_dicomweb import pixel_spacing
from ez_wsi_dicomweb import slide_level_map
import numpy as np
import retrying
@dataclasses.dataclass(frozen=True)
class BatchEmbeddingRequest:
"""Batch Embedding request."""
json_request: str
prepared_request: Sequence[
patch_embedding_endpoints.AbstractPreparedEmbeddingRequest
]
# by default embedding request throttling is disabled.
_max_requests_per_minute: Optional[int] = None
_last_request_time = 0.0
_request_lock = threading.Lock()
def _init_request_throttle() -> None:
global _request_lock
_request_lock = threading.Lock()
def disable_embedding_request_throttling() -> None:
"""Disables embedding request throttling."""
global _max_requests_per_minute
with _request_lock:
_max_requests_per_minute = None
def set_max_embedding_requests_per_min(max_requests_per_minute: int) -> None:
"""Sets maximum number of requests per minute which can occure in process."""
global _max_requests_per_minute
with _request_lock:
_max_requests_per_minute = max_requests_per_minute
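

# Example (illustrative sketch): process-wide throttling of embedding
# requests, e.g. to stay within an endpoint quota. The rate limit value below
# is an assumption for illustration, not an endpoint requirement.
#
#   from ez_wsi_dicomweb import patch_embedding
#
#   patch_embedding.set_max_embedding_requests_per_min(120)
#   ...  # request embeddings; calls average >= 0.5 sec apart.
#   patch_embedding.disable_embedding_request_throttling()
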
def _get_embedding_thread(
endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
slide_embeddings: Sequence[
patch_embedding_endpoints.AbstractPreparedEmbeddingRequest
],
) -> List[patch_embedding_types.PatchEmbeddingEnsembleResult]:
"""Returns endpoint embedding for list of slide patches."""
@retrying.retry(
**error_retry_util.other_http_exception_retry_config(
endpoint.retry_count()
)
)
def _inner_func() -> str:
if _max_requests_per_minute is None:
# embedding request throttling is disabled.
return endpoint.request_embeddings(slide_embeddings)
# embedding request throttling is enabled.
while True:
global _last_request_time
      max_requests_per_minute = _max_requests_per_minute
      if max_requests_per_minute is None:
        # throttling was disabled while waiting; send request immediately.
        return endpoint.request_embeddings(slide_embeddings)
      min_sec_between_requests = 60.0 / max_requests_per_minute
with _request_lock:
current_time = time.time()
# min average time between requests
delta = min_sec_between_requests - (current_time - _last_request_time)
if delta <= 0:
_last_request_time = current_time
return endpoint.request_embeddings(slide_embeddings)
# sleep until delta predicted to expire.
time.sleep(delta)
response = _inner_func()
return endpoint.process_response(
[s.slide_embedding_source for s in slide_embeddings], response
)
@dataclasses.dataclass(frozen=True)
class _GeneratedPreparedRequest:
prepared_requests: List[
patch_embedding_endpoints.AbstractPreparedEmbeddingRequest
]
source_overflow: List[patch_embedding_types.SlideEmbeddingSource]
overflow_size: int
class _EmbeddingAPIRequest:
"""Collects patch embedding api requests."""
def __init__(self):
self._slide_processing: Union[
gcs_image.GcsImage,
slide_level_map.Level,
slide_level_map.ResizedLevel,
None,
] = None
    # list of unique slides, gcs images, etc.; each has a list of one or more
    # patches.
self._queued_embedding_image_requests: List[
patch_embedding_types.SlideEmbeddingSource
] = []
self._mag_scaled_patch_count = 0
self._patch_count = 0
@property
def has_queued_embedding_requests(self) -> bool:
    # is there an image (slide, gcs image) queued.
return bool(self._queued_embedding_image_requests)
def __len__(self) -> int:
"""Number of image sources, slides, gcsimages that have patch requests."""
return len(self._queued_embedding_image_requests)
  def _recalculate_patch_counts_for_queued_requests(self) -> None:
self._patch_count = 0
self._mag_scaled_patch_count = 0
for image_request in self._queued_embedding_image_requests:
self._mag_scaled_patch_count += (
image_request.mag_scaled_embedding_patch_count
)
self._patch_count += len(image_request.patches)
def _generate_request(
self, endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint
) -> _GeneratedPreparedRequest:
"""Transforms list of embedding request into a request and overflow.
Vertex Endpoints have size limits that restrict the total number of bytes
that can be sent in a single request and the pathology embedding endpoints
restrict the total number of patch embeddings that can be requested at once.
This function transforms a list of embedding requests into a valid
request and a overflow list that will be processed in a different request.
Args:
endpoint: Patch embedding endpoint to use.
Returns:
_GeneratedPreparedRequest(
prepared list of embeddings that can be processed,
overflow list of embeddings that could not be processed,
size of the overflow in bytes.
)
Raises:
ez_wsi_errors.PatchEmbeddingEndpointError: If the first embedding request
exceeds the size limit of the endpoint.
"""
pending_request_size_in_bytes = 0
prepared_request_list: List[
patch_embedding_endpoints.AbstractPreparedEmbeddingRequest
] = []
request_overflow = []
overflow_size = 0
for index, embedding_request in enumerate(
self._queued_embedding_image_requests
):
prepared_request = endpoint.prepare_embedding_request(embedding_request)
prepared_request_size = prepared_request.json_size_in_bytes
max_request_size = endpoint.max_request_size_bytes()
if (
pending_request_size_in_bytes + prepared_request_size
<= max_request_size
):
pending_request_size_in_bytes += prepared_request_size
prepared_request_list.append(prepared_request)
prepared_request.finalize()
else:
split_prepared_request, overflow_embedding_source = (
prepared_request.split(
endpoint, max_request_size - pending_request_size_in_bytes
)
)
if split_prepared_request is not None:
split_prepared_request.finalize()
prepared_request_list.append(split_prepared_request)
          # Slightly underestimates the size of the overflow; does not count
          # duplicated state.
overflow_size = (
prepared_request.json_size_in_bytes
- split_prepared_request.json_size_in_bytes
)
elif index == 0:
raise ez_wsi_errors.PatchEmbeddingEndpointError(
'Embedding request size,'
f' {prepared_request_size} (bytes), exceeds endpoint'
f' size limit, {max_request_size} (bytes).'
)
else:
overflow_size = prepared_request.json_size_in_bytes
request_overflow = [overflow_embedding_source]
request_overflow.extend(
self._queued_embedding_image_requests[index + 1 :]
)
break
return _GeneratedPreparedRequest(
prepared_request_list, request_overflow, overflow_size
)
def generate_prepared_embedding_request(
self,
endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
) -> Iterator[_GeneratedPreparedRequest]:
"""returns prepared embedding requests."""
run_batch_request_loop = self.has_queued_embedding_requests
max_request_size = endpoint.max_request_size_bytes()
while run_batch_request_loop:
gen_request = self._generate_request(endpoint)
run_batch_request_loop = (
len(gen_request.source_overflow) > 1
or gen_request.overflow_size > max_request_size
)
yield gen_request
self._queued_embedding_image_requests = gen_request.source_overflow
if not self._queued_embedding_image_requests:
self._slide_processing = None
      self._recalculate_patch_counts_for_queued_requests()
def add_new_slide(
self,
slide_key: Union[
gcs_image.GcsImage,
slide_level_map.Level,
slide_level_map.ResizedLevel,
],
):
"""Adds new slide to the embedding request."""
self._queued_embedding_image_requests.append(
patch_embedding_types.SlideEmbeddingSource([])
)
self._slide_processing = slide_key
@property
def slide_processing(self) -> Union[
gcs_image.GcsImage,
slide_level_map.Level,
slide_level_map.ResizedLevel,
None,
]:
"""Returns key for the slide currently being processed."""
return self._slide_processing
@property
def mag_scaled_patch_count(self) -> int:
"""Returns total number of embeddings requested scaled by magnification."""
return self._mag_scaled_patch_count
@property
def patch_count(self) -> int:
"""Returns total number of embeddings requested."""
return self._patch_count
def add_patch(
self,
embedding_request: patch_embedding_types.PatchEmbeddingSource,
mag_scaled_patch_count: int,
) -> None:
"""Adds an embedding request for current slide."""
self._queued_embedding_image_requests[-1].patches.append(embedding_request)
self._mag_scaled_patch_count += mag_scaled_patch_count
self._patch_count += 1
def _generate_prepared_embedding_requests(
endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
patch_embedding_sources: Union[
Iterator[patch_embedding_types.PatchEmbeddingSource],
Sequence[patch_embedding_types.PatchEmbeddingSource],
],
) -> Iterator[List[patch_embedding_endpoints.AbstractPreparedEmbeddingRequest]]:
"""Yields embedding requests to be processed on an endpoint.
Args:
endpoint: Patch embedding endpoint to use.
patch_embedding_sources: Iterator of embedding requests.
Yields:
    Embedding requests to perform in batch.
"""
api_request = _EmbeddingAPIRequest()
max_number_of_patches_per_request = (
endpoint.max_number_of_patches_per_request()
)
endpoint_max_mag_scaled_patch_count = (
endpoint.endpoint_max_number_of_patches_per_request()
)
for patch_embedding_source in patch_embedding_sources:
patch = patch_embedding_source.patch
if isinstance(patch, dicom_slide.DicomPatch):
slide_key = patch.level
elif isinstance(patch, gcs_image.GcsPatch):
slide_key = patch.source
else:
raise ez_wsi_errors.InternalError(
'Patch is not a dicom_slide.DicomPatch or gcs_image.GcsPatch.'
)
patch_count = patch_embedding_source.mag_scaled_embedding_patch_count
if api_request.mag_scaled_patch_count > 0 and (
api_request.mag_scaled_patch_count + patch_count
> endpoint_max_mag_scaled_patch_count
or api_request.patch_count + 1 > max_number_of_patches_per_request
):
for br in api_request.generate_prepared_embedding_request(endpoint):
yield br.prepared_requests
if api_request.slide_processing != slide_key:
api_request.add_new_slide(slide_key)
api_request.add_patch(patch_embedding_source, patch_count)
while api_request.has_queued_embedding_requests:
yield_result = False
for br in api_request.generate_prepared_embedding_request(endpoint):
yield br.prepared_requests
yield_result = True
if not yield_result:
      raise ez_wsi_errors.InternalError(
          'Embedding request queue is not making progress.'
      )
def _embedding_api_call(
endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
patch_embedding_sources: Union[
Iterator[patch_embedding_types.PatchEmbeddingSource],
Sequence[patch_embedding_types.PatchEmbeddingSource],
],
) -> Iterator[List[patch_embedding_types.PatchEmbeddingEnsembleResult]]:
"""Yields an embedding results.
Args:
endpoint: Patch embedding endpoint to use.
patch_embedding_sources: Iterator of embedding requests.
Yields:
Embedding results.
"""
max_threads = endpoint.max_threads()
map_func = functools.partial(_get_embedding_thread, endpoint)
prepared_embedding_requests = _generate_prepared_embedding_requests(
endpoint, patch_embedding_sources
)
if max_threads < 2:
for response in map(map_func, prepared_embedding_requests):
yield response
else:
try:
with futures.ThreadPoolExecutor(max_workers=max_threads) as pool:
for response in pool.map(
map_func,
prepared_embedding_requests,
# scale endpoint timeout to allow for internal retry.
timeout=None if endpoint.timeout is None else endpoint.timeout * 4,
):
yield response
    # Prior to Python 3.11, futures.TimeoutError is not an alias of the
    # builtin TimeoutError; catch the futures flavor so pool.map timeouts
    # are handled on all supported versions.
    except futures.TimeoutError as exp:
raise ez_wsi_errors.ThreadPoolTimeoutError(
'Timeout while waiting for embedding results.'
) from exp
def _generate_ensemble_for_patches(
endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
ensemble_method: patch_embedding_ensemble_methods.PatchEnsembleMethod,
patches: Union[
Sequence[patch_embedding_types.EmbeddingPatch],
Iterator[patch_embedding_types.EmbeddingPatch],
],
) -> Iterator[patch_embedding_types.PatchEmbeddingSource]:
"""Yields embedding api requests for user defined patches.
Args:
endpoint: Patch embedding endpoint to use.
    ensemble_method: Method to use to generate embedding requests for a user
      defined patch.
patches: Iterator of user defined patches.
Yields:
Embedding api requests.
"""
for patch in patches:
for ag_patch in ensemble_method.generate_ensemble(endpoint, patch):
yield ag_patch
def _reduce_embedding_ensemble(
ensemble_method: patch_embedding_ensemble_methods.PatchEnsembleMethod,
result_lists: Iterator[
Sequence[patch_embedding_types.PatchEmbeddingEnsembleResult]
],
) -> Iterator[patch_embedding_types.EmbeddingResult]:
"""Yields embedding results for user defined patches.
Args:
ensemble_method: Method to use to genenerate embedding requests for user
defined patch.
result_lists: Iterator of List embedding results.
Yields:
Embedding results for use defined patches.
"""
ensemble_id = ''
ensemble_list: List[patch_embedding_types.PatchEmbeddingEnsembleResult] = []
for result_list in result_lists:
for result in result_list:
if result.input_patch.ensemble_id == ensemble_id:
ensemble_list.append(result)
continue
if ensemble_list:
yield ensemble_method.reduce_ensemble(
ensemble_list[0].input_patch.ensemble_source_patch, ensemble_list
)
ensemble_id = result.input_patch.ensemble_id
ensemble_list = [result]
if ensemble_list:
yield ensemble_method.reduce_ensemble(
ensemble_list[0].input_patch.ensemble_source_patch, ensemble_list
)
def _create_patch_embedding_batch_request(
endpoint: patch_embedding_endpoints.AbstractVertexPatchEmbeddingEndpointBase,
patches: Union[
Sequence[patch_embedding_types.EmbeddingPatch],
Iterator[patch_embedding_types.EmbeddingPatch],
],
ensemble_method: Optional[
patch_embedding_ensemble_methods.PatchEnsembleMethod
] = None,
) -> Sequence[BatchEmbeddingRequest]:
"""Returns Sequence of embedding requests to be processed in batch.
Args:
endpoint: Patch embedding endpoint to use.
patches: Iterator of user defined patches.
    ensemble_method: Method to use to generate embedding requests for a user
      defined patch.
Returns:
Sequence of embedding requests to be processed in batch.
"""
if ensemble_method is None:
ensemble_method = (
patch_embedding_ensemble_methods.DefaultSinglePatchEnsemble()
)
  # Clear the cache to force embedding requests to be created with credentials
  # acquired at the time the batch request is initialized.
credential_factory_module.clear_credential_cache()
embedding_request = []
for prepared_requests in _generate_prepared_embedding_requests(
endpoint,
_generate_ensemble_for_patches(endpoint, ensemble_method, patches),
):
embedding_request.append(
BatchEmbeddingRequest(
endpoint.get_embedding_request(
typing.cast(
Sequence[
patch_embedding_endpoints.PreparedVertexEmbeddingRequest
],
prepared_requests,
),
),
prepared_requests,
)
)
return embedding_request
def generate_patch_embeddings(
endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
patches: Union[
Sequence[patch_embedding_types.EmbeddingPatch],
Iterator[patch_embedding_types.EmbeddingPatch],
],
ensemble_method: Optional[
patch_embedding_ensemble_methods.PatchEnsembleMethod
] = None,
) -> Iterator[patch_embedding_types.EmbeddingResult]:
"""Returns Iterator embedding results for user defined patches.
Args:
endpoint: Patch embedding endpoint to use.
patches: Iterator of user defined patches.
ensemble_method: Method to use to genenerate embedding requests for user
defined patch.
Returns:
Iterator embedding results for user defined patches.
"""
if ensemble_method is None:
ensemble_method = (
patch_embedding_ensemble_methods.DefaultSinglePatchEnsemble()
)
return _reduce_embedding_ensemble(
ensemble_method,
_embedding_api_call(
endpoint,
_generate_ensemble_for_patches(endpoint, ensemble_method, patches),
),
)
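

# Example (illustrative sketch; `endpoint` stands for any concrete
# AbstractPatchEmbeddingEndpoint implementation and `patches` for an iterable
# of DicomPatch or GcsPatch objects; both are assumptions):
#
#   for result in generate_patch_embeddings(endpoint, patches):
#     print(result.embedding.shape)
#
# Results stream lazily; consuming the iterator lets patches be batched and
# fetched via concurrent endpoint requests.
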
def get_patch_embedding(
endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
patch: patch_embedding_types.EmbeddingPatch,
ensemble_method: Optional[
patch_embedding_ensemble_methods.PatchEnsembleMethod
] = None,
) -> np.ndarray:
"""Returns embedding for a user defined patch.
Args:
endpoint: Patch embedding endpoint to use.
patch: user defined patch.
    ensemble_method: Method to use to generate embedding requests for a user
      defined patch.

  Returns:
    Embedding (numpy array) for the user defined patch.
"""
return next(
generate_patch_embeddings(endpoint, [patch], ensemble_method)
).embedding
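

# Example (illustrative sketch; the endpoint class name and patch construction
# below are assumptions for illustration, not part of this module):
#
#   endpoint = patch_embedding_endpoints.V1PatchEmbeddingEndpoint()
#   patch = ...  # a single dicom_slide.DicomPatch or gcs_image.GcsPatch
#   embedding = get_patch_embedding(endpoint, patch)  # returns np.ndarray
#
# For many patches prefer generate_patch_embeddings, which batches requests.
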
class PatchEmbeddingSequence(
collections.abc.Sequence[patch_embedding_types.EmbeddingResult]
):
"""Sequence of patches to return embeddings by index.
If all embeddings in the sequence are going to be iteratated across accessing
the embeddings via the iterator will provide higher performance by enabling
multiple patches to be requested concurrently.
"""
def __init__(
self,
endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
patches: Sequence[patch_embedding_types.EmbeddingPatch],
ensemble_method: Optional[
patch_embedding_ensemble_methods.PatchEnsembleMethod
] = None,
):
"""Constructor for PatchEmbeddingSequence.
Args:
endpoint: Is the an abstraction interface through which EZ-WSI
communicates with the various embedding model VertexAI endpoints and or
local execution.
patches: A sequence of patches. Patches should be clustered in the input
sequence such that patch from the same data source are fall sequentially
in the sequence.
ensemble_method: Ensemble methods are optional and enable EZ-WSI to
generate embeddings for patches which exceed the embedding dimensions of
the endpoint. If not provided, input patches must match the input width
and height dimensions of the endpoint.
"""
super().__init__()
self._endpoint = endpoint
self._patches = patches
self._ensemble_method = ensemble_method
def __eq__(self, value: Any) -> bool:
if not isinstance(value, PatchEmbeddingSequence):
return False
return self._patches == value._patches
def __contains__(self, value: Any) -> bool:
if not isinstance(value, patch_embedding_types.EmbeddingPatch):
return False
return value in self._patches
  def __getitem__(
      self, index: Union[int, slice]
  ) -> Union[
      patch_embedding_types.EmbeddingResult,
      List[patch_embedding_types.EmbeddingResult],
  ]:
if isinstance(index, int):
return next(
generate_patch_embeddings(
self._endpoint, [self._patches[index]], self._ensemble_method
)
)
return list(
generate_patch_embeddings(
self._endpoint, self._patches[index], self._ensemble_method
)
)
def get_patch(self, index: int) -> patch_embedding_types.EmbeddingPatch:
return self._patches[index]
def get_embedding(self, index: int) -> np.ndarray:
return self.__getitem__(index).embedding # pytype: disable=attribute-error
def __iter__(self) -> Iterator[patch_embedding_types.EmbeddingResult]:
return generate_patch_embeddings(
self._endpoint, self._patches, self._ensemble_method
)
def __len__(self) -> int:
return len(self._patches)
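

# Example (illustrative sketch; `endpoint` and `patches` are assumptions as
# above): iterate rather than index when all embeddings are needed.
#
#   embeddings = PatchEmbeddingSequence(endpoint, patches)
#   results = [r.embedding for r in embeddings]  # one streamed, batched pass
#   first = embeddings[0].embedding  # each index is a standalone request
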
def get_dicom_image_embeddings(
endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
slide: dicom_slide.DicomSlide,
ps: Union[
slide_level_map.Level,
slide_level_map.ResizedLevel,
pixel_spacing.PixelSpacing,
],
patch_size: Optional[int] = None,
mask: Optional[np.ndarray] = None,
stride_size: Optional[int] = None,
min_luminance: Optional[float] = None,
max_luminance: Optional[float] = None,
mask_level: Union[
slide_level_map.Level,
slide_level_map.ResizedLevel,
pixel_spacing.PixelSpacing,
None,
] = None,
ensemble_method: Optional[
patch_embedding_ensemble_methods.PatchEnsembleMethod
] = None,
) -> PatchEmbeddingSequence:
"""Returns Itertor of embeddings for a level of whole slide pyramid.
Args:
endpoint: Patch embedding endpoint to use.
slide: DICOM Slide to extract patches from.
ps: Pixel spacing of the slide pyramid level to extract patches from.
patch_size: Size of the patch to extract defaults to endpoint patch size.
mask: If provided, will be used as the embedding patch sampling mask.
stride_size: Stride size to use when extracting patches defaults to patch
size.
min_luminance: Regions with luminance (grayscale) < this threshold are to be
considered non-tissue background, and will be discarded in the patch
sampling.
max_luminance: Regions with luminance (grayscale) > this threshold are to be
considered non-tissue background, and will be discarded in the patch
sampling.
mask_level: Pyramid level to use to determine where tissue is present if a
tissue mask is not provded.
ensemble_method: Method to use to genenerate embedding patches; required
only patch dimensions != endpoint patch dimensions.
Returns:
Sequence of embedding results
"""
if patch_size is None:
patch_size = endpoint.patch_width()
if stride_size is None:
stride_size = patch_size
if mask is None and mask_level is None:
mask_level = patch_generator.TISSUE_MASK_PIXEL_SPACING
if isinstance(mask_level, pixel_spacing.PixelSpacing):
if (
slide.get_level_by_pixel_spacing(mask_level, maximum_downsample=8.0)
is None
):
mask_level = ps
target_icc_profile_bytes = endpoint.icc_profile_bytes()
color_transform = (
slide.create_icc_profile_transformation(target_icc_profile_bytes)
if target_icc_profile_bytes
else None
)
return PatchEmbeddingSequence(
endpoint,
patch_generator.DicomPatchGenerator(
slide,
ps,
patch_size=patch_size,
mask=mask,
stride_size=stride_size,
min_luminance=min_luminance,
max_luminance=max_luminance,
mask_level=mask_level,
mask_color_transform=color_transform,
),
ensemble_method,
)
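

# Example (illustrative sketch; the slide construction and magnification
# string are assumptions for illustration):
#
#   slide = dicom_slide.DicomSlide(...)  # hypothetical construction
#   results = get_dicom_image_embeddings(
#       endpoint,
#       slide,
#       pixel_spacing.PixelSpacing.FromMagnificationString('10X'),
#   )
#   embeddings = [r.embedding for r in results]
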
def get_gcs_image_embeddings(
endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
image: Union[gcs_image.GcsImage, local_image.LocalImage],
patch_size: Optional[int] = None,
mask: Optional[np.ndarray] = None,
stride_size: Optional[int] = None,
min_luminance: Optional[float] = None,
max_luminance: Optional[float] = None,
ensemble_method: Optional[
patch_embedding_ensemble_methods.PatchEnsembleMethod
] = None,
) -> PatchEmbeddingSequence:
"""Returns Itertor of embeddings for a level of whole slide pyramid.
Args:
endpoint: Patch embedding endpoint to use.
image: Image to extract patches from.
patch_size: Size of the patch to extract defaults to endpoint patch size.
mask: If provided, will be used as the embedding patch sampling mask.
stride_size: Stride size to use when extracting patches defaults to patch
size.
min_luminance: Regions with luminance (grayscale) < this threshold are to be
considered non-tissue background, and will be discarded in the patch
sampling.
max_luminance: Regions with luminance (grayscale) > this threshold are to be
considered non-tissue background, and will be discarded in the patch
sampling.
ensemble_method: Method to use to genenerate embedding patches; required
only patch dimensions != endpoint patch dimensions.
Returns:
Iterator embedding results
"""
if patch_size is None:
patch_size = endpoint.patch_width()
if stride_size is None:
stride_size = patch_size
target_icc_profile_bytes = endpoint.icc_profile_bytes()
color_transform = (
image.create_icc_profile_transformation(target_icc_profile_bytes)
if target_icc_profile_bytes
else None
)
return PatchEmbeddingSequence(
endpoint,
patch_generator.GcsImagePatchGenerator(
image,
patch_size=patch_size,
mask=mask,
stride_size=stride_size,
min_luminance=min_luminance,
max_luminance=max_luminance,
mask_color_transform=color_transform,
),
ensemble_method,
)
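

# Example (illustrative sketch; the bucket path is a placeholder and the
# GcsImage construction is an assumption for illustration):
#
#   image = gcs_image.GcsImage('gs://bucket/image.png')
#   results = get_gcs_image_embeddings(endpoint, image)
#   embeddings = [r.embedding for r in results]
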
def gcs_images_to_embeddings(
endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
images: patch_generator.GcsImagesToPatchesInputTypes,
credential_factory: Optional[
credential_factory_module.AbstractCredentialFactory
] = None,
image_dimensions: Optional[gcs_image.ImageDimensions] = None,
ensemble_method: Optional[
patch_embedding_ensemble_methods.PatchEnsembleMethod
] = None,
) -> Iterator[patch_embedding_types.EmbeddingResult]:
"""Converts whole images in GCS into embeddings."""
return generate_patch_embeddings(
endpoint,
patch_generator.gcs_images_to_patches(
images, credential_factory, image_dimensions, endpoint.max_threads()
),
ensemble_method,
)
def local_images_to_embeddings(
endpoint: patch_embedding_endpoints.AbstractPatchEmbeddingEndpoint,
images: patch_generator.LocalImagesToPatchesInputTypes,
image_dimensions: Optional[gcs_image.ImageDimensions] = None,
ensemble_method: Optional[
patch_embedding_ensemble_methods.PatchEnsembleMethod
] = None,
) -> Iterator[patch_embedding_types.EmbeddingResult]:
"""Converts whole local images into embeddings."""
return generate_patch_embeddings(
endpoint,
patch_generator.local_images_to_patches(images, image_dimensions),
ensemble_method,
)
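

# Example (illustrative sketch; file names are placeholders):
#
#   for result in local_images_to_embeddings(endpoint, ['a.png', 'b.png']):
#     print(result.embedding.shape)
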
# Re-initialize module-level state (the request throttle lock) in child
# processes created via fork.
os.register_at_fork(after_in_child=_init_request_throttle) # pylint: disable=protected-access