def forward()

in botorch/models/kernels/linear_truncated_fidelity.py [0:0]


    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Tensor:
        r"""Evaluate the linear truncated fidelity kernel.

        The input columns are split into non-fidelity ("active") columns and
        the fidelity columns listed in ``self.fidelity_dims``. The result is

            ``K_unbiased(x1_, x2_) + bias_factor * K_biased(x1_, x2_)``

        where ``x1_`` / ``x2_`` keep only the active columns, and
        ``bias_factor`` is built from the fidelity columns. Each fidelity
        column is clamped to ``[0, 1]``.

        With one fidelity column ``s`` and power ``p = self.power``::

            bias_factor = (1 - s1)(1 - s2)(1 + s1 s2)^p

        With two or more fidelity columns, the first two (``s`` and ``t``)
        are used::

            bias_factor = (1 - s1)(1 - s2)(1 + s1 s2)^p
                        + (1 - t1)(1 - t2)(1 + t1 t2)^p
                        + (1 - s1)(1 - s2)(1 - t1)(1 - t2)(1 + s1 s2 + t1 t2)^p

        Args:
            x1: Input tensor whose last dimension holds the features, including
                the fidelity columns; rows are data points (e.g. ``n x d``).
            x2: Second input tensor with the same feature layout (e.g. ``m x d``).
            diag: If True, return only the diagonal of the kernel matrix.
            **params: Extra keyword arguments. Only ``last_dim_is_batch`` is
                inspected, and it must be falsy. Nothing else is forwarded to
                the sub-kernels.

        Returns:
            The kernel matrix (``n x m`` for the shapes above). When
            ``diag=True``, returns its diagonal, reshaped to the shape of the
            biased sub-kernel's diagonal.

        Raises:
            NotImplementedError: If ``last_dim_is_batch`` is requested.
            RuntimeError: If every input column is a fidelity column.
        """
        if params.get("last_dim_is_batch", False):
            raise NotImplementedError(
                "last_dim_is_batch not yet supported by LinearTruncatedFidelityKernel"
            )

        # Trailing singleton dims let the power broadcast over the kernel matrix.
        power = self.power.view(*self.batch_shape, 1, 1)

        # Check emptiness on the plain list before building the index tensor:
        # torch.tensor([]) would otherwise silently be a float tensor.
        active_dims = [
            i for i in range(x1.size(-1)) if i not in self.fidelity_dims
        ]
        if not active_dims:
            raise RuntimeError(
                "Input to LinearTruncatedFidelityKernel must have at least one "
                "non-fidelity dimension"
            )
        active_idx = torch.tensor(active_dims, dtype=torch.long, device=x1.device)
        x1_ = x1.index_select(dim=-1, index=active_idx)
        x2_ = x2.index_select(dim=-1, index=active_idx)
        covar_unbiased = self.covar_module_unbiased(x1_, x2_, diag=diag)
        covar_biased = self.covar_module_biased(x1_, x2_, diag=diag)

        # First fidelity column, clamped to [0, 1] to avoid numerical issues.
        fd_idxr0 = torch.full(
            (1,), self.fidelity_dims[0], dtype=torch.long, device=x1.device
        )
        x11_ = x1.index_select(dim=-1, index=fd_idxr0).clamp(0, 1)
        x21t_ = x2.index_select(dim=-1, index=fd_idxr0).clamp(0, 1)
        if not diag:
            # Turn x2's column into a row so the products broadcast to a matrix.
            x21t_ = x21t_.transpose(-1, -2)
        cross_term_1 = (1 - x11_) * (1 - x21t_)
        bias_factor = cross_term_1 * (1 + x11_ * x21t_).pow(power)

        if len(self.fidelity_dims) > 1:
            # NOTE(review): only the first two fidelity dims enter the bias
            # factor. Any further fidelity dims are excluded from the active
            # dims above but otherwise ignored.
            # Second fidelity column, clamped to [0, 1] to avoid numerical issues.
            fd_idxr1 = torch.full(
                (1,), self.fidelity_dims[1], dtype=torch.long, device=x1.device
            )
            x12_ = x1.index_select(dim=-1, index=fd_idxr1).clamp(0, 1)
            x22t_ = x2.index_select(dim=-1, index=fd_idxr1).clamp(0, 1)
            x1b_ = torch.cat([x11_, x12_], dim=-1)
            if diag:
                # Row-wise inner product of the stacked (s, t) fidelity pairs.
                x2bt_ = torch.cat([x21t_, x22t_], dim=-1)
                k = (1 + (x1b_ * x2bt_).sum(dim=-1, keepdim=True)).pow(power)
            else:
                # Stack x2's (s, t) pairs as columns so a matmul gives all pairs.
                x22t_ = x22t_.transpose(-1, -2)
                x2bt_ = torch.cat([x21t_, x22t_], dim=-2)
                k = (1 + x1b_ @ x2bt_).pow(power)

            cross_term_2 = (1 - x12_) * (1 - x22t_)
            bias_factor += cross_term_2 * (1 + x12_ * x22t_).pow(power)
            bias_factor += cross_term_2 * cross_term_1 * k

        if diag:
            # bias_factor still carries a trailing singleton dim; drop it to
            # match the diagonal returned by the biased sub-kernel.
            bias_factor = bias_factor.view(covar_biased.shape)

        return covar_unbiased + bias_factor * covar_biased