#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
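
"""Stub factories for benchmarking unit tests: benchmark problems, methods, and
results constructed with small, fixed settings."""
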
from functools import reduce
from types import FunctionType
from typing import Any, cast

import numpy as np
from ax.benchmark.benchmark_problem import BenchmarkProblem, SimpleBenchmarkProblem
from ax.benchmark2.benchmark_method import BenchmarkMethod
from ax.benchmark2.benchmark_problem import (
MultiObjectiveBenchmarkProblem,
SingleObjectiveBenchmarkProblem,
)
from ax.benchmark2.benchmark_result import (
AggregatedBenchmarkResult,
BenchmarkResult,
)
from ax.core.experiment import Experiment
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.service.scheduler import SchedulerOptions
from ax.utils.common.constants import Keys
from ax.utils.measurement.synthetic_functions import branin
from ax.utils.testing.core_stubs import (
get_branin_optimization_config,
get_branin_search_space,
)
from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement
from botorch.models.gp_regression import FixedNoiseGP
from botorch.test_functions.multi_objective import BraninCurrin
from botorch.test_functions.synthetic import Branin


def get_branin_simple_benchmark_problem() -> SimpleBenchmarkProblem:
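    """Return a SimpleBenchmarkProblem wrapping the synthetic Branin function."""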
return SimpleBenchmarkProblem(f=branin)


def get_sum_simple_benchmark_problem() -> SimpleBenchmarkProblem:
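    """Return a SimpleBenchmarkProblem that sums its inputs over the unit square."""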
return SimpleBenchmarkProblem(f=sum, name="Sum", domain=[(0.0, 1.0), (0.0, 1.0)])


def sample_multiplication_fxn(*args: Any) -> float:
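    """Multiply all positional arguments together."""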
return reduce(lambda x, y: x * y, args)


def get_mult_simple_benchmark_problem() -> SimpleBenchmarkProblem:
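    """Return a SimpleBenchmarkProblem that multiplies its inputs on the unit cube."""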
return SimpleBenchmarkProblem(
f=cast(FunctionType, sample_multiplication_fxn),
        name="Mult",
domain=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)],
)


def get_branin_benchmark_problem() -> BenchmarkProblem:
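    """Return a BenchmarkProblem for Branin with the standard search space."""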
return BenchmarkProblem(
search_space=get_branin_search_space(),
optimization_config=get_branin_optimization_config(),
optimal_value=branin.fmin,
evaluate_suggested=False,
)


# Benchmark2
def get_single_objective_benchmark_problem() -> SingleObjectiveBenchmarkProblem:
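    """Return a single-objective problem built from BoTorch's Branin test function."""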
return SingleObjectiveBenchmarkProblem.from_botorch_synthetic(test_problem=Branin())


def get_multi_objective_benchmark_problem() -> MultiObjectiveBenchmarkProblem:
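    """Return a multi-objective problem built from BoTorch's BraninCurrin function."""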
return MultiObjectiveBenchmarkProblem.from_botorch_multi_objective(
test_problem=BraninCurrin()
)


def get_sobol_benchmark_method() -> BenchmarkMethod:
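    """Return a BenchmarkMethod that generates every trial with quasi-random Sobol."""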
return BenchmarkMethod(
name="SOBOL",
generation_strategy=GenerationStrategy(
steps=[GenerationStep(model=Models.SOBOL, num_trials=-1)],
name="SOBOL",
),
scheduler_options=SchedulerOptions(
total_trials=4, init_seconds_between_polls=0
),
)


def get_sobol_gpei_benchmark_method() -> BenchmarkMethod:
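    """Return a BenchmarkMethod using Sobol initialization followed by modular
    BoTorch optimization (FixedNoiseGP surrogate with qNoisyExpectedImprovement).
    """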
return BenchmarkMethod(
name="MBO_SOBOL_GPEI",
generation_strategy=GenerationStrategy(
name="Modular::Sobol+GPEI",
steps=[
GenerationStep(model=Models.SOBOL, num_trials=3, min_trials_observed=3),
GenerationStep(
model=Models.BOTORCH_MODULAR,
num_trials=-1,
                    model_kwargs={
                        # Modular BoTorch setup: fixed-noise GP surrogate with qNEI.
                        "surrogate": Surrogate(FixedNoiseGP),
                        "botorch_acqf_class": qNoisyExpectedImprovement,
                    },
                    model_gen_kwargs={
                        "model_gen_options": {
                            # Settings for the acquisition-function optimizer.
                            Keys.OPTIMIZER_KWARGS: {
                                "num_restarts": 50,
                                "raw_samples": 1024,
                            },
                            # Acquisition options: prune baseline points and use
                            # quasi-MC sampling with 512 samples.
                            Keys.ACQF_KWARGS: {
                                "prune_baseline": True,
                                "qmc": True,
                                "mc_samples": 512,
                            },
                        }
                    },
),
],
),
scheduler_options=SchedulerOptions(
total_trials=4, init_seconds_between_polls=0
),
)


def get_benchmark_result() -> BenchmarkResult:
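    """Return a BenchmarkResult with a stub experiment and a fixed descending trace."""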
problem = get_single_objective_benchmark_problem()
return BenchmarkResult(
        name="test_benchmarking_result",
experiment=Experiment(
name="test_benchmarking_experiment",
search_space=problem.search_space,
optimization_config=problem.optimization_config,
runner=problem.runner,
is_test=True,
),
optimization_trace=np.array([3, 2, 1, 0]),
fit_time=0.1,
gen_time=0.2,
)


def get_aggregated_benchmark_result() -> AggregatedBenchmarkResult:
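    """Return an AggregatedBenchmarkResult built from a single stub result."""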
return AggregatedBenchmarkResult.from_benchmark_results([get_benchmark_result()])