in examples/contrib/epidemiology/sir.py [0:0]
def _pearson_correlation(x):
    """Return the Pearson correlation matrix between the columns of ``x``.

    :param torch.Tensor x: a 2-D tensor whose rows are samples and whose
        columns are variables.
    :returns: a square tensor of pairwise column correlations, clamped to
        ``[-1, 1]`` to absorb floating-point round-off.
    :rtype: torch.Tensor
    """
    centered = x - x.mean(0)
    # torch's ``std`` uses the unbiased (N - 1) estimator by default, so the
    # cross-product must also be divided by N - 1; dividing by N would shrink
    # every correlation (including the unit diagonal) by a factor (N - 1) / N.
    standardized = centered / centered.std(0)
    corr = standardized.t().matmul(standardized) / (x.size(0) - 1)
    return corr.clamp(min=-1, max=1)


def evaluate(args, model, samples):
    """Log posterior summaries and, optionally, plot posterior diagnostics.

    For each scalar parameter the run was configured to infer, logs the
    ground-truth value (read from ``args``) next to the posterior mean and
    standard deviation. When ``args.plot`` is true, also creates three
    matplotlib figures: per-parameter posterior histograms, a grid of
    pairwise scatter plots, and a Pearson correlation matrix computed in
    unconstrained coordinates. Figures are created but not shown here.

    :param args: options object; reads ``heterogeneous``, ``concentration``,
        ``plot``, ``num_samples`` and the true parameter values
        ``basic_reproduction_number``, ``response_rate``, ``concentration``
        and ``overdispersion`` (each only when the matching parameter is
        reported).
    :param model: reads ``model.population`` and ``model.compartments``
        (plotting only).
    :param dict samples: maps sample-site names (``"R0"``, ``"rho"``,
        ``"k"``, ``"od"``, ``"auxiliary"``) to tensors of posterior samples.
    """
    # Print estimated values. Maps attribute name on ``args`` -> sample key.
    names = {"basic_reproduction_number": "R0"}
    if not args.heterogeneous:
        names["response_rate"] = "rho"
    if args.concentration < math.inf:
        names["concentration"] = "k"
    if "od" in samples:
        names["overdispersion"] = "od"
    for name, key in names.items():
        mean = samples[key].mean().item()
        std = samples[key].std().item()
        logging.info("{}: truth = {:0.3g}, estimate = {:0.3g} \u00B1 {:0.3g}"
                     .format(key, getattr(args, name), mean, std))

    # Optionally plot histograms and pairwise correlations.
    if not args.plot:
        return
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Plot individual histograms, one row per reported parameter.
    # squeeze=False always yields a 2-D array of Axes, even for one row.
    _, axes = plt.subplots(len(names), 1, figsize=(5, 2.5 * len(names)),
                           squeeze=False)
    axes = axes[:, 0]
    axes[0].set_title("Posterior parameter estimates")
    for ax, (name, key) in zip(axes, names.items()):
        truth = getattr(args, name)
        # NOTE(review): seaborn has deprecated ``distplot``; consider
        # ``histplot``/``kdeplot`` when the example's seaborn pin allows it.
        sns.distplot(samples[key], ax=ax, label="posterior")
        ax.axvline(truth, color="k", label="truth")
        ax.set_xlabel(key + " = " + name.replace("_", " "))
        ax.set_yticks(())
        ax.legend(loc="best")
    plt.tight_layout()

    # Plot pairwise joint distributions for selected variables: every
    # reported parameter plus the first and last entry of each auxiliary row.
    covariates = [(key, samples[key]) for key in names.values()]
    # NOTE(review): assumes dim 1 of samples["auxiliary"] is a singleton and
    # dim -2 indexes compartments — confirm against the model.
    for i, aux in enumerate(samples["auxiliary"].squeeze(1).unbind(-2)):
        covariates.append(("aux[{},0]".format(i), aux[:, 0]))
        covariates.append(("aux[{},-1]".format(i), aux[:, -1]))
    N = len(covariates)
    _, axes = plt.subplots(N, N, figsize=(8, 8), sharex="col", sharey="row",
                           squeeze=False)
    for i, (row_label, row_values) in enumerate(covariates):
        axes[i][0].set_ylabel(row_label)
        axes[0][i].set_xlabel(row_label)
        axes[0][i].xaxis.set_label_position("top")
        for j, (_, col_values) in enumerate(covariates):
            ax = axes[i][j]
            ax.set_xticks(())
            ax.set_yticks(())
            # NOTE(review): y-values are negated; the intent is not evident
            # from this function — confirm before changing.
            ax.scatter(col_values, -row_values,
                       lw=0, color="darkblue", alpha=0.3)
    plt.tight_layout()
    plt.subplots_adjust(wspace=0, hspace=0)

    # Plot Pearson correlation for every pair of unconstrained variables.
    def unconstrain(constraint, value):
        # Map to unconstrained space, then flatten all but the sample dim.
        value = biject_to(constraint).inv(value)
        return value.reshape(args.num_samples, -1)

    # Labels match the sample-site keys they were computed from.
    covariates = [("R0", unconstrain(constraints.positive, samples["R0"]))]
    if not args.heterogeneous:
        covariates.append(
            ("rho", unconstrain(constraints.unit_interval, samples["rho"])))
    if "k" in samples:
        covariates.append(
            ("k", unconstrain(constraints.positive, samples["k"])))
    # Auxiliary values are unconstrained via the interval
    # (-0.5, population + 0.5).
    constraint = constraints.interval(-0.5, model.population + 0.5)
    for name, aux in zip(model.compartments, samples["auxiliary"].unbind(-2)):
        covariates.append((name, unconstrain(constraint, aux)))
    corr = _pearson_correlation(torch.cat([v for _, v in covariates], dim=-1))
    plt.figure(figsize=(8, 8))
    # Pin the diverging colormap to [-1, 1] so white means zero correlation.
    plt.imshow(corr, cmap="bwr", vmin=-1, vmax=1)
    # One y tick at the middle of each variable's block of columns.
    ticks = torch.tensor([0] + [v.size(-1) for _, v in covariates]).cumsum(0)
    ticks = (ticks[1:] + ticks[:-1]) / 2
    plt.yticks(ticks, [name for name, _ in covariates])
    plt.xticks(())
    plt.tick_params(length=0)
    plt.title("Pearson correlation (unconstrained coordinates)")
    plt.tight_layout()