in pplbench/ppls/pymc3/noisy_or_topic.py [0:0]
def extract_data_from_pymc3(self, samples: MultiTrace) -> xr.Dataset:
    """
    Gather the per-index "active[j]" traces into one xarray Dataset.

    :param samples: trace object indexed by variable name; the entries
        "active[0]" through "active[self.num_topics]" are read.
    :returns: a Dataset holding a single variable "active" with dimensions
        ("draw", "topic"), where column j comes from samples["active[j]"].
    """
    # The topic axis has 1 + num_topics entries; the reason for the extra
    # index is not visible in this method.
    topic_count = 1 + self.num_topics

    # NOTE(review): assumes every "active[j]" trace has the same length,
    # one value per draw -- confirm against the model definition.
    columns = [samples[f"active[{idx}]"] for idx in range(topic_count)]

    # Stacking on a new axis 1 puts draws on axis 0 and topics on axis 1.
    active = np.stack(columns, axis=1)

    # The draw coordinate is sized from the first trace.
    draw_count = len(samples["active[0]"])

    return xr.Dataset(
        {"active": (["draw", "topic"], active)},
        coords={
            "draw": np.arange(draw_count),
            "topic": np.arange(topic_count),
        },
    )