def guide_hierarchical()

in leaderboard/irt/pyirt/multidim_one_param_logistic.py [0:0]


    def guide_hierarchical(self, models, items, obs):
        """Mean-field variational guide for the hierarchical model.

        Registers the variational parameters in Pyro's param store and samples
        every latent site from an independent distribution:

        * ``mu_b``, ``mu_theta``, ``mu_gamma`` (shape ``(dims,)``): ``Normal``
          with a learnable loc (init 0) and a positive scale (init 100).
        * ``u_b``, ``u_theta``, ``u_gamma`` (shape ``(dims,)``): ``Gamma`` with
          learnable positive concentration and rate (both init 1).
        * ``theta`` (shape ``(num_models, dims)``; params ``loc_ability`` /
          ``scale_ability``), ``b`` and ``gamma`` (shape ``(num_items, dims)``;
          params ``loc_diff`` / ``scale_diff`` and ``loc_disc`` /
          ``scale_disc``): ``Normal`` with loc init 0 and positive scale
          init 1, sampled inside nested plates.

        Args:
            models: Unused here; accepted because Pyro's SVI calls the guide
                with the same arguments as the model.
            items: Unused here (see ``models``).
            obs: Unused here (see ``models``).

        Returns:
            None. The guide communicates only through its sample sites.
        """
        device = self.device
        dims_shape = (self.dims,)
        model_shape = (self.num_models, self.dims)
        item_shape = (self.num_items, self.dims)

        def loc_param(name, shape):
            """Register an unconstrained parameter initialised to zeros."""
            return pyro.param(name, torch.zeros(shape, device=device))

        def positive_param(name, shape, init=1.0):
            """Register a positive-constrained parameter filled with ``init``."""
            return pyro.param(
                name,
                torch.full(shape, init, device=device),
                constraint=constraints.positive,
            )

        # Parameters are created in a fixed order so the param store layout is
        # stable across runs.

        # Hyper-mean parameters; the wide initial scale (100) keeps them diffuse.
        loc_mu_b = loc_param("loc_mu_b", dims_shape)
        scale_mu_b = positive_param("scale_mu_b", dims_shape, 1e2)
        loc_mu_theta = loc_param("loc_mu_theta", dims_shape)
        scale_mu_theta = positive_param("scale_mu_theta", dims_shape, 1e2)
        loc_mu_gamma = loc_param("loc_mu_gamma", dims_shape)
        scale_mu_gamma = positive_param("scale_mu_gamma", dims_shape, 1e2)

        # Gamma (concentration, rate) parameters for the u_* sites.
        alpha_b = positive_param("alpha_b", dims_shape)
        beta_b = positive_param("beta_b", dims_shape)
        alpha_theta = positive_param("alpha_theta", dims_shape)
        beta_theta = positive_param("beta_theta", dims_shape)
        alpha_gamma = positive_param("alpha_gamma", dims_shape)
        beta_gamma = positive_param("beta_gamma", dims_shape)

        # Per-model (ability) and per-item (difficulty, discrimination) params.
        m_theta = loc_param("loc_ability", model_shape)
        s_theta = positive_param("scale_ability", model_shape)
        m_b = loc_param("loc_diff", item_shape)
        s_b = positive_param("scale_diff", item_shape)
        m_gamma = loc_param("loc_disc", item_shape)
        s_gamma = positive_param("scale_disc", item_shape)

        # Sample statements. The order is kept fixed so random draws are
        # reproducible under a fixed seed. Return values are not needed: a guide
        # only has to register each site.
        pyro.sample("mu_b", dist.Normal(loc_mu_b, scale_mu_b))
        pyro.sample("u_b", dist.Gamma(alpha_b, beta_b))
        pyro.sample("mu_theta", dist.Normal(loc_mu_theta, scale_mu_theta))
        pyro.sample("u_theta", dist.Gamma(alpha_theta, beta_theta))

        pyro.sample("mu_gamma", dist.Normal(loc_mu_gamma, scale_mu_gamma))
        pyro.sample("u_gamma", dist.Gamma(alpha_gamma, beta_gamma))

        # NOTE(review): plate names and dims presumably mirror the model's
        # plates; keep them in sync with the model.
        with pyro.plate("thetas", self.num_models, dim=-2, device=device):
            with pyro.plate("theta_dims", self.dims, dim=-1):
                pyro.sample("theta", dist.Normal(m_theta, s_theta))

        with pyro.plate("bs", self.num_items, dim=-2, device=device):
            with pyro.plate("bs_dims", self.dims, dim=-1):
                pyro.sample("b", dist.Normal(m_b, s_b))

        with pyro.plate("gammas", self.num_items, dim=-2, device=device):
            with pyro.plate("gamma_dims", self.dims, dim=-1, device=device):
                pyro.sample("gamma", dist.Normal(m_gamma, s_gamma))