{
"cells": [
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"from ax.modelbridge.generation_strategy import GenerationStrategy, GenerationStep\n",
"from ax.modelbridge.registry import Models, ModelRegistryBase\n",
"from ax.modelbridge.dispatch_utils import choose_generation_strategy\n",
"from ax.modelbridge.modelbridge_utils import get_pending_observation_features\n",
"\n",
"from ax.utils.testing.core_stubs import get_branin_search_space, get_branin_experiment"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Generation Strategy (GS) Tutorial\n",
"\n",
"`GenerationStrategy` ([API reference](https://ax.dev/api/modelbridge.html#ax.modelbridge.generation_strategy.GenerationStrategy)) is a key abstraction in Ax:\n",
"- It allows for specifying multiple optimization algorithms to chain one after another in the course of the optimization. \n",
"- Many higher-level APIs in Ax use generation strategies: Service and Loop APIs, `Scheduler` etc. (tutorials for all those higher-level APIs are here: https://ax.dev/tutorials/).\n",
"- Generation strategy allows for storage and resumption of modeling setups, making optimization resumable from SQL or JSON snapshots.\n",
"\n",
"This tutorial walks through a few examples of generation strategies and discusses their important settings. Before reading it, we recommend familiarizing yourself with how `Model` and `ModelBridge` work in Ax: https://ax.dev/docs/models.html#deeper-dive-organization-of-the-modeling-stack.\n",
"\n",
"**Contents:**\n",
"1. Quick-start examples\n",
" 1. Manually configured GS\n",
" 2. Auto-selected GS\n",
" 3. Candidate generation from a GS\n",
"2. Deep dive: `GenerationStep` as a building block of the generation strategy\n",
" 1. Describing a model\n",
" 2. Other `GenerationStep` settings\n",
" 3. Chaining `GenerationStep`-s together\n",
" 4. `max_parallelism` enforcement and handling the `MaxParallelismReachedException`\n",
"3. `GenerationStrategy` storage\n",
" 1. SQL storage\n",
" 2. JSON storage\n",
"4. Advanced considerations / \"gotchas\"\n",
" 1. Generation strategy produces `GeneratorRun`-s, not `Trial`-s\n",
" 2. `model_kwargs` elements that don't have associated serialization logic in Ax\n",
" 3. Why prefer `Models` registry enum entries over a factory function?\n",
" 4. How to request more modeling setups in `Models`?\n",
" \n",
"----"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Quick-start examples\n",
"\n",
"### 1A. Manually configured generation strategy\n",
"\n",
"Below is a typical generation strategy used for most single-objective optimization cases in Ax:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"gs = GenerationStrategy(\n",
" steps=[\n",
" # 1. Initialization step (does not require pre-existing data and is well-suited for \n",
" # initial sampling of the search space)\n",
" GenerationStep(\n",
" model=Models.SOBOL,\n",
" num_trials=5, # How many trials should be produced from this generation step\n",
" min_trials_observed=3, # How many trials need to be completed to move to next model\n",
" max_parallelism=5, # Max parallelism for this step\n",
" model_kwargs={\"seed\": 999}, # Any kwargs you want passed into the model\n",
" model_gen_kwargs={}, # Any kwargs you want passed to `modelbridge.gen`\n",
" ),\n",
" # 2. Bayesian optimization step (requires data obtained from previous phase and learns\n",
" # from all data available at the time of each new candidate generation call)\n",
" GenerationStep(\n",
" model=Models.GPEI,\n",
" num_trials=-1, # No limitation on how many trials should be produced from this step\n",
" max_parallelism=3, # Parallelism limit for this step, often lower than for Sobol\n",
" # More on parallelism vs. required samples in BayesOpt:\n",
" # https://ax.dev/docs/bayesopt.html#tradeoff-between-parallelism-and-total-number-of-trials\n",
" ),\n",
" ]\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1B. Auto-selected generation strategy\n",
"\n",
"Ax provides a [`choose_generation_strategy`](https://github.com/facebook/Ax/blob/main/ax/modelbridge/dispatch_utils.py#L115) utility, which can auto-select a suitable generation strategy given a search space and an array of other optional settings. The utility is fairly simple at the moment, but additional development (support for multi-objective optimization, multi-fidelity optimization, Bayesian optimization with categorical kernels etc.) is coming soon."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[INFO 06-15 07:59:03] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting.\n"
]
},
{
"data": {
"text/plain": [
"GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials])"
]
},
"execution_count": 12,
"metadata": {
"bento_obj_id": "139922521218736"
},
"output_type": "execute_result"
}
],
"source": [
"gs = choose_generation_strategy(\n",
" # Required arguments:\n",
" search_space=get_branin_search_space(), # Ax `SearchSpace`\n",
" \n",
" # Some optional arguments (shown with their defaults), see API docs for more settings:\n",
" # https://ax.dev/api/modelbridge.html#module-ax.modelbridge.dispatch_utils\n",
" use_batch_trials=False, # Whether this GS will be used to generate 1-arm `Trial`-s or `BatchTrials`\n",
" no_bayesian_optimization=False, # Use quasi-random candidate generation without BayesOpt\n",
" max_parallelism_override=None, # Integer, to which to set the `max_parallelism` setting of all steps in this GS \n",
")\n",
"gs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1C. Candidate generation from a generation strategy\n",
"\n",
"A generation strategy is often used through the Service or Loop API or other higher-level abstractions like the Ax `Scheduler`, where it fits models and produces candidates from them under the hood. It is also possible to use the GS directly, in place of a `ModelBridge` instance: the interface of `GenerationStrategy.gen` is the same as that of `ModelBridge.gen`.\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"experiment = get_branin_experiment()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that it's important to **pass pending observations** to the call to `gen` to avoid getting the same points re-suggested. Without the `pending_observations` argument, Ax models are not aware of points that should be excluded from generation. Points are considered \"pending\" when they belong to `STAGED`, `RUNNING`, or `ABANDONED` trials (the latter are included so that the model does not re-suggest points that are considered \"bad\").\n",
"\n",
"If the call to `get_pending_observation_features` becomes slow in your setup (since it performs data-fetching etc.), you can opt for `get_pending_observation_features_based_on_trial_status` (also from `ax.modelbridge.modelbridge_utils`), but note the limitations of that utility (detailed in its docstring)."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GeneratorRun(1 arms, total weight 1.0)"
]
},
"execution_count": 23,
"metadata": {
"bento_obj_id": "139922521218448"
},
"output_type": "execute_result"
}
],
"source": [
"generator_run = gs.gen(\n",
" experiment=experiment, # Ax `Experiment`, for which to generate new candidates\n",
" data=None, # Ax `Data` to use for model training, optional.\n",
" n=1, # Number of candidate arms to produce\n",
" pending_observations=get_pending_observation_features(experiment), # Points that should not be re-generated\n",
" # Any other kwargs specified will be passed through to `ModelBridge.gen` along with `GenerationStep.model_gen_kwargs`\n",
")\n",
"generator_run"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we can add the newly produced [`GeneratorRun`](https://ax.dev/docs/glossary.html#generator-run) to the experiment as a [`Trial` (or `BatchTrial` if `n` > 1)](https://ax.dev/docs/glossary.html#trial):"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Trial(experiment_name='branin_test_experiment', index=0, status=TrialStatus.CANDIDATE, arm=Arm(name='0_0', parameters={'x1': 2.4094051076099277, 'x2': 13.29242150299251}))"
]
},
"execution_count": 24,
"metadata": {
"bento_obj_id": "139923550679968"
},
"output_type": "execute_result"
}
],
"source": [
"trial = experiment.new_trial(generator_run)\n",
"trial"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Important notes on `GenerationStrategy.gen`:**\n",
"- if `data` argument above is not specified, GS will pull experiment data from cache via `experiment.lookup_data`,\n",
"- without `pending_observations`, the GS (and any model in Ax) could produce the same candidate over and over, because the model is then not 'aware' that the candidate is part of a `RUNNING` or `ABANDONED` trial and should not be re-suggested.\n",
"\n",
"In cases where `get_pending_observation_features` is too slow and the experiment consists of 1-arm `Trial`-s only, it's possible to use `get_pending_observation_features_based_on_trial_status` instead (found in the same file)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that when using the Ax Service API, one of the arguments to `AxClient` is `choose_generation_strategy_kwargs`; specifying that argument is a convenient way to influence the choice of generation strategy in `AxClient` without manually specifying a full `GenerationStrategy`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"-----"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. `GenerationStep` as a building block of generation strategy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2A. Describing a model to use in a given `GenerationStep`\n",
"\n",
"There are two ways of specifying a model for a generation step: via an entry in the `Models` enum or via a 'factory function', i.e. a callable model constructor (e.g. [`get_GPEI`](https://github.com/facebook/Ax/blob/0e454b71d5e07b183c0866855555b6a21ddd5da1/ax/modelbridge/factory.py#L154) and other factory functions in the same file). Note that using a factory function prevents `GenerationStrategy` storage and is generally discouraged."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2B. Other `GenerationStep` settings\n",
"\n",
"All of the available settings are described in the documentation:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"One step in the generation strategy, corresponds to a single model.\n",
" Describes the model, how many trials will be generated with this model, what\n",
" minimum number of observations is required to proceed to the next model, etc.\n",
"\n",
" NOTE: Model can be specified either from the model registry\n",
" (`ax.modelbridge.registry.Models` or using a callable model constructor. Only\n",
" models from the registry can be saved, and thus optimization can only be\n",
" resumed if interrupted when using models from the registry.\n",
"\n",
" Args:\n",
" model: A member of `Models` enum or a callable returning an instance of\n",
" `ModelBridge` with an instantiated underlying `Model`. Refer to\n",
" `ax/modelbridge/factory.py` for examples of such callables.\n",
" num_trials: How many trials to generate with the model from this step.\n",
" If set to -1, trials will continue to be generated from this model\n",
" as long as `generation_strategy.gen` is called (available only for\n",
" the last of the generation steps).\n",
" min_trials_observed: How many trials must be completed before the\n",
" generation strategy can proceed to the next step. Defaults to 0.\n",
" If `num_trials` of a given step have been generated but `min_trials_\n",
" observed` have not been completed, a call to `generation_strategy.gen`\n",
" will fail with a `DataRequiredError`.\n",
" max_parallelism: How many trials generated in the course of this step are\n",
" allowed to be run (i.e. have `trial.status` of `RUNNING`) simultaneously.\n",
" If `max_parallelism` trials from this step are already running, a call\n",
" to `generation_strategy.gen` will fail with a `MaxParallelismReached\n",
" Exception`, indicating that more trials need to be completed before\n",
" generating and running next trials.\n",
" use_update: Whether to use `model_bridge.update` instead or reinstantiating\n",
" model + bridge on every call to `gen` within a single generation step.\n",
" NOTE: use of `update` on stateful models that do not implement `_get_state`\n",
" may result in inability to correctly resume a generation strategy from\n",
" a serialized state.\n",
" enforce_num_trials: Whether to enforce that only `num_trials` are generated\n",
" from the given step. If False and `num_trials` have been generated, but\n",
" `min_trials_observed` have not been completed, `generation_strategy.gen`\n",
" will continue generating trials from the current step, exceeding `num_\n",
" trials` for it. Allows to avoid `DataRequiredError`, but delays\n",
" proceeding to next generation step.\n",
" model_kwargs: Dictionary of kwargs to pass into the model constructor on\n",
" instantiation. E.g. if `model` is `Models.SOBOL`, kwargs will be applied\n",
" as `Models.SOBOL(**model_kwargs)`; if `model` is `get_sobol`, `get_sobol(\n",
" **model_kwargs)`. NOTE: if generation strategy is interrupted and\n",
" resumed from a stored snapshot and its last used model has state saved on\n",
" its generator runs, `model_kwargs` is updated with the state dict of the\n",
" model, retrieved from the last generator run of this generation strategy.\n",
" model_gen_kwargs: Each call to `generation_strategy.gen` performs a call to the\n",
" step's model's `gen` under the hood; `model_gen_kwargs` will be passed to\n",
" the model's `gen` like so: `model.gen(**model_gen_kwargs)`.\n",
" index: Index of this generation step, for use internally in `Generation\n",
" Strategy`. Do not assign as it will be reassigned when instantiating\n",
" `GenerationStrategy` with a list of its steps.\n",
"\n",
" \n"
]
}
],
"source": [
"print(GenerationStep.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2C. Chaining `GenerationStep`-s together\n",
"\n",
"A `GenerationStrategy` moves from one step to another when: \n",
"1. `N=num_trials` generator runs were produced and attached as trials to the experiment AND \n",
"2. `M=min_trials_observed` have been completed and have data.\n",
"\n",
"**Caveat: `enforce_num_trials` setting**:\n",
"\n",
"1. If `enforce_num_trials=True` for a given generation step and condition 1) is met but 2) is not, the generation strategy will raise a `DataRequiredError`, indicating that more trials need to be completed before moving to the next step.\n",
"2. If `enforce_num_trials=False`, the GS will continue producing generator runs from the current step until 2) is reached."
]
},
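{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the transition rule above concrete, here is a minimal pure-Python sketch of the decision. It is a simplified model for illustration, not Ax's actual implementation: the function name `next_action`, its arguments, and the returned labels are all hypothetical, and the function encodes only the two conditions and the `enforce_num_trials` caveat described in this section. With the settings of the Sobol step from section 1A, the three calls print `data_required`, `generate`, and `advance`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def next_action(generated, completed, num_trials, min_trials_observed=0, enforce_num_trials=True):\n",
"    # Toy model of the step-transition rule; every name here is hypothetical.\n",
"    # Condition 1: this step has produced its quota of trials (-1 means no limit).\n",
"    enough_generated = num_trials != -1 and generated >= num_trials\n",
"    # Condition 2: enough trials have completed and have data.\n",
"    enough_observed = completed >= min_trials_observed\n",
"    if enough_generated and enough_observed:\n",
"        return \"advance\"  # Move on to the next generation step\n",
"    if enough_generated and enforce_num_trials:\n",
"        return \"data_required\"  # Where the GS would raise `DataRequiredError`\n",
"    return \"generate\"  # Keep producing generator runs from the current step\n",
"\n",
"\n",
"# Settings of the Sobol step from section 1A, with 5 trials generated so far:\n",
"sobol = dict(num_trials=5, min_trials_observed=3)\n",
"print(next_action(generated=5, completed=2, **sobol))\n",
"print(next_action(generated=5, completed=2, enforce_num_trials=False, **sobol))\n",
"print(next_action(generated=5, completed=3, **sobol))"
]
},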
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2D. `max_parallelism` enforcement\n",
"\n",
"A generation strategy can restrict the number of trials that can be run simultaneously (to encourage sequential optimization, which benefits Bayesian optimization performance). When the parallelism limit is reached, a call to `GenerationStrategy.gen` will result in a `MaxParallelismReachedException`.\n",
"\n",
"The correct way to handle this exception:\n",
"1. Make sure that `GenerationStep.max_parallelism` is configured correctly for all steps in your generation strategy (to disable it completely, configure `GenerationStep.max_parallelism=None`),\n",
"2. When you encounter the exception, wait until more trial evaluations complete, log each completion via `trial.mark_completed`, and only then produce more generator runs."
]
},
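{
"cell_type": "markdown",
"metadata": {},
"source": [
"The waiting pattern from step 2 can be sketched as a small retry helper. This is an illustrative sketch, not part of Ax: `gen_when_capacity_frees`, `produce`, `wait_for_completions`, and `max_attempts` are hypothetical names, and the `Fake*`/`fake_*` stand-ins exist only so that the cell runs without a real experiment. In real use, `retry_on` would be `MaxParallelismReachedException` (assumed here to be importable from `ax.exceptions.generation_strategy`; please check your Ax version), `produce` would wrap the call to `gs.gen`, and `wait_for_completions` would poll the running trials and call `trial.mark_completed` on the finished ones."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def gen_when_capacity_frees(produce, wait_for_completions, retry_on, max_attempts=10):\n",
"    # Call `produce` until it succeeds, waiting between attempts whenever it\n",
"    # raises `retry_on` (the parallelism-limit exception in real use).\n",
"    for _ in range(max_attempts):\n",
"        try:\n",
"            return produce()\n",
"        except retry_on:\n",
"            # Limit reached: let running trials finish and be marked completed.\n",
"            wait_for_completions()\n",
"    raise RuntimeError(\"Parallelism limit still reached after max_attempts retries.\")\n",
"\n",
"\n",
"# Stand-ins so that the sketch runs without a real experiment.\n",
"class FakeLimitReached(Exception):\n",
"    pass\n",
"\n",
"\n",
"running = [3]  # Pretend 3 trials are running and the limit is 3\n",
"\n",
"\n",
"def fake_gen():\n",
"    if running[0] >= 3:\n",
"        raise FakeLimitReached()\n",
"    return \"generator_run\"\n",
"\n",
"\n",
"def fake_wait():\n",
"    running[0] -= 1  # One trial finishes and is marked completed\n",
"\n",
"\n",
"print(gen_when_capacity_frees(fake_gen, fake_wait, retry_on=FakeLimitReached))"
]
},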
{
"cell_type": "markdown",
"metadata": {},
"source": [
"----\n",
"\n",
"## 3. SQL and JSON storage of a generation strategy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When used through the Service API or `Scheduler`, a generation strategy is stored to SQL or JSON automatically once `DBSettings` are specified to either `AxClient` or `Scheduler` (details in the respective tutorials on the [\"Tutorials\" page](https://ax.dev/tutorials/)). A generation strategy can also be stored to SQL or JSON individually, as shown below.\n",
"\n",
"More detail on SQL and JSON storage in Ax generally can be [found in \"Building Blocks of Ax\" tutorial](https://ax.dev/tutorials/building_blocks.html#9.-Save-to-JSON-or-SQL)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3A. SQL storage\n",
"For SQL storage setup in Ax, read through the [\"Storage\" documentation page](https://ax.dev/docs/storage.html).\n",
"\n",
"Note that unlike an Ax experiment, a generation strategy does not have a name or another unique identifier. Therefore, a generation strategy is stored in association with an experiment and can be retrieved by that experiment's name."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ax.storage.sqa_store.db import init_engine_and_session_factory, get_engine, create_all_tables\n",
"from ax.storage.sqa_store.load import load_experiment, load_generation_strategy_by_experiment_name\n",
"from ax.storage.sqa_store.save import save_experiment, save_generation_strategy\n",
"\n",
"# Point Ax at a local SQLite database and create the tables it needs.\n",
"init_engine_and_session_factory(url='sqlite:///foo2.db')\n",
"engine = get_engine()\n",
"create_all_tables(engine)\n",
"\n",
"# The generation strategy is stored in association with its experiment,\n",
"# so the experiment is saved first.\n",
"save_experiment(experiment)\n",
"save_generation_strategy(gs)\n",
"\n",
"experiment = load_experiment(experiment_name=experiment.name)\n",
"gs = load_generation_strategy_by_experiment_name(\n",
" experiment_name=experiment.name, \n",
" experiment=experiment, # Can optionally specify experiment object to avoid loading it from database twice\n",
")\n",
"gs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3B. JSON storage"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials])"
]
},
"execution_count": 31,
"metadata": {
"bento_obj_id": "139923550893296"
},
"output_type": "execute_result"
}
],
"source": [
"from ax.storage.json_store.encoder import object_to_json\n",
"from ax.storage.json_store.decoder import object_from_json\n",
"\n",
"gs_json = object_to_json(gs) # Can be written to a file or string via `json.dump` etc.\n",
"gs = object_from_json(gs_json) # Decoded back from JSON (can be loaded from file, string via `json.load` etc.)\n",
"gs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"------"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Advanced considerations\n",
"\n",
"Below is a list of important \"gotchas\" of using a generation strategy (especially outside of the higher-level APIs like the Service API or the `Scheduler`):\n",
"\n",
"### 4A. `GenerationStrategy.gen` produces `GeneratorRun`-s, not trials\n",
"\n",
"Since `GenerationStrategy.gen` mimics `ModelBridge.gen` and allows for human-in-the-loop usage mode, a call to `gen` produces a `GeneratorRun`, which can then be added (or altered before addition or not added at all) to a `Trial` or `BatchTrial` on a given experiment. So it's important to add the generator run to a trial, since otherwise it will not be attached to the experiment on its own."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Trial(experiment_name='branin_test_experiment', index=1, status=TrialStatus.CANDIDATE, arm=Arm(name='1_0', parameters={'x1': -0.34071301110088825, 'x2': 7.061324520036578}))"
]
},
"execution_count": 27,
"metadata": {
"bento_obj_id": "139923551043648"
},
"output_type": "execute_result"
}
],
"source": [
"generator_run = gs.gen(\n",
" experiment=experiment, n=1, pending_observations=get_pending_observation_features(experiment)\n",
")\n",
"experiment.new_trial(generator_run)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4B. `model_kwargs` elements that don't have associated serialization logic in Ax"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that passing objects that are not yet serializable in Ax (e.g. a BoTorch `Prior` object) as part of `GenerationStep.model_kwargs` or `GenerationStep.model_gen_kwargs` will prevent correct generation strategy storage. If this becomes a problem, feel free to open an issue on our GitHub (https://github.com/facebook/Ax/issues) to get help with adding storage support for a given object."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4C. Why prefer `Models` enum entries over a factory function?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. **Storage potential:** a call to, for example, `Models.GPEI` captures all arguments to the model and model bridge and stores them on the generator runs subsequently produced by the model. Since the capturing logic is part of the `Models.__call__` function, it is not present in a factory function. Furthermore, there is no safe and flexible way to serialize callables in Python.\n",
"2. **Standardization:** While a 'factory function' is by default more flexible (accepts any specified inputs and produces a `ModelBridge` with an underlying `Model` instance based on them), it is not standard in terms of its inputs. `Models` introduces a standardized interface, making it easy to adapt any example to one's specific case."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4D. How can I request that more modeling setups be added to `Models` and natively supported in Ax?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Please open a [GitHub issue](https://github.com/facebook/Ax/issues) to request a new modeling setup in Ax (or for any other questions or requests)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}