def pystan_vb_extract()

in pplbench/ppls/stan/inference.py [0:0]


    def pystan_vb_extract(results: OrderedDict):
        """
        Convert PyStan vb results into a format similar to ``fit.extract()``,
        where ``fit`` is returned from sampling.

        From: https://gist.github.com/lwiklendt/9c7099288f85b59edc903a5aed2d2d64
        This version is modified from the above reference to add a chain dimension
        for consistency with fit.extract(..)

        :param results: returned from vb. Only the ``"sampler_param_names"`` and
            ``"sampler_params"`` entries are read. Names look like ``"mu"`` or
            ``"theta[2,3]"`` (indices are 1-based), and ``sampler_params[k]``
            holds the draws for ``sampler_param_names[k]``.
        :returns: an ``OrderedDict`` mapping each base parameter name to a float
            array of shape ``(n_draws, 1) + param_shape``. Axis 1 is the added
            chain dimension. Elements not named by any entry are left as NaN.
        """
        param_specs = results["sampler_param_names"]
        samples = results["sampler_params"]
        # Use len() rather than truthiness so numpy arrays work too; an empty
        # result yields zero draws instead of an IndexError on samples[0].
        n = len(samples[0]) if len(samples) > 0 else 0

        def _parse(spec: str):
            """Split ``"name[i,j]"`` into ``("name", [i, j])``; indices stay 1-based."""
            splt = spec.split("[")
            if len(splt) > 1:
                return splt[0], [int(i) for i in splt[1][:-1].split(",")]
            return splt[0], []

        parsed = [_parse(spec) for spec in param_specs]

        # First pass, calculate the shape: pystan returns 1-based indexes for vb,
        # so the element-wise maximum index of a parameter is its shape.
        param_shapes: dict = OrderedDict()
        for name, idxs in parsed:
            param_shapes[name] = np.maximum(idxs, param_shapes.get(name, idxs))

        # Create NaN-filled arrays with a leading draw axis.
        params = OrderedDict(
            (name, np.full((n,) + tuple(int(d) for d in shape), np.nan))
            for name, shape in param_shapes.items()
        )

        # Second pass, set arrays: convert to 0-based indexes and write every draw
        # of this element at once (the leading ``...`` spans the draw axis).
        for (name, idxs), param_samples in zip(parsed, samples):
            params[name][(...,) + tuple(i - 1 for i in idxs)] = param_samples

        # Finally, add the chain dimension (a single chain at axis 1).
        return OrderedDict(
            (name, np.expand_dims(value, axis=1)) for name, value in params.items()
        )