def evaluate()

in examples/contrib/epidemiology/sir.py


# Imports required by this excerpt (defined at module level in sir.py; here
# biject_to and constraints are taken from torch.distributions).
import logging
import math

import torch
from torch.distributions import biject_to, constraints


def evaluate(args, model, samples):
    # Print estimated values.
    names = {"basic_reproduction_number": "R0"}
    if not args.heterogeneous:
        names["response_rate"] = "rho"
    if args.concentration < math.inf:
        names["concentration"] = "k"
    if "od" in samples:
        names["overdispersion"] = "od"
    for name, key in names.items():
        mean = samples[key].mean().item()
        std = samples[key].std().item()
        logging.info("{}: truth = {:0.3g}, estimate = {:0.3g} \u00B1 {:0.3g}"
                     .format(key, getattr(args, name), mean, std))

    # Optionally plot histograms and pairwise correlations.
    if args.plot:
        import matplotlib.pyplot as plt
        import seaborn as sns

        # Plot individual histograms.
        fig, axes = plt.subplots(len(names), 1, figsize=(5, 2.5 * len(names)))
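        # plt.subplots returns a bare Axes (not an array) when there is a
        # single row, so wrap it to keep the loop below uniform.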
        if len(names) == 1:
            axes = [axes]
        axes[0].set_title("Posterior parameter estimates")
        for ax, (name, key) in zip(axes, names.items()):
            truth = getattr(args, name)
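            # Note: sns.distplot is deprecated in recent seaborn releases;
            # sns.histplot(samples[key], kde=True, ax=ax) is the modern equivalent.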
            sns.distplot(samples[key], ax=ax, label="posterior")
            ax.axvline(truth, color="k", label="truth")
            ax.set_xlabel(key + " = " + name.replace("_", " "))
            ax.set_yticks(())
            ax.legend(loc="best")
        plt.tight_layout()

        # Plot pairwise joint distributions for selected variables.
        covariates = [(name, samples[name]) for name in names.values()]
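        # Also include the first and last time step of each auxiliary
        # compartment trajectory (one series of shape (num_samples, duration)
        # per compartment).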
        for i, aux in enumerate(samples["auxiliary"].squeeze(1).unbind(-2)):
            covariates.append(("aux[{},0]".format(i), aux[:, 0]))
            covariates.append(("aux[{},-1]".format(i), aux[:, -1]))
        N = len(covariates)
        fig, axes = plt.subplots(N, N, figsize=(8, 8), sharex="col", sharey="row")
        for i in range(N):
            axes[i][0].set_ylabel(covariates[i][0])
            axes[0][i].set_xlabel(covariates[i][0])
            axes[0][i].xaxis.set_label_position("top")
            for j in range(N):
                ax = axes[i][j]
                ax.set_xticks(())
                ax.set_yticks(())
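                # Negate the y values so panels read top-to-bottom in the
                # same order as the row labels.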
                ax.scatter(covariates[j][1], -covariates[i][1],
                           lw=0, color="darkblue", alpha=0.3)
        plt.tight_layout()
        plt.subplots_adjust(wspace=0, hspace=0)

        # Plot Pearson correlation for every pair of unconstrained variables.
        def unconstrain(constraint, value):
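            # Map constrained samples to unconstrained coordinates via the
            # inverse of the constraint's bijection, then flatten trailing
            # dimensions to one row per posterior sample.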
            value = biject_to(constraint).inv(value)
            return value.reshape(args.num_samples, -1)

        covariates = [("R1", unconstrain(constraints.positive, samples["R0"]))]
        if not args.heterogeneous:
            covariates.append(
                ("rho", unconstrain(constraints.unit_interval, samples["rho"])))
        if "k" in samples:
            covariates.append(
                ("k", unconstrain(constraints.positive, samples["k"])))
        constraint = constraints.interval(-0.5, model.population + 0.5)
        for name, aux in zip(model.compartments, samples["auxiliary"].unbind(-2)):
            covariates.append((name, unconstrain(constraint, aux)))
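        # Standardize each unconstrained coordinate and form the empirical
        # Pearson correlation matrix as X^T X / num_samples.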
        x = torch.cat([v for _, v in covariates], dim=-1)
        x -= x.mean(0)
        x /= x.std(0)
        x = x.t().matmul(x)
        x /= args.num_samples
        x.clamp_(min=-1, max=1)
        plt.figure(figsize=(8, 8))
        plt.imshow(x, cmap="bwr")
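        # Place one y tick at the center of each variable's block of coordinates.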
        ticks = torch.tensor([0] + [v.size(-1) for _, v in covariates]).cumsum(0)
        ticks = (ticks[1:] + ticks[:-1]) / 2
        plt.yticks(ticks, [name for name, _ in covariates])
        plt.xticks(())
        plt.tick_params(length=0)
        plt.title("Pearson correlation (unconstrained coordinates)")
        plt.tight_layout()
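
For reference, the heatmap step above standardizes the unconstrained samples and forms X^T X / num_samples. The following is a minimal, self-contained sketch of that computation on toy data; the toy tensor, its shape, and the variable names are illustrative and not part of sir.py.

import torch
from torch.distributions import biject_to, constraints

num_samples = 1000
# Toy positive-valued draws standing in for posterior samples of, e.g., R0 and k.
positive_samples = torch.rand(num_samples, 3).exp()
# Log-transform to unconstrained coordinates (inverse of the positive-constraint bijection).
x = biject_to(constraints.positive).inv(positive_samples)
x = x - x.mean(0)   # center each coordinate
x = x / x.std(0)    # scale to unit variance
corr = (x.t() @ x / num_samples).clamp(min=-1, max=1)  # empirical Pearson correlation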