def extract_data_from_stan()

in pplbench/ppls/stan/robust_regression.py [0:0]


    def extract_data_from_stan(self, samples: Dict) -> xr.Dataset:
        """
        Repackage raw Stan draws as the ``xr.Dataset`` PPLBench consumes.

        :param samples: samples dictionary from Stan, keyed by parameter name
            (``alpha``, ``beta``, ``nu``, ``sigma``); axis 0 of each array
            indexes draws and axis 1 indexes chains
        :returns: a dataset over inferred parameters, with a ``draw``
            dimension on every variable and an extra ``feature`` dimension
            on ``beta``
        """

        def _drop_chain_axis(name: str):
            # Axis 1 is the chain axis. squeeze(1) removes it and raises if
            # that axis does not have length 1, i.e. only one chain is allowed.
            return samples[name].squeeze(1)

        data_vars = {
            "alpha": (["draw"], _drop_chain_axis("alpha")),
            "beta": (["draw", "feature"], _drop_chain_axis("beta")),
            "nu": (["draw"], _drop_chain_axis("nu")),
            "sigma": (["draw"], _drop_chain_axis("sigma")),
        }

        # Coordinates come from the *unsqueezed* beta array. Axis 0 counts
        # draws and the last axis counts regression features.
        raw_beta_shape = samples["beta"].shape
        coords = {
            "draw": np.arange(raw_beta_shape[0]),
            "feature": np.arange(raw_beta_shape[-1]),
        }
        return xr.Dataset(data_vars, coords=coords)