in src/beanmachine/ppl/experimental/abc/abc_infer.py [0:0]
def _single_inference_step(self) -> int:
"""
Single inference step of the vanilla ABC algorithm which attempts to obtain a sample.
Samples are generated from the prior of the node to be observed, and their summary statistic is
compared with summary statistic of provided observations. If distance is within provided
tolerence values, the sample is accepted.
:returns: 1 if sample is accepted and 0 if sample is rejected (used to update the tqdm iterator)
"""
self.world_ = World()
self.world_.set_initialize_from_prior(True)
self.world_.set_maintain_graph(False)
self.world_.set_cache_functionals(True)
if self.simulate:
# in simulate mode, user passes obtained samples as observations and shall query nodes to be
# simulated. This required observations to set instead of being sampled from prior
self.world_.set_observations(self.observations_)
for summary_statistic, observed_summary in self.observations_.items():
# makes the call for the summary statistic node, which will run sample(node())
# that results in adding its corresponding Variable and its dependent
# Variable to the world, as well as computing it's value
computed_summary = self.world_.call(summary_statistic)
if self.simulate:
# if we are simulating, simply accept sample
self._accept_sample()
return 1
# check if passed observation is a tensor, if not, cast it
if not torch.is_tensor(observed_summary):
observed_summary = torch.tensor(observed_summary)
# check if the shapes of computed and provided summary matches
if computed_summary.shape != observed_summary.shape:
raise ValueError(
f"Shape mismatch in random variable {summary_statistic}"
+ "\nshape does not match with observation\n"
+ f"Expected observation shape: {computed_summary.shape};"
+ f"Provided observation shape{observed_summary.shape}"
)
# if user passed a dict for distance functions, load from it, else load default
if isinstance(self.distance_function, dict):
distance_function = self.distance_function[summary_statistic]
else:
distance_function = self.distance_function
# we allow users to pass either a dict or a single value for tolerance
if isinstance(self.tolerance, dict):
tolerance = self.tolerance[summary_statistic]
else:
tolerance = self.tolerance
# perform rejection
reject = torch.gt(
distance_function(computed_summary.float(), observed_summary.float()),
tolerance,
)
if reject:
self._reject_sample(node_key=summary_statistic)
return 0
self._accept_sample()
return 1