tutorials/generation_strategy.ipynb (566 lines of code) (raw):

{ "cells": [ { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "from ax.modelbridge.generation_strategy import GenerationStrategy, GenerationStep\n", "from ax.modelbridge.registry import Models, ModelRegistryBase\n", "from ax.modelbridge.dispatch_utils import choose_generation_strategy\n", "from ax.modelbridge.modelbridge_utils import get_pending_observation_features\n", "\n", "from ax.utils.testing.core_stubs import get_branin_search_space, get_branin_experiment" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Generation Strategy (GS) Tutorial\n", "\n", "`GenerationStrategy` ([API reference](https://ax.dev/api/modelbridge.html#ax.modelbridge.generation_strategy.GenerationStrategy)) is a key abstraction in Ax:\n", "- It allows for specifying multiple optimization algorithms to chain one after another in the course of the optimization.\n", "- Many higher-level APIs in Ax use generation strategies: Service and Loop APIs, `Scheduler` etc. (tutorials for all those higher-level APIs are here: https://ax.dev/tutorials/).\n", "- Generation strategy allows for storage and resumption of modeling setups, making optimization resumable from SQL or JSON snapshots.\n", "\n", "This tutorial walks through a few examples of generation strategies and discusses their important settings. Before reading it, we recommend familiarizing yourself with how `Model` and `ModelBridge` work in Ax: https://ax.dev/docs/models.html#deeper-dive-organization-of-the-modeling-stack.\n", "\n", "**Contents:**\n", "1. Quick-start examples\n", " 1. Manually configured GS\n", " 2. Auto-selected GS\n", " 3. Candidate generation from a GS\n", "2. Deep dive: `GenerationStep` as a building block of the generation strategy\n", " 1. Describing a model\n", " 2. Other `GenerationStep` settings\n", " 3. Chaining `GenerationStep`-s together\n", " 4. `max_parallelism` enforcement and handling the `MaxParallelismReachedException`\n", "3. 
`GenerationStrategy` storage\n", " 1. SQL storage\n", " 2. JSON storage\n", "4. Advanced considerations / \"gotchas\"\n", " 1. Generation strategy produces `GeneratorRun`-s, not `Trial`-s\n", " 2. `model_kwargs` elements that don't have associated serialization logic in Ax\n", " 3. Why prefer `Models` registry enum entries over a factory function?\n", " 4. How to request more modeling setups in `Models`?\n", " \n", "----" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Quick-start examples\n", "\n", "### 1A. Manually configured generation strategy\n", "\n", "Below is a typical generation strategy used for most single-objective optimization cases in Ax:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "gs = GenerationStrategy(\n", " steps=[\n", " # 1. Initialization step (does not require pre-existing data and is well-suited for \n", " # initial sampling of the search space)\n", " GenerationStep(\n", " model=Models.SOBOL,\n", " num_trials=5, # How many trials should be produced from this generation step\n", " min_trials_observed=3, # How many trials need to be completed to move to next model\n", " max_parallelism=5, # Max parallelism for this step\n", " model_kwargs={\"seed\": 999}, # Any kwargs you want passed into the model\n", " model_gen_kwargs={}, # Any kwargs you want passed to `modelbridge.gen`\n", " ),\n", " # 2. Bayesian optimization step (requires data obtained from previous phase and learns\n", " # from all data available at the time of each new candidate generation call)\n", " GenerationStep(\n", " model=Models.GPEI,\n", " num_trials=-1, # No limitation on how many trials should be produced from this step\n", " max_parallelism=3, # Parallelism limit for this step, often lower than for Sobol\n", " # More on parallelism vs. 
required samples in BayesOpt:\n", " # https://ax.dev/docs/bayesopt.html#tradeoff-between-parallelism-and-total-number-of-trials\n", " ),\n", " ]\n", ")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1B. Auto-selected generation strategy\n", "\n", "Ax provides a [`choose_generation_strategy`](https://github.com/facebook/Ax/blob/main/ax/modelbridge/dispatch_utils.py#L115) utility, which can auto-select a suitable generation strategy given a search space and an array of other optional settings. The utility is fairly simple at the moment, but additional development (support for multi-objective optimization, multi-fidelity optimization, Bayesian optimization with categorical kernels etc.) is coming soon." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[INFO 06-15 07:59:03] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials]). 
Iterations after 5 will take longer to generate due to model-fitting.\n" ] }, { "data": { "text/plain": [ "GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials])" ] }, "execution_count": 12, "metadata": { "bento_obj_id": "139922521218736" }, "output_type": "execute_result" } ], "source": [ "gs = choose_generation_strategy(\n", " # Required arguments:\n", " search_space=get_branin_search_space(), # Ax `SearchSpace`\n", " \n", " # Some optional arguments (shown with their defaults), see API docs for more settings:\n", " # https://ax.dev/api/modelbridge.html#module-ax.modelbridge.dispatch_utils\n", " use_batch_trials=False, # Whether this GS will be used to generate 1-arm `Trial`-s or `BatchTrials`\n", " no_bayesian_optimization=False, # Use quasi-random candidate generation without BayesOpt\n", " max_parallelism_override=None, # Integer, to which to set the `max_parallelism` setting of all steps in this GS \n", ")\n", "gs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1C. Candidate generation from a generation strategy\n", "\n", "While often used through Service or Loop API or other higher-order abstractions like the Ax `Scheduler` (where the generation strategy is used to fit models and produce candidates from them under-the-hood), it's also possible to use the GS directly, in place of a `ModelBridge` instance. The interface of `GenerationStrategy.gen` is the same as `ModelBridge.gen`.\n" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "experiment = get_branin_experiment()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that it's important to **specify pending observations** to the call to `gen` to avoid getting the same points re-suggested. Without `pending_observations` argument, Ax models are not aware of points that should be excluded from generation. 
Points are considered \"pending\" when they belong to `STAGED`, `RUNNING`, or `ABANDONED` trials (`ABANDONED` trials are included so that the model does not re-suggest points that are already considered \"bad\").\n", "\n", "If the call to `get_pending_observation_features` becomes slow in your setup (since it performs data-fetching etc.), you can opt for `get_pending_observation_features_based_on_trial_status` (also from `ax.modelbridge.modelbridge_utils`), but note the limitations of that utility (detailed in its docstring)." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GeneratorRun(1 arms, total weight 1.0)" ] }, "execution_count": 23, "metadata": { "bento_obj_id": "139922521218448" }, "output_type": "execute_result" } ], "source": [ "generator_run = gs.gen(\n", " experiment=experiment, # Ax `Experiment`, for which to generate new candidates\n", " data=None, # Ax `Data` to use for model training, optional.\n", " n=1, # Number of candidate arms to produce\n", " pending_observations=get_pending_observation_features(experiment), # Points that should not be re-generated\n", " # Any other kwargs specified will be passed through to `ModelBridge.gen` along with `GenerationStep.model_gen_kwargs`\n", ")\n", "generator_run" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we can add the newly produced [`GeneratorRun`](https://ax.dev/docs/glossary.html#generator-run) to the experiment as a [`Trial` (or `BatchTrial` if `n` > 1)](https://ax.dev/docs/glossary.html#trial):" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Trial(experiment_name='branin_test_experiment', index=0, status=TrialStatus.CANDIDATE, arm=Arm(name='0_0', parameters={'x1': 2.4094051076099277, 'x2': 13.29242150299251}))" ] }, "execution_count": 24, "metadata": { "bento_obj_id": "139923550679968" }, "output_type": "execute_result" } ], "source": [ "trial = 
experiment.new_trial(generator_run)\n", "trial" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Important notes on `GenerationStrategy.gen`:**\n", "- if the `data` argument above is not specified, the GS will pull experiment data from cache via `experiment.lookup_data`,\n", "- without `pending_observations`, the GS (and any model in Ax) could produce the same candidate over and over, because the model is then not 'aware' that the candidate is part of a `RUNNING` or `ABANDONED` trial and should not be re-suggested.\n", "\n", "In cases where `get_pending_observation_features` is too slow and the experiment consists of 1-arm `Trial`-s only, it's possible to use `get_pending_observation_features_based_on_trial_status` instead (found in the same file)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that when using the Ax Service API, one of the arguments to `AxClient` is `choose_generation_strategy_kwargs`; specifying that argument is a convenient way to influence the choice of generation strategy in `AxClient` without manually specifying a full `GenerationStrategy`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "-----" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. `GenerationStep` as a building block of generation strategy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2A. Describing a model to use in a given `GenerationStep`\n", "\n", "There are two ways of specifying a model for a generation step: via an entry in the `Models` enum or via a 'factory function', i.e. a callable model constructor (e.g. [`get_GPEI`](https://github.com/facebook/Ax/blob/0e454b71d5e07b183c0866855555b6a21ddd5da1/ax/modelbridge/factory.py#L154) and other factory functions in the same file). Note that using the latter path, a factory function, will prohibit `GenerationStrategy` storage and is generally discouraged." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2B. 
Other `GenerationStep` settings\n", "\n", "All of the available settings are described in the documentation:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "One step in the generation strategy, corresponds to a single model.\n", " Describes the model, how many trials will be generated with this model, what\n", " minimum number of observations is required to proceed to the next model, etc.\n", "\n", " NOTE: Model can be specified either from the model registry\n", " (`ax.modelbridge.registry.Models` or using a callable model constructor. Only\n", " models from the registry can be saved, and thus optimization can only be\n", " resumed if interrupted when using models from the registry.\n", "\n", " Args:\n", " model: A member of `Models` enum or a callable returning an instance of\n", " `ModelBridge` with an instantiated underlying `Model`. Refer to\n", " `ax/modelbridge/factory.py` for examples of such callables.\n", " num_trials: How many trials to generate with the model from this step.\n", " If set to -1, trials will continue to be generated from this model\n", " as long as `generation_strategy.gen` is called (available only for\n", " the last of the generation steps).\n", " min_trials_observed: How many trials must be completed before the\n", " generation strategy can proceed to the next step. Defaults to 0.\n", " If `num_trials` of a given step have been generated but `min_trials_\n", " observed` have not been completed, a call to `generation_strategy.gen`\n", " will fail with a `DataRequiredError`.\n", " max_parallelism: How many trials generated in the course of this step are\n", " allowed to be run (i.e. 
have `trial.status` of `RUNNING`) simultaneously.\n", " If `max_parallelism` trials from this step are already running, a call\n", " to `generation_strategy.gen` will fail with a `MaxParallelismReached\n", " Exception`, indicating that more trials need to be completed before\n", " generating and running next trials.\n", " use_update: Whether to use `model_bridge.update` instead or reinstantiating\n", " model + bridge on every call to `gen` within a single generation step.\n", " NOTE: use of `update` on stateful models that do not implement `_get_state`\n", " may result in inability to correctly resume a generation strategy from\n", " a serialized state.\n", " enforce_num_trials: Whether to enforce that only `num_trials` are generated\n", " from the given step. If False and `num_trials` have been generated, but\n", " `min_trials_observed` have not been completed, `generation_strategy.gen`\n", " will continue generating trials from the current step, exceeding `num_\n", " trials` for it. Allows to avoid `DataRequiredError`, but delays\n", " proceeding to next generation step.\n", " model_kwargs: Dictionary of kwargs to pass into the model constructor on\n", " instantiation. E.g. if `model` is `Models.SOBOL`, kwargs will be applied\n", " as `Models.SOBOL(**model_kwargs)`; if `model` is `get_sobol`, `get_sobol(\n", " **model_kwargs)`. NOTE: if generation strategy is interrupted and\n", " resumed from a stored snapshot and its last used model has state saved on\n", " its generator runs, `model_kwargs` is updated with the state dict of the\n", " model, retrieved from the last generator run of this generation strategy.\n", " model_gen_kwargs: Each call to `generation_strategy.gen` performs a call to the\n", " step's model's `gen` under the hood; `model_gen_kwargs` will be passed to\n", " the model's `gen` like so: `model.gen(**model_gen_kwargs)`.\n", " index: Index of this generation step, for use internally in `Generation\n", " Strategy`. 
Do not assign as it will be reassigned when instantiating\n", " `GenerationStrategy` with a list of its steps.\n", "\n", " \n" ] } ], "source": [ "print(GenerationStep.__doc__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2C. Chaining `GenerationStep`-s together\n", "\n", "A `GenerationStrategy` moves from one step to another when:\n", "1. `N=num_trials` generator runs were produced and attached as trials to the experiment AND\n", "2. `M=min_trials_observed` trials have been completed and have data.\n", "\n", "**Caveat: `enforce_num_trials` setting**:\n", "\n", "1. If `enforce_num_trials=True` for a given generation step and condition 1) is met but 2) is not, the generation strategy will raise a `DataRequiredError`, indicating that more trials need to be completed before the next step.\n", "2. If `enforce_num_trials=False`, the GS will continue producing generator runs from the current step until 2) is met." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2D. `max_parallelism` enforcement\n", "\n", "Generation strategy can restrict the number of trials that can be run simultaneously (to encourage sequential optimization, which benefits Bayesian optimization performance). When the parallelism limit is reached, a call to `GenerationStrategy.gen` will result in a `MaxParallelismReachedException`.\n", "\n", "The correct way to handle this exception:\n", "1. Make sure that `GenerationStep.max_parallelism` is configured correctly for all steps in your generation strategy (to disable it completely, configure `GenerationStep.max_parallelism=None`),\n", "2. When encountering the exception, wait to produce more generator runs until more trial evaluations complete, and log each trial completion via `trial.mark_completed`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "----\n", "\n", "## 3. 
SQL and JSON storage of a generation strategy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When used through the Service API or `Scheduler`, a generation strategy will be automatically stored to SQL or JSON by specifying `DBSettings` to either `AxClient` or `Scheduler` (details in the respective tutorials on the [\"Tutorials\" page](https://ax.dev/tutorials/)). A generation strategy can also be stored to SQL or JSON individually, as shown below.\n", "\n", "More detail on SQL and JSON storage in Ax generally can be [found in the \"Building Blocks of Ax\" tutorial](https://ax.dev/tutorials/building_blocks.html#9.-Save-to-JSON-or-SQL)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3A. SQL storage\n", "For SQL storage setup in Ax, read through the [\"Storage\" documentation page](https://ax.dev/docs/storage.html).\n", "\n", "Note that unlike an Ax experiment, a generation strategy does not have a name or another unique identifier. Therefore, a generation strategy is stored in association with an experiment and can be retrieved by the associated experiment's name." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from ax.storage.sqa_store.db import init_engine_and_session_factory, get_engine, create_all_tables\n", "from ax.storage.sqa_store.load import load_experiment, load_generation_strategy_by_experiment_name\n", "from ax.storage.sqa_store.save import save_experiment, save_generation_strategy\n", "\n", "# Point Ax at a local SQLite database and create the storage tables\n", "init_engine_and_session_factory(url='sqlite:///foo2.db')\n", "\n", "engine = get_engine()\n", "create_all_tables(engine)\n", "\n", "# The generation strategy is stored in association with its experiment\n", "save_experiment(experiment)\n", "save_generation_strategy(gs)\n", "\n", "experiment = load_experiment(experiment_name=experiment.name)\n", "gs = load_generation_strategy_by_experiment_name(\n", " experiment_name=experiment.name,\n", " experiment=experiment, # Can optionally specify experiment object to avoid loading it from database twice\n", ")\n", "gs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3B. JSON storage" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials])" ] }, "execution_count": 31, "metadata": { "bento_obj_id": "139923550893296" }, "output_type": "execute_result" } ], "source": [ "from ax.storage.json_store.encoder import object_to_json\n", "from ax.storage.json_store.decoder import object_from_json\n", "\n", "gs_json = object_to_json(gs) # Can be written to a file or string via `json.dump` etc.\n", "gs = object_from_json(gs_json) # Decoded back from JSON (can be loaded from file, string via `json.load` etc.)\n", "gs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "------" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. 
Advanced considerations\n", "\n", "Below is a list of important \"gotchas\" of using generation strategy (especially outside of the higher-level APIs like the Service API or the `Scheduler`):\n", "\n", "### 4A. `GenerationStrategy.gen` produces `GeneratorRun`-s, not trials\n", "\n", "Since `GenerationStrategy.gen` mimics `ModelBridge.gen` and allows for human-in-the-loop usage mode, a call to `gen` produces a `GeneratorRun`, which can then be added (or altered before addition or not added at all) to a `Trial` or `BatchTrial` on a given experiment. So it's important to add the generator run to a trial, since otherwise it will not be attached to the experiment on its own." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Trial(experiment_name='branin_test_experiment', index=1, status=TrialStatus.CANDIDATE, arm=Arm(name='1_0', parameters={'x1': -0.34071301110088825, 'x2': 7.061324520036578}))" ] }, "execution_count": 27, "metadata": { "bento_obj_id": "139923551043648" }, "output_type": "execute_result" } ], "source": [ "generator_run = gs.gen(\n", " experiment=experiment, n=1, pending_observations=get_pending_observation_features(experiment)\n", ")\n", "experiment.new_trial(generator_run)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4B. `model_kwargs` elements that do not define serialization logic in Ax" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that passing objects that are not yet serializable in Ax (e.g. a BoTorch `Prior` object) as part of `GenerationStep.model_kwargs` or `GenerationStep.model_gen_kwargs` will prevent correct generation strategy storage. If this becomes a problem, feel free to open an issue on our GitHub: https://github.com/facebook/Ax/issues to get help with adding storage support for a given object." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4C. Why prefer `Models` enum entries over a factory function?" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1. **Storage potential:** a call to, for example, `Models.GPEI` captures all arguments to the model and model bridge and stores them on the generator runs subsequently produced by the model. Since the capturing logic is part of the `Models.__call__` function, it is not present in a factory function. Furthermore, there is no safe and flexible way to serialize callables in Python.\n", "2. **Standardization:** While a 'factory function' is by default more flexible (accepts any specified inputs and produces a `ModelBridge` with an underlying `Model` instance based on them), it is not standard in terms of its inputs. `Models` introduces a standardized interface, making it easy to adapt any example to one's specific case." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4D. How can I request more modeling setups added to `Models` and natively supported in Ax?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Please open a [GitHub issue](https://github.com/facebook/Ax/issues) to request a new modeling setup in Ax (or for any other questions or requests)." ] } ], "metadata": { "kernelspec": { "display_name": "python3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }