def compute_dists()

in ax/models/torch/fully_bayesian.py [0:0]


def compute_dists(X: Tensor, Z: Tensor, lengthscale: Tensor) -> Tensor:
    """Compute pairwise Euclidean distances between the rows of ``X`` and ``Z``
    after centering by the mean of ``X`` and rescaling by ``lengthscale``.

    TODO: use gpytorch `Distance` module. This will require some care to make sure
    jit compilation works as expected.

    Args:
        X: An `n x d` tensor of points.
        Z: An `m x d` tensor of points with the same feature dimension as `X`.
        lengthscale: A `d`-dim (or broadcastable) tensor of per-dimension
            lengthscales used to rescale both inputs.

    Returns:
        An `n x m` tensor of distances between the scaled rows of `X` and `Z`.
    """
    mean = X.mean(dim=0)
    x1 = (X - mean).div(lengthscale)
    x2 = (Z - mean).div(lengthscale)
    # Subtract the same adjustment from both inputs for numerical stability;
    # pairwise differences (and hence distances) are unchanged by this shift.
    adjustment = x1.mean(-2, keepdim=True)
    x1 = x1 - adjustment
    # x1 and x2 should be identical in all dims except -2 at this point
    x2 = x2 - adjustment
    # Reuse x1's norms and zero the diagonal only when the inputs coincide and
    # no gradients are needed (the in-place diagonal fill is not autograd-safe).
    x1_eq_x2 = (
        torch.equal(x1, x2) and not x1.requires_grad and not x2.requires_grad
    )

    # Compute squared distance matrix using quadratic expansion:
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, assembled via a single matmul.
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
    x1_pad = torch.ones_like(x1_norm)
    if x1_eq_x2:
        # NOTE: previously x2_norm/x2_pad were recomputed unconditionally right
        # after this branch, making the reuse here dead code; that duplicate
        # recomputation has been removed.
        x2_norm, x2_pad = x1_norm, x1_pad
    else:
        x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
        x2_pad = torch.ones_like(x2_norm)
    x1_ = torch.cat([-2.0 * x1, x1_norm, x1_pad], dim=-1)
    x2_ = torch.cat([x2, x2_pad, x2_norm], dim=-1)
    res = x1_.matmul(x2_.transpose(-2, -1))

    if x1_eq_x2:
        # Exact zeros on the diagonal avoid tiny negative values produced by
        # floating-point cancellation when X == Z.
        res.diagonal(dim1=-2, dim2=-1).fill_(0)  # pyre-ignore [16]

    # Clamp to a small positive floor before sqrt: guards against negative
    # round-off values and keeps the sqrt gradient finite at zero distance.
    dist = res.clamp_min_(1e-30).sqrt()
    return dist