def _initialize_adapter()

in pyro/infer/mcmc/hmc.py [0:0]


    def _initialize_adapter(self):
        if self._adapter.dense_mass is False:
            dense_sites_list = []
        elif self._adapter.dense_mass is True:
            dense_sites_list = [tuple(sorted(self.initial_params))]
        else:
            msg = "full_mass should be a list of tuples of site names."
            dense_sites_list = self._adapter.dense_mass
            assert isinstance(dense_sites_list, list), msg
            for dense_sites in dense_sites_list:
                assert dense_sites and isinstance(dense_sites, tuple), msg
                for name in dense_sites:
                    assert isinstance(name, str) and name in self.initial_params, msg
        dense_sites_set = set().union(*dense_sites_list)
        diag_sites = tuple(sorted([name for name in self.initial_params
                                   if name not in dense_sites_set]))
        assert len(diag_sites) + sum([len(sites) for sites in dense_sites_list]) == len(self.initial_params), \
            "Site names specified in full_mass are duplicated."

        mass_matrix_shape = OrderedDict()
        for dense_sites in dense_sites_list:
            size = sum([self.initial_params[site].numel() for site in dense_sites])
            mass_matrix_shape[dense_sites] = (size, size)

        if diag_sites:
            size = sum([self.initial_params[site].numel() for site in diag_sites])
            mass_matrix_shape[diag_sites] = (size,)

        options = {"dtype": self._potential_energy_last.dtype,
                   "device": self._potential_energy_last.device}
        self._adapter.configure(self._warmup_steps,
                                mass_matrix_shape=mass_matrix_shape,
                                find_reasonable_step_size_fn=self._find_reasonable_step_size,
                                options=options)

        if self._adapter.adapt_step_size:
            self._adapter.reset_step_size_adaptation(self._initial_params)