in pyro/contrib/forecast/forecaster.py [0:0]
def __init__(self, model, data, covariates, *,
             guide=None,
             init_loc_fn=init_to_sample,
             init_scale=0.1,
             create_plates=None,
             optim=None,
             learning_rate=0.01,
             betas=(0.9, 0.99),
             learning_rate_decay=0.1,
             clip_norm=10.0,
             dct_gradients=False,
             subsample_aware=False,
             num_steps=1001,
             num_particles=1,
             vectorize_particles=True,
             warm_start=False,
             log_every=100):
    """
    Fit ``guide`` (defaulting to an :class:`AutoNormal` guide) to ``model``
    via SVI over ``(data, covariates)``, recording per-step losses
    (normalized by ``data.numel()``) in ``self.losses``.

    :param callable model: The forecasting model to train.
    :param data: Observed time series; time is the -2 dimension.
    :param covariates: Covariate tensor; must match ``data`` along dim -2.
    :param guide: Optional guide. If ``None``, an ``AutoNormal`` guide is
        built from ``model`` using ``init_loc_fn``, ``init_scale`` and
        ``create_plates``.
    :param optim: Optional Pyro optimizer. If ``None``, a ``DCTAdam`` is
        created from ``learning_rate``, ``betas``, ``learning_rate_decay``,
        ``clip_norm`` and ``subsample_aware``.
    :param bool dct_gradients: Whether to mark time-dim params for DCT
        gradient preconditioning.
    :param bool warm_start: Whether to initialize from a previous fit via
        prefix warm-starting.
    :param int num_steps: Number of SVI steps; 0 skips training entirely.
    :param int log_every: Log loss every this many steps; 0 disables logging.
    """
    # Data and covariates must agree along the time dimension (dim -2).
    assert data.size(-2) == covariates.size(-2)
    super().__init__()
    self.model = model
    if guide is None:
        guide = AutoNormal(self.model, init_loc_fn=init_loc_fn,
                           init_scale=init_scale,
                           create_plates=create_plates)
    self.guide = guide
    # Initialize. The handler-wrapped locals below are what training must
    # use; self.model / self.guide stay unwrapped for later forecasting.
    if warm_start:
        model = PrefixWarmStartMessenger()(model)
        guide = PrefixWarmStartMessenger()(guide)
    if dct_gradients:
        model = MarkDCTParamMessenger("time")(model)
        guide = MarkDCTParamMessenger("time")(guide)
    elbo = Trace_ELBO(num_particles=num_particles,
                      vectorize_particles=vectorize_particles)
    elbo._guess_max_plate_nesting(model, guide, (data, covariates), {})
    elbo.max_plate_nesting = max(elbo.max_plate_nesting, 1)  # force a time plate
    losses = []
    if num_steps:
        if optim is None:
            optim = DCTAdam({"lr": learning_rate, "betas": betas,
                             "lrd": learning_rate_decay ** (1 / num_steps),
                             "clip_norm": clip_norm,
                             "subsample_aware": subsample_aware})
        # BUG FIX: train on the wrapped locals (model, guide), not on the
        # raw self.model / self.guide. Previously SVI received the unwrapped
        # attributes, so warm_start=True and dct_gradients=True were
        # silently ignored during training.
        svi = SVI(model, guide, optim, elbo)
        for step in range(num_steps):
            loss = svi.step(data, covariates) / data.numel()
            if log_every and step % log_every == 0:
                logger.info("step {: >4d} loss = {:0.6g}".format(step, loss))
            losses.append(loss)
    self.guide.create_plates = None  # Disable subsampling after training.
    self.max_plate_nesting = elbo.max_plate_nesting
    self.losses = losses