in pplbench/ppls/stan/crowd_sourced_annotation.py [0:0]
def extract_data_from_stan(self, samples: Dict) -> xr.Dataset:
    """
    Repackage the sample dictionary produced by Stan as the labelled
    dataset that PPLBench consumes.

    :param samples: samples dictionary from Stan; its "prev" and
        "confusion_matrix" entries are read and must expose `.shape`
        and `.squeeze` (presumably numpy arrays -- confirm with caller)
    :returns: a dataset over inferred parameters, holding "prev" with
        dims (draw, true_category) and "confusion_matrix" with dims
        (draw, labeler, true_category, obs_category)
    """
    raw_prev = samples["prev"]

    # Axis 1 is the chains dimension; squeezing it leaves `draw` as the
    # leading axis of both parameters.
    # NOTE(review): squeeze(1) presumably requires exactly one chain -- confirm.
    prev = raw_prev.squeeze(1)
    confusion = samples["confusion_matrix"].squeeze(1)

    num_draws = raw_prev.shape[0]
    # The same category count indexes both the true and observed axes.
    num_categories = raw_prev.shape[-1]
    num_labelers = confusion.shape[1]

    data_vars = {
        "prev": (["draw", "true_category"], prev),
        "confusion_matrix": (
            ["draw", "labeler", "true_category", "obs_category"],
            confusion,
        ),
    }
    coords = {
        "draw": np.arange(num_draws),
        "true_category": np.arange(num_categories),
        "obs_category": np.arange(num_categories),
        "labeler": np.arange(num_labelers),
    }
    return xr.Dataset(data_vars, coords=coords)