def compute_posterior_stats()

in examples/sparse_regression.py


import numpy as np
import torch

# NB: `kernel(X, Z, eta1, eta2, c)` refers to the pairwise kernel defined earlier
# in examples/sparse_regression.py and is assumed to be in scope.
def compute_posterior_stats(X, Y, msq, lam, eta1, xisq, c, sigma, jitter=1.0e-4):
    N, P = X.shape

    # prepare for computation of posterior statistics for singleton weights
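    # each probe consists of the pair of points (+e_i, -e_i) along coordinate i;
    # evaluating the GP posterior at these points lets us read off the singleton
    # weight theta_i by finite differencing (see `vec` below)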
    probe = torch.zeros((P, 2, P), dtype=X.dtype, device=X.device)
    probe[:, 0, :] = torch.eye(P, dtype=X.dtype, device=X.device)
    probe[:, 1, :] = -torch.eye(P, dtype=X.dtype, device=X.device)

    eta2 = eta1.pow(2.0) * xisq.sqrt() / msq
    kappa = msq.sqrt() * lam / (msq + (eta1 * lam).pow(2.0)).sqrt()
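    # eta2 and kappa are deterministic functions of the sampled hyperparameters:
    # eta2 sets the scale of the quadratic/interaction terms, while kappa rescales
    # each input dimension according to its inferred relevance lam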

    kX = kappa * X
    kprobe = kappa * probe
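    # flatten the probe tensor from (P, 2, P) to (2P, P) so that the kernel can be
    # evaluated at all probe points in a single call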
    kprobe = kprobe.reshape(-1, P)

    # compute various kernels
    k_xx = kernel(kX, kX, eta1, eta2, c) + (jitter + sigma ** 2) * torch.eye(N, dtype=X.dtype, device=X.device)
    k_xx_inv = torch.inverse(k_xx)
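    # NB: an explicit inverse is used here for clarity; a Cholesky factorization
    # (torch.linalg.cholesky + torch.cholesky_solve) would be more numerically stable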
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    # compute mean and variance for singleton weights
    vec = torch.tensor([0.50, -0.50], dtype=X.dtype, device=X.device)
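    # contracting the posterior mean of f at (+e_i, -e_i) with vec = (1/2, -1/2)
    # forms the central difference (f(e_i) - f(-e_i)) / 2, which isolates theta_i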
    mu = torch.matmul(k_probeX, torch.matmul(k_xx_inv, Y).unsqueeze(-1)).squeeze(-1).reshape(P, 2)
    mu = (mu * vec).sum(-1)

    var = k_prbprb - torch.matmul(k_probeX, torch.matmul(k_xx_inv, k_probeX.t()))
    var = var.reshape(P, 2, P, 2).diagonal(dim1=-4, dim2=-2)  # shape: (2, 2, P)
    std = ((var * vec.unsqueeze(-1)).sum(-2) * vec.unsqueeze(-1)).sum(-2).clamp(min=0.0).sqrt()

    # select active dimensions (those that are non-zero with sufficient statistical significance)
    active_dims = ((mu - 4.0 * std) > 0.0) | ((mu + 4.0 * std) < 0.0)
    active_dims = active_dims.nonzero().squeeze(-1)

    print("Identified the following active dimensions:", active_dims.data.numpy().flatten())
    print("Mean estimate for active singleton weights:\n", mu[active_dims].data.numpy())

    # if there are 0 or 1 active dimensions there are no quadratic weights to be found
    M = len(active_dims)
    if M < 2:
        return active_dims.data.numpy(), []

    # prepare for computation of posterior statistics for quadratic weights
    left_dims, right_dims = torch.ones(M, M).triu(1).nonzero().t()
    left_dims, right_dims = active_dims[left_dims], active_dims[right_dims]
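    # (left_dims[k], right_dims[k]) ranges over all pairs i < j of active dimensions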

    probe = torch.zeros(left_dims.size(0), 4, P, dtype=X.dtype, device=X.device)
    left_dims_expand = left_dims.unsqueeze(-1).expand(left_dims.size(0), P)
    right_dims_expand = right_dims.unsqueeze(-1).expand(right_dims.size(0), P)
    for dim, value in zip(range(4), [1.0, 1.0, -1.0, -1.0]):
        probe[:, dim, :].scatter_(-1, left_dims_expand, value)
    for dim, value in zip(range(4), [1.0, -1.0, 1.0, -1.0]):
        probe[:, dim, :].scatter_(-1, right_dims_expand, value)
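    # the four probe points per pair (i, j) are the sign combinations
    # (+e_i + e_j, +e_i - e_j, -e_i + e_j, -e_i - e_j)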

    kprobe = kappa * probe
    kprobe = kprobe.reshape(-1, P)
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    # compute mean and covariance for a subset of weights theta_ij (namely those with
    # 'active' dimensions i and j)
    vec = torch.tensor([0.25, -0.25, -0.25, 0.25], dtype=X.dtype, device=X.device)
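    # contracting with vec = (1/4, -1/4, -1/4, 1/4) forms the mixed second-order
    # difference (f(e_i + e_j) - f(e_i - e_j) - f(-e_i + e_j) + f(-e_i - e_j)) / 4,
    # which isolates the interaction weight theta_ij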
    mu = torch.matmul(k_probeX, torch.matmul(k_xx_inv, Y).unsqueeze(-1)).squeeze(-1).reshape(left_dims.size(0), 4)
    mu = (mu * vec).sum(-1)

    var = k_prbprb - torch.matmul(k_probeX, torch.matmul(k_xx_inv, k_probeX.t()))
    var = var.reshape(left_dims.size(0), 4, left_dims.size(0), 4).diagonal(dim1=-4, dim2=-2)
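    # var now has shape (4, 4, num_pairs), analogous to the (2, 2, P) case above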
    std = ((var * vec.unsqueeze(-1)).sum(-2) * vec.unsqueeze(-1)).sum(-2).clamp(min=0.0).sqrt()

    active_quad_dims = (((mu - 4.0 * std) > 0.0) | ((mu + 4.0 * std) < 0.0)) & (mu.abs() > 1.0e-4)
    active_quad_dims = active_quad_dims.nonzero()

    left_active = left_dims[active_quad_dims].data.numpy().flatten()
    right_active = right_dims[active_quad_dims].data.numpy().flatten()
    active_quadratic_dims = [(int(i), int(j)) for i, j in zip(left_active, right_active)]

    return active_dims.data.numpy(), active_quadratic_dims
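

A minimal usage sketch follows, assuming the `kernel` function defined earlier in
examples/sparse_regression.py is in scope; the toy data and the hyperparameter
values (standing in for a single posterior draw of the model hyperparameters) are
illustrative assumptions, not part of the example itself.

if __name__ == "__main__":
    torch.manual_seed(0)
    N, P = 100, 10
    X = torch.randn(N, P)
    # toy responses: one linear effect (dim 0) and one interaction (dims 1, 2)
    Y = X[:, 0] + X[:, 1] * X[:, 2] + 0.05 * torch.randn(N)

    # stand-ins for a single posterior draw of the model hyperparameters
    msq, lam = torch.tensor(1.0), torch.ones(P)
    eta1, xisq = torch.tensor(1.0), torch.tensor(1.0)
    c, sigma = 1.0, torch.tensor(0.1)

    active_dims, active_quad_dims = compute_posterior_stats(
        X, Y, msq, lam, eta1, xisq, c, sigma
    )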