in pplbench/ppls/jags/robust_regression.py [0:0]
def extract_data_from_jags(self, samples: Dict) -> xr.Dataset:
    """Convert JAGS robust-regression samples into an xarray Dataset.

    :param samples: mapping with keys ``"alpha"``, ``"nu"``, ``"sigma"`` and
        ``"beta"``. The scalar parameters are expected with shape
        ``[1, samples]`` and ``beta`` with shape ``[k, samples]``. A trailing
        chains axis (dim 2, i.e. ``[..., samples, 1]``) is also accepted and
        squeezed out; it must have size 1.
    :returns: Dataset with ``alpha``, ``nu`` and ``sigma`` indexed by
        ``draw``, and ``beta`` indexed by ``(draw, feature)``.
    :raises ValueError: if a chains axis is present with more than one chain,
        or if a scalar parameter's leading dimension is not 1.
    """

    def _drop_chain_dim(values):
        # Dim 2, when present, is the chains dimension; remove it so every
        # array is [param_dim, samples]. 2-D input is assumed to have had it
        # removed already and is passed through unchanged.
        arr = np.asarray(values)
        if arr.ndim == 3:
            arr = arr.squeeze(2)
        return arr

    # alpha, nu, sigma dimensions are [1, samples], we want [samples]
    alpha = _drop_chain_dim(samples["alpha"]).squeeze(0)
    nu = _drop_chain_dim(samples["nu"]).squeeze(0)
    sigma = _drop_chain_dim(samples["sigma"]).squeeze(0)
    # beta dimensions are [k, samples], we want [samples, k]
    beta = _drop_chain_dim(samples["beta"]).T

    num_draws, num_features = beta.shape
    return xr.Dataset(
        {
            "alpha": (["draw"], alpha),
            "nu": (["draw"], nu),
            "sigma": (["draw"], sigma),
            "beta": (["draw", "feature"], beta),
        },
        coords={
            "draw": np.arange(num_draws),
            "feature": np.arange(num_features),
        },
    )