in pplbench/ppls/stan/robust_regression.py [0:0]
def extract_data_from_stan(self, samples: Dict) -> xr.Dataset:
    """
    Convert the raw output of Stan into the dataset format PPLBench expects.

    :param samples: dictionary of samples from Stan, keyed by parameter name
        ("alpha", "beta", "nu", "sigma"); axis 1 of each array is the
        chains axis
    :returns: a dataset over the inferred parameters, indexed by "draw"
        (and additionally by "feature" for "beta")
    """
    # Dimension names for each inferred parameter once the chain axis is gone.
    dims_by_param = {
        "alpha": ["draw"],
        "beta": ["draw", "feature"],
        "nu": ["draw"],
        "sigma": ["draw"],
    }
    # Axis 1 is the chains axis; squeeze(1) drops it (and fails loudly if it
    # does not have length 1, i.e. if more than one chain was returned).
    data_vars = {
        param: (dims, samples[param].squeeze(1))
        for param, dims in dims_by_param.items()
    }
    # Coordinates are read from the raw (unsqueezed) beta array: axis 0 is the
    # draw axis and the last axis is the feature axis.
    raw_beta_shape = samples["beta"].shape
    coords = {
        "draw": np.arange(raw_beta_shape[0]),
        "feature": np.arange(raw_beta_shape[-1]),
    }
    return xr.Dataset(data_vars, coords=coords)