# examples/sparse_regression.py
import numpy as np
import torch
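
# NOTE: `kernel` is defined elsewhere in examples/sparse_regression.py and is
# not part of this excerpt. The sketch below is NOT the file's definition; it
# is one plausible implementation, consistent with the call signature
# kernel(X, Z, eta1, eta2, c) used here, of the degree-two polynomial kernel
# that arises when the linear and quadratic regression weights are
# marginalized out.
def kernel(X, Z, eta1, eta2, c):
    eta1sq, eta2sq = eta1.pow(2.0), eta2.pow(2.0)
    XZ = torch.mm(X, Z.t())                  # pairwise dot products
    k1 = 0.5 * eta2sq * (1.0 + XZ).pow(2.0)  # quadratic part
    k2 = -0.5 * eta2sq * torch.mm(X.pow(2.0), Z.t().pow(2.0))
    k3 = (eta1sq - eta2sq) * XZ              # linear part
    k4 = c ** 2 - 0.5 * eta2sq               # constant offset
    return k1 + k2 + k3 + k4
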
def compute_posterior_stats(X, Y, msq, lam, eta1, xisq, c, sigma, jitter=1.0e-4):
    N, P = X.shape
    # prepare for computation of posterior statistics for singleton weights
    probe = torch.zeros((P, 2, P), dtype=X.dtype, device=X.device)
    probe[:, 0, :] = torch.eye(P, dtype=X.dtype, device=X.device)
    probe[:, 1, :] = -torch.eye(P, dtype=X.dtype, device=X.device)
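    # derived hyperparameters: eta2 sets the scale of the quadratic (pairwise)
    # terms in the kernel, while kappa rescales each input dimension by its
    # local scale lam (dimensions with small lam are effectively switched off)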
    eta2 = eta1.pow(2.0) * xisq.sqrt() / msq
    kappa = msq.sqrt() * lam / (msq + (eta1 * lam).pow(2.0)).sqrt()
    kX = kappa * X
    kprobe = kappa * probe
    kprobe = kprobe.reshape(-1, P)
    # compute various kernels
    k_xx = kernel(kX, kX, eta1, eta2, c) + (jitter + sigma ** 2) * torch.eye(N, dtype=X.dtype, device=X.device)
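    # invert the regularized N x N kernel matrix once; k_xx_inv is reused
    # below for both the singleton and the pairwise posterior statistics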
    k_xx_inv = torch.inverse(k_xx)
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)
    # compute mean and variance for singleton weights
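    # each coordinate i gets two probe points, +e_i and -e_i; the posterior
    # mean of the singleton weight is the finite difference
    # (f(e_i) - f(-e_i)) / 2, implemented by contracting with vec = [0.5, -0.5]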
    vec = torch.tensor([0.50, -0.50], dtype=X.dtype, device=X.device)
    mu = torch.matmul(k_probeX, torch.matmul(k_xx_inv, Y).unsqueeze(-1)).squeeze(-1).reshape(P, 2)
    mu = (mu * vec).sum(-1)
    var = k_prbprb - torch.matmul(k_probeX, torch.matmul(k_xx_inv, k_probeX.t()))
    var = var.reshape(P, 2, P, 2).diagonal(dim1=-4, dim2=-2)  # shape: (2, 2, P)
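    # std_i = sqrt(vec^T Sigma_i vec); the clamp guards against small negative
    # values arising from numerical error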
    std = ((var * vec.unsqueeze(-1)).sum(-2) * vec.unsqueeze(-1)).sum(-2).clamp(min=0.0).sqrt()
    # select active dimensions (those that are non-zero with sufficient statistical significance)
    active_dims = (((mu - 4.0 * std) > 0.0) | ((mu + 4.0 * std) < 0.0)).bool()
    active_dims = active_dims.nonzero().squeeze(-1)
    print("Identified the following active dimensions:", active_dims.data.numpy().flatten())
    print("Mean estimate for active singleton weights:\n", mu[active_dims].data.numpy())
    # if there are 0 or 1 active dimensions there are no quadratic weights to be found
    M = len(active_dims)
    if M < 2:
        return active_dims.data.numpy(), []
    # prepare for computation of posterior statistics for quadratic weights
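    # enumerate all unordered pairs (i, j) with i < j among the active
    # dimensions via the strict upper triangle of an M x M matrix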
    left_dims, right_dims = torch.ones(M, M).triu(1).nonzero().t()
    left_dims, right_dims = active_dims[left_dims], active_dims[right_dims]
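    # each pair gets four probe points, (+-e_i) + (+-e_j): the left (i) signs
    # are [+, +, -, -] and the right (j) signs are [+, -, +, -], so probe has
    # shape (num_pairs, 4, P)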
    probe = torch.zeros(left_dims.size(0), 4, P, dtype=X.dtype, device=X.device)
    left_dims_expand = left_dims.unsqueeze(-1).expand(left_dims.size(0), P)
    right_dims_expand = right_dims.unsqueeze(-1).expand(right_dims.size(0), P)
    for dim, value in zip(range(4), [1.0, 1.0, -1.0, -1.0]):
        probe[:, dim, :].scatter_(-1, left_dims_expand, value)
    for dim, value in zip(range(4), [1.0, -1.0, 1.0, -1.0]):
        probe[:, dim, :].scatter_(-1, right_dims_expand, value)
    kprobe = kappa * probe
    kprobe = kprobe.reshape(-1, P)
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)
    # compute mean and covariance for a subset of weights theta_ij (namely those with
    # 'active' dimensions i and j)
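    # the posterior mean of theta_ij is the second-order finite difference
    # (f(e_i + e_j) - f(e_i - e_j) - f(-e_i + e_j) + f(-e_i - e_j)) / 4,
    # implemented by contracting with vec = [0.25, -0.25, -0.25, 0.25]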
    vec = torch.tensor([0.25, -0.25, -0.25, 0.25], dtype=X.dtype, device=X.device)
    mu = torch.matmul(k_probeX, torch.matmul(k_xx_inv, Y).unsqueeze(-1)).squeeze(-1).reshape(left_dims.size(0), 4)
    mu = (mu * vec).sum(-1)
    var = k_prbprb - torch.matmul(k_probeX, torch.matmul(k_xx_inv, k_probeX.t()))
    var = var.reshape(left_dims.size(0), 4, left_dims.size(0), 4).diagonal(dim1=-4, dim2=-2)  # shape: (4, 4, num_pairs)
    std = ((var * vec.unsqueeze(-1)).sum(-2) * vec.unsqueeze(-1)).sum(-2).clamp(min=0.0).sqrt()
    # a quadratic weight is active if it is non-zero at the ~4-sigma level and
    # its mean estimate is not negligibly small
    active_quad_dims = (((mu - 4.0 * std) > 0.0) | ((mu + 4.0 * std) < 0.0)) & (mu.abs() > 1.0e-4).bool()
    active_quad_dims = active_quad_dims.nonzero().squeeze(-1)
    active_quadratic_dims = np.stack([left_dims[active_quad_dims].data.numpy().flatten(),
                                      right_dims[active_quad_dims].data.numpy().flatten()], axis=1)
    # convert each row to an (i, j) tuple; unlike np.split, this also handles
    # the edge case of zero active pairs without raising
    active_quadratic_dims = [tuple(row) for row in active_quadratic_dims.tolist()]
    return active_dims.data.numpy(), active_quadratic_dims
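

# A minimal usage sketch (hypothetical data and hyperparameter values; in the
# full example, msq, lam, eta1, xisq and sigma would come from posterior
# samples of the model's hyperparameters rather than being fixed by hand):
if __name__ == "__main__":
    torch.manual_seed(0)
    N, P = 200, 10
    X = torch.randn(N, P)
    # ground truth: singleton weights on dims 0 and 1 plus a (0, 1) interaction
    Y = X[:, 0] - 0.5 * X[:, 1] + X[:, 0] * X[:, 1] + 0.05 * torch.randn(N)
    active, pairs = compute_posterior_stats(
        X, Y,
        msq=torch.tensor(1.0), lam=torch.ones(P), eta1=torch.tensor(1.0),
        xisq=torch.tensor(1.0), c=torch.tensor(1.0), sigma=torch.tensor(0.05),
    )
    print("active singletons:", active, "active pairs:", pairs)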