def compute_bin_probs()

in pyro/contrib/epidemiology/util.py [0:0]


def compute_bin_probs(s, num_quant_bins):
    """
    Compute categorical probabilities for a quantization scheme with
    ``num_quant_bins`` many bins.

    The probabilities are piecewise-polynomial spline weights in the real
    offset ``s``. For 2, 4, 8 and 12 bins the weights along the last
    dimension sum to exactly one. The weights at ``s = 0`` also equal the
    weights at ``s = 1`` shifted by one bin, so the quantization varies
    continuously as ``s`` wraps around.

    :param torch.Tensor s: A real-valued tensor with values in ``[0, 1]``.
    :param int num_quant_bins: Number of bins; one of 2, 4, 8, 12 or 16.
    :returns: Probabilities of shape ``s.shape + (num_quant_bins,)``.
    :rtype: torch.Tensor
    :raises ValueError: If ``num_quant_bins`` is not one of the supported sizes.
    """

    t = 1 - s

    if num_quant_bins == 2:
        # Linear interpolation between the two nearest integers.
        probs = torch.stack([t, s], dim=-1)
        return probs

    ss = s * s
    tt = t * t

    if num_quant_bins == 4:
        # This cubic spline interpolates over the nearest four integers, ensuring
        # piecewise quadratic gradients.
        probs = torch.stack([
            t * tt,
            4 + ss * (3 * s - 6),
            4 + tt * (3 * t - 6),
            s * ss,
        ], dim=-1) * (1 / 6)
        return probs

    if num_quant_bins == 8:
        # This quintic spline interpolates over the nearest eight integers, ensuring
        # piecewise quartic gradients.
        s3 = ss * s
        s4 = ss * ss
        s5 = s3 * ss

        t3 = tt * t
        t4 = tt * tt
        t5 = t3 * tt

        # The cubic coefficient in bins 2 and 5 must be -10. Only then do
        # neighbouring pieces agree in value and derivatives at s in {0, 1}
        # and the eight weights sum to exactly 840 for every s.
        probs = torch.stack([
            2 * t5,
            2 + 10 * t + 20 * tt + 20 * t3 + 10 * t4 - 7 * t5,
            55 + 115 * t + 70 * tt - 10 * t3 - 25 * t4 + 7 * t5,
            302 - 100 * ss + 10 * s4,
            302 - 100 * tt + 10 * t4,
            55 + 115 * s + 70 * ss - 10 * s3 - 25 * s4 + 7 * s5,
            2 + 10 * s + 20 * ss + 20 * s3 + 10 * s4 - 7 * s5,
            2 * s5,
        ], dim=-1) * (1 / 840)
        return probs

    if num_quant_bins == 12:
        # This septic spline interpolates over the nearest 12 integers
        s3 = ss * s
        s4 = ss * ss
        s5 = s3 * ss
        s6 = s3 * s3
        s7 = s4 * s3

        t3 = tt * t
        t4 = tt * tt
        t5 = t3 * tt
        t6 = t3 * t3
        t7 = t4 * t3

        probs = torch.stack([
            693 * t7,
            693 + 4851 * t + 14553 * tt + 24255 * t3 + 24255 * t4 + 14553 * t5 + 4851 * t6 - 3267 * t7,
            84744 + 282744 * t + 382536 * tt + 249480 * t3 + 55440 * t4 - 24948 * t5 - 18018 * t6 + 5445 * t7,
            1017423 + 1823283 * t + 1058211 * tt + 51975 * t3 - 148995 * t4 - 18711 * t5 + 20097 * t6 - 3267 * t7,
            3800016 + 3503808 * t + 365904 * tt - 443520 * t3 - 55440 * t4 + 33264 * t5 - 2772 * t6,
            8723088 - 1629936 * ss + 110880.0 * s4 - 2772 * s6,
            8723088 - 1629936 * tt + 110880.0 * t4 - 2772 * t6,
            3800016 + 3503808 * s + 365904 * ss - 443520 * s3 - 55440 * s4 + 33264 * s5 - 2772 * s6,
            1017423 + 1823283 * s + 1058211 * ss + 51975 * s3 - 148995 * s4 - 18711 * s5 + 20097 * s6 - 3267 * s7,
            84744 + 282744 * s + 382536 * ss + 249480 * s3 + 55440 * s4 - 24948 * s5 - 18018 * s6 + 5445 * s7,
            693 + 4851 * s + 14553 * ss + 24255 * s3 + 24255 * s4 + 14553 * s5 + 4851 * s6 - 3267 * s7,
            693 * s7,
        ], dim=-1) * (1 / 32931360)
        return probs

    if num_quant_bins == 16:
        # This nonic spline interpolates over the nearest 16 integers.
        # W16 is a module-level numpy coefficient table defined elsewhere.
        # NOTE(review): presumably shape (8, 10), i.e. 8 half-spline bins x
        # polynomial coefficients for powers 0..9 -- confirm against its definition.
        w16 = torch.from_numpy(W16).to(s.device).type_as(s)
        # Create exponents and the reorder index on s's device. CPU-only
        # helper tensors would make pow/index_select fail for CUDA inputs.
        exponents = torch.arange(10, device=s.device, dtype=s.dtype)
        s_powers = s.unsqueeze(-1).unsqueeze(-1).pow(exponents)
        t_powers = t.unsqueeze(-1).unsqueeze(-1).pow(exponents)
        splines_t = (w16 * t_powers).sum(-1)
        splines_s = (w16 * s_powers).sum(-1)
        # Interleave the t-indexed and s-indexed halves into bin order.
        index = torch.tensor([0, 1, 2, 3, 4, 5, 6, 15, 7, 14, 13, 12, 11, 10, 9, 8],
                             device=s.device)
        probs = torch.cat([splines_t, splines_s], dim=-1)
        probs = probs.index_select(-1, index)
        return probs

    raise ValueError("Unsupported num_quant_bins: {}".format(num_quant_bins))