def extract_data_from_jags()

in pplbench/ppls/jags/n_schools.py [0:0]


    def extract_data_from_jags(self, samples: Dict) -> xr.Dataset:
        """Convert the dictionary of JAGS samples into an ``xr.Dataset``.

        ``samples`` maps each model variable name to the array of draws
        JAGS produced for it. The returned Dataset holds every variable,
        indexed by ``draw`` and, where applicable, by the ``state``,
        ``district`` and ``type`` axes sized from ``self.attrs``.
        """
        # JAGS adds an extra leading dimension for scalars; drop it so each
        # scalar parameter becomes a plain series over draws.
        scalar_names = (
            "sigma_state",
            "sigma_district",
            "sigma_type",
            "beta_baseline",
        )
        data_vars = {
            name: (["draw"], samples[name].squeeze(0)) for name in scalar_names
        }

        # For the array-valued parameters the draw axis is the last one.
        data_vars["beta_state"] = (["state", "draw"], samples["beta_state"])
        data_vars["beta_district"] = (
            ["state", "district", "draw"],
            samples["beta_district"],
        )
        data_vars["beta_type"] = (["type", "draw"], samples["beta_type"])

        num_draws = samples["beta_baseline"].shape[-1]
        coords = {
            "draw": np.arange(num_draws),
            "state": np.arange(self.attrs["num_states"]),
            "district": np.arange(self.attrs["num_districts_per_state"]),
            "type": np.arange(self.attrs["num_types"]),
        }
        return xr.Dataset(data_vars, coords=coords)