def extract_data_from_jags()

in pplbench/ppls/jags/logistic_regression.py [0:0]


    def extract_data_from_jags(self, samples: Dict) -> xr.Dataset:
        """Convert JAGS posterior samples into an ``xr.Dataset``.

        :param samples: mapping from JAGS variable name to an array of draws.
            ``samples["alpha"]`` is expected to have shape ``[1, samples]`` and
            ``samples["beta"]`` shape ``[k, samples]``. An optional trailing
            chains dimension (dim 2) of size 1 is also accepted and squeezed
            out.
        :returns: dataset with ``alpha`` over ``("draw",)`` and ``beta`` over
            ``("draw", "feature")``, with integer ``draw`` and ``feature``
            coordinates.
        :raises ValueError: if a chains dimension is present with more than
            one chain (only a single chain fits the ``draw``-indexed layout).
        """
        alpha = np.asarray(samples["alpha"])
        beta = np.asarray(samples["beta"])
        # dim 2, when present, is the chains dimension; squeeze it out so the
        # arrays match the [1, samples] / [k, samples] layout handled below.
        # ndarray.squeeze raises ValueError if that axis has size != 1.
        if alpha.ndim == 3:
            alpha = alpha.squeeze(2)
        if beta.ndim == 3:
            beta = beta.squeeze(2)
        num_features = beta.shape[0]
        num_draws = beta.shape[1]
        return xr.Dataset(
            {
                # alpha dimensions are [1, samples], we want [samples]
                "alpha": (["draw"], alpha.squeeze(0)),
                # beta dimensions are [k, samples], we want [samples, k]
                "beta": (["draw", "feature"], beta.T),
            },
            coords={
                "draw": np.arange(num_draws),
                "feature": np.arange(num_features),
            },
        )