in leaderboard/irt/pyirt/three_param_logistic.py [0:0]
def model_hierarchical(self, models, items, obs):
    """Pyro model: hierarchical IRT with a per-item "fraction feasible" factor.

    Each observation ``obs[i]`` is a Bernoulli draw for the pair
    ``(models[i], items[i])`` with success probability::

        lambdas[item] * sigmoid(gamma[item] * (theta[model] - b[item]))

    The variables are:

    * ``theta``: per-model ability.
    * ``b``: per-item difficulty.
    * ``gamma``: per-item discrimination.
    * ``lambdas``: a per-item factor in [0, 1] that scales the success
      probability.

    Ability, difficulty and discrimination each get a hierarchical prior:

    * a location ``mu_* ~ Normal(0, 1e6)``;
    * a ``u_* ~ Gamma(1, 1)``, used as an inverse scale, i.e. each element
      is drawn from ``Normal(mu_*, 1 / u_*)``.

    Args:
        models: Per-observation indices into the ``self.num_models``
            abilities. Presumably a 1-D integer tensor aligned with ``obs``;
            confirm with the caller.
        items: Per-observation indices into the ``self.num_items`` per-item
            variables. Same assumption as ``models``.
        obs: Observed outcomes. Its first dimension sets the size of the
            ``observe_data`` plate. Values are expected to lie in the
            Bernoulli support {0, 1}.

    Reads ``self.device``, ``self.num_models`` and ``self.num_items``.
    Registers the param ``"lambdas"`` in Pyro's param store. In the disabled
    Beta variant it registers ``"lambda_alpha"``/``"lambda_beta"`` instead.
    Returns ``None``.
    """
    device = self.device

    def _hyperprior(name):
        """Sample the location ``mu_<name>`` and inverse scale ``u_<name>``."""
        mu = pyro.sample(
            f"mu_{name}",
            dist.Normal(
                torch.tensor(0.0, device=device),
                torch.tensor(1.0e6, device=device),
            ),
        )
        u = pyro.sample(
            f"u_{name}",
            dist.Gamma(
                torch.tensor(1.0, device=device),
                torch.tensor(1.0, device=device),
            ),
        )
        return mu, u

    # Same site names and the same order (b, theta, gamma) as before, so
    # traces and RNG draws are unchanged.
    mu_b, u_b = _hyperprior("b")
    mu_theta, u_theta = _hyperprior("theta")
    mu_gamma, u_gamma = _hyperprior("gamma")

    # Fraction of feasible: a per-item multiplier on the success probability.
    # Hard-coded to the point-estimate variant.
    fixed = True
    if fixed:
        # Implementation 1: a plain per-item parameter constrained to [0, 1],
        # initialised to 1 (i.e. every item fully feasible).
        lambdas = pyro.param(
            "lambdas",
            torch.ones(self.num_items, device=device),
            constraint=constraints.unit_interval,
        )
    else:
        # Implementation 2: per-item RV drawn from a shared Beta whose
        # concentrations are learned. The initial values are float tensors
        # on ``device`` so the Beta sample lives on the same device as
        # ``p_star`` below (previously both were CPU tensors and
        # ``lambda_beta`` was the integer ``torch.tensor(1)``).
        # NOTE(review): this variant adds a latent "lambda" sample site; the
        # matching guide (not visible here) would need to handle it.
        lambda_alpha = pyro.param(
            "lambda_alpha",
            torch.tensor(20.0, device=device),
            constraint=constraints.positive,
        )
        lambda_beta = pyro.param(
            "lambda_beta",
            torch.tensor(1.0, device=device),
            constraint=constraints.positive,
        )
        with pyro.plate("lambdas", self.num_items, device=device):
            lambdas = pyro.sample("lambda", dist.Beta(lambda_alpha, lambda_beta))

    # NOTE(review): u_* is used as an inverse *scale* (scale = 1 / u). If it
    # is meant to be a precision, the scale would be 1 / sqrt(u); confirm.
    with pyro.plate("thetas", self.num_models, device=device):
        ability = pyro.sample("theta", dist.Normal(mu_theta, 1.0 / u_theta))
    with pyro.plate("bs", self.num_items, device=device):
        diff = pyro.sample("b", dist.Normal(mu_b, 1.0 / u_b))
    with pyro.plate("gammas", self.num_items, device=device):
        disc = pyro.sample("gamma", dist.Normal(mu_gamma, 1.0 / u_gamma))

    # Likelihood: one conditionally independent Bernoulli per observation.
    # The plate is placed on ``device`` like the other plates.
    with pyro.plate("observe_data", obs.size(0), device=device):
        p_star = torch.sigmoid(disc[items] * (ability[models] - diff[items]))
        pyro.sample(
            "obs", dist.Bernoulli(probs=lambdas[items] * p_star), obs=obs,
        )