in orbit/pyro/lgt.py [0:0]
def __call__(self):
    """Run one pass of the model: draw the latent sites, roll the
    level / slope / seasonality recursions forward and score the response.

    All inputs are read from attributes on ``self`` (``response``,
    ``num_of_obs``, ``t_star``, the ``*_sm_input`` smoothing inputs, the
    ``pr`` / ``nr`` / ``rr`` regression settings, ``is_seasonal`` and
    ``seasonality``).
    NOTE(review): ``response`` is indexed as ``response[t]``, so it is
    presumably 1-D with length ``num_of_obs`` -- confirm against the caller.

    A negative ``lev_sm_input`` / ``slp_sm_input`` / ``sea_sm_input`` means
    the matching smoothing weight is sampled from ``Uniform(0, 1)``;
    otherwise the supplied value is used as a constant and echoed back in
    the returned dict.

    The ``"response"`` site is observed on ``response[1:]`` only, and its
    contribution is scaled by ``1 / self.t_star``.

    Returns:
        dict: always contains
            ``"beta"``     -- ``pr_beta``, ``nr_beta`` (already negated) and
                              ``rr_beta`` concatenated on the last dim,
            ``"b"``        -- stacked slope states,
            ``"l"``        -- stacked level states,
            ``"s"``        -- stacked seasonality states,
            ``"lgt_sum"``  -- trend term shifted by one step,
            ``"log_prob"`` -- unscaled Student-T log-density of
                              ``response[1:]``;
        plus ``"lev_sm"`` / ``"slp_sm"`` / ``"sea_sm"`` when the respective
        weight was fixed through its ``*_sm_input``.

    Raises:
        ValueError: if regressors are present and ``reg_penalty_type`` is
            not 0 (fixed scale ridge), 1 (lasso) or 2 (auto scale ridge).
            This path previously failed with an ``UnboundLocalError``.
    """
    response = self.response
    num_of_obs = self.num_of_obs
    extra_out = {}
    # added for tempered sampling: used below as the 1 / t_star likelihood scale
    t_star = self.t_star
    bad_penalty_msg = (
        "reg_penalty_type must be 0 (fixed scale ridge), 1 (lasso) "
        "or 2 (auto scale ridge); got {!r}"
    )

    # smoothing params
    if self.lev_sm_input < 0:
        lev_sm = pyro.sample("lev_sm", dist.Uniform(0, 1))
    else:
        lev_sm = torch.tensor(self.lev_sm_input, dtype=torch.double)
        extra_out["lev_sm"] = lev_sm
    if self.slp_sm_input < 0:
        slp_sm = pyro.sample("slp_sm", dist.Uniform(0, 1))
    else:
        slp_sm = torch.tensor(self.slp_sm_input, dtype=torch.double)
        extra_out["slp_sm"] = slp_sm

    # residual tuning parameters
    nu = pyro.sample("nu", dist.Uniform(self.min_nu, self.max_nu))

    # prior for residuals
    obs_sigma = pyro.sample("obs_sigma", dist.HalfCauchy(self.cauchy_sd))

    # regression parameters
    # "pr" coefficients: drawn from a folded prior, so they are non-negative
    if self.num_of_pr == 0:
        pr = torch.zeros(num_of_obs)
        pr_beta = pyro.deterministic("pr_beta", torch.zeros(0))
    else:
        with pyro.plate("pr", self.num_of_pr):
            # fixed scale ridge
            if self.reg_penalty_type == 0:
                pr_sigma = self.pr_sigma_prior
            # auto scale ridge
            elif self.reg_penalty_type == 2:
                # weak prior for sigma
                pr_sigma = pyro.sample(
                    "pr_sigma", dist.HalfCauchy(self.auto_ridge_scale)
                )
            elif self.reg_penalty_type != 1:
                raise ValueError(bad_penalty_msg.format(self.reg_penalty_type))
            # case when it is not lasso
            if self.reg_penalty_type != 1:
                # weak prior for betas
                pr_beta = pyro.sample(
                    "pr_beta",
                    dist.FoldedDistribution(
                        dist.Normal(self.pr_beta_prior, pr_sigma)
                    ),
                )
            else:
                pr_beta = pyro.sample(
                    "pr_beta",
                    dist.FoldedDistribution(
                        dist.Laplace(self.pr_beta_prior, self.lasso_scale)
                    ),
                )
        pr = pr_beta @ self.pr_mat.transpose(-1, -2)

    # "nr" coefficients: same folded prior, then negated after sampling
    if self.num_of_nr == 0:
        nr = torch.zeros(num_of_obs)
        nr_beta = pyro.deterministic("nr_beta", torch.zeros(0))
    else:
        with pyro.plate("nr", self.num_of_nr):
            # fixed scale ridge
            if self.reg_penalty_type == 0:
                nr_sigma = self.nr_sigma_prior
            # auto scale ridge
            elif self.reg_penalty_type == 2:
                # weak prior for sigma
                nr_sigma = pyro.sample(
                    "nr_sigma", dist.HalfCauchy(self.auto_ridge_scale)
                )
            elif self.reg_penalty_type != 1:
                raise ValueError(bad_penalty_msg.format(self.reg_penalty_type))
            # case when it is not lasso
            if self.reg_penalty_type != 1:
                # weak prior for betas
                nr_beta = -1.0 * pyro.sample(
                    "nr_beta",
                    dist.FoldedDistribution(
                        dist.Normal(self.nr_beta_prior, nr_sigma)
                    ),
                )
            else:
                nr_beta = -1.0 * pyro.sample(
                    "nr_beta",
                    dist.FoldedDistribution(
                        dist.Laplace(self.nr_beta_prior, self.lasso_scale)
                    ),
                )
        nr = nr_beta @ self.nr_mat.transpose(-1, -2)

    # "rr" coefficients: plain (unfolded) prior, so no sign constraint
    if self.num_of_rr == 0:
        rr = torch.zeros(num_of_obs)
        rr_beta = pyro.deterministic("rr_beta", torch.zeros(0))
    else:
        with pyro.plate("rr", self.num_of_rr):
            # fixed scale ridge
            if self.reg_penalty_type == 0:
                rr_sigma = self.rr_sigma_prior
            # auto scale ridge
            elif self.reg_penalty_type == 2:
                # weak prior for sigma
                rr_sigma = pyro.sample(
                    "rr_sigma", dist.HalfCauchy(self.auto_ridge_scale)
                )
            elif self.reg_penalty_type != 1:
                raise ValueError(bad_penalty_msg.format(self.reg_penalty_type))
            # case when it is not lasso
            if self.reg_penalty_type != 1:
                # weak prior for betas
                rr_beta = pyro.sample(
                    "rr_beta", dist.Normal(self.rr_beta_prior, rr_sigma)
                )
            else:
                rr_beta = pyro.sample(
                    "rr_beta", dist.Laplace(self.rr_beta_prior, self.lasso_scale)
                )
        rr = rr_beta @ self.rr_mat.transpose(-1, -2)

    # a hack to make sure we don't use a dimension "1" due to rr_beta and pr_beta sampling
    r = pr + nr + rr
    if r.dim() > 1:
        r = r.unsqueeze(-2)

    # trend parameters
    # local trend proportion
    lt_coef = pyro.sample("lt_coef", dist.Uniform(0, 1))
    # global trend proportion
    gt_coef = pyro.sample("gt_coef", dist.Uniform(-0.5, 0.5))
    # global trend parameter
    gt_pow = pyro.sample("gt_pow", dist.Uniform(0, 1))

    # seasonal parameters
    if self.is_seasonal:
        # seasonality smoothing parameter
        if self.sea_sm_input < 0:
            sea_sm = pyro.sample("sea_sm", dist.Uniform(0, 1))
        else:
            sea_sm = torch.tensor(self.sea_sm_input, dtype=torch.double)
            extra_out["sea_sm"] = sea_sm

        # initial seasonality
        # 33% lift is with 1 sd prob.
        init_sea = pyro.sample(
            "init_sea", dist.Normal(0, 0.33).expand([self.seasonality]).to_event(1)
        )
        # centre the initial seasonal terms so they average to zero
        init_sea = init_sea - init_sea.mean(-1, keepdim=True)

    # per-step state lists, filled by index and stacked into tensors below
    slope_steps = [None] * num_of_obs  # slope
    level_steps = [None] * num_of_obs  # level
    if self.is_seasonal:
        season_steps = [None] * (self.num_of_obs + self.seasonality)
        for t in range(self.seasonality):
            season_steps[t] = init_sea[..., t]
        # slot `seasonality` is never written by the update loop (which starts
        # at t = 1), so seed it with the first initial seasonal term
        season_steps[self.seasonality] = init_sea[..., 0]
    else:
        season_steps = [torch.tensor(0.0)] * num_of_obs

    # states initial condition
    slope_steps[0] = torch.zeros_like(slp_sm)
    if self.is_seasonal:
        level_steps[0] = response[0] - r[..., 0] - season_steps[0]
    else:
        level_steps[0] = response[0] - r[..., 0]

    # update process
    for t in range(1, num_of_obs):
        # this update equation with l[t-1] ONLY.
        # intentionally different from the Holt-Winter form
        # this change is suggested from Slawek's original SLGT model
        level_steps[t] = (
            lev_sm * (response[t] - season_steps[t] - r[..., t])
            + (1 - lev_sm) * level_steps[t - 1]
        )
        slope_steps[t] = (
            slp_sm * (level_steps[t] - level_steps[t - 1])
            + (1 - slp_sm) * slope_steps[t - 1]
        )
        if self.is_seasonal:
            season_steps[t + self.seasonality] = (
                sea_sm * (response[t] - level_steps[t] - r[..., t])
                + (1 - sea_sm) * season_steps[t]
            )

    # evaluation process
    # vectorize as much math as possible
    for steps in [slope_steps, level_steps, season_steps]:
        # torch.stack requires all items to have the same shape, but the
        # initial items of our lists may not have batch_shape, so we expand.
        steps[0] = steps[0].expand_as(steps[-1])
    slope = torch.stack(slope_steps, dim=-1).reshape(
        slope_steps[0].shape[:-1] + (-1,)
    )
    level = torch.stack(level_steps, dim=-1).reshape(
        level_steps[0].shape[:-1] + (-1,)
    )
    season = torch.stack(season_steps, dim=-1).reshape(
        season_steps[0].shape[:-1] + (-1,)
    )

    lgt_sum = level + gt_coef * level.abs() ** gt_pow + lt_coef * slope
    lgt_sum = torch.cat([level[..., :1], lgt_sum[..., :-1]], dim=-1)  # shift by 1
    # a hack here as well to get rid of the extra "1" in r.shape
    if r.dim() >= 2:
        r = r.squeeze(-2)
    yhat = lgt_sum + season[..., :num_of_obs] + r

    # Build the likelihood once so the observed site and the reported
    # log_prob are guaranteed to use the same distribution.
    likelihood = dist.StudentT(nu, yhat[..., 1:], obs_sigma)
    with pyro.plate("response_plate", num_of_obs - 1):
        with pyro.poutine.scale(scale=1.0 / t_star):
            pyro.sample(
                "response",
                likelihood,
                obs=response[1:],
            )
    # computed outside the scale context, so this is the unscaled log-density
    log_prob = likelihood.log_prob(response[1:])

    # we care beta not the pr_beta, nr_beta, ...
    extra_out["beta"] = torch.cat([pr_beta, nr_beta, rr_beta], dim=-1)
    extra_out.update(
        {"b": slope, "l": level, "s": season, "lgt_sum": lgt_sum, "log_prob": log_prob}
    )
    return extra_out