def __init__()

in pyro/contrib/forecast/forecaster.py [0:0]


    def __init__(self, model, data, covariates, *,
                 guide=None,
                 init_loc_fn=init_to_sample,
                 init_scale=0.1,
                 create_plates=None,
                 optim=None,
                 learning_rate=0.01,
                 betas=(0.9, 0.99),
                 learning_rate_decay=0.1,
                 clip_norm=10.0,
                 dct_gradients=False,
                 subsample_aware=False,
                 num_steps=1001,
                 num_particles=1,
                 vectorize_particles=True,
                 warm_start=False,
                 log_every=100):
        """
        Fit a guide to ``model`` on ``(data, covariates)`` using SVI.

        :param model: The model to fit. It is called as ``model(data, covariates)``
            during training. Presumably a ``ForecastingModel`` instance -- verify
            against callers.
        :param data: Observed data tensor. Its ``size(-2)`` must equal
            ``covariates.size(-2)``; dim ``-2`` is presumably the time dimension
            (TODO confirm).
        :param covariates: Covariate tensor with the same ``size(-2)`` as ``data``.
        :param guide: Optional guide. If ``None``, an ``AutoNormal`` guide is built
            from ``model``, ``init_loc_fn``, ``init_scale`` and ``create_plates``.
        :param init_loc_fn: Passed to the default ``AutoNormal`` guide only.
        :param init_scale: Passed to the default ``AutoNormal`` guide only.
        :param create_plates: Passed to the default ``AutoNormal`` guide only.
            After training, ``self.guide.create_plates`` is reset to ``None``.
        :param optim: Optional optimizer. If ``None``, a ``DCTAdam`` optimizer is
            built from ``learning_rate``, ``betas``, ``learning_rate_decay``,
            ``clip_norm`` and ``subsample_aware``. Those five arguments are
            ignored when ``optim`` is given.
        :param float learning_rate: Initial learning rate of the default optimizer.
        :param tuple betas: Adam betas of the default optimizer.
        :param float learning_rate_decay: Total decay factor of the default
            optimizer's learning rate over ``num_steps`` steps.
        :param float clip_norm: Gradient clipping norm of the default optimizer.
        :param bool dct_gradients: If true, wrap the model and guide in
            ``MarkDCTParamMessenger("time")`` for training.
        :param bool subsample_aware: Forwarded to the default ``DCTAdam`` optimizer.
        :param int num_steps: Number of SVI steps. If falsy (e.g. ``0``), no
            training is done and ``self.losses`` is empty.
        :param int num_particles: Forwarded to ``Trace_ELBO``.
        :param bool vectorize_particles: Forwarded to ``Trace_ELBO``.
        :param bool warm_start: If true, wrap the model and guide in
            ``PrefixWarmStartMessenger`` for training.
        :param int log_every: Log the loss every ``log_every`` steps. A falsy
            value disables logging.
        :raises AssertionError: If ``data.size(-2) != covariates.size(-2)``
            (only when assertions are enabled).

        After construction, the instance has these attributes:

        - ``self.model`` and ``self.guide``: the unwrapped model and guide.
        - ``self.max_plate_nesting``: the ELBO's plate nesting, which is at least 1.
        - ``self.losses``: the list of per-element training losses.
        """
        assert data.size(-2) == covariates.size(-2), (
            "data and covariates must agree in size(-2), but got {} vs {}"
            .format(data.size(-2), covariates.size(-2)))
        super().__init__()
        self.model = model
        if guide is None:
            guide = AutoNormal(self.model, init_loc_fn=init_loc_fn, init_scale=init_scale,
                               create_plates=create_plates)
        self.guide = guide

        # Initialize. The messenger wrappers below affect only the local
        # `model` / `guide` used for training; self.model and self.guide keep
        # the unwrapped callables.
        if warm_start:
            model = PrefixWarmStartMessenger()(model)
            guide = PrefixWarmStartMessenger()(guide)
        if dct_gradients:
            model = MarkDCTParamMessenger("time")(model)
            guide = MarkDCTParamMessenger("time")(guide)
        elbo = Trace_ELBO(num_particles=num_particles,
                          vectorize_particles=vectorize_particles)
        elbo._guess_max_plate_nesting(model, guide, (data, covariates), {})
        elbo.max_plate_nesting = max(elbo.max_plate_nesting, 1)  # force a time plate

        losses = []
        if num_steps:
            if optim is None:
                # lrd ** num_steps == learning_rate_decay, so the learning rate
                # decays by a total factor of learning_rate_decay over training.
                # This assumes DCTAdam applies lrd once per step.
                optim = DCTAdam({"lr": learning_rate, "betas": betas,
                                 "lrd": learning_rate_decay ** (1 / num_steps),
                                 "clip_norm": clip_norm,
                                 "subsample_aware": subsample_aware})
            # Train the *wrapped* model and guide. Passing self.model and
            # self.guide here would silently drop the warm_start and
            # dct_gradients wrappers applied above.
            svi = SVI(model, guide, optim, elbo)
            for step in range(num_steps):
                # Record the loss per data element.
                loss = svi.step(data, covariates) / data.numel()
                if log_every and step % log_every == 0:
                    logger.info("step {: >4d} loss = {:0.6g}".format(step, loss))
                losses.append(loss)

        self.guide.create_plates = None  # Disable subsampling after training.
        self.max_plate_nesting = elbo.max_plate_nesting
        self.losses = losses