ax/benchmark2/benchmark.py (45 lines of code) (raw):

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Module for benchmarking Ax algorithms.

Key terms used:

* Replication: 1 run of an optimization loop for a (BenchmarkProblem, BenchmarkMethod) pair.
* Test: multiple replications, run for statistical significance.
* Full run: multiple tests on many (BenchmarkProblem, BenchmarkMethod) pairs.
* Method: (one of) the algorithm(s) being benchmarked.
* Problem: a synthetic function, a surrogate surface, or an ML model, on which
  to assess the performance of algorithms.
"""

from time import time
from typing import List, Iterable

from ax.benchmark2.benchmark_method import BenchmarkMethod
from ax.benchmark2.benchmark_problem import BenchmarkProblem
from ax.benchmark2.benchmark_result import BenchmarkResult, AggregatedBenchmarkResult
from ax.core.experiment import Experiment
from ax.service.scheduler import Scheduler


def benchmark_replication(
    problem: BenchmarkProblem,
    method: BenchmarkMethod,
) -> BenchmarkResult:
    """Runs one benchmarking replication (equivalent to one optimization loop).

    Args:
        problem: The BenchmarkProblem to test against (can be synthetic or real).
        method: The BenchmarkMethod to test.
    """

    experiment = Experiment(
        name=f"{problem.name}x{method.name}_{time()}",
        search_space=problem.search_space,
        optimization_config=problem.optimization_config,
        runner=problem.runner,
    )

    scheduler = Scheduler(
        experiment=experiment,
        generation_strategy=method.generation_strategy.clone_reset(),
        options=method.scheduler_options,
    )

    scheduler.run_all_trials()

    return BenchmarkResult.from_scheduler(scheduler=scheduler)


def benchmark_test(
    problem: BenchmarkProblem, method: BenchmarkMethod, num_replications: int = 10
) -> AggregatedBenchmarkResult:
    """Runs `num_replications` replications of one (problem, method) pair and
    aggregates the results."""

    return AggregatedBenchmarkResult.from_benchmark_results(
        results=[
            benchmark_replication(problem=problem, method=method)
            for _ in range(num_replications)
        ]
    )


def benchmark_full_run(
    problems: Iterable[BenchmarkProblem],
    methods: Iterable[BenchmarkMethod],
    num_replications: int = 10,
) -> List[AggregatedBenchmarkResult]:
    """Runs a benchmark test for every (problem, method) pair in the Cartesian
    product of `problems` and `methods`."""

    return [
        benchmark_test(
            problem=problem, method=method, num_replications=num_replications
        )
        for problem in problems
        for method in methods
    ]
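
# --- Illustrative usage sketch (not part of the original module) ---
# The names `my_problem` and `my_method` are hypothetical placeholders for a
# BenchmarkProblem and a BenchmarkMethod constructed elsewhere; this only shows
# how the three entry points above compose:
#
#     # One replication: a single optimization loop on one (problem, method) pair.
#     result = benchmark_replication(problem=my_problem, method=my_method)
#
#     # One test: several replications of the same pair, aggregated.
#     aggregated = benchmark_test(
#         problem=my_problem, method=my_method, num_replications=10
#     )
#
#     # Full run: one test per (problem, method) combination.
#     all_results = benchmark_full_run(
#         problems=[my_problem], methods=[my_method], num_replications=10
#     )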