in ax/models/torch/fully_bayesian.py [0:0]
def compute_dists(X: Tensor, Z: Tensor, lengthscale: Tensor) -> Tensor:
    """Compute pairwise Euclidean distances between the rows of ``X`` and
    ``Z`` after centering (by the mean of ``X``) and dividing each feature
    dimension by ``lengthscale``.

    TODO: use gpytorch `Distance` module. This will require some care to make sure
    jit compilation works as expected.

    Args:
        X: An `n x d`-dim tensor of points.
        Z: An `m x d`-dim tensor of points (same trailing dim as ``X``).
        lengthscale: A tensor broadcastable against the trailing (feature)
            dimension of ``X`` and ``Z``.

    Returns:
        An `n x m`-dim tensor of pairwise distances. Exact zeros on the
        diagonal of an X-vs-X computation come out as ``sqrt(1e-30)`` due to
        the stability clamp below.
    """
    mean = X.mean(dim=0)
    # Center by the mean of X and scale by the lengthscale; centering improves
    # the numerical stability of the quadratic expansion below.
    X_ = (X - mean).div(lengthscale)
    Z_ = (Z - mean).div(lengthscale)
    x1 = X_
    x2 = Z_
    adjustment = x1.mean(-2, keepdim=True)
    x1 = x1 - adjustment
    # x1 and x2 should be identical in all dims except -2 at this point
    x2 = x2 - adjustment
    x1_eq_x2 = torch.equal(x1, x2)
    # Compute squared distance matrix using the quadratic expansion:
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
    x1_pad = torch.ones_like(x1_norm)
    if x1_eq_x2 and not x1.requires_grad and not x2.requires_grad:
        # Reuse the norms already computed for x1. (Previously these were
        # recomputed unconditionally after this branch, which made the
        # shortcut dead code and wasted work in the common X == Z case.)
        x2_norm, x2_pad = x1_norm, x1_pad
    else:
        x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
        x2_pad = torch.ones_like(x2_norm)
    x1_ = torch.cat([-2.0 * x1, x1_norm, x1_pad], dim=-1)
    x2_ = torch.cat([x2, x2_pad, x2_norm], dim=-1)
    res = x1_.matmul(x2_.transpose(-2, -1))
    if x1_eq_x2 and not x1.requires_grad and not x2.requires_grad:
        # The diagonal of an X-vs-X distance matrix is exactly zero; overwrite
        # it to remove rounding noise from the expansion.
        res.diagonal(dim1=-2, dim2=-1).fill_(0)  # pyre-ignore [16]
    # Clamp tiny negative values (rounding artifacts of the expansion) to a
    # small positive floor before taking the square root.
    dist = res.clamp_min_(1e-30).sqrt()
    return dist