in botorch/acquisition/multi_objective/multi_output_risk_measures.py [0:0]
def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
    r"""Calculate the MVaR corresponding to the given samples.

    Args:
        samples: A `sample_shape x batch_shape x (q * n_w) x m`-dim tensor of
            posterior samples. The q-batches should be ordered so that each
            `n_w` block of samples corresponds to the same input.
        X: A `batch_shape x q x d`-dim tensor of inputs. Ignored.

    Returns:
        A `sample_shape x batch_shape x q x m`-dim tensor of MVaR values
        if `self.expectation=True`. Otherwise, this returns a
        `sample_shape x batch_shape x (q * k') x m`-dim tensor, where `k'` is
        the maximum `k` across all batches that is returned by
        `get_mvar_set_...`. Each `(q * k') x m` block corresponds to the `k`
        MVaR values for each of the `q` batches of `n_w` inputs, padded up to
        `k'` by repeating the last element. If `self.pad_to_n_w`, we set
        `k' = self.n_w`, producing a deterministic return shape.
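
    Example:
        An illustrative sketch, assuming the `n_w` and `alpha` constructor
        arguments of `MVaR`; the sample values below are made up.

        >>> mvar = MVaR(n_w=32, alpha=0.8)
        >>> # 64 posterior samples, q=1 candidate, n_w=32, m=2 outputs.
        >>> samples = torch.randn(64, 32, 2)
        >>> mvar_vals = mvar(samples)  # `64 x k' x 2`, since `expectation=False`.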
"""
    batch_shape, m = samples.shape[:-2], samples.shape[-1]
    prepared_samples = self._prepare_samples(samples)
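    # `_prepare_samples` separates the `(q * n_w)` dimension into `q x n_w`,
    # yielding a `sample_shape x batch_shape x q x n_w x m`-dim tensor.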
    # Flatten the batch dimensions: this is `-1 x n_w x m`.
    prepared_samples = prepared_samples.reshape(-1, *prepared_samples.shape[-2:])
    # Get the MVaR set using the appropriate method based on device, m & n_w.
    # NOTE: The `n_w <= 64` part is based on testing on a 24 core CPU.
    # `get_mvar_set_gpu` heavily relies on parallelized batch computations and
    # may scale worse on CPUs with fewer cores.
    # Using `no_grad` here since `MVaR` is not differentiable.
    with torch.no_grad():
        if (
            samples.device == torch.device("cpu")
            and m == 2
            and prepared_samples.shape[-2] <= 64
        ):
            mvar_set = self.get_mvar_set_cpu(prepared_samples)
        else:
            mvar_set = self.get_mvar_set_gpu(prepared_samples)
    if samples.requires_grad:
        # TODO: Investigate differentiability of MVaR.
        warnings.warn(
            "Got `samples` that requires grad, but computing MVaR involves "
            "non-differentiable operations and the results will not be "
            "differentiable. This may lead to errors down the line!",
            RuntimeWarning,
        )
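    # `mvar_set` is a list with one `k x m`-dim tensor of MVaR values per
    # flattened batch; `k` may differ across batches, hence the padding below.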
    # Set the `pad_size` to either `self.n_w` or the size of the largest MVaR set.
    pad_size = self.n_w if self.pad_to_n_w else max(_.shape[0] for _ in mvar_set)
    padded_mvar_list = []
    for mvar_ in mvar_set:
        if self.expectation:
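            # Reduce each MVaR set to its mean, giving one `m`-dim value per
            # batch; no padding is needed in this case.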
            padded_mvar_list.append(mvar_.mean(dim=0))
        else:
            # Repeat the last entry to pad `mvar_` up to `pad_size x m`.
            repeats_needed = pad_size - mvar_.shape[0]
            padded_mvar_list.append(
                torch.cat([mvar_, mvar_[-1].expand(repeats_needed, m)], dim=0)
            )
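    # Stack the per-batch results and restore the batch dimensions. The
    # trailing dimensions are `q x m` if `self.expectation=True`, and
    # `(q * k') x m` otherwise.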
    mvars = torch.stack(padded_mvar_list, dim=0)
    return mvars.view(*batch_shape, -1, m)