in aepsych/generators/optimize_acqf_generator.py [0:0]
def gen(self, num_points: int, model: ModelProtocol) -> np.ndarray:
    """Query next point(s) to run by optimizing the acquisition function.

    If ``self.max_gen_time`` is None, the acquisition function is optimized
    with ``optimize_acqf`` using ``self.restarts`` restarts and
    ``self.samps`` raw samples.

    Otherwise, one acquisition evaluation is timed and used to derive an
    evaluation budget (``maxfun``) that roughly fits in
    ``self.max_gen_time`` seconds. If that budget is 10 evaluations per
    restart or fewer, the method instead scores a quasi-random (Sobol)
    sample inside the model bounds and keeps the best-scoring point(s).

    Args:
        num_points (int): Number of points to query.
        model (ModelProtocol): Fitted model of the data.

    Returns:
        np.ndarray: Next set of point(s) to evaluate, [num_points x dim].
    """
    # Floor for the measured per-evaluation time. A fast acquisition function
    # can time as 0.0 s, which would make the budget arithmetic below divide
    # by zero.
    min_eval_time = 1e-9
    # Minimum size of the quasi-random fallback sample, so the search is never
    # empty when max_gen_time is shorter than a single evaluation.
    min_random_samples = 10
    # Heuristic: with this many evals per restart or fewer, gradient-based
    # optimization is not worth it and quasi-random search is used instead.
    min_evals_per_restart = 10

    # eval should be inherited from superclass
    model.eval()  # type: ignore
    train_x = model.train_inputs[0]
    acqf = self._instantiate_acquisition_fn(model, train_x)
    # Stack lb/ub column-wise, then transpose, giving a [2 x dim] tensor cast
    # to train_x's dtype/device.
    bounds = torch.tensor(np.c_[model.lb, model.ub]).T.to(train_x)

    logger.info("Starting gen...")
    starttime = time.perf_counter()

    if self.max_gen_time is None:
        new_candidate, _ = optimize_acqf(
            acq_function=acqf,
            bounds=bounds,
            q=num_points,
            num_restarts=self.restarts,
            raw_samples=self.samps,
        )
    else:
        # Figure out how long evaluating a single sample takes.
        probe_start = time.perf_counter()
        _ = acqf(train_x[0:1, :])
        single_eval_time = max(time.perf_counter() - probe_start, min_eval_time)

        # Only a heuristic for total num evals since everything is stochastic,
        # but the reasoning is: we initialize with self.samps samps, subsample
        # self.restarts from them in proportion to the value of the acqf, and
        # run that many optimizations. So:
        # total_time = single_eval_time * n_eval * restarts + single_eval_time * samps
        # and we solve for n_eval.
        n_eval = int(
            (self.max_gen_time - single_eval_time * self.samps)
            / (single_eval_time * self.restarts)
        )

        if n_eval > min_evals_per_restart:
            options = {"maxfun": n_eval}
            logger.info(f"gen maxfun is {n_eval}")
            new_candidate, _ = optimize_acqf(
                acq_function=acqf,
                bounds=bounds,
                q=num_points,
                num_restarts=self.restarts,
                raw_samples=self.samps,
                options=options,
            )
        else:
            logger.info(f"gen maxfun is {n_eval}, falling back to random search...")
            # Spend roughly the whole time budget on evaluations, but always
            # draw enough points to return num_points candidates.
            nsamp = max(
                int(self.max_gen_time / single_eval_time),
                min_random_samples,
                num_points,
            )
            # Generate the points at which to sample.
            X = make_scaled_sobol(lb=model.lb, ub=model.ub, size=nsamp)
            # Each point is scored as its own q=1 batch. This presumably yields
            # one acquisition value per point (shape [nsamp]) -- confirm for
            # the acquisition functions in use.
            acqvals = acqf(X[:, None, :])
            # Keep the num_points highest-scoring points. For num_points == 1
            # this matches the previous argmax selection and [1 x dim] shape.
            best_indx = torch.topk(acqvals, k=num_points, dim=0).indices
            new_candidate = X[best_indx]

    logger.info(f"Gen done, time={time.perf_counter()-starttime}")
    # detach/cpu so tensors that track gradients or live on an accelerator
    # can still be converted to numpy.
    return new_candidate.detach().cpu().numpy()