def effective_sample_size()

in utils.py [0:0]


def effective_sample_size(query_samples: Tensor) -> Tensor:
    """
    Computes effective sample size given query_samples
    :param query_samples: samples of shape (num_chains, iterations) from the posterior
    """
    n_chains, n_samples, *query_dim = query_samples.shape

    samples = query_samples - query_samples.mean(dim=1, keepdim=True)
    samples = samples.transpose(1, -1)
    # computes fourier transform (with padding)
    padding = torch.zeros(samples.shape, dtype=samples.dtype)
    padded_samples = torch.cat((samples, padding), dim=-1)
    fvi = torch.rfft(padded_samples, 1, onesided=False)
    # multiply by complex conjugate
    acf = fvi.pow(2).sum(-1, keepdim=True)
    # transform back to reals (with padding)
    padding = torch.zeros(acf.shape, dtype=acf.dtype)
    padded_acf = torch.cat((acf, padding), dim=-1)
    rho_per_chain = torch.irfft(padded_acf, 1, onesided=False)

    rho_per_chain = rho_per_chain.narrow(-1, 0, n_samples)
    num_per_lag = torch.tensor(range(n_samples, 0, -1), dtype=samples.dtype)
    rho_per_chain = torch.div(rho_per_chain, num_per_lag)
    rho_per_chain = rho_per_chain.transpose(1, -1)

    rho_avg = rho_per_chain.mean(dim=0)
    w, var_hat = _compute_var(query_samples)
    if n_chains > 1:
        rho = 1 - ((w - rho_avg) / var_hat)
    else:
        rho = rho_avg / var_hat
    rho[0] = 1

    # reshape to 2d matrix where each row contains all samples for specific dim
    rho_2d = torch.stack(torch.unbind(rho, dim=0), dim=-1).reshape(-1, n_samples)
    rho_sum = torch.zeros(rho_2d.shape[0])

    for i, chain in enumerate(torch.unbind(rho_2d, dim=0)):
        total_sum = torch.tensor(0.0, dtype=samples.dtype)
        for t in range(n_samples // 2):
            rho_even = chain[2 * t]
            rho_odd = chain[2 * t + 1]
            if rho_even + rho_odd < 0:
                break
            else:
                total_sum += rho_even + rho_odd
        rho_sum[i] = total_sum

    rho_sum = torch.reshape(rho_sum, query_dim)
    return torch.div(n_chains * n_samples, -1 + 2 * rho_sum)