in pplbench/ppls/stan/inference.py [0:0]
def _split_vb_param_spec(param_spec: str) -> tuple:
    """Split a flattened Stan parameter name such as ``"theta[2,3]"``.

    Returns ``(name, idxs)`` where ``idxs`` is the list of integer indices as
    written in the name (PyStan reports these 1-based for vb). A name without
    ``[`` (a scalar parameter) yields an empty index list.
    """
    name, bracket, rest = param_spec.partition("[")
    if not bracket:
        return name, []
    # ``rest`` looks like "2,3]": drop the trailing "]" before splitting.
    return name, [int(i) for i in rest[:-1].split(",")]


def pystan_vb_extract(results: OrderedDict):
    """
    From: https://gist.github.com/lwiklendt/9c7099288f85b59edc903a5aed2d2d64
    Converts vb results from pystan into a format similar to fit.extract()
    where fit is returned from sampling.
    This version is modified from the above reference to add a chain dimension
    for consistency with fit.extract(..)

    :param results: returned from vb; must provide ``"sampler_param_names"``
        (flattened names such as ``"theta[1,2]"``) and ``"sampler_params"``
        (one sequence of draws per name, in the same order).
    :returns: OrderedDict mapping each base parameter name to a float array of
        shape ``(n_draws, 1, *param_shape)``; axis 1 is a singleton chain
        dimension. Elements with no matching flattened name remain NaN.
        An empty result (no parameters) yields an empty OrderedDict.
    """
    param_specs = results["sampler_param_names"]
    samples = results["sampler_params"]
    # Number of draws; guarded so a result with no parameters doesn't raise.
    n = len(samples[0]) if len(samples) > 0 else 0

    # Parse every flattened name exactly once; both passes reuse this.
    parsed = [_split_vb_param_spec(spec) for spec in param_specs]

    # First pass: a parameter's shape is the element-wise max of its indices.
    # No +1 is needed because pystan already returns 1-based indexes for vb.
    param_shapes: dict = OrderedDict()
    for name, idxs in parsed:
        param_shapes[name] = np.maximum(idxs, param_shapes.get(name, idxs))

    # Allocate NaN-filled arrays with a leading draw dimension.
    params = OrderedDict(
        (name, np.full((n,) + tuple(shape), np.nan))
        for name, shape in param_shapes.items()
    )

    # Second pass: scatter each flattened column into place.
    # -1 because pystan returns 1-based indexes for vb.
    for (name, idxs), param_samples in zip(parsed, samples):
        params[name][(...,) + tuple(i - 1 for i in idxs)] = param_samples

    # Finally, add the chain dimension at axis 1: (n, 1, *shape).
    return OrderedDict(
        (name, np.expand_dims(value, axis=1)) for name, value in params.items()
    )