in src/beanmachine/ppl/experimental/abc/abc_smc_infer.py [0:0]
def _single_inference_step(self, stage: int) -> int:
    """
    Run a single inference step of the ABC-SMC algorithm, which attempts to obtain one sample.

    In the first stage (``stage == 0``), samples are generated from the prior of the node to
    be observed, and their summary statistics are compared with the summary statistics of the
    provided observations. If every distance is within the stage's tolerance, the sample is
    accepted. In subsequent stages, a sample is drawn from the pool of samples accepted in the
    previous stage (``self.weighted_sample_draw``), perturbed using the perturbation kernel
    (``self.perturb_kernel``), and then its summary statistics are computed and compared for
    accept/reject. Each stage uses a different tolerance value from the tolerance schedule.

    On acceptance, the mean of ``(tolerance - distance) / tolerance`` over all summary
    statistics is appended to ``self.queries_sample_weights``.

    :param stage: the stage of ABC-SMC inference, used to choose from the tolerance schedule
    :returns: 1 if the sample is accepted and 0 if it is rejected (used to update the tqdm
        iterator)
    :raises ValueError: if a computed summary statistic's shape does not match the shape of
        its provided observation
    """
    self.world_ = World()
    self.world_.set_initialize_from_prior(True)
    self.world_.set_maintain_graph(False)
    self.world_.set_cache_functionals(True)
    if stage != 0:
        weighted_sample = self.weighted_sample_draw()
        perturbed_sample_draw = self.perturb_kernel(weighted_sample, stage)
        # Although this method is meant for setting observations, it is used here to pin
        # the random variables' values to the generated perturbations.
        self.world_.set_observations(perturbed_sample_draw)
    weights = []
    for summary_statistic, observed_summary in self.observations_.items():
        # Calling the summary statistic node runs sample(node()), which adds its
        # corresponding Variable and its dependent Variables to the world and computes
        # its value.
        computed_summary = self.world_.call(summary_statistic)
        # Observations may be provided as plain Python values; cast them to tensors.
        if not torch.is_tensor(observed_summary):
            observed_summary = torch.tensor(observed_summary)
        # The computed and provided summaries must have the same shape to be compared.
        if computed_summary.shape != observed_summary.shape:
            raise ValueError(
                f"Shape mismatch in random variable {summary_statistic}"
                + "\nshape does not match with observation\n"
                + f"Expected observation shape: {computed_summary.shape}; "
                + f"Provided observation shape: {observed_summary.shape}"
            )
        # Users may pass either a dict of per-statistic distance functions or one
        # distance function shared by all statistics.
        if isinstance(self.distance_function, dict):
            distance_function = self.distance_function[summary_statistic]
        else:
            distance_function = self.distance_function
        # Users may pass either a dict of per-statistic tolerance schedules or one
        # schedule shared by all statistics; either way it is indexed by stage.
        if isinstance(self.tolerance_schedule, dict):
            tolerance = self.tolerance_schedule[summary_statistic][stage]
        else:
            tolerance = self.tolerance_schedule[stage]
        # Perform rejection: a single statistic outside its tolerance rejects the sample.
        distance = distance_function(
            computed_summary.float(), observed_summary.float()
        )
        reject = torch.gt(distance, tolerance)
        if reject:
            self._reject_sample(node_key=summary_statistic)
            return 0
        # NOTE(review): a tolerance of 0 makes this weight 0/0 (nan) even for an exact
        # match -- confirm the tolerance schedule is strictly positive.
        weights.append((tolerance - distance) / tolerance)
    self.queries_sample_weights.append(torch.mean(torch.stack(weights)))
    self._accept_sample()
    return 1