ax/storage/json_store/registry.py (242 lines of code) (raw):

#!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from typing import Any, Callable, Dict, Type from ax.benchmark.benchmark_problem import BenchmarkProblem, SimpleBenchmarkProblem from ax.benchmark.benchmark_result import BenchmarkResult from ax.core import ObservationFeatures from ax.core.arm import Arm from ax.core.base_trial import TrialStatus from ax.core.batch_trial import AbandonedArm, BatchTrial, GeneratorRunStruct from ax.core.data import Data from ax.core.experiment import DataType, Experiment from ax.core.generator_run import GeneratorRun from ax.core.map_data import MapData, MapKeyInfo from ax.core.map_metric import MapMetric from ax.core.metric import Metric from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective, ScalarizedObjective from ax.core.optimization_config import ( MultiObjectiveOptimizationConfig, OptimizationConfig, ) from ax.core.outcome_constraint import ObjectiveThreshold, OutcomeConstraint from ax.core.parameter import ( ChoiceParameter, FixedParameter, ParameterType, RangeParameter, ) from ax.core.parameter_constraint import ( OrderConstraint, ParameterConstraint, SumConstraint, ) from ax.core.search_space import SearchSpace, HierarchicalSearchSpace from ax.core.trial import Trial from ax.core.types import ComparisonOp from ax.early_stopping.strategies import ( PercentileEarlyStoppingStrategy, ThresholdEarlyStoppingStrategy, ) from ax.metrics.branin import AugmentedBraninMetric, BraninMetric, NegativeBraninMetric from ax.metrics.branin_map import BraninTimestampMapMetric from ax.metrics.chemistry import ChemistryProblemType, ChemistryMetric from ax.metrics.factorial import FactorialMetric from ax.metrics.hartmann6 import AugmentedHartmann6Metric, Hartmann6Metric from ax.metrics.l2norm import L2NormMetric from ax.metrics.noisy_function import NoisyFunctionMetric from ax.metrics.sklearn import SklearnMetric, SklearnDataset, SklearnModelType from ax.modelbridge.factory import Models from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.winsorize import WinsorizationConfig from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.list_surrogate import ListSurrogate from ax.models.torch.botorch_modular.model import BoTorchModel from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.runners.synthetic import SyntheticRunner from ax.storage.json_store.decoders import ( class_from_json, transform_type_from_json, ) from ax.storage.json_store.encoders import ( arm_to_dict, batch_to_dict, benchmark_problem_to_dict, botorch_component_to_dict, botorch_model_to_dict, botorch_modular_to_dict, choice_parameter_to_dict, data_to_dict, experiment_to_dict, fixed_parameter_to_dict, generation_step_to_dict, generation_strategy_to_dict, generator_run_to_dict, map_data_to_dict, map_key_info_to_dict, metric_to_dict, multi_objective_optimization_config_to_dict, multi_objective_to_dict, multi_type_experiment_to_dict, objective_to_dict, observation_features_to_dict, optimization_config_to_dict, order_parameter_constraint_to_dict, outcome_constraint_to_dict, parameter_constraint_to_dict, percentile_early_stopping_strategy_to_dict, range_parameter_to_dict, runner_to_dict, scalarized_objective_to_dict, search_space_to_dict, sum_parameter_constraint_to_dict, surrogate_to_dict, transform_type_to_dict, trial_to_dict, threshold_early_stopping_strategy_to_dict, winsorization_config_to_dict, ) from ax.storage.utils import DomainType, ParameterConstraintType from botorch.acquisition.acquisition import AcquisitionFunction from botorch.models.model import Model from gpytorch.constraints import Interval from gpytorch.likelihoods.likelihood import Likelihood from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood from gpytorch.priors.torch_priors import GammaPrior from torch.nn import Module CORE_ENCODER_REGISTRY: Dict[Type, Callable[[Any], Dict[str, Any]]] = { Arm: arm_to_dict, AugmentedBraninMetric: metric_to_dict, AugmentedHartmann6Metric: metric_to_dict, BatchTrial: batch_to_dict, BenchmarkProblem: benchmark_problem_to_dict, BoTorchModel: botorch_model_to_dict, BraninMetric: metric_to_dict, BraninTimestampMapMetric: metric_to_dict, ChoiceParameter: choice_parameter_to_dict, Data: data_to_dict, Experiment: experiment_to_dict, FactorialMetric: metric_to_dict, FixedParameter: fixed_parameter_to_dict, GammaPrior: botorch_component_to_dict, GenerationStep: generation_step_to_dict, GenerationStrategy: generation_strategy_to_dict, GeneratorRun: generator_run_to_dict, Hartmann6Metric: metric_to_dict, Interval: botorch_component_to_dict, ListSurrogate: surrogate_to_dict, L2NormMetric: metric_to_dict, MapData: map_data_to_dict, MapKeyInfo: map_key_info_to_dict, MapMetric: metric_to_dict, Metric: metric_to_dict, MultiObjective: multi_objective_to_dict, MultiObjectiveOptimizationConfig: multi_objective_optimization_config_to_dict, MultiTypeExperiment: multi_type_experiment_to_dict, PercentileEarlyStoppingStrategy: percentile_early_stopping_strategy_to_dict, SklearnMetric: metric_to_dict, ChemistryMetric: metric_to_dict, NegativeBraninMetric: metric_to_dict, NoisyFunctionMetric: metric_to_dict, Objective: objective_to_dict, ObjectiveThreshold: outcome_constraint_to_dict, OptimizationConfig: optimization_config_to_dict, OrderConstraint: order_parameter_constraint_to_dict, OutcomeConstraint: outcome_constraint_to_dict, ParameterConstraint: parameter_constraint_to_dict, RangeParameter: range_parameter_to_dict, ScalarizedObjective: scalarized_objective_to_dict, SearchSpace: search_space_to_dict, HierarchicalSearchSpace: search_space_to_dict, SimpleBenchmarkProblem: benchmark_problem_to_dict, SumConstraint: sum_parameter_constraint_to_dict, Surrogate: surrogate_to_dict, SyntheticRunner: runner_to_dict, ThresholdEarlyStoppingStrategy: threshold_early_stopping_strategy_to_dict, Trial: trial_to_dict, ObservationFeatures: observation_features_to_dict, WinsorizationConfig: winsorization_config_to_dict, } # Registry for class types, not instances. # NOTE: Avoid putting a class along with its subclass in `CLASS_ENCODER_REGISTRY`. # The encoder iterates through this dictionary and uses the first superclass that # it finds, which might not be the intended superclass. CORE_CLASS_ENCODER_REGISTRY: Dict[Type, Callable[[Any], Dict[str, Any]]] = { Acquisition: botorch_modular_to_dict, # Ax MBM component AcquisitionFunction: botorch_modular_to_dict, # BoTorch component Likelihood: botorch_modular_to_dict, # BoTorch component Module: botorch_modular_to_dict, # BoTorch module MarginalLogLikelihood: botorch_modular_to_dict, # BoTorch component Model: botorch_modular_to_dict, # BoTorch component Transform: transform_type_to_dict, # Ax general (not just MBM) component } CORE_DECODER_REGISTRY: Dict[str, Type] = { "AbandonedArm": AbandonedArm, "AugmentedBraninMetric": AugmentedBraninMetric, "AugmentedHartmann6Metric": AugmentedHartmann6Metric, "Arm": Arm, "BatchTrial": BatchTrial, "BenchmarkProblem": BenchmarkProblem, "BenchmarkResult": BenchmarkResult, "BoTorchModel": BoTorchModel, "BraninMetric": BraninMetric, "BraninTimestampMapMetric": BraninTimestampMapMetric, "ChemistryMetric": ChemistryMetric, "ChemistryProblemType": ChemistryProblemType, "ChoiceParameter": ChoiceParameter, "ComparisonOp": ComparisonOp, "Data": Data, "DataType": DataType, "DomainType": DomainType, "Experiment": Experiment, "FactorialMetric": FactorialMetric, "FixedParameter": FixedParameter, "GammaPrior": GammaPrior, "GenerationStrategy": GenerationStrategy, "GenerationStep": GenerationStep, "GeneratorRun": GeneratorRun, "GeneratorRunStruct": GeneratorRunStruct, "Hartmann6Metric": Hartmann6Metric, "HierarchicalSearchSpace": HierarchicalSearchSpace, "Interval": Interval, "ListSurrogate": ListSurrogate, "L2NormMetric": L2NormMetric, "MapData": MapData, "MapMetric": MapMetric, "MapKeyInfo": MapKeyInfo, "Metric": Metric, "Models": Models, "MultiObjective": MultiObjective, "MultiObjectiveOptimizationConfig": MultiObjectiveOptimizationConfig, "MultiTypeExperiment": MultiTypeExperiment, "NegativeBraninMetric": NegativeBraninMetric, "NoisyFunctionMetric": NoisyFunctionMetric, "Objective": Objective, "ObjectiveThreshold": ObjectiveThreshold, "OptimizationConfig": OptimizationConfig, "OrderConstraint": OrderConstraint, "OutcomeConstraint": OutcomeConstraint, "ParameterConstraint": ParameterConstraint, "ParameterConstraintType": ParameterConstraintType, "ParameterType": ParameterType, "PercentileEarlyStoppingStrategy": PercentileEarlyStoppingStrategy, "RangeParameter": RangeParameter, "ScalarizedObjective": ScalarizedObjective, "SearchSpace": SearchSpace, "SimpleBenchmarkProblem": SimpleBenchmarkProblem, "SklearnDataset": SklearnDataset, "SklearnMetric": SklearnMetric, "SklearnModelType": SklearnModelType, "SumConstraint": SumConstraint, "Surrogate": Surrogate, "SyntheticRunner": SyntheticRunner, "Trial": Trial, "TrialStatus": TrialStatus, "ThresholdEarlyStoppingStrategy": ThresholdEarlyStoppingStrategy, "ObservationFeatures": ObservationFeatures, "WinsorizationConfig": WinsorizationConfig, } # Registry for class types, not instances. CORE_CLASS_DECODER_REGISTRY: Dict[str, Callable[[Dict[str, Any]], Any]] = { "Type[Acquisition]": class_from_json, "Type[AcquisitionFunction]": class_from_json, "Type[Likelihood]": class_from_json, "Type[Module]": class_from_json, "Type[MarginalLogLikelihood]": class_from_json, "Type[Model]": class_from_json, "Type[Transform]": transform_type_from_json, }