def predict()

in examples/contrib/epidemiology/sir.py [0:0]


def _plot_forecast(plt, time, series, median, obs, duration, ylabel, title,
                   truth=None):
    """Draw one sampled time series on the current matplotlib figure.

    Plots a 90% band (5th to 95th percentile across dim 0 of ``series``),
    the median, the observed data for the first ``duration`` steps, an
    optional ground-truth curve, and a vertical line separating the
    observed window from the forecast window.

    :param plt: the ``matplotlib.pyplot`` module.
    :param time: 1-D tensor of time indices used as the x axis.
    :param series: sampled values; dim 0 is reduced over as the sample
        dimension.
    :param median: precomputed median of ``series`` over dim 0.
    :param obs: observed values plotted against ``time[:duration]``.
    :param int duration: number of observed time steps.
    :param str ylabel: y-axis label.
    :param str title: figure title.
    :param truth: optional series plotted against ``time`` as "truth".
    """
    # Derive the rank for each quantile from the number of samples actually
    # present, so the band stays correct (and kthvalue stays in range) even
    # if that differs from the count requested on the command line.
    num_samples = series.size(0)
    lower_rank = int(round(0.5 + 0.05 * num_samples))
    upper_rank = int(round(0.5 + 0.95 * num_samples))
    p05 = series.kthvalue(lower_rank, dim=0).values
    p95 = series.kthvalue(upper_rank, dim=0).values

    plt.fill_between(time, p05, p95, color="red", alpha=0.3, label="90% CI")
    plt.plot(time, median, "r-", label="median")
    plt.plot(time[:duration], obs, "k.", label="observed")
    if truth is not None:
        plt.plot(time, truth, "k--", label="truth")
    # Mark the boundary between the observed and forecast windows.
    plt.axvline(duration - 0.5, color="gray", lw=1)
    plt.xlim(0, len(time) - 1)
    plt.ylim(0, None)
    plt.xlabel("day after first infection")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend(loc="upper left")
    plt.tight_layout()


def predict(args, model, truth):
    """Log the median predicted new infections and optionally plot them.

    Calls ``model.predict(forecast=args.forecast)``, takes the ``"S2I"``
    series from the result (falling back to ``"E2I"``), and logs its median
    over dim 0 as space-separated integers. When ``args.plot`` is set, a
    figure of new infections is drawn; when ``args.heterogeneous`` is also
    set, a second figure of the ``"Re"`` series is drawn.

    :param args: namespace providing ``forecast`` and ``plot``; when
        plotting, also ``duration``, ``population`` and ``heterogeneous``.
    :param model: object providing ``predict(forecast=...)`` (returning a
        mapping of sampled series) and a ``data`` attribute with the
        observations.
    :param truth: optional series plotted as "truth" on the new-infections
        figure; ignored when ``None``.
    :raises KeyError: if the samples contain neither ``"S2I"`` nor
        ``"E2I"``.
    """
    samples = model.predict(forecast=args.forecast)

    obs = model.data

    new_I = samples.get("S2I", samples.get("E2I"))
    if new_I is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise KeyError("model.predict() returned neither 'S2I' nor 'E2I'")
    median = new_I.median(dim=0).values
    logging.info("Median prediction of new infections (starting on day 0):\n{}"
                 .format(" ".join(str(int(value)) for value in median)))

    if not args.plot:
        return

    # Optionally plot the latent and forecasted series of new infections.
    # matplotlib is imported lazily so it is only required when plotting.
    import matplotlib.pyplot as plt
    time = torch.arange(args.duration + args.forecast)

    plt.figure()
    _plot_forecast(
        plt, time, new_I, median, obs, args.duration,
        ylabel="new infections per day",
        title="New infections in population of {}".format(args.population),
        truth=truth,
    )

    # Plot Re time series.
    if args.heterogeneous:
        plt.figure()
        Re = samples["Re"]
        # NOTE(review): obs (observed new infections) is overlaid on the Re
        # axis here, as in the original code -- confirm this is intended.
        _plot_forecast(
            plt, time, Re, Re.median(dim=0).values, obs, args.duration,
            ylabel="Re",
            title="Effective reproductive number over time",
        )