in pyro/infer/trace_mmd.py
# Module-level imports this excerpt relies on (reconstructed here; the real file
# defines them at its top). `_compute_mmd` is a kernel-MMD helper defined
# elsewhere in trace_mmd.py, and this function is a method of the Trace_MMD
# loss class, hence `self`.
from collections import defaultdict

import torch

from pyro import poutine
from pyro.util import warn_if_nan


def _differentiable_loss_parts(self, model, guide, args, kwargs):
    # Per-site sample buffers plus the two parts of the surrogate loss:
    # the expected log-likelihood and the MMD penalty.
    all_model_samples = defaultdict(list)
    all_guide_samples = defaultdict(list)
    loglikelihood = 0.0
    penalty = 0.0
    # One (model_trace, guide_trace) pair per particle (a single vectorized
    # pair when self.vectorize_particles is set).
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
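        # `_get_traces` replays the model against the guide's samples, so the
        # latents in `model_trace` equal the guide's values. To get samples
        # genuinely drawn from the model for the MMD term, re-run the model in
        # its own independent trace.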
        if self.vectorize_particles:
            model_trace_independent = poutine.trace(
                self._vectorized_num_particles(model)
            ).get_trace(*args, **kwargs)
        else:
            model_trace_independent = poutine.trace(model, graph_type='flat').get_trace(*args, **kwargs)
        loglikelihood_particle = 0.0
        for name, model_site in model_trace.nodes.items():
            if model_site['type'] == 'sample':
                if name in guide_trace and not model_site['is_observed']:
                    guide_site = guide_trace.nodes[name]
                    independent_model_site = model_trace_independent.nodes[name]
                    # MMD gradients flow through the samples themselves, so
                    # both distributions must support reparameterized sampling.
                    if not independent_model_site["fn"].has_rsample:
                        raise ValueError("Model site {} is not reparameterizable".format(name))
                    if not guide_site["fn"].has_rsample:
                        raise ValueError("Guide site {} is not reparameterizable".format(name))
                    # The particle plate sits max_plate_nesting batch dims to
                    # the left of the event dims, so locate it counting from
                    # the right of the value's shape.
                    particle_dim = -self.max_plate_nesting - independent_model_site["fn"].event_dim
                    model_samples = independent_model_site['value']
                    guide_samples = guide_site['value']
                    if self.vectorize_particles:
                        # Move the particle dim to the front (dim 0 is
                        # -value.dim() counted from the right), then flatten
                        # each particle's sample to a row: (num_particles, -1).
                        model_samples = model_samples.transpose(-model_samples.dim(), particle_dim)
                        model_samples = model_samples.view(model_samples.shape[0], -1)
                        guide_samples = guide_samples.transpose(-guide_samples.dim(), particle_dim)
                        guide_samples = guide_samples.view(guide_samples.shape[0], -1)
                    else:
                        # Sequential particles: each trace holds one sample,
                        # which becomes one row.
                        model_samples = model_samples.view(1, -1)
                        guide_samples = guide_samples.view(1, -1)
                    all_model_samples[name].append(model_samples)
                    all_guide_samples[name].append(guide_samples)
                else:
                    # Observed sites (and model sites absent from the guide)
                    # contribute their log-probability to the likelihood part.
                    loglikelihood_particle = loglikelihood_particle + model_site['log_prob_sum']
        loglikelihood = loglikelihood_particle / self.num_particles + loglikelihood
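    # Per-site penalty: concatenate each site's samples across particles and
    # evaluate the kernel MMD between the model's and the guide's samples,
    #     MMD^2(p, q) = E_p[k(x, x')] + E_q[k(z, z')] - 2 E_{p,q}[k(x, z)],
    # weighted by that site's configured scale.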
    for name in all_model_samples.keys():
        all_model_samples[name] = torch.cat(all_model_samples[name])
        all_guide_samples[name] = torch.cat(all_guide_samples[name])
        divergence = _compute_mmd(all_model_samples[name], all_guide_samples[name], kernel=self._kernel[name])
        penalty = self._mmd_scale[name] * divergence + penalty
    warn_if_nan(loglikelihood, "loglikelihood")
    warn_if_nan(penalty, "penalty")
    return loglikelihood, penalty
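
For context, a minimal sketch of the kind of estimator `_compute_mmd` evaluates: a biased (V-statistic) estimate of squared MMD, assuming a `pyro.contrib.gp.kernels` kernel whose call returns the Gram matrix between two batches of flattened samples. `_mmd_sketch` is a hypothetical name for illustration, not the helper actually defined in trace_mmd.py.

import torch
from pyro.contrib.gp.kernels import RBF

def _mmd_sketch(model_samples, guide_samples, kernel):
    # V-statistic estimate of MMD^2(p, q):
    #   E_p[k(x, x')] + E_q[k(z, z')] - 2 E_{p,q}[k(x, z)]
    k_xx = kernel(model_samples).mean()
    k_zz = kernel(guide_samples).mean()
    k_xz = kernel(model_samples, guide_samples).mean()
    return k_xx + k_zz - 2.0 * k_xz

x = torch.randn(8, 3)        # 8 "model" samples, flattened to 3 features
z = torch.randn(8, 3) + 1.0  # 8 "guide" samples from a shifted distribution
print(_mmd_sketch(x, z, RBF(input_dim=3)))

Downstream, the loss is consumed through SVI like any other ELBO variant. A hypothetical usage sketch, assuming the Trace_MMD constructor accepts a kernel and an mmd_scale that are broadcast to every latent site (as the per-site lookups self._kernel[name] and self._mmd_scale[name] above suggest); the toy model, guide, and hyperparameters are illustrative only.

import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.gp.kernels import RBF
from pyro.infer import SVI, Trace_MMD
from pyro.optim import Adam

def model(data):
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(z, 1.0), obs=data)

def guide(data):
    loc = pyro.param("loc", torch.tensor(0.0))
    scale = pyro.param("scale", torch.tensor(1.0),
                       constraint=dist.constraints.positive)
    pyro.sample("z", dist.Normal(loc, scale))

data = torch.randn(100) + 3.0
loss = Trace_MMD(
    kernel=RBF(input_dim=1),  # Gram-matrix kernel for the MMD estimate
    mmd_scale=10.0,           # weight of the penalty vs. the log-likelihood
    num_particles=8,
    vectorize_particles=True,
    max_plate_nesting=2,      # data plate at dim -1, particle plate at dim -2
)
svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=loss)
for step in range(1000):
    svi.step(data)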