in botorch/models/gpytorch.py [0:0]
def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel:
    r"""Subset the model along the output dimension.

    The model itself is not modified: a deep copy is made, subset, and
    returned. When exactly one index is given, the output batch dimension is
    dropped entirely rather than kept with size 1.

    Args:
        idcs: The output indices to subset the model to. Must be non-empty.

    Returns:
        A deep copy of the current model, subset to the specified output
        indices.

    Raises:
        NotImplementedError: If the model does not define a
            `_subset_batch_dict` attribute.
        ValueError: If `idcs` is empty.
    """
    try:
        subset_batch_dict = self._subset_batch_dict
    except AttributeError:
        raise NotImplementedError(
            "subset_output requires the model to define a "
            "`_subset_batch_dict` attribute"
        ) from None

    m = len(idcs)
    if m == 0:
        raise ValueError("subset_output requires at least one output index.")

    new_model = deepcopy(self)

    tidxr = torch.tensor(idcs, device=new_model.train_targets.device)
    # With a single index, index with a plain int so the output dimension is
    # removed instead of leaving a size-1 dimension behind.
    idxr = tidxr if m > 1 else idcs[0]
    new_tail_bs = torch.Size([m]) if m > 1 else torch.Size()

    new_model._num_outputs = m
    # The trailing entry of `_aug_batch_shape` is replaced by the new output
    # count (or dropped when only one output remains).
    new_model._aug_batch_shape = new_model._aug_batch_shape[:-1] + new_tail_bs
    # Outputs are indexed along the third-to-last dim of each training input
    # and the second-to-last dim of the training targets.
    new_model.train_inputs = tuple(
        ti[..., idxr, :, :] for ti in new_model.train_inputs
    )
    new_model.train_targets = new_model.train_targets[..., idxr, :]

    # Adjust the batch shapes of parameters/buffers where necessary.
    for full_name, p in itertools.chain(
        new_model.named_parameters(), new_model.named_buffers()
    ):
        if full_name in subset_batch_dict:
            # `_subset_batch_dict` maps a parameter/buffer name to the dim
            # along which its outputs are indexed.
            idx = subset_batch_dict[full_name]
            new_data = p.index_select(dim=idx, index=tidxr)
            if m == 1:
                new_data = new_data.squeeze(idx)
            p.data = new_data
        # NOTE(review): this runs for every parameter/buffer, not only the
        # subset ones. Presumably `mod_batch_shape` updates the owning
        # submodule's batch shape (0 signalling the output dim was dropped);
        # confirm against its definition.
        mod_name = full_name.split(".")[:-1]
        mod_batch_shape(new_model, mod_name, m if m > 1 else 0)

    # Subset the outcome transform if the model has one. Only a missing
    # transform is skipped. An AttributeError raised inside the transform's
    # own `subset_output` is a genuine error. Swallowing it would leave a
    # transform that no longer matches the model's outputs.
    octf = getattr(new_model, "outcome_transform", None)
    if octf is not None:
        new_model.outcome_transform = octf.subset_output(idcs=idcs)

    return new_model