in botorch/models/kernels/linear_truncated_fidelity.py
def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Tensor:
    if params.get("last_dim_is_batch", False):
        raise NotImplementedError(
            "last_dim_is_batch not yet supported by LinearTruncatedFidelityKernel"
        )
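    # Reshape the power parameter so it broadcasts across batch dimensions.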
    power = self.power.view(*self.batch_shape, 1, 1)
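    # Every column that is not a fidelity parameter is an active design dimension.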
    active_dimsM = torch.tensor(
        [i for i in range(x1.size(-1)) if i not in self.fidelity_dims],
        device=x1.device,
    )
    if len(active_dimsM) == 0:
        raise RuntimeError(
            "Input to LinearTruncatedFidelityKernel must have at least one "
            "non-fidelity dimension."
        )
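    # Evaluate both base kernels on the non-fidelity columns only.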
    x1_ = x1.index_select(dim=-1, index=active_dimsM)
    x2_ = x2.index_select(dim=-1, index=active_dimsM)
    covar_unbiased = self.covar_module_unbiased(x1_, x2_, diag=diag)
    covar_biased = self.covar_module_biased(x1_, x2_, diag=diag)
    # Extract the first fidelity parameter; clamp to [0, 1] to avoid
    # numerical issues.
    fd_idxr0 = torch.full(
        (1,), self.fidelity_dims[0], dtype=torch.long, device=x1.device
    )
    x11_ = x1.index_select(dim=-1, index=fd_idxr0).clamp(0, 1)
    x21t_ = x2.index_select(dim=-1, index=fd_idxr0).clamp(0, 1)
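    # For a full covariance matrix, transpose x2's fidelity column so that
    # broadcasting produces n1 x n2 pairwise terms.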
    if not diag:
        x21t_ = x21t_.transpose(-1, -2)
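    # Bias weight from the first fidelity: (1 - u)(1 - v) * (1 + u * v)^p,
    # where u and v are the first-fidelity values of x1 and x2.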
    cross_term_1 = (1 - x11_) * (1 - x21t_)
    bias_factor = cross_term_1 * (1 + x11_ * x21t_).pow(power)
    if len(self.fidelity_dims) > 1:
        # Extract the second fidelity parameter; clamp to [0, 1] to avoid
        # numerical issues.
        fd_idxr1 = torch.full(
            (1,), self.fidelity_dims[1], dtype=torch.long, device=x1.device
        )
        x12_ = x1.index_select(dim=-1, index=fd_idxr1).clamp(0, 1)
        x22t_ = x2.index_select(dim=-1, index=fd_idxr1).clamp(0, 1)
        x1b_ = torch.cat([x11_, x12_], dim=-1)
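        # Joint term over both fidelities: per-row dot products on the
        # diagonal, a matrix product for the full covariance.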
        if diag:
            x2bt_ = torch.cat([x21t_, x22t_], dim=-1)
            k = (1 + (x1b_ * x2bt_).sum(dim=-1, keepdim=True)).pow(power)
        else:
            x22t_ = x22t_.transpose(-1, -2)
            x2bt_ = torch.cat([x21t_, x22t_], dim=-2)
            k = (1 + x1b_ @ x2bt_).pow(power)
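        # Add the second-fidelity bias weight and the joint two-fidelity term.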
        cross_term_2 = (1 - x12_) * (1 - x22t_)
        bias_factor += cross_term_2 * (1 + x12_ * x22t_).pow(power)
        bias_factor += cross_term_2 * cross_term_1 * k
    if diag:
        bias_factor = bias_factor.view(covar_biased.shape)
    return covar_unbiased + bias_factor * covar_biased
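
A minimal usage sketch, not part of the source file: it relies on the kernel's
public constructor arguments fidelity_dims and dimension; the input shapes and
fidelity column indices below are illustrative assumptions, and how the lazy
covariance is materialized depends on the installed gpytorch version.

import torch
from botorch.models.kernels import LinearTruncatedFidelityKernel

# 5 input columns in total; the last two are fidelity parameters in [0, 1].
kernel = LinearTruncatedFidelityKernel(fidelity_dims=[3, 4], dimension=5)
x1 = torch.rand(10, 5)  # 10 points: 3 design dims + 2 fidelity dims
x2 = torch.rand(7, 5)
# __call__ returns a lazy tensor for the full matrix; materialize it with
# .to_dense() (newer gpytorch) or .evaluate() (older versions).
full_covar = kernel(x1, x2).to_dense()   # shape: 10 x 7
diag_covar = kernel(x1, x1, diag=True)   # shape: 10 (diagonal entries only)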