def plot_acquisition_examples()

in pubs/owenetal/code/stratplots.py [0:0]


def plot_acquisition_examples(sobol_trials, opt_trials, target_level=0.75):
    """Compare BALV, BALD and LSE acquisition on the same monotonic model.

    Builds one ``SequentialStrategy`` per acquisition function listed under
    ``configs["experiment"]["acqf"]``. Each strategy is run for
    ``sobol_trials`` Sobol trials and then ``opt_trials`` model-based trials
    on a 1-D problem whose true response is ``norm.cdf(3 * x)`` on [-3, 3].
    The results go on a 2x2 grid of axes:

    * top-left: the first strategy's model after the Sobol phase, overlaid
      with the first point each acquisition function generates next (y is the
      mean of 10000 posterior samples pushed through ``norm.cdf``);
    * the other three panels: each strategy's model after all trials.

    NOTE(review): ``global_seed`` and ``refit_every`` are not defined in this
    function; presumably they are module-level settings — confirm.

    Args:
        sobol_trials: Number of initial Sobol trials per strategy.
        opt_trials: Number of model-based trials per strategy after Sobol.
        target_level: Target level passed to the configs and to the plots.

    Returns:
        Tuple ``(fig, lgd)`` of the figure and its legend. The legend is
        anchored outside the figure's right edge, so it is returned for the
        caller to include when saving.
    """
    ### Same model, different acqf figure ####

    # Single source of truth for the bounds: str([-3]) == "[-3]", matching the
    # string form the config expects.
    lb = [-3]
    ub = [3]

    configs = {
        "common": {
            "pairwise": False,
            "target": target_level,
            "lb": str(lb),
            "ub": str(ub),
        },
        "experiment": {
            "acqf": [
                "MonotonicMCPosteriorVariance",
                "MonotonicBernoulliMCMutualInformation",
                "MonotonicMCLSE",
            ],
            "modelbridge_cls": "MonotonicSingleProbitModelbridge",
            "init_strat_cls": "SobolStrategy",
            "opt_strat_cls": "ModelWrapperStrategy",
            "model": "MonotonicRejectionGP",
            "parnames": "[intensity]",
        },
        "MonotonicMCLSE": {
            "target": target_level,
            "beta": 3.98,
        },
        "MonotonicRejectionGP": {
            "inducing_size": 100,
            "mean_covar_factory": "monotonic_mean_covar_factory",
            "monotonic_idxs": "[0]",
            "uniform_idxs": "[]",
        },
        "MonotonicSingleProbitModelbridge": {"restarts": 10, "samps": 1000},
        "SobolStrategy": {"n_trials": sobol_trials},
        "ModelWrapperStrategy": {
            "n_trials": opt_trials,
            "refit_every": refit_every,
        },
    }

    def true_testfun(x):
        return norm.cdf(3 * x)

    class SimpleLinearProblem(Problem):
        def f(self, x):
            # Latent function: probit of the true response curve.
            return norm.ppf(true_testfun(x))

    logger = BenchmarkLogger()
    problem = SimpleLinearProblem(lb, ub)
    bench = Benchmark(
        problem=problem,
        logger=logger,
        configs=configs,
        global_seed=global_seed,
        n_reps=1,
    )

    def run_trials(strat, n_trials):
        # Ask the strategy for n_trials points, feeding back one sampled
        # response from the problem after each.
        for _ in range(n_trials):
            next_x = strat.gen()
            strat.add_data(next_x, [problem.sample_y(next_x)])

    # Run every config combination through the Sobol phase only, reseeding
    # numpy and torch before each so all runs start from the same RNG state.
    strats = []
    for combo in bench.combinations:
        np.random.seed(global_seed)
        torch.manual_seed(global_seed)
        strat = SequentialStrategy.from_config(Config(config_dict=combo))
        run_trials(strat, sobol_trials)
        strats.append(strat)

    # First model-based point each strategy generates (not added as data).
    first_gens = [strat.gen() for strat in strats]

    fig, ax = plt.subplots(2, 2)
    plot_strat(
        strat=strats[0],
        title=f"First active trial\n (after {sobol_trials} Sobol trials)",
        ax=ax[0, 0],
        true_testfun=true_testfun,
        target_level=target_level,
        show=False,
        include_legend=False,
    )

    # NOTE(review): labels assume bench.combinations preserves the order of
    # the "acqf" list above — confirm against Benchmark.
    names = ["First BALV sample", "First BALD sample", "First LSE sample"]
    markers = ["s", "*", "^"]
    for strat, gen, name, marker in zip(strats, first_gens, names, markers):
        samples = norm.cdf(strat.sample(torch.Tensor(gen), num_samples=10000))
        ax[0, 0].scatter(
            gen[0][0],
            np.mean(samples),
            label=name,
            marker=marker,
            color="black",
        )

    # Now run every strategy for the full model-based phase.
    for strat in strats:
        run_trials(strat, opt_trials)

    total_trials = sobol_trials + opt_trials
    plotting_axes = [ax[0, 1], ax[1, 0], ax[1, 1]]
    titles = [
        f"Monotonic RBF Model,\n BALV, after {total_trials} total trials",
        f"Monotonic RBF Model,\n BALD, after {total_trials} total trials",
        f"Monotonic RBF Model,\n LSE (ours) after {total_trials} total trials",
    ]
    for axis, strat, title in zip(plotting_axes, strats, titles):
        plot_strat(
            strat=strat,
            title=title,
            ax=axis,
            true_testfun=true_testfun,
            target_level=target_level,
            show=False,
            include_legend=False,
        )

    fig.tight_layout()
    # Only the top-left panel carries labelled artists (the first-sample markers).
    handles, labels = ax[0, 0].get_legend_handles_labels()
    lgd = fig.legend(handles, labels, loc="lower right", bbox_to_anchor=(1.5, 0.25))
    # return legend so savefig works correctly
    return fig, lgd