def get_model()

in pplbench/ppls/pymc3/noisy_or_topic.py [0:0]


    def get_model(self, data: xr.Dataset) -> pm.Model:
        """Build the PyMC3 noisy-or topic model conditioned on ``data``.

        Node 0 is drawn with ``p=1.0``, so it is always active. Each topic
        node ``k`` (1..num_topics) is a Bernoulli whose parents are the nodes
        ``0..k-1``, weighted by ``self.edge_weight[k, :k]``. The observed
        words ``S`` are Bernoullis whose parents are all ``1 + num_topics``
        nodes, weighted by the rows of ``self.edge_weight`` starting at
        ``1 + num_topics``.

        Args:
            data: Dataset holding a variable ``S`` with ``"sentence"`` and
                ``"word"`` dimensions. Only the first sentence
                (``S.values[0]`` after transposing) is used as the
                observation.

        Returns:
            The constructed ``pm.Model``. It contains the data container
            ``"S_obs"``, the latent nodes ``"active[0]"`` ...
            ``"active[num_topics]"``, and the observed variable ``"S"``.
        """

        def noisy_or_logit(weight_sum):
            # Noisy-or: p = 1 - exp(-w), hence logit(p) = log(1 - exp(-w)) + w.
            # pm.math.log1mexp(x) evaluates log(1 - exp(-x)).
            return pm.math.log1mexp(weight_sum) + weight_sum

        # Put "sentence" on axis 0 and "word" on axis 1 so the indexing below
        # does not depend on the incoming dimension order.
        ordered = data.transpose("sentence", "word")
        num_nodes = self.num_topics + 1
        active = [None] * num_nodes
        with pm.Model() as model:
            observed_words = pm.Data("S_obs", ordered.S.values[0])
            active[0] = pm.Bernoulli("active[0]", p=1.0)
            for node in range(1, num_nodes):
                # Node `node` only depends on nodes created before it.
                parent_sum = self.edge_weight[node, :node] @ active[:node]
                active[node] = pm.Bernoulli(
                    f"active[{node}]", logit_p=noisy_or_logit(parent_sum)
                )
            # Rows from num_nodes onward hold the word weights over every
            # node. Presumably there are num_words such rows, matching the
            # `shape` below; TODO confirm against the data generator.
            word_sum = pm.math.dot(self.edge_weight[num_nodes:], active)
            pm.Bernoulli(
                "S",
                logit_p=noisy_or_logit(word_sum),
                observed=observed_words,
                shape=self.num_words,
            )

        return model