in pyro/infer/predictive.py [0:0]
def _predictive(model, posterior_samples, num_samples, return_sites=(),
                return_trace=False, parallel=False, model_args=(), model_kwargs=None):
    """
    Run ``model`` conditioned on ``posterior_samples`` and collect predicted site values.

    :param model: Model callable to run.
    :param dict posterior_samples: Map from site name to a tensor of draws whose leading
        dimension indexes the ``num_samples`` draws.
    :param int num_samples: Number of draws contained in ``posterior_samples``.
    :param return_sites: Which sites to collect.

        - ``()`` (default) collects every stochastic/observed site that is not already
          in ``posterior_samples``.
        - ``None`` collects every stochastic/observed site (used for guides).
        - A non-empty collection collects only the listed sites. It may contain
          ``"_RETURN"`` to also collect the model's return value.
    :param bool return_trace: If True, return the trace of the vectorized, conditioned
        model instead of a dict of predictions.
    :param bool parallel: If True, evaluate all draws in a single vectorized pass under a
        ``pyro.plate``. Otherwise delegate to ``_predictive_sequential``.
    :param tuple model_args: Positional arguments passed to ``model``.
    :param dict model_kwargs: Keyword arguments passed to ``model``; ``None`` means none.
    :return: The vectorized trace when ``return_trace`` is True. Otherwise a dict mapping
        each selected site name to its predictions. In the parallel path, tensor values
        have shape ``(num_samples,) + (1,) * k + site_value_shape``, and a non-tensor
        ``"_RETURN"`` value is returned as-is.
    """
    # ``None`` default avoids a single mutable dict shared across calls.
    if model_kwargs is None:
        model_kwargs = {}
    max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
    # Place the draw dimension to the left of every plate dim the model uses so it
    # never collides with the model's own batch dimensions.
    vectorize = pyro.plate("_num_predictive_samples", num_samples, dim=-max_plate_nesting-1)
    # One non-vectorized run of the model, used only to read site shapes/distributions.
    model_trace = prune_subsample_sites(poutine.trace(model).get_trace(*model_args, **model_kwargs))

    reshaped_samples = {}
    for name, sample in posterior_samples.items():
        sample_shape = sample.shape[1:]
        # Only the batch part of a draw is left-padded with singleton dims up to
        # ``max_plate_nesting``; event dims must stay rightmost. Counting event dims
        # as batch dims (as ``len(sample_shape)`` alone would) under-pads sites with a
        # non-empty event_shape, misaligning them with the draw plate. Sites unknown to
        # the model fall back to treating every trailing dim as a batch dim.
        node = model_trace.nodes.get(name)
        if node is not None and node["type"] == "sample":
            event_ndim = len(node["fn"].event_shape)
        else:
            event_ndim = 0
        batch_ndim = max(len(sample_shape) - event_ndim, 0)
        pad = (1,) * (max_plate_nesting - batch_ndim)
        reshaped_samples[name] = sample.reshape((num_samples,) + pad + sample_shape)

    # The model with every draw evaluated at once, with posterior sites fixed to their draws.
    vectorized_model = poutine.condition(vectorize(model), reshaped_samples)

    if return_trace:
        return poutine.trace(vectorized_model).get_trace(*model_args, **model_kwargs)

    return_site_shapes = {}
    for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
        if return_sites:
            # non-empty return_sites: only the listed sites
            wanted = site in return_sites
        elif return_sites is None:
            # special case (for guides): include all sites
            wanted = True
        else:
            # default case (return_sites = ()): all sites not in posterior samples
            wanted = site not in posterior_samples
        if not wanted:
            continue
        node = model_trace.nodes[site]
        append_ndim = max_plate_nesting - len(node["fn"].batch_shape)
        return_site_shapes[site] = (num_samples,) + (1,) * append_ndim + node["value"].shape

    # The model's return value is only collected when explicitly requested.
    if return_sites is not None and '_RETURN' in return_sites:
        value = model_trace.nodes['_RETURN']['value']
        # A shape of None marks a non-tensor return value, which cannot be reshaped.
        shape = (num_samples,) + value.shape if torch.is_tensor(value) else None
        return_site_shapes['_RETURN'] = shape

    if not parallel:
        return _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples,
                                      return_site_shapes, return_trace=False)

    trace = poutine.trace(vectorized_model).get_trace(*model_args, **model_kwargs)
    predictions = {}
    for site, shape in return_site_shapes.items():
        value = trace.nodes[site]['value']
        if site == '_RETURN' and shape is None:
            predictions[site] = value
            continue
        # Values that did not depend on the draw plate (fewer elements than the target
        # shape) are broadcast; otherwise the vectorized value is reshaped into place.
        if value.numel() < reduce((lambda x, y: x * y), shape):
            predictions[site] = value.expand(shape)
        else:
            predictions[site] = value.reshape(shape)
    return predictions