in pyro/contrib/epidemiology/util.py [0:0]
def compute_bin_probs(s, num_quant_bins):
    """
    Compute categorical probabilities for a quantization scheme with
    ``num_quant_bins`` many bins.

    Each supported scheme evaluates a fixed set of polynomials in ``s`` and
    ``t = 1 - s``. The output is symmetric: swapping ``s`` and ``1 - s``
    reverses the order of the bins.

    :param torch.Tensor s: A real-valued tensor with values in [0, 1].
    :param int num_quant_bins: Number of bins; one of 2, 4, 8, 12 or 16.
    :returns: Probabilities of shape ``s.shape + (num_quant_bins,)``.
    :rtype: torch.Tensor
    :raises ValueError: If ``num_quant_bins`` is not a supported value.
    """
    t = 1 - s

    if num_quant_bins == 2:
        # Linear weights (1 - s, s).
        return torch.stack([t, s], dim=-1)

    ss = s * s
    tt = t * t

    if num_quant_bins == 4:
        # This cubic spline interpolates over the nearest four integers, ensuring
        # piecewise quadratic gradients.
        probs = torch.stack([
            t * tt,
            4 + ss * (3 * s - 6),
            4 + tt * (3 * t - 6),
            s * ss,
        ], dim=-1) * (1/6)
        return probs

    if num_quant_bins == 8:
        # This quintic spline interpolates over the nearest eight integers, ensuring
        # piecewise quartic gradients.
        s3 = ss * s
        s4 = ss * ss
        s5 = s3 * ss

        t3 = tt * t
        t4 = tt * tt
        t5 = t3 * tt

        # The cubic coefficient of the third/sixth piece must be -10: only
        # then do neighbouring pieces join continuously (212 at the joint
        # with the centre piece) and the eight weights sum to 840 for all s.
        probs = torch.stack([
            2 * t5,
            2 + 10 * t + 20 * tt + 20 * t3 + 10 * t4 - 7 * t5,
            55 + 115 * t + 70 * tt - 10 * t3 - 25 * t4 + 7 * t5,
            302 - 100 * ss + 10 * s4,
            302 - 100 * tt + 10 * t4,
            55 + 115 * s + 70 * ss - 10 * s3 - 25 * s4 + 7 * s5,
            2 + 10 * s + 20 * ss + 20 * s3 + 10 * s4 - 7 * s5,
            2 * s5,
        ], dim=-1) * (1/840)
        return probs

    if num_quant_bins == 12:
        # This septic spline interpolates over the nearest 12 integers
        s3 = ss * s
        s4 = ss * ss
        s5 = s3 * ss
        s6 = s3 * s3
        s7 = s4 * s3

        t3 = tt * t
        t4 = tt * tt
        t5 = t3 * tt
        t6 = t3 * t3
        t7 = t4 * t3

        probs = torch.stack([
            693 * t7,
            (693 + 4851 * t + 14553 * tt + 24255 * t3 + 24255 * t4
             + 14553 * t5 + 4851 * t6 - 3267 * t7),
            (84744 + 282744 * t + 382536 * tt + 249480 * t3 + 55440 * t4
             - 24948 * t5 - 18018 * t6 + 5445 * t7),
            (1017423 + 1823283 * t + 1058211 * tt + 51975 * t3 - 148995 * t4
             - 18711 * t5 + 20097 * t6 - 3267 * t7),
            (3800016 + 3503808 * t + 365904 * tt - 443520 * t3 - 55440 * t4
             + 33264 * t5 - 2772 * t6),
            8723088 - 1629936 * ss + 110880.0 * s4 - 2772 * s6,
            8723088 - 1629936 * tt + 110880.0 * t4 - 2772 * t6,
            (3800016 + 3503808 * s + 365904 * ss - 443520 * s3 - 55440 * s4
             + 33264 * s5 - 2772 * s6),
            (1017423 + 1823283 * s + 1058211 * ss + 51975 * s3 - 148995 * s4
             - 18711 * s5 + 20097 * s6 - 3267 * s7),
            (84744 + 282744 * s + 382536 * ss + 249480 * s3 + 55440 * s4
             - 24948 * s5 - 18018 * s6 + 5445 * s7),
            (693 + 4851 * s + 14553 * ss + 24255 * s3 + 24255 * s4
             + 14553 * s5 + 4851 * s6 - 3267 * s7),
            693 * s7,
        ], dim=-1) * (1/32931360)
        return probs

    if num_quant_bins == 16:
        # This nonic spline interpolates over the nearest 16 integers
        # W16 is a module-level NumPy coefficient table; presumably of shape
        # (8, 10), one row of power-series coefficients per spline piece --
        # TODO confirm against its definition.
        w16 = torch.from_numpy(W16).to(s.device).type_as(s)
        # Helper tensors are created on s.device so that pow() and
        # index_select() also work when s does not live on the default device.
        exponents = torch.arange(10., device=s.device)
        s_powers = s.unsqueeze(-1).unsqueeze(-1).pow(exponents)
        t_powers = t.unsqueeze(-1).unsqueeze(-1).pow(exponents)
        splines_t = (w16 * t_powers).sum(-1)
        splines_s = (w16 * s_powers).sum(-1)
        # Output order: the first seven t-splines, then the last s-spline and
        # the last t-spline, then the remaining s-splines in reverse order.
        index = torch.tensor(
            [0, 1, 2, 3, 4, 5, 6, 15, 7, 14, 13, 12, 11, 10, 9, 8],
            device=s.device,
        )
        probs = torch.cat([splines_t, splines_s], dim=-1)
        return probs.index_select(-1, index)

    raise ValueError("Unsupported num_quant_bins: {}".format(num_quant_bins))