# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core methods to return endpoint generated embeddings for patch pixels."""

from __future__ import annotations

import abc
from collections.abc import Sequence
import concurrent
import copy
import dataclasses
import enum
import json
import math
import threading
import typing
from typing import Any, Callable, Generic, List, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar, Union

import cachetools
from ez_wsi_dicomweb import credential_factory as credential_factory_module
from ez_wsi_dicomweb import dicom_slide
from ez_wsi_dicomweb import error_retry_util
from ez_wsi_dicomweb import ez_wsi_errors
from ez_wsi_dicomweb import gcs_image
from ez_wsi_dicomweb import patch_embedding_types
from ez_wsi_dicomweb import slide_level_map
from ez_wsi_dicomweb.ml_toolkit import tags
import google.auth
import numpy as np
from PIL import ImageCms
import requests
import retrying


class EndpointJsonKeys:
  """JSON Keys for pathology v1 and v2 encoder endpoint."""

  # V2 encoder key types
  IMAGE_FILE_URI = 'image_file_uri'
  RAW_IMAGE_BYTES = 'raw_image_bytes'
  DICOM_PATH = 'dicom_path'
  SERIES_PATH = 'series_path'
  INSTANCE_UIDS = 'instance_uids'
  BEARER_TOKEN = 'bearer_token'
  INSTANCES = 'instances'

  PATCH_COORDINATES = 'patch_coordinates'
  X_ORIGIN = 'x_origin'
  Y_ORIGIN = 'y_origin'
  WIDTH = 'width'
  HEIGHT = 'height'

  EXTENSIONS = 'extensions'
  IMAGE_DIMENSIONS = 'image_dimensions'
  TRANSFORM_IMAGING_TO_ICC_PROFILE = 'transform_imaging_to_icc_profile'
  REQUIRE_PATCHES_FULLY_IN_SOURCE_IMAGE = (
      'require_patches_fully_in_source_image'
  )
  EZ_WSI_STATE = 'ez_wsi_state'

  # key for list of patch bytes, base64 encoded
  PATCHES = 'patches'
  # key for whole image bytes, base64 encoded
  IMAGE = 'image'
  # ICC profile normalization performed on patch imaging.
  ICC_PROFILE_METADATA_NORMALIZATION = 'icc_profile_metadata_normalization'
  # height and width of source image for patch
  SOURCE_IMAGE_WIDTH_PX = 'source_image_width_px'
  SOURCE_IMAGE_HEIGHT_PX = 'source_image_height_px'

  # V1 encoder only
  DICOM_WEB_STORE_URL = 'dicom_web_store_url'
  DICOM_STUDY_UID = 'dicom_study_uid'  # V1 encoder only
  DICOM_SERIES_UID = 'dicom_series_uid'  # V1 encoder only
  GCS_IMAGE_URL = 'gcs_image_url'  # V1 encoder only
  PROJECT_NAME = 'project_name'  # V1 encoder only
  PARAMETERS = 'parameters'  # V1 encoder only
  MODEL_SIZE = 'model_size'  # V1 encoder only
  MODEL_KIND = 'model_kind'  # V1 encoder only

  # embedding encoder response
  PREDICTIONS = 'predictions'
  VERTEXAI_ERROR = 'error'
  ERROR = 'error'
  ERROR_CODE = 'code'
  ERROR_CODE_DESCRIPTION = 'description'
  RESULT = 'result'
  EMBEDDINGS = 'embeddings'  # V1 encoder only
  ERROR_RESPONSE = 'error_response'  # V1 encoder only
  EMBEDDING_RESULT = 'embedding_result'  # V1 encoder only
  EMBEDDING_VECTOR = 'embedding_vector'  # V2 encoder
  PATCH_EMBEDDINGS = 'patch_embeddings'
  PATCH_COORDINATE = 'patch_coordinate'

  # Retryable error codes
  INVALID_CREDENTIALS = 'INVALID_CREDENTIALS_ERROR'


# Maximum request size in bytes for the endpoint. Less than the Vertex AI max
# to provide a safety margin.
_MAX_VERTEX_AI_V1_REQUEST_SIZE_BYTES = 1300000

_DEFAULT_RETRY_COUNT = 5
_DEFAULT_ENDPOINT_THREADS = 5
_DEFAULT_MAX_PATCHES_PER_REQUEST = 100

_MAX_ENDPOINT_THREADS = 10
_ITERATOR_MAX_ENDPOINT_PATCHES_PER_REQUEST = 100
_MAX_V1_ENDPOINT_PATCHES_PER_REQUEST = 3000
_MAX_V2_ENDPOINT_PATCHES_PER_REQUEST = 3000
_DEFAULT_DICOM_INSTANCE_ICC_PROFILE_CACHE_COUNT = 20

# Size safety buffer that encoded whole images may not exceed.
# Max size = _MAX_VERTEX_AI_V1_REQUEST_SIZE_BYTES - _WHOLE_IMAGE_SIZE_SAFTY_MARGIN.
# Images exceeding this are encoded as patches, which enables them to be
# split across multiple Vertex AI requests.
_WHOLE_IMAGE_SIZE_SAFTY_MARGIN = 300000

# Pyramid ICC profiles are optimally serialized in JSON to avoid repetitive
# re-initialization. However, some digital pathology DICOM, e.g. Leica, have
# very large ICC profiles, e.g., 12 MB. The default max size of the ICC profile
# controls the maximum size of the ICC profile serialized in JSON.
_MAX_DICOM_SLIDE_ICCPROFILE_METADATA_SIZE = min(
    204800, slide_level_map.DEFAULT_MAX_JSON_ENCODED_ICC_PROFILE_SIZE_IN_BYTES
)

_DEFAULT_TIMEOUT = None  # Time out in seconds; None for no timeout.


class IccProfileNormalization(enum.Enum):
  """ICC Profile To Normalize Embedding Patches To."""

  NONE = 'NONE'
  SRGB = 'SRGB'
  ADOBERGB = 'ADOBERGB'
  ROMMRGB = 'ROMMRGB'


_UINT8_MAX_VALUE = 255.0


def _test_patch_coordinates_match(
    pc: Mapping[str, Any], x: int, y: int, width: int, height: int
) -> bool:
  """Test if dict encoded coordinates match expected coordinates."""
  try:
    if pc[EndpointJsonKeys.X_ORIGIN] != x or pc[EndpointJsonKeys.Y_ORIGIN] != y:
      return False
    if (
        pc.get(EndpointJsonKeys.WIDTH, width) != width
        or pc.get(EndpointJsonKeys.HEIGHT, height) != height
    ):
      return False
    return True
  except (IndexError, KeyError, ValueError, TypeError) as _:
    return False


def _get_icc_profile_bytes(
    icc_profile_normalization: IccProfileNormalization,
) -> bytes:
  """Returns ICC Profile bytes for endpoint."""
  if icc_profile_normalization == IccProfileNormalization.NONE:
    return b''
  if icc_profile_normalization == IccProfileNormalization.SRGB:
    return dicom_slide.get_srgb_icc_profile_bytes()
  if icc_profile_normalization == IccProfileNormalization.ADOBERGB:
    return dicom_slide.get_adobergb_icc_profile_bytes()
  if icc_profile_normalization == IccProfileNormalization.ROMMRGB:
    return dicom_slide.get_rommrgb_icc_profile_bytes()
  raise ez_wsi_errors.InternalError('ICC Profile not supported')


RequestResponseType = TypeVar('RequestResponseType')


def normalized_patch_channels(
    width: int, height: int, patch: np.ndarray
) -> np.ndarray:
  """Normalize monochrome and RGBA imaging to RGB."""
  if patch.shape == (height, width, 3):
    return patch
  if patch.shape == (height, width):
    patch = np.expand_dims(patch, axis=-1)
  if patch.shape == (height, width, 1):
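    # Single-channel imaging: replicate the channel into all three RGB planes.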
    mem = np.zeros((height, width, 3), dtype=patch.dtype)
    mem[..., np.arange(3)] = patch[...]
    return mem
  if patch.shape == (height, width, 4):
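    # RGBA imaging: drop the alpha channel.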
    return patch[..., :3]
  raise ez_wsi_errors.PatchEmbeddingDimensionError


class AbstractPreparedEmbeddingRequest(
    Generic[RequestResponseType], metaclass=abc.ABCMeta
):
  """Base class for prepared embedding requests."""

  def __init__(
      self,
      slide_embedding_source: patch_embedding_types.SlideEmbeddingSource,
  ):
    self._slide_embedding_source = slide_embedding_source

  @property
  def slide_embedding_source(
      self,
  ) -> patch_embedding_types.SlideEmbeddingSource:
    return self._slide_embedding_source

  @property
  @abc.abstractmethod
  def json_size_in_bytes(self) -> int:
    """Return size in bytes of json sent to endpoint."""

  @abc.abstractmethod
  def finalize(self) -> None:
    """finalize after this there will be no more changes."""

  @abc.abstractmethod
  def split(
      self,
      endpoint: AbstractPatchEmbeddingEndpoint[RequestResponseType],
      max_size: int,
  ) -> Tuple[
      Optional[AbstractPreparedEmbeddingRequest[RequestResponseType]],
      patch_embedding_types.SlideEmbeddingSource,
  ]:
    """Splits object into parts which meet size and exceed size req."""


def _copy_dict_excluding_keys(
    state: Mapping[str, Any],
    exclude: List[List[str]],
    base_keys: Optional[List[str]] = None,
) -> MutableMapping[str, Any]:
  """Duplicates str keyed dict excluding predfined keys."""
  if base_keys is None:
    base_keys = []
  copy_dict = {}
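  # base_keys tracks the key path from the root of the dict; when the current
  # path matches an entry in exclude, that key (and its subtree) is omitted.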
  for key, value in state.items():
    base_keys.append(key)
    if base_keys in exclude:
      base_keys.pop()
      continue
    if not isinstance(value, Mapping):
      copy_dict[key] = value
    else:
      copy_dict[key] = _copy_dict_excluding_keys(value, exclude, base_keys)
    base_keys.pop()
  return copy_dict


@dataclasses.dataclass
class _VertexModelResult:
  instances: List[Mapping[str, Any]]


class PreparedVertexEmbeddingRequest(
    AbstractPreparedEmbeddingRequest[_VertexModelResult]
):
  """Internral respresentation of embedding json embedding requests."""

  def __init__(
      self,
      prepared_request: Mapping[str, Any],
      slide_embedding_source: patch_embedding_types.SlideEmbeddingSource,
  ):
    super().__init__(slide_embedding_source)
    self._embedding_request: Mapping[str, Any] = prepared_request
    self._embedding_json: Optional[str] = None

  @classmethod
  def init_from_json_finalized(
      cls,
      json_str: str,
      slide_embedding_source: patch_embedding_types.SlideEmbeddingSource,
  ) -> PreparedVertexEmbeddingRequest:
    """create finalized vertex embedding request from json and source."""
    instance = PreparedVertexEmbeddingRequest.__new__(
        PreparedVertexEmbeddingRequest
    )
    super(PreparedVertexEmbeddingRequest, instance).__init__(
        slide_embedding_source
    )
    instance._embedding_json = json_str
    instance._embedding_request = None
    return instance

  @property
  def embedding_request(self) -> Mapping[str, Any]:
    if self._embedding_request is None:
      raise ez_wsi_errors.InternalError('Request has been finalized.')
    return self._embedding_request

  @property
  def json(self) -> str:
    if self._embedding_json is None:
      self._embedding_json = json.dumps(self._embedding_request)
    return self._embedding_json

  @property
  def json_size_in_bytes(self) -> int:
    return len(self.json)

  def finalize(self) -> None:
    if self._embedding_json is None:
      self._embedding_json = json.dumps(self._embedding_request)
    self._embedding_request = None

  def _non_splitable(
      self,
  ) -> Tuple[None, patch_embedding_types.SlideEmbeddingSource]:
    return None, self.slide_embedding_source

  def _split_results(self, split_request: str, end_split_index: int) -> Tuple[
      Optional[PreparedVertexEmbeddingRequest],
      patch_embedding_types.SlideEmbeddingSource,
  ]:
    """Returns split prepared vertex embedding result."""
    split_half_slide_embedding_source = (
        patch_embedding_types.SlideEmbeddingSource(
            self.slide_embedding_source.patches[:end_split_index]
        )
    )
    finalized_split_half = (
        PreparedVertexEmbeddingRequest.init_from_json_finalized(
            split_request, split_half_slide_embedding_source
        )
    )
    slide_embedding_source = patch_embedding_types.SlideEmbeddingSource(
        self.slide_embedding_source.patches[end_split_index:]
    )
    return (
        finalized_split_half,
        slide_embedding_source,
    )

  def _split_on_patch_coordinates_only(
      self,
      endpoint: AbstractPatchEmbeddingEndpoint[_VertexModelResult],
      max_size: int,
  ) -> Tuple[
      Optional[PreparedVertexEmbeddingRequest],
      patch_embedding_types.SlideEmbeddingSource,
  ]:
    """If possilbe splits prepared request to meet size req."""
    if self.json_size_in_bytes < endpoint.max_request_size_bytes():
      # It is much less desirable to split DicomPatches or GcsImages into
      # multiple requests. The majority of the metadata could be in state that
      # cannot be split and would be sent in duplicate across multiple
      # requests. Test whether the entire message can be encoded by itself; if
      # it can, defer sending.
      return self._non_splitable()
    try:
      coordinates = self.embedding_request[EndpointJsonKeys.PATCH_COORDINATES]
    except KeyError:
      return self._non_splitable()
    # Make a copy of the request, excluding the patch coordinates.
    base_request = _copy_dict_excluding_keys(
        self.embedding_request,
        [[EndpointJsonKeys.PATCH_COORDINATES]],
    )
    request_size = len(json.dumps(base_request))
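    # Greedily accumulate patch coordinates until the estimated request size
    # reaches max_size.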
    end_split_index = 0
    for coordinate in coordinates:
      patch_md_size = len(json.dumps(coordinate))
      if request_size + patch_md_size >= max_size:
        break
      request_size += patch_md_size
      end_split_index += 1
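    # The per-coordinate estimate ignores JSON separators; shrink the split
    # until the serialized request actually fits within max_size.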
    while True:
      if end_split_index <= 0:
        return self._non_splitable()
      split_request = copy.copy(base_request)
      split_request[EndpointJsonKeys.PATCH_COORDINATES] = coordinates[
          :end_split_index
      ]
      split_request = json.dumps(split_request)
      if len(split_request) <= max_size:
        break
      end_split_index -= 1
    return self._split_results(split_request, end_split_index)

  def split(
      self,
      endpoint: AbstractPatchEmbeddingEndpoint[_VertexModelResult],
      max_size: int,
  ) -> Tuple[
      Optional[PreparedVertexEmbeddingRequest],
      patch_embedding_types.SlideEmbeddingSource,
  ]:
    """If possible splits object into parts which meet size and exceed size req."""
    if EndpointJsonKeys.DICOM_PATH in self.embedding_request:
      return self._split_on_patch_coordinates_only(endpoint, max_size)
    elif EndpointJsonKeys.IMAGE_FILE_URI in self.embedding_request:
      try:
        ez_wsi_state = self.embedding_request[EndpointJsonKeys.EXTENSIONS][
            EndpointJsonKeys.EZ_WSI_STATE
        ]
        patches = ez_wsi_state[EndpointJsonKeys.PATCHES]
        coordinates = self.embedding_request[EndpointJsonKeys.PATCH_COORDINATES]
      except KeyError:
        return self._split_on_patch_coordinates_only(endpoint, max_size)
      if len(patches) <= 1:
        return self._non_splitable()
      if len(coordinates) != len(patches):
        raise ez_wsi_errors.InternalError(
            'Patch state and coordinate counts do not match.'
        )
      # Make a copy of the request, excluding patch coordinates and cached
      # patch bytes.
      base_request = _copy_dict_excluding_keys(
          self.embedding_request,
          [
              [EndpointJsonKeys.PATCH_COORDINATES],
              [
                  EndpointJsonKeys.EXTENSIONS,
                  EndpointJsonKeys.EZ_WSI_STATE,
                  EndpointJsonKeys.PATCHES,
              ],
          ],
      )
      request_size = len(json.dumps(base_request))
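      # Greedily accumulate base64-encoded patch payloads until the estimated
      # request size reaches max_size.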
      end_split_index = 0
      for patch_metadata in patches:
        patch_md_size = len(patch_metadata)
        if request_size + patch_md_size >= max_size:
          break
        request_size += patch_md_size
        end_split_index += 1
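      # Back off until the JSON for the split's patch coordinates also fits
      # within max_size.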
      while True:
        if end_split_index <= 0:
          return self._non_splitable()
        patch_coordinate_size = len(json.dumps(coordinates[:end_split_index]))
        if patch_coordinate_size + request_size < max_size:
          break
        end_split_index -= 1
        if end_split_index > 0:
          request_size -= len(patches[end_split_index])
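      # Re-serialize the candidate split and shrink further if the full request
      # still exceeds max_size.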
      while True:
        if end_split_index <= 0:
          return self._non_splitable()
        split_request = copy.copy(base_request)
        split_request[EndpointJsonKeys.PATCH_COORDINATES] = coordinates[
            :end_split_index
        ]
        split_request[EndpointJsonKeys.EXTENSIONS][
            EndpointJsonKeys.EZ_WSI_STATE
        ][EndpointJsonKeys.PATCHES] = patches[:end_split_index]
        split_request = json.dumps(split_request)
        if len(split_request) <= max_size:
          break
        end_split_index -= 1
      return self._split_results(split_request, end_split_index)
    raise ez_wsi_errors.InternalError('unidentified JSON')


class AbstractPatchEmbeddingEndpoint(
    Generic[RequestResponseType], metaclass=abc.ABCMeta
):
  """Abstract class for embedding endpoint."""

  def __init__(
      self,
      icc_profile_normalization: IccProfileNormalization,
      timeout: Optional[Union[float, int]],
  ):
    self._icc_profile_normalization = icc_profile_normalization
    self._icc_profile_bytes = None
    self._timeout = timeout

  @abc.abstractmethod
  def max_request_size_bytes(self) -> int:
    """Maximum size in bytes that can be sent in single request."""

  @abc.abstractmethod
  def max_threads(self) -> int:
    """Returns maximum number of threads to spawn."""

  @abc.abstractmethod
  def patch_width(self) -> int:
    """Returns embedding endpoint input size width in pixels."""

  @abc.abstractmethod
  def patch_height(self) -> int:
    """Returns embedding endpoint input size height in pixels."""

  @abc.abstractmethod
  def max_number_of_patches_per_request(self) -> int:
    """Maximum number of patches to send endpoint in a request."""

  @abc.abstractmethod
  def endpoint_max_number_of_patches_per_request(self) -> int:
    """Maximum number of patches that the endpoint supports."""

  @abc.abstractmethod
  def retry_count(self) -> int:
    """Maximum number of get_embedding attempts before endpoint raises.."""

  @abc.abstractmethod
  def prepare_embedding_request(
      self,
      slide_embedding: patch_embedding_types.SlideEmbeddingSource,
  ) -> AbstractPreparedEmbeddingRequest[RequestResponseType]:
    """Converts slide embedding source to request to JSON formatted request."""

  @abc.abstractmethod
  def request_embeddings(
      self,
      embedding_inputs: Sequence[
          AbstractPreparedEmbeddingRequest[RequestResponseType]
      ],
  ) -> RequestResponseType:
    """Returns raw embedding result."""

  @property
  def request_embedding_return_type(self) -> Type[RequestResponseType]:
    return RequestResponseType

  @abc.abstractmethod
  def process_response(
      self,
      embedding_inputs: Sequence[patch_embedding_types.SlideEmbeddingSource],
      msg: RequestResponseType,
  ) -> List[patch_embedding_types.PatchEmbeddingEnsembleResult]:
    """Converts raw embedding response to list of embedding results."""

  @property
  def icc_profile_normalization(self) -> IccProfileNormalization:
    """Returns ICC Profile bytes for endpoint will transform imaging to."""
    return self._icc_profile_normalization

  def icc_profile_bytes(self) -> bytes:
    """Returns ICC Profile bytes for endpoint will transform imaging to."""
    if self._icc_profile_bytes is None:
      self._icc_profile_bytes = _get_icc_profile_bytes(
          self._icc_profile_normalization
      )
    return self._icc_profile_bytes

  @property
  def timeout(self) -> Optional[Union[float, int]]:
    return self._timeout

  @timeout.setter
  def timeout(self, val: Optional[Union[float, int]]) -> None:
    self._timeout = val


def _get_gcs_image_md_size(
    json_metadata: Mapping[str, Union[int, str, List[str]]],
) -> int:
  """returns size in bytes of data encoded in metadata."""
  size = 0
  for value in json_metadata.values():
    if isinstance(value, int):
      size += len(str(value))
    elif isinstance(value, str):
      size += len(value)
    elif isinstance(value, list):
      size += sum(len(md) for md in value)
    else:
      raise ez_wsi_errors.InternalError(
          f'Unsupported metadata value type: {value}'
      )
  return size


def _patch_pixel_area(
    slide_embedding: patch_embedding_types.SlideEmbeddingSource,
) -> int:
  patch_pixel_area = 0
  for ip in slide_embedding.patches:
    patch_pixel_area += ip.patch.width * ip.patch.height
  return patch_pixel_area


def _get_gcs_whole_image_metadata(
    source_image: gcs_image.GcsImage,
    icc_profile_normalization: IccProfileNormalization,
    c_transform: Optional[ImageCms.ImageCmsTransform],
    max_image_size_bytes: int,
) -> Tuple[Mapping[str, Any], int]:
  """Return whole image metadata and metadata size in bytes."""
  if source_image.size_bytes_of_source_image is not None:
    # Image bytes are base64 encoded; estimate the encoded size as 4/3x the
    # original plus the length of the ICC profile normalization value.
    estimated_size = int(
        math.ceil(source_image.size_bytes_of_source_image * 8 / 6)
        + len(IccProfileNormalization.NONE.value)
    )
    if estimated_size >= max_image_size_bytes:
      return {}, estimated_size
  try:
    source_image_metadata = {
        EndpointJsonKeys.IMAGE: source_image.source_image_bytes_json_metadata(),
        EndpointJsonKeys.ICC_PROFILE_METADATA_NORMALIZATION: (
            IccProfileNormalization.NONE.value
        ),
    }
  except ez_wsi_errors.GcsImageError:
    # Image bytes were initialized from an in-memory representation.
    source_image_md_size = int(
        math.ceil(
            source_image.width
            * source_image.height
            * source_image.bytes_pre_pixel
            * 8  # bits per channel
            / 12  # base 64 encoding + est 2x reduction due to PNG compression
        )
        + len(icc_profile_normalization.value)
    )
    if source_image_md_size >= max_image_size_bytes:
      return {}, source_image_md_size
    source_image_metadata = {
        EndpointJsonKeys.IMAGE: source_image.json_metadata(c_transform),
        EndpointJsonKeys.ICC_PROFILE_METADATA_NORMALIZATION: (
            icc_profile_normalization.value
        ),
    }
  source_image_md_size = _get_gcs_image_md_size(source_image_metadata)
  if source_image_md_size >= max_image_size_bytes:
    return {}, source_image_md_size
  return source_image_metadata, source_image_md_size


def _gcs_image_json_metadata(
    slide_embedding: patch_embedding_types.SlideEmbeddingSource,
    icc_profile_normalization: IccProfileNormalization,
    c_transform: Optional[ImageCms.ImageCmsTransform],
    max_endpoint_request_size_bytes: int,
) -> Mapping[str, Any]:
  """Returns metadata for GCS images."""
  patch = typing.cast(gcs_image.GcsPatch, slide_embedding.patches[0].patch)
  source_image = patch.source
  max_image_size_bytes = max(
      0, max_endpoint_request_size_bytes - _WHOLE_IMAGE_SIZE_SAFTY_MARGIN
  )
  source_image_metadata, source_image_md_size = _get_gcs_whole_image_metadata(
      source_image,
      icc_profile_normalization,
      c_transform,
      max_image_size_bytes,
  )
  if not source_image_metadata or _patch_pixel_area(slide_embedding) < int(
      0.95 * float(source_image.width * source_image.height)
  ):
    # If the patch area is smaller than 95% of the source image area, compute
    # patch metadata. (The 5% factor accounts for the overhead associated with
    # defining multiple patches instead of a single image.)
    patch_metadata = {
        EndpointJsonKeys.SOURCE_IMAGE_WIDTH_PX: int(patch.source.width),
        EndpointJsonKeys.SOURCE_IMAGE_HEIGHT_PX: int(patch.source.height),
        EndpointJsonKeys.ICC_PROFILE_METADATA_NORMALIZATION: (
            icc_profile_normalization.value
        ),
        EndpointJsonKeys.PATCHES: [
            typing.cast(gcs_image.GcsPatch, ip.patch).json_metadata(c_transform)
            for ip in slide_embedding.patches
        ],
    }
    if (
        not source_image_metadata
        or _get_gcs_image_md_size(patch_metadata) < source_image_md_size
    ):
      return patch_metadata
  return source_image_metadata


def _get_gcs_image_metadata(
    max_endpoint_request_size_bytes: int,
    encode_patch_data_in_request: bool,
    icc_profile_normalization: IccProfileNormalization,
    icc_profile_bytes: bytes,
    slide_embedding: patch_embedding_types.SlideEmbeddingSource,
) -> Mapping[str, Any]:
  """Returns metadata for GCS images."""
  source_image = typing.cast(
      gcs_image.GcsPatch, slide_embedding.patches[0].patch
  ).source
  c_transform = source_image.create_icc_profile_transformation(
      icc_profile_bytes
  )
  gcs_image_url = source_image.uri
  if not gcs_image_url:
    # Always send image data if there is no image URL.
    return _gcs_image_json_metadata(
        slide_embedding,
        icc_profile_normalization,
        c_transform,
        max_endpoint_request_size_bytes,
    )
  if encode_patch_data_in_request and source_image.are_image_bytes_loaded:
    json_metadata = _gcs_image_json_metadata(
        slide_embedding,
        icc_profile_normalization,
        c_transform,
        max_endpoint_request_size_bytes,
    )
    if (
        source_image.size_bytes_of_source_image is None
        or _get_gcs_image_md_size(json_metadata)
        <= source_image.size_bytes_of_source_image
    ):
      # Send JSON metadata if the image was initialized from raw bytes or the
      # size of the JSON is smaller than the source image.
      return json_metadata
  return {}


def _get_dicom_instance_uids_and_required_levels(
    slide_embedding: patch_embedding_types.SlideEmbeddingSource,
) -> Tuple[List[str], List[str]]:
  """Returns list of DICOM instances and leveles required for WSI patches."""
  instance_uids = set()
  required_levels = []
  for patch in slide_embedding.patches:
    level = typing.cast(
        dicom_slide.DicomPatch, patch.patch
    ).get_pyramid_imaging_source_level()
    if level in required_levels:
      continue
    required_levels.append(level)
    for instance in level.instances.values():
      instance_uids.add(instance.dicom_object.get_value(tags.SOP_INSTANCE_UID))
  return list(instance_uids), required_levels


class AbstractVertexPatchEmbeddingEndpointBase(
    AbstractPatchEmbeddingEndpoint[_VertexModelResult]
):
  """Shared implementation of patch embedding endpoint for V1 and V2 endpoints."""

  def __init__(
      self,
      patch_width: int,
      patch_height: int,
      icc_profile_normalization: IccProfileNormalization,
      send_gcs_patch_bytes_from_client_to_server: bool,
      end_point_url: str,
      max_threads: int,
      max_patches_per_request: int,
      endpoint_max_patches_per_request: int,
      retry_count: int,
      credential_factory: Optional[
          credential_factory_module.AbstractCredentialFactory
      ],
      timeout: Optional[Union[int, float]],
  ):
    super().__init__(icc_profile_normalization, timeout)
    self._patch_width = patch_width
    self._patch_height = patch_height
    self._credentials = None
    self._credentials_factory = (
        credential_factory
        if credential_factory is not None
        else credential_factory_module.DefaultCredentialFactory()
    )
    self._send_gcs_patch_bytes_from_client_to_server = (
        send_gcs_patch_bytes_from_client_to_server
    )
    self._end_point_url = end_point_url
    self._max_threads = max(1, min(max_threads, _MAX_ENDPOINT_THREADS))
    self._endpoint_max_patches_per_request = int(
        max(1, endpoint_max_patches_per_request)
    )
    self._max_patches_per_request = int(
        max(
            1,
            min(
                max_patches_per_request, self._endpoint_max_patches_per_request
            ),
        )
    )
    self._retry_count = max(0, retry_count)

  @property
  def end_point_url(self) -> str:
    return self._end_point_url

  def max_request_size_bytes(self) -> int:
    """Maximum size in bytes that can be sent in single request."""
    return _MAX_VERTEX_AI_V1_REQUEST_SIZE_BYTES

  @property
  def credentials(self) -> google.auth.credentials.Credentials:
    if self._credentials is None:
      self._credentials = self._credentials_factory.get_credentials()
    else:
      self._credentials = credential_factory_module.refresh_credentials(
          self._credentials, self._credentials_factory
      )
    return self._credentials

  def vertex_endpoint_authentication_header(self) -> MutableMapping[str, str]:
    headers = {}
    self.credentials.apply(headers)
    return headers

  def retry_count(self) -> int:
    return self._retry_count

  def max_threads(self) -> int:
    return self._max_threads

  def patch_width(self) -> int:
    return self._patch_width

  def patch_height(self) -> int:
    return self._patch_height

  def max_number_of_patches_per_request(self) -> int:
    return self._max_patches_per_request

  def endpoint_max_number_of_patches_per_request(self) -> int:
    """Maximum number of patches that can be sent to the endpoint at once."""
    return self._endpoint_max_patches_per_request

  @abc.abstractmethod
  def _dicom_patch_embedding_request(
      self,
      bearer_token: str,
      slide_embedding: patch_embedding_types.SlideEmbeddingSource,
  ) -> Mapping[str, Any]:
    """Returns DICOM patch embedding request.

    Args:
      bearer_token: Bearer token for DICOM store requests.
      slide_embedding: DICOM embedding inputs.

    Returns:
      JSON formatted embedding request.
    """

  @abc.abstractmethod
  def get_embedding_request(
      self, embedding_inputs: Sequence[PreparedVertexEmbeddingRequest]
  ) -> str:
    """Returns patch embedding request.

    Args:
      embedding_inputs: Embedding inputs.

    Returns:
      JSON formatted embedding request.
    """

  @abc.abstractmethod
  def _gcs_patch_embedding_request(
      self,
      bearer_token: str,
      slide_embedding: patch_embedding_types.SlideEmbeddingSource,
  ) -> Mapping[str, Any]:
    """Returns GCS patch embedding request.

    Args:
      bearer_token: Bearer token for store requests.
      slide_embedding: GCS embedding inputs.

    Returns:
      JSON formatted embedding request.
    """

  def prepare_embedding_request(
      self,
      slide_embedding: patch_embedding_types.SlideEmbeddingSource,
  ) -> PreparedVertexEmbeddingRequest:
    first_patch = slide_embedding.patches[0].patch
    if len(slide_embedding.patches) > 1:
      first_patch_type = type(first_patch)
      for p in slide_embedding.patches:
        if not isinstance(p.patch, first_patch_type):
          raise ez_wsi_errors.InternalError(
              'Patches in request are not all the same type.'
          )
    if isinstance(first_patch, dicom_slide.DicomPatch):
      bearer_token = slide_embedding.get_bearer_token()
      return PreparedVertexEmbeddingRequest(
          self._dicom_patch_embedding_request(
              bearer_token,
              slide_embedding,
          ),
          slide_embedding,
      )
    elif isinstance(first_patch, gcs_image.GcsPatch):
      bearer_token = slide_embedding.get_bearer_token()
      return PreparedVertexEmbeddingRequest(
          self._gcs_patch_embedding_request(
              bearer_token,
              slide_embedding,
          ),
          slide_embedding,
      )
    raise ez_wsi_errors.InternalError(
        'Patch is not a dicom_slide.DicomPatch or gcs_image.GcsPatch.'
    )

  @property
  def send_gcs_patch_bytes_from_client_to_server(self) -> bool:
    return self._send_gcs_patch_bytes_from_client_to_server

  @retrying.retry(**error_retry_util.HTTP_AUTH_ERROR_RETRY_CONFIG)
  @retrying.retry(**error_retry_util.HTTP_SERVER_ERROR_RETRY_CONFIG)
  def _request_embeddings(self, json_msg: str) -> str:
    """Sends json request to Vertex AI endpoint."""
    try:
      headers = self.vertex_endpoint_authentication_header()
      headers['Content-Length'] = f'{len(json_msg)}'
      headers['Content-Type'] = 'application/json'
      response = requests.post(
          self._end_point_url,
          headers=headers,
          data=json_msg,
          timeout=self.timeout,
      )
      # Raises an HTTPError if the response code was not 200.
      response.raise_for_status()
      return response.text
    except requests.exceptions.HTTPError as exp:
      ez_wsi_errors.raise_ez_wsi_http_exception(exp.response.reason, exp)
    except requests.exceptions.Timeout as timeout_error:
      raise ez_wsi_errors.HttpRequestTimeoutError(
          str(timeout_error), 'Request Timeout'
      ) from timeout_error

  @abc.abstractmethod
  def _is_request_error_retryable(
      self, json_response: Mapping[str, Any]
  ) -> bool:
    """Returns true if error at request level is retryable."""

  @abc.abstractmethod
  def _decode_response(
      self,
      embedding_inputs: Sequence[PreparedVertexEmbeddingRequest],
      json_response: Mapping[str, Any],
  ) -> _VertexModelResult:
    """Decodes json_response response from Vertex AI endpoint into _VertexModelResult."""

  @abc.abstractmethod
  def _instance_has_retryable_error(self, json_dict: Mapping[str, Any]) -> bool:
    """Decodes response from Vertex AI endpoint into _VertexModelResult."""

  def _regenerate_instance_with_new_auth_token(
      self, prepared_request: PreparedVertexEmbeddingRequest
  ) -> PreparedVertexEmbeddingRequest:
    prepared_request = self.prepare_embedding_request(
        prepared_request.slide_embedding_source
    )
    prepared_request.finalize()
    return prepared_request

  def _merge_embedding_input_results(
      self, model_result: _VertexModelResult, partial_result: _VertexModelResult
  ) -> _VertexModelResult:
    partial_result_counter = 0
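    # Instance slots that failed with a retryable error are refilled, in order,
    # from the partial retry results.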
    for index in range(len(model_result.instances)):
      if self._instance_has_retryable_error(model_result.instances[index]):
        model_result.instances[index] = partial_result.instances[
            partial_result_counter
        ]
        partial_result_counter += 1
    return model_result

  def _regenerate_embedding_input_requests_with_new_auth_token(
      self, embedding_inputs: Sequence[PreparedVertexEmbeddingRequest]
  ) -> Tuple[str, Sequence[PreparedVertexEmbeddingRequest]]:
    embedding_inputs = [
        self._regenerate_instance_with_new_auth_token(ei)
        for ei in embedding_inputs
    ]
    return self.get_embedding_request(embedding_inputs), embedding_inputs

  def _retry_failed_embedding_input_requests(
      self,
      model_result: _VertexModelResult,
      embedding_inputs: Sequence[PreparedVertexEmbeddingRequest],
  ):
    retry_list = []
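    # Collect only the inputs whose instance-level result has a retryable error
    # so their requests can be rebuilt with a fresh auth token.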
    for index, embedding_input in enumerate(embedding_inputs):
      if self._instance_has_retryable_error(model_result.instances[index]):
        retry_list.append(embedding_input)
    if not retry_list:
      return '', retry_list
    return self._regenerate_embedding_input_requests_with_new_auth_token(
        retry_list
    )

  def request_embeddings(
      self,
      embedding_inputs: Sequence[
          AbstractPreparedEmbeddingRequest[_VertexModelResult]
      ],
  ) -> _VertexModelResult:
    """Requests embeddings from Vertex AI endpoint.

    Method is robust to expiration of authentication tokens in either the
    instances or the endpoint. It will retry up to three times to attempt to
    correct authentication issues.

    Args:
      embedding_inputs: Prepared list of instances (images) with patches to get
        embeddings.

    Returns:
      _VertexModelResult, a list of JSON results, one for each instance.
    """
    if not embedding_inputs:
      return _VertexModelResult([])
    embedding_inputs = typing.cast(
        Sequence[PreparedVertexEmbeddingRequest], embedding_inputs
    )
    attempts = 0
    # generates initial embedding request
    json_msg = self.get_embedding_request(embedding_inputs)
    if not json_msg and not embedding_inputs:
      return _VertexModelResult([])
    request_embedding_inputs = embedding_inputs
    model_result = None
    # This retry loop handles authentication errors encountered by
    # the endpoint attempting to connect to data sources for which
    # expired tokens have been provided. In this case we are connecting
    # to the endpoint, but the endpoint is unable to fulfill the request
    # completely and is returning errors in one or more of its responses.
    # More general authentication and retry logic for connecting to the endpoint
    # is handled by the decorators attached to _request_embeddings.
    while True:
      # Get JSON response from Vertex AI endpoint.
      json_response = self._request_embeddings(json_msg)
      attempts += 1
      # Decode JSON response.
      try:
        json_response = json.loads(json_response)
      except json.JSONDecodeError as exp:
        raise ez_wsi_errors.PatchEmbeddingEndpointError(
            'Endpoint returned invalid JSON.'
        ) from exp
      # If a retryable error is found at the response level, regenerate the
      # whole request with a new authentication token and retry.
      if self._is_request_error_retryable(json_response):
        json_msg, request_embedding_inputs = (
            self._regenerate_embedding_input_requests_with_new_auth_token(
                request_embedding_inputs
            )
        )
        continue
      # Decode json response into _VertexModelResult.
      result = self._decode_response(request_embedding_inputs, json_response)
      if model_result is None:
        model_result = result
      else:
        # If results were already processed, merge with pre-existing results.
        self._merge_embedding_input_results(model_result, result)
      if attempts >= 3:
        # If on the third attempt, just return the result; it should never take
        # more than two attempts.
        return model_result
      # Retry retrieving embeddings only for the instances that failed with
      # retryable errors.
      json_msg, request_embedding_inputs = (
          self._retry_failed_embedding_input_requests(
              model_result, embedding_inputs
          )
      )
      if not json_msg:
        # If the JSON message is empty there is nothing to do but return.
        if model_result is None:
          # Return an empty result; this should never occur.
          return _VertexModelResult([])
        else:
          return model_result


class V1PatchEmbeddingEndpoint(AbstractVertexPatchEmbeddingEndpointBase):
  """Implements Patch embedding V1 API."""

  def __init__(
      self,
      patch_width: int = 224,
      patch_height: int = 224,
      endpoint_api_version: str = 'v1',  # Vertex API version
      project_id: str = 'hai-cd3-foundations',
      endpoint_location: str = 'us-central1',
      endpoint_id: str = '160',
      max_threads: int = _DEFAULT_ENDPOINT_THREADS,
      max_patches_per_request: int = _DEFAULT_MAX_PATCHES_PER_REQUEST,
      retry_count: int = _DEFAULT_RETRY_COUNT,
      send_gcs_patch_bytes_from_client_to_server: bool = False,
      credential_factory: Optional[
          credential_factory_module.AbstractCredentialFactory
      ] = None,
      timeout: Optional[Union[int, float]] = _DEFAULT_TIMEOUT,
  ):
    end_point: List[str] = [
        f'https://{endpoint_location}-aiplatform.googleapis.com',
        endpoint_api_version,
        'projects',
        project_id,
        'locations',
        endpoint_location,
        'endpoints',
        f'{endpoint_id}:predict',
    ]
    end_point_url = '/'.join([ep.strip('/') for ep in end_point])
    super().__init__(
        patch_width,
        patch_height,
        IccProfileNormalization.NONE,
        send_gcs_patch_bytes_from_client_to_server,
        end_point_url,
        max_threads,
        max_patches_per_request,
        _MAX_V1_ENDPOINT_PATCHES_PER_REQUEST,
        retry_count,
        credential_factory,
        timeout,
    )
    self._model_size = 'MEDIUM'
    self._model_kind = 'LOW_PIXEL_SPACING'

  def _dicom_patch_embedding_request(
      self,
      bearer_token: str,
      slide_embedding: patch_embedding_types.SlideEmbeddingSource,
  ) -> Mapping[str, Any]:
    """Generates Embedding request for image stored in DICOM Store."""
    patch = typing.cast(
        dicom_slide.DicomPatch, slide_embedding.patches[0].patch
    )
    source_series = patch.source
    path = source_series.path
    instance_uids, required_levels = (
        _get_dicom_instance_uids_and_required_levels(slide_embedding)
    )
    if patch.is_resized:
      raise ez_wsi_errors.PatchEmbeddingEndpointError(
          'V1 encoder does not support image level resizing.'
      )
    if not bearer_token:
      raise ez_wsi_errors.PatchEmbeddingEndpointError(
          'V1 encoder does not support empty bearer tokens.'
      )
    dicom_store_url = path.GetStorePath().complete_url
    if dicom_store_url.startswith('https://healthcare.googleapis.com/v1/'):
      # If calling a Google DICOM store, use the legacy path format expected by
      # the V1 embedding endpoint.
      dicom_store_url = f'projects/{path.project_id}/locations/{path.location}/datasets/{path.dataset_id}/dicomStores/{path.store_id}'
    return {
        EndpointJsonKeys.DICOM_WEB_STORE_URL: dicom_store_url,
        EndpointJsonKeys.DICOM_STUDY_UID: path.study_uid,
        EndpointJsonKeys.DICOM_SERIES_UID: path.series_uid,
        EndpointJsonKeys.BEARER_TOKEN: bearer_token,
        EndpointJsonKeys.EZ_WSI_STATE: source_series.json_metadata_dict(
            level_subset=required_levels,
            max_json_encoded_icc_profile_size=_MAX_DICOM_SLIDE_ICCPROFILE_METADATA_SIZE,
        ),
        EndpointJsonKeys.INSTANCE_UIDS: instance_uids,
        EndpointJsonKeys.PATCH_COORDINATES: [
            {
                EndpointJsonKeys.X_ORIGIN: int(input.patch.x),
                EndpointJsonKeys.Y_ORIGIN: int(input.patch.y),
                EndpointJsonKeys.WIDTH: int(input.patch.width),
                EndpointJsonKeys.HEIGHT: int(input.patch.height),
            }
            for input in slide_embedding.patches
        ],
    }

  def _gcs_patch_embedding_request(
      self,
      bearer_token: str,
      slide_embedding: patch_embedding_types.SlideEmbeddingSource,
  ) -> Mapping[str, Any]:
    """Generates Embedding request for image stored in DICOM Store."""
    json_metadata = _get_gcs_image_metadata(
        self.max_request_size_bytes(),
        self.send_gcs_patch_bytes_from_client_to_server,
        self.icc_profile_normalization,
        self.icc_profile_bytes(),
        slide_embedding,
    )
    gcs_patch = typing.cast(
        gcs_image.GcsPatch, slide_embedding.patches[0].patch
    )
    uri = gcs_patch.source.uri
    if gcs_patch.is_resized:
      raise ez_wsi_errors.PatchEmbeddingEndpointError(
          'V1 encoder does not support image resizing.'
      )
    if not bearer_token and gcs_patch.source.uri:
      raise ez_wsi_errors.PatchEmbeddingEndpointError(
          'V1 encoder does not support empty bearer tokens.'
      )
    return {
        EndpointJsonKeys.PROJECT_NAME: 'placeholder',
        EndpointJsonKeys.GCS_IMAGE_URL: uri,
        EndpointJsonKeys.BEARER_TOKEN: bearer_token,
        EndpointJsonKeys.EZ_WSI_STATE: json_metadata,
        EndpointJsonKeys.PATCH_COORDINATES: [
            {
                EndpointJsonKeys.X_ORIGIN: int(input.patch.x),
                EndpointJsonKeys.Y_ORIGIN: int(input.patch.y),
                EndpointJsonKeys.WIDTH: int(input.patch.width),
                EndpointJsonKeys.HEIGHT: int(input.patch.height),
            }
            for input in slide_embedding.patches
        ],
    }

  def _validate_embedding_response(
      self,
      embedding_source: patch_embedding_types.SlideEmbeddingSource,
      embedding_response: Mapping[str, Any],
  ):
    """Validate embedding DICOM UID match patch request."""
    if embedding_source.patches and isinstance(
        embedding_source.patches[0].patch, dicom_slide.DicomPatch
    ):
      # Test StudyInstanceUID and SeriesInstanceUID from request and response
      # match if returning result for DICOM slide.
      patch = typing.cast(
          dicom_slide.DicomPatch, embedding_source.patches[0].patch
      )
      path = patch.source.path
      if (
          embedding_response[EndpointJsonKeys.DICOM_STUDY_UID] != path.study_uid
          or embedding_response[EndpointJsonKeys.DICOM_SERIES_UID]
          != path.series_uid
      ):
        raise ez_wsi_errors.PatchEmbeddingEndpointError(
            'Study or Series UID of embedding does not match request.'
        )

  def _is_request_error_retryable(
      self, json_response: Mapping[str, Any]
  ) -> bool:
    """Returns true if error at request level is retryable."""
    try:
      predictions = json_response[EndpointJsonKeys.PREDICTIONS]
      if isinstance(predictions, Mapping):
        # This is the raw response; the Vertex endpoint translates the response
        # to an alternative format.
        error = predictions[EndpointJsonKeys.ERROR_RESPONSE]
      else:
        # This is the response as translated by the Vertex endpoint.
        returned_slide_embeddings, error, ml_version = predictions
        del returned_slide_embeddings, ml_version
      return error == EndpointJsonKeys.INVALID_CREDENTIALS
    except (KeyError, ValueError, TypeError) as _:
      return False

  def _decode_response(
      self,
      embedding_inputs: Sequence[PreparedVertexEmbeddingRequest],
      json_response: Mapping[str, Any],
  ) -> _VertexModelResult:
    """Decodes response from Vertex AI endpoint into _VertexModelResult."""
    try:
      predictions = json_response[EndpointJsonKeys.PREDICTIONS]
      if isinstance(predictions, Mapping):
        # This is the raw response; the Vertex endpoint translates the response
        # to an alternative format.
        returned_slide_embeddings = predictions[
            EndpointJsonKeys.EMBEDDING_RESULT
        ]
        error = predictions[EndpointJsonKeys.ERROR_RESPONSE]
      else:
        # This is the response as translated by the Vertex endpoint.
        returned_slide_embeddings, error, _ = predictions
    except (KeyError, ValueError, TypeError) as exp:
      raise ez_wsi_errors.PatchEmbeddingEndpointError(
          'Endpoint returned incorrectly formatted JSON.'
      ) from exp
    if error is not None and error:
      raise ez_wsi_errors.PatchEmbeddingEndpointError(
          f'Endpoint returned error: {error}'
      )
    # Test that the number of slide embedding responses matches the request.
    if len(embedding_inputs) != len(returned_slide_embeddings):
      raise ez_wsi_errors.PatchEmbeddingEndpointError(
          'Number of embedding responses received does not match number of'
          f' embedding requests; expected: {len(embedding_inputs)}; received:'
          f' {len(returned_slide_embeddings)}.'
      )
    return _VertexModelResult(returned_slide_embeddings)

  def _instance_has_retryable_error(self, json_dict: Mapping[str, Any]) -> bool:
    return False

  def process_response(
      self,
      embedding_inputs: Sequence[patch_embedding_types.SlideEmbeddingSource],
      msg: _VertexModelResult,
  ) -> List[patch_embedding_types.PatchEmbeddingEnsembleResult]:
    """Returns patch embedding results for input and returned embeddings."""
    result = []
    endpoint_patch_width = self.patch_width()
    endpoint_patch_height = self.patch_height()
    for patch_embeddings, instance_input in zip(
        msg.instances, embedding_inputs
    ):
      self._validate_embedding_response(instance_input, patch_embeddings)
      patch_embeddings = patch_embeddings[EndpointJsonKeys.PATCH_EMBEDDINGS]
      # Test that the number of patches received for the slide matches the
      # request.
      if len(patch_embeddings) != len(instance_input.patches):
        raise ez_wsi_errors.PatchEmbeddingEndpointError(
            'Number of patches in embedding response does not match request;'
            f' expected: {len(instance_input.patches)}; received:'
            f' {len(patch_embeddings)}.'
        )
      for patch_embedding, source in zip(
          patch_embeddings, instance_input.patches
      ):
        pc = patch_embedding[EndpointJsonKeys.PATCH_COORDINATE]
        # Test that the coordinates of the patch match the request.
        if not _test_patch_coordinates_match(
            pc,
            source.patch.x,
            source.patch.y,
            endpoint_patch_width,
            endpoint_patch_height,
        ):
          raise ez_wsi_errors.PatchEmbeddingEndpointError(
              'Embedding patch coordinates or dimensions do not match request.'
          )
        embedding_value = np.asarray(
            patch_embedding[EndpointJsonKeys.EMBEDDINGS]
        )
        result.append(
            patch_embedding_types.PatchEmbeddingEnsembleResult(
                source, embedding_value, None
            )
        )
    return result

  def get_embedding_request(
      self, embedding_inputs: Sequence[PreparedVertexEmbeddingRequest]
  ) -> str:
    """Returns JSON formmatted embedding request."""
    instances = ','.join(i.json for i in embedding_inputs)
    return f'{{"{EndpointJsonKeys.PARAMETERS}":{{"{EndpointJsonKeys.MODEL_SIZE}":"{self._model_size}","{EndpointJsonKeys.MODEL_KIND}":"{self._model_kind}"}},"{EndpointJsonKeys.INSTANCES}":[{instances}]}}'


def _gen_v2_extensions(
    require_fully_in_source_image: bool,
    image_dimensions: Optional[dicom_slide.ImageDimensions],
    icc_profile: IccProfileNormalization,
    ez_wsi_state: Mapping[str, Any],
) -> Mapping[str, Any]:
  """Returns extensions for pathology embeddings."""
  extension = {
      EndpointJsonKeys.REQUIRE_PATCHES_FULLY_IN_SOURCE_IMAGE: str(
          require_fully_in_source_image
      )
  }
  if image_dimensions is not None:
    extension[EndpointJsonKeys.IMAGE_DIMENSIONS] = dataclasses.asdict(
        image_dimensions
    )
  if icc_profile:
    extension[EndpointJsonKeys.TRANSFORM_IMAGING_TO_ICC_PROFILE] = str(
        icc_profile.value
    )
  if ez_wsi_state:
    extension[EndpointJsonKeys.EZ_WSI_STATE] = ez_wsi_state
  return extension


def _format_error_message(error_code: str, error_description: str) -> str:
  if not error_description:
    return f'Error code: {error_code}'
  return f'Error code: {error_code}; {error_description}'


class V2PatchEmbeddingEndpoint(AbstractVertexPatchEmbeddingEndpointBase):
  """Implements Patch embedding V2 API."""

  def __init__(
      self,
      patch_width: int = 224,
      patch_height: int = 224,
      endpoint_api_version: str = 'v1',  # Vertex API version
      project_id: str = 'hai-cd3-foundations',
      endpoint_location: str = 'us-central1',
      endpoint_id: str = '162',
      max_threads: int = _DEFAULT_ENDPOINT_THREADS,
      max_patches_per_request: int = _DEFAULT_MAX_PATCHES_PER_REQUEST,
      retry_count: int = _DEFAULT_RETRY_COUNT,
      icc_profile_normalization: IccProfileNormalization = (
          IccProfileNormalization.NONE
      ),
      send_gcs_patch_bytes_from_client_to_server: bool = False,
      require_fully_in_source_image: bool = True,
      credential_factory: Optional[
          credential_factory_module.AbstractCredentialFactory
      ] = None,
      timeout: Optional[Union[int, float]] = _DEFAULT_TIMEOUT,
  ):
    end_point: List[str] = [
        f'https://{endpoint_location}-aiplatform.googleapis.com',
        endpoint_api_version,
        'projects',
        project_id,
        'locations',
        endpoint_location,
        'endpoints',
        f'{endpoint_id}:predict',
    ]
    end_point_url = '/'.join([ep.strip('/') for ep in end_point])
    super().__init__(
        patch_width,
        patch_height,
        icc_profile_normalization,
        send_gcs_patch_bytes_from_client_to_server,
        end_point_url,
        max_threads,
        max_patches_per_request,
        _MAX_V2_ENDPOINT_PATCHES_PER_REQUEST,
        retry_count,
        credential_factory,
        timeout,
    )
    self._require_fully_in_source_image = require_fully_in_source_image

  def _dicom_patch_embedding_request(
      self,
      bearer_token: str,
      slide_embedding: patch_embedding_types.SlideEmbeddingSource,
  ) -> Mapping[str, Any]:
    """Generates Embedding request for image stored in DICOM Store."""
    patch = typing.cast(
        dicom_slide.DicomPatch, slide_embedding.patches[0].patch
    )
    instance_uids, required_levels = (
        _get_dicom_instance_uids_and_required_levels(slide_embedding)
    )
    if patch.is_resized:
      image_dimensions = dicom_slide.ImageDimensions(
          patch.level.width, patch.level.height
      )
    else:
      image_dimensions = None
    request = {
        EndpointJsonKeys.DICOM_PATH: {
            EndpointJsonKeys.SERIES_PATH: str(
                patch.source.path.GetSeriesPath()
            ),
            EndpointJsonKeys.INSTANCE_UIDS: instance_uids,
        },
        EndpointJsonKeys.EXTENSIONS: _gen_v2_extensions(
            self._require_fully_in_source_image,
            image_dimensions,
            self._icc_profile_normalization,
            patch.source.json_metadata_dict(
                level_subset=required_levels,
                max_json_encoded_icc_profile_size=_MAX_DICOM_SLIDE_ICCPROFILE_METADATA_SIZE,
            ),
        ),
        EndpointJsonKeys.PATCH_COORDINATES: [
            {
                EndpointJsonKeys.X_ORIGIN: int(input.patch.x),
                EndpointJsonKeys.Y_ORIGIN: int(input.patch.y),
                EndpointJsonKeys.WIDTH: int(input.patch.width),
                EndpointJsonKeys.HEIGHT: int(input.patch.height),
            }
            for input in slide_embedding.patches
        ],
    }
    if bearer_token:
      request[EndpointJsonKeys.BEARER_TOKEN] = bearer_token
    return request

  def _gcs_patch_embedding_request(
      self,
      bearer_token: str,
      slide_embedding: patch_embedding_types.SlideEmbeddingSource,
  ) -> Mapping[str, Any]:
    """Generates Embedding request for image stored in DICOM Store."""
    json_metadata = _get_gcs_image_metadata(
        self.max_request_size_bytes(),
        self.send_gcs_patch_bytes_from_client_to_server,
        self.icc_profile_normalization,
        self.icc_profile_bytes(),
        slide_embedding,
    )
    gcs_patch = typing.cast(
        gcs_image.GcsPatch, slide_embedding.patches[0].patch
    )
    uri = gcs_patch.source.uri
    request = {
        EndpointJsonKeys.IMAGE_FILE_URI: (
            uri if uri else 'gs:///placeholder.png'
        ),
        EndpointJsonKeys.EXTENSIONS: _gen_v2_extensions(
            self._require_fully_in_source_image,
            gcs_patch.source.resize_dimensions,
            self._icc_profile_normalization,
            json_metadata,
        ),
        EndpointJsonKeys.PATCH_COORDINATES: [
            {
                EndpointJsonKeys.X_ORIGIN: int(input.patch.x),
                EndpointJsonKeys.Y_ORIGIN: int(input.patch.y),
                EndpointJsonKeys.WIDTH: int(input.patch.width),
                EndpointJsonKeys.HEIGHT: int(input.patch.height),
            }
            for input in slide_embedding.patches
        ],
    }
    if bearer_token:
      request[EndpointJsonKeys.BEARER_TOKEN] = bearer_token
    return request

  def get_embedding_request(
      self, embedding_inputs: Sequence[PreparedVertexEmbeddingRequest]
  ) -> str:
    """Returns JSON formmatted embedding request."""
    instances = ','.join(i.json for i in embedding_inputs)
    return f'{{"{EndpointJsonKeys.INSTANCES}":[{instances}]}}'

  def _is_request_error_retryable(
      self, json_response: Mapping[str, Any]
  ) -> bool:
    """Returns true if error at request level is retryable."""
    try:
      error_code = json_response[EndpointJsonKeys.VERTEXAI_ERROR]
      if isinstance(error_code, dict):
        error_code = error_code[EndpointJsonKeys.ERROR_CODE]
      return error_code == EndpointJsonKeys.INVALID_CREDENTIALS
    except (KeyError, ValueError, TypeError):
      return False

  def _decode_response(
      self,
      embedding_inputs: Sequence[PreparedVertexEmbeddingRequest],
      json_response: Mapping[str, Any],
  ) -> _VertexModelResult:
    """Decodes response from Vertex AI endpoint into _VertexModelResult."""
    try:
      returned_slide_embeddings = json_response[EndpointJsonKeys.PREDICTIONS]
    except (KeyError, ValueError, TypeError):
      try:
        error_code = json_response[EndpointJsonKeys.VERTEXAI_ERROR]
        error_description = ''
        if isinstance(error_code, dict):
          error_description = error_code.get(
              EndpointJsonKeys.ERROR_CODE_DESCRIPTION, ''
          )
          error_code = error_code[EndpointJsonKeys.ERROR_CODE]
        if isinstance(error_code, str):
          msg = _format_error_message(error_code, error_description)
          raise ez_wsi_errors.PatchEmbeddingEndpointError(  # pylint: disable=raise-missing-from
              f'Endpoint error; {msg}'
          )
      except (KeyError, ValueError, TypeError):
        pass
      raise ez_wsi_errors.PatchEmbeddingEndpointError(  # pylint: disable=raise-missing-from
          'Endpoint did not return a valid JSON response.'
      )
    # Test that the number of slide embedding responses matches the request.
    if len(embedding_inputs) != len(returned_slide_embeddings):
      raise ez_wsi_errors.PatchEmbeddingEndpointError(
          'Number of embedding responses received does not match number of'
          f' embedding requests; expected: {len(embedding_inputs)}; received:'
          f' {len(returned_slide_embeddings)}.'
      )
    return _VertexModelResult(returned_slide_embeddings)

  def _instance_has_retryable_error(self, json_dict: Mapping[str, Any]) -> bool:
    """Decodes response from Vertex AI endpoint into _VertexModelResult."""
    error = json_dict.get(EndpointJsonKeys.ERROR)
    if error is None:
      return False
    try:
      return (
          error[EndpointJsonKeys.ERROR_CODE]
          == EndpointJsonKeys.INVALID_CREDENTIALS
      )
    except (KeyError, TypeError, IndexError) as _:
      return False

  def process_response(
      self,
      embedding_inputs: Sequence[patch_embedding_types.SlideEmbeddingSource],
      msg: _VertexModelResult,
  ) -> List[patch_embedding_types.PatchEmbeddingEnsembleResult]:
    """Returns patch embedding results for input and returned embeddings."""
    result = []
    endpoint_patch_width = self.patch_width()
    endpoint_patch_height = self.patch_height()
    for returned_instance, instance_input in zip(
        msg.instances, embedding_inputs
    ):
      try:
        error = returned_instance.get(EndpointJsonKeys.ERROR)
        if error is not None:
          error_code = error[EndpointJsonKeys.ERROR_CODE]
          error_description = error.get(
              EndpointJsonKeys.ERROR_CODE_DESCRIPTION, ''
          )
          error_message = '\n'.join([
              'Endpoint error generating instance embeddings.',
              f'Endpoint: {self.end_point_url}',
              f'{_format_error_message(error_code, error_description)}',
          ])
          error = patch_embedding_types.PatchEmbeddingError(
              error_code, error_message
          )
          # Return a PatchEmbeddingEnsembleResult with an error for each
          # expected patch. Errors will be raised when embedding values from
          # the patches are accessed. Typically this will occur almost
          # immediately afterwards during ensemble reduction. However, this
          # also enables callers using PatchEmbeddingSequence (index based) to
          # access data for instances which may succeed after an instance
          # fails.
          for patch_source in instance_input.patches:
            result.append(
                patch_embedding_types.PatchEmbeddingEnsembleResult(
                    patch_source, None, error
                )
            )
          continue
        patch_embeddings = returned_instance[EndpointJsonKeys.RESULT][
            EndpointJsonKeys.PATCH_EMBEDDINGS
        ]
        # Test that the number of patches received for the slide matches the
        # request.
        if len(patch_embeddings) != len(instance_input.patches):
          raise ez_wsi_errors.PatchEmbeddingEndpointError(
              'Number of patches in embedding response does not match request;'
              f' expected: {len(instance_input.patches)}; received:'
              f' {len(patch_embeddings)}.'
          )
        for patch_embedding, patch_source in zip(
            patch_embeddings, instance_input.patches
        ):
          pc = patch_embedding[EndpointJsonKeys.PATCH_COORDINATE]
          # Test that the coordinates of the patch match the request.
          if not _test_patch_coordinates_match(
              pc,
              patch_source.patch.x,
              patch_source.patch.y,
              endpoint_patch_width,
              endpoint_patch_height,
          ):
            raise ez_wsi_errors.PatchEmbeddingEndpointError(
                'Embedding patch coordinates or dimensions do not match'
                ' request.'
            )
          embedding_value = np.asarray(
              patch_embedding[EndpointJsonKeys.EMBEDDING_VECTOR]
          )
          result.append(
              patch_embedding_types.PatchEmbeddingEnsembleResult(
                  patch_source, embedding_value, None
              )
          )
      except (KeyError, IndexError, TypeError, ValueError) as exp:
        raise ez_wsi_errors.PatchEmbeddingEndpointError(
            'Endpoint returned an unexpected response.'
        ) from exp
    return result


class PreparedLocalEmbeddingRequest(
    AbstractPreparedEmbeddingRequest[np.ndarray]
):
  """Base class for prepared embedding requests."""

  def __init__(
      self,
      slide_embedding_source: patch_embedding_types.SlideEmbeddingSource,
      endpoint_thread_pool: Optional[concurrent.futures.ThreadPoolExecutor],
      icc_profile_bytes: bytes,
      icc_profile_cache: cachetools.LRUCache,
      icc_profile_cache_lock: threading.Lock,
      require_fully_in_source_image: bool,
  ):
    super().__init__(slide_embedding_source)
    self._input_patch_bytes = []
    self._input_patch_bytes_future = None
    self._thread_pool = endpoint_thread_pool
    self._target_icc_profile_bytes = icc_profile_bytes
    self._icc_profile_cache = icc_profile_cache
    self._icc_profile_cache_lock = icc_profile_cache_lock
    self._require_fully_in_source_image = require_fully_in_source_image

  @property
  def json_size_in_bytes(self) -> int:
    return 0

  def _load_patch_bytes(self) -> List[np.ndarray]:
    """Loads patch bytes for each request."""
    if (
        self._slide_embedding_source is None
        or not self._slide_embedding_source.patches
    ):
      return []
    first_patch = self._slide_embedding_source.patches[0].patch
    if isinstance(first_patch, gcs_image.GcsPatch):
      # Make a shallow copy of the source patch's GcsImage to ensure image
      # bytes are not retained after the request is processed.
      source_copy = copy.copy(first_patch.source)
      image_bytes = []
      icc_profile_normalization = source_copy.create_icc_profile_transformation(
          self._target_icc_profile_bytes
      )
      for patch in self._slide_embedding_source.patches:
        p = patch.patch
        if (
            self._require_fully_in_source_image
            and not p.is_patch_fully_in_source_image()
        ):
          raise ez_wsi_errors.PatchOutsideOfImageDimensionsError(
              'A portion of the patch does not overlap the image.'
          )
        # Create a patch with the same coordinates on the temporary source.
        # Requiring the patch to fall inside the source image is not relevant
        # to the temp patch, so it is set to False.
        temp_patch = gcs_image.GcsPatch(
            source_copy, p.x, p.y, p.width, p.height, False
        )
        image_bytes.append(temp_patch.image_bytes(icc_profile_normalization))
      return image_bytes
    if isinstance(first_patch, dicom_slide.DicomPatch):
      # Load the DICOM slide imaging using a frame cache that is not shared,
      # to scope imaging bytes to the embedding request and also avoid
      # possible LRU cache eviction across parallel reads.

      # Make a shallow copy of the source patch's source, a DicomSlide or
      # DICOM Microscopy image.
      source_copy = copy.copy(first_patch.source)

      # Init a new frame cache on the shallow-copied source.
      # The source and copy will no longer share the cache.
      fc = source_copy.init_slide_frame_cache()

      # Construct the list of patches to return embeddings for.
      patch_list = typing.cast(
          List[dicom_slide.DicomPatch],
          [patch.patch for patch in self._slide_embedding_source.patches],
      )
      # Preload the list of patches into the frame cache, copying across any
      # imaging that was already loaded in the original cache.
      source_copy.preload_patches_in_frame_cache(
          patch_list, False, first_patch.source.slide_frame_cache
      )

      if (
          self._target_icc_profile_bytes is None
          or not self._target_icc_profile_bytes
      ):
        icc_profile_normalization = None
      else:
        dicom_path = str(source_copy.path)
        with self._icc_profile_cache_lock:
          source_icc_profile_bytes = self._icc_profile_cache.get(dicom_path)
        if source_icc_profile_bytes is None:
          if isinstance(source_copy, dicom_slide.DicomSlide):
            source_icc_profile_bytes = source_copy.get_icc_profile_bytes()
          elif isinstance(source_copy, dicom_slide.DicomMicroscopeImage):
            source_icc_profile_bytes = source_copy.get_level_icc_profile_bytes(
                first_patch.level
            )
          else:
            raise ValueError('Unexpected object')
        with self._icc_profile_cache_lock:
          self._icc_profile_cache[dicom_path] = source_icc_profile_bytes
        icc_profile_normalization = (
            dicom_slide.create_icc_profile_transformation(
                source_icc_profile_bytes, self._target_icc_profile_bytes
            )
        )
      fc.block_until_frames_are_loaded()
      # Generate image bytes for each patch.
      image_bytes = []
      for p in patch_list:
        if (
            self._require_fully_in_source_image
            and not p.is_patch_fully_in_source_image()
        ):
          raise ez_wsi_errors.PatchOutsideOfImageDimensionsError(
              'A portion of the patch does not overlap the image.'
          )
        # Create a copy of the patch and point it at the copied source so
        # patch image retrieval reads from the copied source's frame cache.
        temp_patch = dicom_slide.DicomPatch(
            p.get_pyramid_imaging_source_level(),
            p.x,
            p.y,
            p.width,
            p.height,
            source_copy,
            p.level,
            # ensuring patch falls inside image dimensions is not relevant
            # to the temp patch.
            False,
        )
        image_bytes.append(temp_patch.image_bytes(icc_profile_normalization))
      return image_bytes
    raise ValueError('Unexpected object')

  @property
  def input_patch_bytes(self) -> List[np.ndarray]:
    if self._input_patch_bytes_future is not None:
      self._input_patch_bytes = self._input_patch_bytes_future.result()
      self._input_patch_bytes_future = None
    return self._input_patch_bytes

  def finalize(self) -> None:
    if self._thread_pool is None:
      self._input_patch_bytes_future = None
      self._input_patch_bytes = self._load_patch_bytes()
      return
    self._input_patch_bytes_future = self._thread_pool.submit(
        self._load_patch_bytes
    )

  def split(
      self, endpoint: AbstractPatchEmbeddingEndpoint[np.ndarray], max_size: int
  ) -> Tuple[
      Optional[AbstractPreparedEmbeddingRequest[np.ndarray]],
      patch_embedding_types.SlideEmbeddingSource,
  ]:
    # Splitting is not relevant for local endpoints; return an unsplit
    # response.
    if self._slide_embedding_source is None:
      raise ValueError('Slide embedding source is None.')
    return None, self._slide_embedding_source


class LocalEndpoint(AbstractPatchEmbeddingEndpoint[np.ndarray]):
  """Endpoint for generating embeddings with a locally loaded model."""

  def __init__(
      self,
      model: Callable[[np.ndarray], np.ndarray],
      icc_profile_normalization: IccProfileNormalization = (
          IccProfileNormalization.NONE
      ),
      patch_width: int = 224,
      patch_height: int = 224,
      require_fully_in_source_image: bool = True,
      max_threads: int = _DEFAULT_ENDPOINT_THREADS,
      retry_count: int = _DEFAULT_RETRY_COUNT,
      max_patches_per_request: int = _DEFAULT_MAX_PATCHES_PER_REQUEST,
      dicom_instance_icc_profile_cache_count: int = _DEFAULT_DICOM_INSTANCE_ICC_PROFILE_CACHE_COUNT,
      timeout: Optional[Union[int, float]] = _DEFAULT_TIMEOUT,
  ):
    super().__init__(icc_profile_normalization, timeout)
    self._require_fully_in_source_image = require_fully_in_source_image
    self._patch_width = patch_width
    self._patch_height = patch_height
    self._max_threads = max(1, max_threads)
    self._max_patches_per_request = int(max(1, max_patches_per_request))
    self._retry_count = max(0, retry_count)
    if self._max_threads < 2:
      self._thread_pool = None
    else:
      self._thread_pool = concurrent.futures.ThreadPoolExecutor(
          max_workers=self._max_threads
      )
    self._dicom_instance_icc_profile_cache_count = (
        dicom_instance_icc_profile_cache_count
    )
    self._model = model
    self._icc_profile_cache = cachetools.LRUCache(
        self._dicom_instance_icc_profile_cache_count
    )
    self._icc_profile_cache_lock = threading.Lock()

  def __del__(self):
    if self._thread_pool is not None:
      self._thread_pool.shutdown(wait=False, cancel_futures=True)  # pylint: disable=attribute-error

  def __getstate__(self) -> MutableMapping[str, Any]:
    """Returns class state for pickle serialization."""
    state = copy.copy(self.__dict__)
    del state['_thread_pool']
    del state['_icc_profile_cache']
    del state['_icc_profile_cache_lock']
    return state

  def __setstate__(self, dct: MutableMapping[str, Any]) -> None:
    """Init class state from pickle serialization."""
    self.__dict__ = dct
    if self._max_threads < 2:
      self._thread_pool = None
    else:
      self._thread_pool = concurrent.futures.ThreadPoolExecutor(
          max_workers=self._max_threads
      )
    self._icc_profile_cache = cachetools.LRUCache(
        self._dicom_instance_icc_profile_cache_count
    )
    self._icc_profile_cache_lock = threading.Lock()

  def max_request_size_bytes(self) -> int:
    """Maximum size in bytes that can be sent in single request."""
    return 0xFFFFFFFFFFFFFFFF  # max size uint64

  def retry_count(self) -> int:
    return self._retry_count

  def max_threads(self) -> int:
    return self._max_threads

  def patch_width(self) -> int:
    return self._patch_width

  def patch_height(self) -> int:
    return self._patch_height

  def max_number_of_patches_per_request(self) -> int:
    return self._max_patches_per_request

  def endpoint_max_number_of_patches_per_request(self) -> int:
    """Maximum number of patches that can be sent to the endpoint at once."""
    return self._max_patches_per_request

  def prepare_embedding_request(
      self,
      slide_embedding: patch_embedding_types.SlideEmbeddingSource,
  ) -> PreparedLocalEmbeddingRequest:
    return PreparedLocalEmbeddingRequest(
        slide_embedding,
        self._thread_pool,
        self.icc_profile_bytes(),
        self._icc_profile_cache,
        self._icc_profile_cache_lock,
        self._require_fully_in_source_image,
    )

  def normalize_imaging(self, input_patch_bytes: np.ndarray) -> np.ndarray:
    """Normalizes input patch bytes to float32 in range [0, 1]."""
    return input_patch_bytes.astype(np.float32) / _UINT8_MAX_VALUE

  def generate_ml_input(
      self,
      embedding_inputs: Sequence[AbstractPreparedEmbeddingRequest[np.ndarray]],
  ) -> np.ndarray:
    """Generates ML input for local model."""
    if not embedding_inputs:
      return np.zeros((), dtype=np.float32)
    normalized_imaging_list = []
    patch_width = self.patch_width()
    patch_height = self.patch_height()
    for e in embedding_inputs:
      e = typing.cast(PreparedLocalEmbeddingRequest, e)
      for single_patch in e.input_patch_bytes:
        normalized_imaging_list.append(
            np.expand_dims(
                normalized_patch_channels(
                    patch_width,
                    patch_height,
                    self.normalize_imaging(single_patch),
                ),
                axis=0,
            )
        )
    return np.concatenate(normalized_imaging_list, axis=0)

  def request_embeddings(
      self,
      embedding_inputs: Sequence[AbstractPreparedEmbeddingRequest[np.ndarray]],
  ) -> np.ndarray:
    """Returns embeddings for input patches."""
    ml_input = self.generate_ml_input(embedding_inputs)
    if not ml_input.shape:
      return ml_input
    return self._model(ml_input)

  def process_response(
      self,
      embedding_inputs: Sequence[patch_embedding_types.SlideEmbeddingSource],
      msg: np.ndarray,
  ) -> List[patch_embedding_types.PatchEmbeddingEnsembleResult]:
    """Converts raw embedding response to list of embedding results."""
    if not bool(msg.shape):
      generated_embedding_count = 0
    else:
      generated_embedding_count = msg.shape[0]
    total_patches = sum(len(i.patches) for i in embedding_inputs)
    if total_patches != generated_embedding_count:
      raise ez_wsi_errors.PatchEmbeddingEndpointError(
          'Number of patches in embedding response does not match request.'
      )
    results = []
    index = 0
    for instance_input in embedding_inputs:
      for patch_source in instance_input.patches:
        results.append(
            patch_embedding_types.PatchEmbeddingEnsembleResult(
                patch_source, msg[index, ...], None
            )
        )
        index += 1
    return results
