ax/storage/registry_bundle.py (111 lines of code) (raw):

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Type

from ax.core.metric import Metric
from ax.core.runner import Runner
from ax.storage.json_store.registry import (
    CORE_CLASS_DECODER_REGISTRY,
    CORE_CLASS_ENCODER_REGISTRY,
    CORE_DECODER_REGISTRY,
    CORE_ENCODER_REGISTRY,
)
from ax.storage.metric_registry import register_metrics
from ax.storage.runner_registry import register_runners
from ax.storage.sqa_store.decoder import Decoder
from ax.storage.sqa_store.encoder import Encoder
from ax.storage.sqa_store.sqa_config import SQAConfig


class RegistryBundleBase(ABC):
    """An abstraction to help with storing experiments with custom Metrics and
    Runners.

    Rather than managing registries individually, the RegistryBundle consumes
    custom metrics, runners, and configuration information and builds the
    storage registries needed for saving and loading. Subclasses must provide
    the ``sqa_config``, ``encoder``, and ``decoder`` properties.

    Args:
        metric_clss: A dictionary from Metric classes to the int their type
            should be encoded as in the associated SQAMetric. If None is
            passed for the int, a hash will be generated.
        runner_clss: A dictionary from Runner classes to the int their type
            should be encoded as in the associated SQARunner. If None is
            passed for the int, a hash will be generated.
        json_encoder_registry: A dictionary from Types to methods from an
            instance of the type to JSON.
        json_class_encoder_registry: A dictionary from Types to methods from
            the type's class to JSON.
        json_decoder_registry: A dictionary from str class labels to their
            associated Type.
        json_class_decoder_registry: A dictionary from str class labels to an
            associated method for reconstruction.
    """

    def __init__(
        self,
        metric_clss: Dict[Type[Metric], Optional[int]],
        runner_clss: Dict[Type[Runner], Optional[int]],
        json_encoder_registry: Dict[Type, Callable[[Any], Dict[str, Any]]],
        json_class_encoder_registry: Dict[Type, Callable[[Any], Dict[str, Any]]],
        json_decoder_registry: Dict[str, Type],
        json_class_decoder_registry: Dict[str, Callable[[Dict[str, Any]], Any]],
    ) -> None:
        # register_metrics returns the metric registry together with JSON
        # encoder/decoder registries (presumably extended with the custom
        # metric classes -- see ax.storage.metric_registry).
        self._metric_registry, encoder_registry, decoder_registry = register_metrics(
            metric_clss=metric_clss,
            encoder_registry=json_encoder_registry,
            decoder_registry=json_decoder_registry,
        )
        # Feed the metric-stage JSON registries into runner registration so the
        # final encoder/decoder registries reflect both metrics and runners.
        (
            self._runner_registry,
            self._encoder_registry,
            self._decoder_registry,
        ) = register_runners(
            runner_clss=runner_clss,
            encoder_registry=encoder_registry,
            decoder_registry=decoder_registry,
        )
        # The class-level JSON registries are stored as given (not copied), so
        # the properties below expose the very dict objects the caller passed.
        self._json_class_encoder_registry = json_class_encoder_registry
        self._json_class_decoder_registry = json_class_decoder_registry

    @property
    def metric_registry(self) -> Dict[Type[Metric], int]:
        """Mapping from Metric classes to the int used to encode their type."""
        return self._metric_registry

    @property
    def runner_registry(self) -> Dict[Type[Runner], int]:
        """Mapping from Runner classes to the int used to encode their type."""
        return self._runner_registry

    @property
    def encoder_registry(self) -> Dict[Type, Callable[[Any], Dict[str, Any]]]:
        """JSON encoder registry as returned by runner registration."""
        return self._encoder_registry

    @property
    def decoder_registry(self) -> Dict[str, Type]:
        """JSON decoder registry as returned by runner registration."""
        return self._decoder_registry

    @property
    def class_encoder_registry(self) -> Dict[Type, Callable[[Any], Dict[str, Any]]]:
        """The JSON class encoder registry passed to ``__init__``."""
        return self._json_class_encoder_registry

    @property
    def class_decoder_registry(self) -> Dict[str, Callable[[Dict[str, Any]], Any]]:
        """The JSON class decoder registry passed to ``__init__``."""
        return self._json_class_decoder_registry

    # ``abc.abstractproperty`` is deprecated since Python 3.3; stacking
    # ``@property`` on ``@abstractmethod`` is the supported equivalent.
    @property
    @abstractmethod
    def sqa_config(self) -> SQAConfig:
        """The SQAConfig built from this bundle's registries."""
        pass

    @property
    @abstractmethod
    def encoder(self) -> Encoder:
        """The SQA Encoder for this bundle."""
        pass

    @property
    @abstractmethod
    def decoder(self) -> Decoder:
        """The SQA Decoder for this bundle."""
        pass


class RegistryBundle(RegistryBundleBase):
    """A concrete implementation of RegistryBundleBase with sensible defaults.

    The JSON registries default to Ax's core registries. The SQAConfig,
    Encoder, and Decoder are all constructed eagerly in ``__init__``.
    """

    def __init__(
        self,
        metric_clss: Dict[Type[Metric], Optional[int]],
        runner_clss: Dict[Type[Runner], Optional[int]],
        json_encoder_registry: Dict[
            Type, Callable[[Any], Dict[str, Any]]
        ] = CORE_ENCODER_REGISTRY,
        json_class_encoder_registry: Dict[
            Type, Callable[[Any], Dict[str, Any]]
        ] = CORE_CLASS_ENCODER_REGISTRY,
        json_decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
        json_class_decoder_registry: Dict[
            str, Callable[[Dict[str, Any]], Any]
        ] = CORE_CLASS_DECODER_REGISTRY,
    ) -> None:
        super().__init__(
            metric_clss=metric_clss,
            runner_clss=runner_clss,
            json_encoder_registry=json_encoder_registry,
            json_class_encoder_registry=json_class_encoder_registry,
            json_decoder_registry=json_decoder_registry,
            json_class_decoder_registry=json_class_decoder_registry,
        )
        # NOTE: the CORE_* registries are unpacked last, so on a key collision
        # the core entry overrides this bundle's entry in the SQAConfig.
        self._sqa_config = SQAConfig(
            json_encoder_registry={**self.encoder_registry, **CORE_ENCODER_REGISTRY},
            json_decoder_registry={**self.decoder_registry, **CORE_DECODER_REGISTRY},
            metric_registry=self.metric_registry,
            runner_registry=self.runner_registry,
            json_class_encoder_registry=self.class_encoder_registry,
            json_class_decoder_registry=self.class_decoder_registry,
        )
        # Encoder and Decoder share the same config instance.
        self._encoder = Encoder(self._sqa_config)
        self._decoder = Decoder(self._sqa_config)

    # TODO[mpolson64] change @property to @cached_property once we deprecate 3.7
    @property
    def sqa_config(self) -> SQAConfig:
        """The SQAConfig built in ``__init__``."""
        return self._sqa_config

    @property
    def encoder(self) -> Encoder:
        """Encoder constructed from ``sqa_config``."""
        return self._encoder

    @property
    def decoder(self) -> Decoder:
        """Decoder constructed from ``sqa_config``."""
        return self._decoder