def plot_novel_lse_grids()

in pubs/owenetal/code/stratplots.py [0:0]


def plot_novel_lse_grids(sobol_trials, opt_trials, funtype="detection"):
    """
    Generates Fig. TBA

    Runs one repetition of each benchmark configuration defined below (a
    MonotonicRejectionGP setup expanded over two ``monotonic_idxs`` settings,
    and a GPClassificationModel setup using the Song heuristic) against the
    selected novel test function, then draws one ``plot_strat`` panel per
    resulting strategy on a 2x2 grid with a shared legend and colorbar.

    Args:
        sobol_trials: Passed as ``n_trials`` of the ``SobolStrategy`` section
            of every benchmark config.
        opt_trials: Passed as ``n_trials`` of the ``ModelWrapperStrategy``
            section of every benchmark config, and as ``log_every`` of the
            benchmark logger.
        funtype: ``"detection"`` or ``"discrimination"``; selects the test
            function and the yes/no trial labels handed to ``plot_strat``.

    Returns:
        The matplotlib figure holding the three panels.

    Raises:
        RuntimeError: If ``funtype`` is not one of the two supported values.

    Note:
        ``refit_every`` and ``global_seed`` are not parameters; they are read
        from the enclosing (presumably module-level) scope.
    """
    # Validate the request before doing any work so that a bad ``funtype``
    # fails immediately instead of after the configs have been assembled.
    if funtype == "detection":
        testfun = novel_detection_testfun
        yes_label = "Detected trial"
        no_label = "Nondetected trial"
    elif funtype == "discrimination":
        testfun = novel_discrimination_testfun
        yes_label = "Correct trial"
        no_label = "Incorrect trial"
    else:
        raise RuntimeError("unknown testfun")

    logger = BenchmarkLogger(log_every=opt_trials)  # we only care about final perf
    bench_rbf = {
        "common": {"pairwise": False, "target": 0.75},
        "experiment": {
            "acqf": "MonotonicMCLSE",
            "modelbridge_cls": "MonotonicSingleProbitModelbridge",
            "init_strat_cls": "SobolStrategy",
            "opt_strat_cls": "ModelWrapperStrategy",
            "model": "MonotonicRejectionGP",
            "parnames": "[context,intensity]",
        },
        "MonotonicMCLSE": {
            "target": 0.75,
            "beta": 3.98,
        },
        "MonotonicRejectionGP": {
            "inducing_size": 100,
            "mean_covar_factory": [
                "monotonic_mean_covar_factory",
            ],
            # Two entries here: one with index 1 marked monotonic and one with
            # no monotonic indices.
            "monotonic_idxs": ["[1]", "[]"],
            "uniform_idxs": "[]",
        },
        "MonotonicSingleProbitModelbridge": {"restarts": 10, "samps": 1000},
        "SobolStrategy": {
            "n_trials": [sobol_trials],
        },
        "ModelWrapperStrategy": {
            "n_trials": [opt_trials],
            "refit_every": [refit_every],
        },
    }
    bench_song = {
        "common": {"pairwise": False, "target": 0.75},
        "experiment": {
            "acqf": "BernoulliMCMutualInformation",
            "modelbridge_cls": "SingleProbitModelbridgeWithSongHeuristic",
            "init_strat_cls": "SobolStrategy",
            "opt_strat_cls": "ModelWrapperStrategy",
            "model": "GPClassificationModel",
            "parnames": "[context,intensity]",
        },
        "GPClassificationModel": {
            "inducing_size": 100,
            "dim": 2,
            "mean_covar_factory": [
                "song_mean_covar_factory",
            ],
        },
        "SingleProbitModelbridgeWithSongHeuristic": {"restarts": 10, "samps": 1000},
        "SobolStrategy": {
            "n_trials": [sobol_trials],
        },
        "ModelWrapperStrategy": {
            "n_trials": [opt_trials],
            "refit_every": [refit_every],
        },
    }
    all_bench_configs = [bench_rbf, bench_song]

    class NovelProblem(LSEProblem, Problem):
        def f(self, x):
            return testfun(x)

    lb = [-1, -1]
    ub = [1, 1]
    problem = NovelProblem(lb, ub, gridsize=50)

    benches = []
    for config in all_bench_configs:
        # Build a fresh "common" section rather than shallow-copying the
        # config and writing into the nested dict it shares with the
        # original. Key order is preserved ("common" stays first).
        full_config = {
            **config,
            "common": {**config["common"], "lb": str(lb), "ub": str(ub)},
        }
        benches.append(
            Benchmark(
                problem=problem,
                logger=logger,
                configs=full_config,
                global_seed=global_seed,
                n_reps=1,
            )
        )
    combo_bench = combine_benchmarks(*benches)

    strats = [
        combo_bench.run_experiment(config, logger, seed=global_seed, rep=0)
        for config in combo_bench.combinations
    ]

    # NOTE(review): titles are matched to strategies purely by position; this
    # assumes combo_bench.combinations yields the monotonic RBF, nonmonotonic
    # RBF and Song configurations in that order -- confirm.
    titles = [
        "Monotonic RBF Model, LSE (ours)",
        "Nonmonotonic RBF Model, LSE (ours)",
        "Linear-Additive Model, BALD",
    ]
    fig, axes = plt.subplots(2, 2, figsize=(7.5, 6.5))
    plotting_axes = [axes[1, 0], axes[0, 1], axes[0, 0]]
    # Only three panels are drawn; the bottom-right cell is removed and its
    # area is reused for the shared legend.
    fig.delaxes(axes[1, 1])

    for ax_, strat_, title_ in zip(plotting_axes, strats, titles):
        plot_strat(
            strat=strat_,
            title=title_,
            ax=ax_,
            true_testfun=testfun,
            yes_label=yes_label,
            no_label=no_label,
            show=False,
            include_legend=False,
            include_colorbar=False,
        )

    fig.tight_layout()
    # Per-panel legends and colorbars were suppressed above; take the handles
    # and the image from one panel and attach a single shared copy of each.
    handles, labels = axes[1, 0].get_legend_handles_labels()

    fig.legend(handles, labels, loc="lower right", bbox_to_anchor=(0.8, 0.2))
    cbr = fig.colorbar(axes[1, 0].images[0], ax=plotting_axes)
    # NOTE(review): this label is used for funtype="discrimination" as well --
    # confirm that is intended.
    cbr.set_label("Probability of Detection")

    return fig