in pyro/infer/mcmc/hmc.py [0:0]
def _initialize_adapter(self):
    """Build the mass-matrix block layout and hand it to ``self._adapter``.

    ``self._adapter.dense_mass`` decides which sites get a dense block:

    * ``False``: no dense block; every site shares one diagonal block.
    * ``True``: a single dense block over every site in ``initial_params``
      (site names in sorted order).
    * anything else: must be a list of non-empty tuples of site names, each
      name present in ``initial_params``. Every tuple becomes its own dense
      block, and the sites not mentioned share one diagonal block.

    The layout is an ``OrderedDict`` keyed by tuples of site names. The value
    is ``(size, size)`` for a dense block and ``(size,)`` for the diagonal
    block, where ``size`` is the summed ``numel()`` of the block's sites.
    Dense blocks come first, in the given order. The diagonal block is added
    only when at least one site is left over.

    The layout is passed to ``self._adapter.configure`` together with:

    * ``self._warmup_steps``;
    * ``self._find_reasonable_step_size``;
    * the dtype and device of ``self._potential_energy_last``.

    Afterwards, if ``self._adapter.adapt_step_size`` is truthy, step-size
    adaptation is reset from ``self._initial_params``.

    Raises:
        AssertionError: if the custom layout is not a list of non-empty
            tuples of known site names, or if a site is listed more than
            once. These checks are ``assert`` statements, so they are skipped
            when Python runs with ``-O``.
    """
    adapter = self._adapter
    layout_spec = adapter.dense_mass
    params = self.initial_params

    if layout_spec is True:
        dense_groups = [tuple(sorted(params))]
    elif layout_spec is False:
        dense_groups = []
    else:
        msg = "full_mass should be a list of tuples of site names."
        dense_groups = layout_spec
        assert isinstance(dense_groups, list), msg
        for group in dense_groups:
            # Operand order kept: truthiness is tested before the type check.
            assert group and isinstance(group, tuple), msg
            for site_name in group:
                assert isinstance(site_name, str) and site_name in params, msg

    # Sites not claimed by any dense group fall into the diagonal block.
    claimed = {site_name for group in dense_groups for site_name in group}
    diag_group = tuple(sorted(site_name for site_name in params if site_name not in claimed))

    # Every site is either claimed or diagonal. The counts only disagree when
    # some site name appears more than once across (or within) the dense groups.
    assigned = len(diag_group) + sum(len(group) for group in dense_groups)
    assert assigned == len(params), \
        "Site names specified in full_mass are duplicated."

    def block_size(group):
        # Total number of scalar entries across the sites in ``group``.
        return sum(params[site].numel() for site in group)

    mass_matrix_shape = OrderedDict()
    for group in dense_groups:
        width = block_size(group)
        mass_matrix_shape[group] = (width, width)
    if diag_group:
        mass_matrix_shape[diag_group] = (block_size(diag_group),)

    potential = self._potential_energy_last
    options = {"dtype": potential.dtype,
               "device": potential.device}
    adapter.configure(self._warmup_steps,
                      mass_matrix_shape=mass_matrix_shape,
                      find_reasonable_step_size_fn=self._find_reasonable_step_size,
                      options=options)

    if not adapter.adapt_step_size:
        return
    adapter.reset_step_size_adaptation(self._initial_params)