def _single_inference_step()

in src/beanmachine/ppl/experimental/abc/adaptive_abc_smc_infer.py [0:0]


    def _single_inference_step(self, stage: int) -> int:
        """
        Run one attempt of the adaptive ABC-SMC algorithm to obtain a single sample.

        Every attempt builds a fresh ``World`` with initialization-from-prior
        enabled, graph maintenance disabled and functional caching enabled.
        In stage 0 nothing else is done before evaluation, so values come from
        the prior. In any later stage a sample is first drawn from the previously
        accepted samples via ``weighted_sample_draw`` and then perturbed with
        ``perturb_kernel``.

        Each summary statistic in ``self.observations_`` is then evaluated in the
        world and compared with its observed value using the configured distance
        function. The sample is rejected as soon as one distance exceeds
        ``self.tolerance``. Otherwise the mean of all distances is appended to
        ``self.queries_sample_weights`` and the sample is accepted.

        NOTE(review): ``self.tolerance`` is read as-is here. Presumably the caller
        updates it between stages according to the tolerance schedule; confirm.

        :param stage: index of the current ABC-SMC stage. Stage 0 samples from
            the prior; any other stage perturbs a previously accepted sample.
            Also forwarded to ``perturb_kernel``.
        :returns: 1 if the sample is accepted and 0 if it is rejected (used to
            update the tqdm iterator)
        :raises ValueError: if a computed summary statistic and its observation
            have different shapes.
        """
        self.world_ = World()
        self.world_.set_initialize_from_prior(True)
        self.world_.set_maintain_graph(False)
        self.world_.set_cache_functionals(True)

        if stage != 0:
            # Later stages: draw from previously accepted samples, then perturb.
            weighted_sample = self.weighted_sample_draw()
            self.perturb_kernel(weighted_sample, stage)

        distances = []
        for summary_statistic, observed_summary in self.observations_.items():
            # Calling the summary statistic node runs sample(node()), which adds
            # its Variable and dependent Variables to the world and computes
            # its value.
            computed_summary = self.world_.call(summary_statistic)
            # Observations may be passed as plain Python values; cast them.
            if not torch.is_tensor(observed_summary):
                observed_summary = torch.tensor(observed_summary)
            if computed_summary.shape != observed_summary.shape:
                raise ValueError(
                    f"Shape mismatch in random variable {summary_statistic}"
                    + "\nshape does not match with observation\n"
                    + f"Expected observation shape: {computed_summary.shape}; "
                    + f"Provided observation shape: {observed_summary.shape}"
                )

            # Distance functions are either one callable for every statistic or
            # a dict keyed by summary statistic.
            if isinstance(self.distance_function, dict):
                distance_function = self.distance_function[summary_statistic]
            else:
                distance_function = self.distance_function

            distance = distance_function(
                computed_summary.float(), observed_summary.float()
            )
            # The truth test below requires a single-element result.
            # Distance functions presumably return a scalar tensor.
            if torch.gt(distance, self.tolerance):
                self._reject_sample(node_key=summary_statistic)
                return 0
            distances.append(distance)

        # The mean distance of the accepted sample is recorded as its weight.
        self.queries_sample_weights.append(torch.mean(torch.stack(distances)))
        self._accept_sample()
        return 1