in pubs/owenetal/code/stratplots.py [0:0]
def plot_novel_lse_grids(sobol_trials, opt_trials, funtype="detection"):
    """
    Generates Fig. TBA.

    Runs one replicate of each benchmarked strategy on a novel 2-D
    ``[context, intensity]`` test function over the box [-1, 1] x [-1, 1] and
    draws the fitted strategies on three panels of a 2x2 figure, with a shared
    legend and colorbar.

    Args:
        sobol_trials: Number of Sobol initialization trials per strategy.
        opt_trials: Number of model-driven trials per strategy; also used as
            the logger's ``log_every`` because only final performance matters.
        funtype: Either ``"detection"`` or ``"discrimination"``; selects the
            ground-truth test function and the labels used in the plot.

    Returns:
        The matplotlib figure containing the panels.

    Raises:
        RuntimeError: If ``funtype`` is not ``"detection"`` or
            ``"discrimination"``.
    """
    # Validate funtype before doing any (potentially expensive) benchmark setup.
    if funtype == "detection":
        testfun = novel_detection_testfun
        yes_label = "Detected trial"
        no_label = "Nondetected trial"
        colorbar_label = "Probability of Detection"
    elif funtype == "discrimination":
        testfun = novel_discrimination_testfun
        yes_label = "Correct trial"
        no_label = "Incorrect trial"
        # The positive outcome here is a correct response, not a detection.
        colorbar_label = "Probability of Correct Response"
    else:
        raise RuntimeError("unknown testfun")

    logger = BenchmarkLogger(log_every=opt_trials)  # we only care about final perf
    # NOTE(review): `refit_every` and `global_seed` are not defined in this
    # function; presumably module-level settings elsewhere in this file.
    bench_rbf = {
        "common": {"pairwise": False, "target": 0.75},
        "experiment": {
            "acqf": "MonotonicMCLSE",
            "modelbridge_cls": "MonotonicSingleProbitModelbridge",
            "init_strat_cls": "SobolStrategy",
            "opt_strat_cls": "ModelWrapperStrategy",
            "model": "MonotonicRejectionGP",
            "parnames": "[context,intensity]",
        },
        "MonotonicMCLSE": {
            "target": 0.75,
            "beta": 3.98,
        },
        "MonotonicRejectionGP": {
            "inducing_size": 100,
            "mean_covar_factory": [
                "monotonic_mean_covar_factory",
            ],
            # Two values: presumably expanded by Benchmark into two runs,
            # monotonic in parameter 1 (intensity) and non-monotonic.
            # TODO confirm against Benchmark's combination logic.
            "monotonic_idxs": ["[1]", "[]"],
            "uniform_idxs": "[]",
        },
        "MonotonicSingleProbitModelbridge": {"restarts": 10, "samps": 1000},
        "SobolStrategy": {
            "n_trials": [sobol_trials],
        },
        "ModelWrapperStrategy": {
            "n_trials": [opt_trials],
            "refit_every": [refit_every],
        },
    }
    bench_song = {
        "common": {"pairwise": False, "target": 0.75},
        "experiment": {
            "acqf": "BernoulliMCMutualInformation",
            "modelbridge_cls": "SingleProbitModelbridgeWithSongHeuristic",
            "init_strat_cls": "SobolStrategy",
            "opt_strat_cls": "ModelWrapperStrategy",
            "model": "GPClassificationModel",
            "parnames": "[context,intensity]",
        },
        "GPClassificationModel": {
            "inducing_size": 100,
            "dim": 2,
            "mean_covar_factory": [
                "song_mean_covar_factory",
            ],
        },
        "SingleProbitModelbridgeWithSongHeuristic": {"restarts": 10, "samps": 1000},
        "SobolStrategy": {
            "n_trials": [sobol_trials],
        },
        "ModelWrapperStrategy": {
            "n_trials": [opt_trials],
            "refit_every": [refit_every],
        },
    }
    all_bench_configs = [bench_rbf, bench_song]

    class NovelProblem(LSEProblem, Problem):
        def f(self, x):
            return testfun(x)

    lb = [-1, -1]
    ub = [1, 1]
    problem = NovelProblem(lb, ub, gridsize=50)
    benches = []
    for config in all_bench_configs:
        # Build a fresh "common" section rather than mutating the nested dict
        # shared with `config` (a shallow copy would alias it).
        full_config = {
            **config,
            "common": {**config["common"], "lb": str(lb), "ub": str(ub)},
        }
        benches.append(
            Benchmark(
                problem=problem,
                logger=logger,
                configs=full_config,
                global_seed=global_seed,
                n_reps=1,
            )
        )
    combo_bench = combine_benchmarks(*benches)
    strats = [
        combo_bench.run_experiment(config, logger, seed=global_seed, rep=0)
        for config in combo_bench.combinations
    ]

    # NOTE(review): titles are matched to strategies purely by position, which
    # assumes `combo_bench.combinations` yields monotonic RBF, non-monotonic
    # RBF, then the BALD model, in that order. Verify if the configs change.
    titles = [
        "Monotonic RBF Model, LSE (ours)",
        "Nonmonotonic RBF Model, LSE (ours)",
        "Linear-Additive Model, BALD",
    ]
    fig, axes = plt.subplots(2, 2, figsize=(7.5, 6.5))
    # Panel for each strategy, in the same order as `titles`.
    plotting_axes = [axes[1, 0], axes[0, 1], axes[0, 0]]
    # The fourth panel is unused; its space hosts the shared legend below.
    fig.delaxes(axes[1, 1])
    for ax_, strat_, title_ in zip(plotting_axes, strats, titles):
        plot_strat(
            strat=strat_,
            title=title_,
            ax=ax_,
            true_testfun=testfun,
            yes_label=yes_label,
            no_label=no_label,
            show=False,
            include_legend=False,
            include_colorbar=False,
        )
    fig.tight_layout()
    # A single figure-level legend and colorbar, taken from one panel, stand
    # in for the per-panel ones suppressed above.
    handles, labels = axes[1, 0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="lower right", bbox_to_anchor=(0.8, 0.2))
    cbr = fig.colorbar(axes[1, 0].images[0], ax=plotting_axes)
    cbr.set_label(colorbar_label)
    return fig