in examples/contrib/forecast/bart.py [0:0]
def model(self, zero_data, covariates):
    """Bivariate seasonal forecasting model with a GaussianHMM noise process.

    The prediction is the sum of two parts:

    * an exactly repeating seasonal pattern of length ``period``;
    * a slowly drifting seasonal component, built as a periodic cumulative
      sum of per-step noise.

    Residuals around that prediction are modeled by a GaussianHMM. Its
    transition noise and its observation noise each have their own
    correlation structure.

    :param zero_data: Tensor whose last two dimensions are
        ``(duration, dim)``. Only its shape is read here. ``dim`` must be 2:
        the two series are arrivals and departures.
    :param covariates: Not referenced by this model. It is accepted to
        match the forecasting-model ``model`` signature (NOTE(review):
        confirm against the base class).
    """
    # Seasonal period in time steps. 24 * 7 is one week if each step is
    # an hour (assumed from the constant; confirm against the data).
    period = 24 * 7
    duration, dim = zero_data.shape[-2:]
    assert dim == 2  # Data is bivariate: (arrivals, departures).

    # Sample global parameters.
    # Fill and scale values are written as floats on purpose.
    # torch.full() infers an integer dtype from an integer fill value
    # (PyTorch >= 1.7), and Normal/LogNormal would then cast the scalar
    # scale to that integer dtype too. Sampling such a distribution fails.
    noise_scale = pyro.sample("noise_scale",
                              dist.LogNormal(torch.full((dim,), -3.), 1.).to_event(1))
    assert noise_scale.shape[-1:] == (dim,)
    trans_timescale = pyro.sample("trans_timescale",
                                  dist.LogNormal(torch.zeros(dim), 1.).to_event(1))
    assert trans_timescale.shape[-1:] == (dim,)

    # A single scalar drift location, shared by (expanded over) both series.
    trans_loc = pyro.sample("trans_loc", dist.Cauchy(0., 1. / period))
    trans_loc = trans_loc.unsqueeze(-1).expand(trans_loc.shape + (dim,))
    assert trans_loc.shape[-1:] == (dim,)

    # Transition noise: per-dimension scales times an LKJ correlation
    # Cholesky factor, giving a scale_tril for a correlated MVN.
    trans_scale = pyro.sample("trans_scale",
                              dist.LogNormal(torch.zeros(dim), 0.1).to_event(1))
    trans_corr = pyro.sample("trans_corr",
                             dist.LKJCorrCholesky(dim, torch.ones(())))
    trans_scale_tril = trans_scale.unsqueeze(-1) * trans_corr
    assert trans_scale_tril.shape[-2:] == (dim, dim)

    # Observation noise is built the same way as the transition noise.
    obs_scale = pyro.sample("obs_scale",
                            dist.LogNormal(torch.zeros(dim), 0.1).to_event(1))
    obs_corr = pyro.sample("obs_corr",
                           dist.LKJCorrCholesky(dim, torch.ones(())))
    obs_scale_tril = obs_scale.unsqueeze(-1) * obs_corr
    assert obs_scale_tril.shape[-2:] == (dim, dim)

    # Note the initial seasonality should be sampled in a plate with the
    # same dim as the time_plate, dim=-1. That way we can repeat the dim
    # below using periodic_repeat().
    with pyro.plate("season_plate", period, dim=-1):
        season_init = pyro.sample("season_init",
                                  dist.Normal(torch.zeros(dim), 1.).to_event(1))
    assert season_init.shape[-2:] == (period, dim)

    # Sample independent noise at each time step.
    with self.time_plate:
        season_noise = pyro.sample("season_noise",
                                   dist.Normal(0., noise_scale).to_event(1))
    assert season_noise.shape[-2:] == (duration, dim)

    # Construct a prediction. This prediction has an exactly repeated
    # seasonal part plus slow seasonal drift. We use two deterministic,
    # linear functions to transform our diagonal Normal noise to nontrivial
    # samples from a Gaussian process.
    prediction = (periodic_repeat(season_init, duration, dim=-2) +
                  periodic_cumsum(season_noise, period, dim=-2))
    assert prediction.shape[-2:] == (duration, dim)

    # Construct a joint noise model. This model is a GaussianHMM, whose
    # .rsample() and .log_prob() methods are parallelized over time; thus
    # this entire model is parallelized over time.
    init_dist = dist.Normal(torch.zeros(dim), 100.).to_event(1)
    # Diagonal transition matrix: each series decays independently by a
    # factor of exp(-trans_timescale) per step.
    trans_mat = trans_timescale.neg().exp().diag_embed()
    trans_dist = dist.MultivariateNormal(trans_loc, scale_tril=trans_scale_tril)
    # The hidden state is observed directly (identity observation matrix),
    # corrupted by correlated observation noise.
    obs_mat = torch.eye(dim)
    obs_dist = dist.MultivariateNormal(torch.zeros(dim), scale_tril=obs_scale_tril)
    noise_model = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                                   duration=duration)
    assert noise_model.event_shape == (duration, dim)

    # The final statement registers our noise model and prediction.
    self.predict(noise_model, prediction)