ez_wsi_dicomweb/credential_factory.py (153 lines of code) (raw):

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DICOMweb Abstract Credential Factory and Default implementation."""
import abc
import copy
import hashlib
import json
import os
import threading
from typing import Any, Dict, List, Mapping, Optional, Union

import cachetools
import google.auth
import google.auth.transport.requests
from google.oauth2 import service_account
import requests  # Required by: google.auth.transport.requests.Request()

_SCOPES = [
    'https://www.googleapis.com/auth/cloud-platform',
    'https://www.googleapis.com/auth/cloud-healthcare',
]
_ONE_HOUR_IN_SECONDS = 60 * 60
# Maximum number of distinct credential sources cached at the same time.
_CREDENTIAL_CACHE_MAX_SIZE = 100
# Cache key for application default credentials requested with _SCOPES.
_APPLICATION_DEFAULT_CREDENTIALS_HASH = 'application_default_credentials'


def _create_credential_cache() -> cachetools.TTLCache:
  """Returns an empty credential cache whose entries expire after one hour."""
  return cachetools.TTLCache(
      maxsize=_CREDENTIAL_CACHE_MAX_SIZE, ttl=_ONE_HOUR_IN_SECONDS
  )


# Guards _credential_factory_cache and serializes credential refreshes.
_cache_tools_lock = threading.Lock()
# Maps credential_source_hash() -> credentials, shared by all factory
# instances in this process that target the same credential source.
_credential_factory_cache = _create_credential_cache()


def _init_fork_module_state():
  """Re-creates the module lock and credential cache.

  Registered to run in a forked child so that the lock and cached
  credentials of the parent process are not reused by the child.
  """
  global _cache_tools_lock
  global _credential_factory_cache
  _cache_tools_lock = threading.Lock()
  _credential_factory_cache = _create_credential_cache()


def clear_credential_cache():
  """Clears the credential factory cache to force credential refresh."""
  with _cache_tools_lock:
    _credential_factory_cache.clear()


class AbstractCredentialFactory(metaclass=abc.ABCMeta):
  """Generates the credentials used to access a DICOM store.

  Implementations of the abstract credential factory should be compatible
  with pickle serialization. The purpose of the credential factory is to
  enable EZ-WSI to construct the credentials needed to access the DICOM store
  following pickle deserialization. As one example, this enables EZ-WSI DICOM
  Store and DICOM Slide classes to be initialized once and passed through a
  cloud dataflow pipeline.
  """

  @abc.abstractmethod
  def get_credentials(self) -> google.auth.credentials.Credentials:
    """Returns credentials to use to access the DICOM store."""

  def credential_source_hash(self) -> str:
    """Returns a hash value that identifies the credential source.

    An empty string is considered undefined. The hash is used to enable
    credential caching across multiple credential factory instances that
    target the same credential source. It should only be defined if acquiring
    the credentials is time consuming.
    """
    return ''


def refresh_credentials(
    auth_credentials: google.auth.credentials.Credentials,
    credential_factory: Optional[AbstractCredentialFactory] = None,
) -> google.auth.credentials.Credentials:
  """Refreshes credentials if they are not valid.

  Args:
    auth_credentials: Credentials to check and, if not valid, refresh.
    credential_factory: Optional factory that produced the credentials. If it
      defines a non-empty credential_source_hash(), the refreshed credentials
      are (re)stored in the module cache under that key.

  Returns:
    The same auth_credentials object, refreshed if it was not valid.
  """
  if auth_credentials.valid:
    return auth_credentials
  # Refresh under the module lock so threads sharing a cached credential
  # object do not refresh it concurrently.
  with _cache_tools_lock:
    # Another thread may have refreshed the same object while this thread was
    # waiting for the lock; skip the redundant network refresh in that case.
    if not auth_credentials.valid:
      auth_credentials.refresh(google.auth.transport.requests.Request())
    if credential_factory is not None:
      credential_source_hash = credential_factory.credential_source_hash()
      if credential_source_hash:
        _credential_factory_cache[credential_source_hash] = auth_credentials
  return auth_credentials


def get_default_gcp_project() -> Optional[str]:
  """Returns the GCP project of the environment the code is running in.

  The value is the project reported by google.auth.default, which may be
  None when the project cannot be determined.
  """
  return google.auth.default(scopes=_SCOPES)[1]


class CredentialFactory(AbstractCredentialFactory):
  """Factory for default or service account credential creation."""

  def __init__(
      self,
      json_param: Optional[
          Union[Mapping[str, Any], str, bytes, os.PathLike[Any]]
      ] = None,
      scopes: Optional[List[str]] = None,
  ) -> None:
    """Credential Factory Constructor.

    Args:
      json_param: Optional parameter that defines the location of a JSON
        file (str, bytes, or path-like), or already-loaded JSON (a mapping),
        that contains the service account credentials to use for auth. If
        undefined (None or empty), the default credentials of the running
        environment are used.
      scopes: Credential scopes. If undefined, defaults to:
        ['https://www.googleapis.com/auth/cloud-platform',
        'https://www.googleapis.com/auth/cloud-healthcare',]
    """
    if not json_param:
      self._json = {}
    elif isinstance(json_param, (str, bytes, os.PathLike)):
      # Read JSON from file; JSON text is UTF-8 by specification.
      with open(json_param, 'rt', encoding='utf-8') as infile:
        self._json = json.load(infile)
    else:
      # Copy in-memory JSON into a plain dict so it can be serialized by
      # json.dumps (which rejects non-dict mappings) and pickled.
      self._json = dict(json_param)
    if not self._json:
      source_hash = _APPLICATION_DEFAULT_CREDENTIALS_HASH
    else:
      source_hash = hashlib.sha3_512(
          json.dumps(self._json).encode('utf-8')
      ).hexdigest()
    self._scopes = list(_SCOPES) if scopes is None else copy.copy(scopes)
    if list(self._scopes) != _SCOPES:
      # Credentials are minted for specific scopes, so factories that target
      # the same source with different scopes must not share a cache entry.
      # Hashes for the default scopes are left unchanged.
      source_hash = hashlib.sha3_512(
          json.dumps(
              {'source': source_hash, 'scopes': list(self._scopes)}
          ).encode('utf-8')
      ).hexdigest()
    self._credential_source_hash = source_hash

  def get_credentials(self) -> google.auth.credentials.Credentials:
    """Returns credentials to use to access the DICOM store.

    Credentials are looked up in the module cache by credential_source_hash();
    on a miss they are created from the service account JSON (if provided) or
    from the environment's default credentials, then refreshed if not valid.
    """
    credential_source_hash = self.credential_source_hash()
    with _cache_tools_lock:
      credentials = _credential_factory_cache.get(credential_source_hash)
      if credentials is None:
        if self._json:
          credentials = service_account.Credentials.from_service_account_info(
              self._json, scopes=self._scopes
          )
        else:
          credentials = google.auth.default(scopes=self._scopes)[0]
        _credential_factory_cache[credential_source_hash] = credentials
    return refresh_credentials(credentials, self)

  def credential_source_hash(self) -> str:
    """Returns the cache key for this factory's credential source and scopes."""
    return self._credential_source_hash


ServiceAccountCredentialFactory = CredentialFactory


class DefaultCredentialFactory(CredentialFactory):
  """Factory for default credential creation."""

  def __init__(self, scopes: Optional[List[str]] = None) -> None:
    super().__init__(scopes=scopes)


class PassThroughCredentials(google.auth.credentials.Credentials):
  """Credentials that wrap a caller-supplied bearer token.

  These are similar to anonymous credentials in that refreshing is not
  possible: the token is treated as never expiring and always valid.
  """

  def __init__(self, token: str):
    """Initializes PassThroughCredentials with a bearer token.

    Args:
      token: Bearer token sent in the authorization header of each request.
    """
    super().__init__()
    self.token = token

  @property
  def expired(self) -> bool:
    """Returns False; assumes tokens never expire."""
    return False

  @property
  def valid(self) -> bool:
    """Returns True; assumes tokens are always valid."""
    return True

  def refresh(self, request: google.auth.transport.Request) -> None:
    """No-op; a passed-through token cannot be refreshed."""
    return

  def apply(self, headers: Dict[Any, Any], token: Optional[str] = None) -> None:
    """Applies the token to the authentication header.

    Args:
      headers: The HTTP request headers; modified in place.
      token: If specified, overrides the current access token.

    Returns:
      Nothing.
    """
    headers['authorization'] = f'Bearer {token or self.token}'

  def before_request(
      self,
      request: google.auth.transport.Request,
      method: str,
      url: str,
      headers: Dict[Any, Any],
  ) -> None:
    """Performs credential-specific before request logic.

    Calls apply to apply the token to the authentication header.

    Args:
      request: The object used to make HTTP requests.
      method: The request's HTTP method or the RPC method being invoked.
      url: The request's URI or the RPC service's URI.
      headers: The request's headers.

    Returns:
      Nothing.
    """
    self.apply(headers)


class NoAuthCredentials(google.auth.credentials.Credentials):
  """Credentials that add no authentication information to requests.

  These are similar to anonymous credentials in that refreshing is not
  possible; they never expire, are always valid, and leave request headers
  untouched.
  """

  @property
  def expired(self) -> bool:
    """Never expire."""
    return False

  @property
  def valid(self) -> bool:
    """Returns True, always valid."""
    return True

  def refresh(self, request: google.auth.transport.Request) -> None:
    """No-op; there is nothing to refresh."""
    return

  def apply(self, headers: Dict[Any, Any], token: Optional[str] = None) -> None:
    """No-op; no authentication header is added."""
    return

  def before_request(
      self,
      request: google.auth.transport.Request,
      method: str,
      url: str,
      headers: Dict[Any, Any],
  ) -> None:
    """No-op; requests are sent without authentication."""
    return


class TokenPassthroughCredentialFactory(AbstractCredentialFactory):
  """Factory for token passthrough credential creation."""

  def __init__(self, bearer_token: str) -> None:
    """Credential Factory Constructor.

    Args:
      bearer_token: The user provided bearer token that allows us to access
        their DICOM store.
    """
    self._bearer_token = bearer_token
    self._credential = PassThroughCredentials(self._bearer_token)

  def get_credentials(self) -> google.auth.credentials.Credentials:
    """Returns credentials to use to access the DICOM store."""
    return self._credential


class GoogleAuthCredentialFactory(AbstractCredentialFactory):
  """Enables externally acquired credentials to be used with EZ-WSI.

  This credential factory cannot be pickled.
  """

  def __init__(self, credentials: google.auth.credentials.Credentials):
    self._auth_credentials = credentials

  def get_credentials(self) -> google.auth.credentials.Credentials:
    """Returns the wrapped credentials, refreshed if not valid."""
    # Credentials origin is unknown, so don't cache them; pass None for
    # credential_factory.
    return refresh_credentials(self._auth_credentials, None)


class NoAuthCredentialsFactory(AbstractCredentialFactory):
  """Credentials that do not provide any authentication information."""

  def __init__(self):
    self._auth_credentials = NoAuthCredentials()

  def get_credentials(self) -> google.auth.credentials.Credentials:
    """Returns credentials that provide no authentication information."""
    return self._auth_credentials


# If module is forked, re-init the threading lock and cache to ensure that
# state from the parent process is not used in the child.
# os.register_at_fork is only available on platforms that support fork
# (Unix); guarding the call keeps this module importable elsewhere (e.g.
# Windows), where there is no fork and thus no child state to reset.
if hasattr(os, 'register_at_fork'):
  os.register_at_fork(after_in_child=_init_fork_module_state)