in leaderboard/irt/pyirt/multidim_one_param_logistic.py [0:0]
def guide_hierarchical(self, models, items, obs):
    """Mean-field variational guide for the hierarchical multidimensional model.

    Registers learnable Pyro params for every latent site and samples:

    * hyper-level sites ``mu_b``, ``mu_theta``, ``mu_gamma`` (Normal) and
      ``u_b``, ``u_theta``, ``u_gamma`` (Gamma), each parameterised by
      vectors of length ``self.dims``;
    * ``theta`` with params of shape ``(self.num_models, self.dims)``
      (param names ``loc_ability`` / ``scale_ability``);
    * ``b`` and ``gamma`` with params of shape ``(self.num_items, self.dims)``
      (param names ``loc_diff`` / ``scale_diff`` and ``loc_disc`` /
      ``scale_disc``).

    Args:
        models: Unused by the guide; presumably present because Pyro's SVI
            calls the guide with the same arguments as the model.
        items: Unused by the guide (see ``models``).
        obs: Unused by the guide (see ``models``).
    """
    dims = self.dims
    num_models = self.num_models
    num_items = self.num_items
    positive = constraints.positive

    def _lazy_param(name, shape, fill, constraint=constraints.real):
        # Register (or fetch) Pyro param ``name`` filled with ``fill``.
        # The initialiser is a callable, so Pyro only evaluates it the first
        # time ``name`` is registered. Later SVI steps no longer allocate a
        # throwaway init tensor. ``fill`` is always a float, so the tensor has
        # the default float dtype, matching torch.zeros / torch.ones.
        return pyro.param(
            name,
            lambda: torch.full(shape, fill, device=self.device),
            constraint=constraint,
        )

    # Hyper-level variational parameters: one entry per latent dimension.
    loc_mu_b = _lazy_param("loc_mu_b", (dims,), 0.0)
    scale_mu_b = _lazy_param("scale_mu_b", (dims,), 1e2, constraint=positive)
    loc_mu_theta = _lazy_param("loc_mu_theta", (dims,), 0.0)
    scale_mu_theta = _lazy_param("scale_mu_theta", (dims,), 1e2, constraint=positive)
    loc_mu_gamma = _lazy_param("loc_mu_gamma", (dims,), 0.0)
    scale_mu_gamma = _lazy_param("scale_mu_gamma", (dims,), 1e2, constraint=positive)
    alpha_b = _lazy_param("alpha_b", (dims,), 1.0, constraint=positive)
    beta_b = _lazy_param("beta_b", (dims,), 1.0, constraint=positive)
    alpha_theta = _lazy_param("alpha_theta", (dims,), 1.0, constraint=positive)
    beta_theta = _lazy_param("beta_theta", (dims,), 1.0, constraint=positive)
    alpha_gamma = _lazy_param("alpha_gamma", (dims,), 1.0, constraint=positive)
    beta_gamma = _lazy_param("beta_gamma", (dims,), 1.0, constraint=positive)

    # Per-model and per-item variational parameters.
    loc_theta = _lazy_param("loc_ability", (num_models, dims), 0.0)
    scale_theta = _lazy_param("scale_ability", (num_models, dims), 1.0, constraint=positive)
    loc_b = _lazy_param("loc_diff", (num_items, dims), 0.0)
    scale_b = _lazy_param("scale_diff", (num_items, dims), 1.0, constraint=positive)
    loc_gamma = _lazy_param("loc_disc", (num_items, dims), 0.0)
    scale_gamma = _lazy_param("scale_disc", (num_items, dims), 1.0, constraint=positive)

    # Sample statements. Values are not needed here: the guide only has to
    # provide each latent site. The sampling order is kept unchanged.
    # NOTE(review): these vector-valued sites are drawn outside any plate and
    # without ``.to_event(1)``; their batch shape must agree with the model's
    # corresponding sites -- confirm against the model.
    pyro.sample("mu_b", dist.Normal(loc_mu_b, scale_mu_b))
    pyro.sample("u_b", dist.Gamma(alpha_b, beta_b))
    pyro.sample("mu_theta", dist.Normal(loc_mu_theta, scale_mu_theta))
    pyro.sample("u_theta", dist.Gamma(alpha_theta, beta_theta))
    pyro.sample("mu_gamma", dist.Normal(loc_mu_gamma, scale_mu_gamma))
    pyro.sample("u_gamma", dist.Gamma(alpha_gamma, beta_gamma))

    # All plates carry ``device`` consistently. The outer plate indexes
    # models/items (dim -2) and the inner plate indexes latent dims (dim -1).
    with pyro.plate("thetas", num_models, dim=-2, device=self.device):
        with pyro.plate("theta_dims", dims, dim=-1, device=self.device):
            pyro.sample("theta", dist.Normal(loc_theta, scale_theta))
    with pyro.plate("bs", num_items, dim=-2, device=self.device):
        with pyro.plate("bs_dims", dims, dim=-1, device=self.device):
            pyro.sample("b", dist.Normal(loc_b, scale_b))
    with pyro.plate("gammas", num_items, dim=-2, device=self.device):
        with pyro.plate("gamma_dims", dims, dim=-1, device=self.device):
            pyro.sample("gamma", dist.Normal(loc_gamma, scale_gamma))