in pyro/infer/mcmc/api.py [0:0]
def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None,
             num_chains=1, hook_fn=None, mp_context=None, disable_progbar=False,
             disable_validation=True, transforms=None):
    """Configure the MCMC run and pick a sequential or parallel sampler.

    :param kernel: transition kernel; stored on the instance and forwarded
        to the sampler. If it is an ``HMC``/``NUTS`` kernel whose
        ``potential_fn`` is set, ``initial_params`` is mandatory.
    :param num_samples: number of samples, forwarded to the sampler.
    :param warmup_steps: number of warmup steps; defaults to ``num_samples``
        when ``None``.
    :param initial_params: optional mapping from name to tensor. With
        ``num_chains > 1`` every tensor's leading dimension must equal
        ``num_chains``; with a single chain a leading dimension of size 1
        is added to every tensor before it is handed to the sampler.
    :param num_chains: number of chains to draw.
    :param hook_fn: forwarded to the sampler as its ``hook`` argument.
    :param mp_context: multiprocessing context forwarded to the parallel
        sampler. When ``None`` and any initial parameter lives on a CUDA
        device, ``"spawn"`` is selected.
    :param disable_progbar: forwarded to the sampler.
    :param disable_validation: stored on the instance.
    :param transforms: stored on the instance.
    :raises ValueError: if the kernel uses ``potential_fn`` but no
        ``initial_params`` are given, or if a tensor in ``initial_params``
        has a leading dimension different from ``num_chains``.
    """
    self.warmup_steps = num_samples if warmup_steps is None else warmup_steps  # Stan
    self.num_samples = num_samples
    self.kernel = kernel
    self.transforms = transforms
    self.disable_validation = disable_validation
    self._samples = None
    self._args = None
    self._kwargs = None

    uses_potential_fn = (isinstance(self.kernel, (HMC, NUTS))
                         and self.kernel.potential_fn is not None)
    if uses_potential_fn and initial_params is None:
        raise ValueError("Must provide valid initial parameters to begin sampling"
                         " when using `potential_fn` in HMC/NUTS kernel.")

    parallel = False
    if num_chains > 1:
        # `initial_params` may be None (or empty) here, e.g. for a model-based
        # kernel, so every access to its tensors stays behind this guard.
        if initial_params:
            # Each tensor must carry one slice per chain along dimension 0.
            for v in initial_params.values():
                if v.shape[0] != num_chains:
                    raise ValueError("The leading dimension of tensors in `initial_params` "
                                     "must match the number of chains.")
            # FIXME: we probably want to use the "spawn" method by default to avoid
            # the CUDA initialization error https://github.com/pytorch/pytorch/issues/2517
            # even when we run MCMC on CPU.
            if mp_context is None and any(v.is_cuda for v in initial_params.values()):
                # Switch the multiprocessing context to 'spawn' for CUDA tensors.
                mp_context = "spawn"

        # Verify that num_chains is compatible with the available CPUs.
        available_cpu = max(mp.cpu_count() - 1, 1)  # reserving 1 for the main process.
        if num_chains <= available_cpu:
            parallel = True
        else:
            warnings.warn("num_chains={} is more than available_cpu={}. "
                          "Chains will be drawn sequentially."
                          .format(num_chains, available_cpu))
    elif initial_params:
        # Single chain: add the leading chain dimension the samplers are given
        # in the multi-chain case.
        initial_params = {k: v.unsqueeze(0) for k, v in initial_params.items()}

    self.num_chains = num_chains
    self._diagnostics = [None] * num_chains
    if parallel:
        self.sampler = _MultiSampler(kernel, num_samples, self.warmup_steps, num_chains, mp_context,
                                     disable_progbar, initial_params=initial_params, hook=hook_fn)
    else:
        self.sampler = _UnarySampler(kernel, num_samples, self.warmup_steps, num_chains, disable_progbar,
                                     initial_params=initial_params, hook=hook_fn)