ax/plot/bandit_rollout.py (56 lines of code) (raw):

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List

import plotly.graph_objs as go
from ax.core.batch_trial import BatchTrial
from ax.core.experiment import Experiment
from ax.plot.base import AxPlotConfig, AxPlotTypes
from ax.plot.color import MIXED_SCALE, rgba


def plot_bandit_rollout(experiment: Experiment) -> AxPlotConfig:
    """Plot the bandit rollout of an experiment as a stacked bar chart.

    Each trial becomes one x-axis category labelled "Round <trial index>",
    ordered by trial index. Within a round, every arm contributes one stacked
    bar segment whose height is the arm's weight as returned by
    ``trial.normalized_arm_weights(total=100)`` (shown on a "Percent" axis).
    Each arm keeps the same color in every round; colors are assigned in
    order of the arm's first appearance, cycling through ``MIXED_SCALE``.

    Args:
        experiment: The experiment to plot. Every trial must be a
            ``BatchTrial``.

    Returns:
        An ``AxPlotConfig`` of type ``AxPlotTypes.GENERIC`` wrapping the
        plotly figure.

    Raises:
        ValueError: If any trial of the experiment is not a ``BatchTrial``.
    """
    categories: List[str] = []
    # Per-arm bar-trace data keyed by arm name. Dicts preserve insertion
    # order, so iteration below follows each arm's first appearance.
    arms: Dict[str, Dict[str, Any]] = {}

    for trial in sorted(experiment.trials.values(), key=lambda t: t.index):
        if not isinstance(trial, BatchTrial):
            raise ValueError(
                "Bandit rollout graph is not supported for BaseTrial."
            )  # pragma: no cover
        category = f"Round {trial.index}"
        categories.append(category)
        for arm, weight in trial.normalized_arm_weights(total=100).items():
            if arm.name not in arms:
                arms[arm.name] = {"name": arm.name, "x": [], "y": [], "text": []}
            trace = arms[arm.name]
            trace["x"].append(category)
            trace["y"].append(weight)
            trace["text"].append(f"{weight:.2f}%")

    colors = [rgba(c) for c in MIXED_SCALE]
    layout = go.Layout(
        title="Rollout Process<br>Bandit Weight Graph",
        xaxis={
            "title": "Rounds",
            "zeroline": False,
            # Keep rounds in trial-index order rather than plotly's default.
            "categoryorder": "array",
            "categoryarray": categories,
        },
        yaxis={"title": "Percent", "showline": False},
        barmode="stack",
        showlegend=False,
        margin={"r": 40},
    )

    bandit_config = {"type": "bar", "hoverinfo": "name+text", "width": 0.5}
    # The color index is the arm's first-appearance position, so no extra
    # bookkeeping key has to be stored in (and later stripped from) the
    # trace dicts passed to plotly.
    bandits = [
        dict(bandit_config, marker={"color": colors[i % len(colors)]}, **trace)
        for i, trace in enumerate(arms.values())
    ]

    fig = go.Figure(data=bandits, layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)