in pplbench/ppls/jags/n_schools.py [0:0]
def extract_data_from_jags(self, samples: Dict) -> xr.Dataset:
    """
    Repackage the samples returned by JAGS as a labelled xarray Dataset.

    :param samples: mapping from variable name to its sampled values; must
        contain "sigma_state", "sigma_district", "sigma_type",
        "beta_baseline", "beta_state", "beta_district" and "beta_type".
        Values are presumably NumPy arrays (they must support ``squeeze``
        and ``shape``) -- confirm against the caller.
    :returns: a Dataset with one data variable per sampled quantity and
        integer coordinates for "draw", "state", "district" and "type".
    """
    # JAGS adds an extra dimension for scalars; drop that leading axis so
    # these variables are indexed by "draw" alone.
    scalar_names = ("sigma_state", "sigma_district", "sigma_type", "beta_baseline")
    data_vars = {
        name: (["draw"], samples[name].squeeze(0)) for name in scalar_names
    }

    # Array-valued variables are passed through untouched; draw is the last
    # dimension of each.
    array_dims = {
        "beta_state": ["state", "draw"],
        "beta_district": ["state", "district", "draw"],
        "beta_type": ["type", "draw"],
    }
    for name, dims in array_dims.items():
        data_vars[name] = (dims, samples[name])

    # The number of draws is read off the last axis of "beta_baseline".
    num_draws = samples["beta_baseline"].shape[-1]
    coords = {
        "draw": np.arange(num_draws),
        "state": np.arange(self.attrs["num_states"]),
        "district": np.arange(self.attrs["num_districts_per_state"]),
        "type": np.arange(self.attrs["num_types"]),
    }
    return xr.Dataset(data_vars, coords=coords)