in pplbench/ppls/stan/n_schools.py [0:0]
def extract_data_from_stan(self, samples: Dict) -> xr.Dataset:
    """
    Repackage the raw Stan sample dictionary as the xarray Dataset PPLBench
    expects.

    Every parameter array has its axis 1 removed via ``squeeze(1)``, so that
    axis must have length 1. It is presumably the chain axis of a single-chain
    run (TODO confirm against the inference driver). The remaining leading
    axis is labelled ``draw``.

    :param samples: samples dictionary from Stan, keyed by parameter name
    :returns: a dataset over inferred parameters, with ``draw``, ``state``,
        ``district`` and ``type`` coordinates sized from the sample count and
        ``self.attrs``
    """
    # Parameter name -> dimension labels after the squeezed axis is dropped.
    # Order matches the data-variable order of the resulting Dataset.
    layout = (
        ("sigma_state", ["draw"]),
        ("sigma_district", ["draw"]),
        ("sigma_type", ["draw"]),
        ("beta_baseline", ["draw"]),
        ("beta_state", ["draw", "state"]),
        ("beta_district", ["draw", "state", "district"]),
        ("beta_type", ["draw", "type"]),
    )
    data_vars = {
        param: (dims, samples[param].squeeze(1)) for param, dims in layout
    }

    # The draw count is read off beta_baseline; all parameters are assumed
    # to share it.
    num_draws = samples["beta_baseline"].shape[0]
    coords = {
        "draw": np.arange(num_draws),
        "state": np.arange(self.attrs["num_states"]),
        "district": np.arange(self.attrs["num_districts_per_state"]),
        "type": np.arange(self.attrs["num_types"]),
    }
    return xr.Dataset(data_vars, coords=coords)