def _differentiable_loss_parts()

in pyro/infer/trace_mmd.py [0:0]


    def _differentiable_loss_parts(self, model, guide, args, kwargs):
        """Compute the two differentiable pieces of the MMD-regularized loss.

        For each particle, re-executes the model in a fresh trace (so its
        latent samples are drawn independently of the guide — presumably from
        the prior; confirm against poutine semantics), accumulates the
        ``log_prob_sum`` of sample sites that are observed or absent from the
        guide, and gathers flattened latent samples from both the independent
        model trace and the guide. The per-site MMD between those sample sets
        forms the penalty term.

        :param model: the model callable.
        :param guide: the guide callable.
        :param args: positional args forwarded to model and guide.
        :param kwargs: keyword args forwarded to model and guide.
        :returns: tuple ``(loglikelihood, penalty)`` — the particle-averaged
            log-likelihood and the kernel-weighted sum of per-site MMDs.
        :raises ValueError: if a matched model or guide site is not
            reparameterizable (``has_rsample`` is False).
        """
        model_samples_by_site = defaultdict(list)
        guide_samples_by_site = defaultdict(list)

        loglikelihood = 0.0
        penalty = 0.0
        for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
            # Fresh model execution, independent of the guide's samples.
            if self.vectorize_particles:
                trace_indep = poutine.trace(
                    self._vectorized_num_particles(model)
                ).get_trace(*args, **kwargs)
            else:
                trace_indep = poutine.trace(model, graph_type='flat').get_trace(*args, **kwargs)

            particle_ll = 0.0
            for name, site in model_trace.nodes.items():
                if site['type'] != 'sample':
                    continue
                if name in guide_trace and not site['is_observed']:
                    # Latent site present in the guide: collect samples for MMD.
                    guide_site = guide_trace.nodes[name]
                    indep_site = trace_indep.nodes[name]
                    if not indep_site["fn"].has_rsample:
                        raise ValueError("Model site {} is not reparameterizable".format(name))
                    if not guide_site["fn"].has_rsample:
                        raise ValueError("Guide site {} is not reparameterizable".format(name))

                    # Negative index of the particle plate, counted past the
                    # event dimensions of the site's distribution.
                    particle_dim = -self.max_plate_nesting - indep_site["fn"].event_dim

                    m_samples = indep_site['value']
                    g_samples = guide_site['value']

                    if self.vectorize_particles:
                        # Bring the particle dimension to the front, then
                        # flatten each sample: (num_particles, -1).
                        m_samples = m_samples.transpose(-m_samples.dim(), particle_dim)
                        m_samples = m_samples.view(m_samples.shape[0], -1)

                        g_samples = g_samples.transpose(-g_samples.dim(), particle_dim)
                        g_samples = g_samples.view(g_samples.shape[0], -1)
                    else:
                        # Sequential particles: one row per trace.
                        m_samples = m_samples.view(1, -1)
                        g_samples = g_samples.view(1, -1)

                    model_samples_by_site[name].append(m_samples)
                    guide_samples_by_site[name].append(g_samples)
                else:
                    # Observed site (or latent not in the guide): contributes
                    # to the log-likelihood term.
                    particle_ll = particle_ll + site['log_prob_sum']

            # Average the per-particle log-likelihood across particles.
            loglikelihood = particle_ll / self.num_particles + loglikelihood

        for name in model_samples_by_site.keys():
            # Stack samples from all particles, then score model vs. guide.
            model_samples_by_site[name] = torch.cat(model_samples_by_site[name])
            guide_samples_by_site[name] = torch.cat(guide_samples_by_site[name])
            divergence = _compute_mmd(model_samples_by_site[name], guide_samples_by_site[name], kernel=self._kernel[name])
            penalty = self._mmd_scale[name] * divergence + penalty

        warn_if_nan(loglikelihood, "loglikelihood")
        warn_if_nan(penalty, "penalty")
        return loglikelihood, penalty