def _single_inference_step()

in src/beanmachine/ppl/experimental/abc/abc_smc_infer.py [0:0]


    def _single_inference_step(self, stage: int) -> int:
        """
        Single inference step of the ABC-SMC algorithm, which attempts to obtain one sample.

        In the first stage (``stage == 0``), values are generated from the prior of the nodes
        to be observed, and the resulting summary statistics are compared with the summary
        statistics of the provided observations. If every distance is within the tolerance
        for this stage, the sample is accepted. In subsequent stages, a sample is drawn from
        the pool of samples accepted in the previous stage and perturbed using the
        perturbation kernel. Its summary statistics are then computed and compared for
        accept/reject. Each stage has its own tolerance value.

        An accepted sample is weighted by the mean, over all summary statistics, of
        ``(tolerance - distance) / tolerance``. A zero tolerance (which only admits
        ``distance <= 0``) contributes a weight of 1 instead of the NaN ``0 / 0`` would give.

        :param stage: the stage of ABC-SMC inference, used to index the tolerance schedule
        :returns: 1 if the sample is accepted and 0 if it is rejected (used to update the
            tqdm iterator)
        :raises ValueError: if the shape of a computed summary statistic differs from the
            shape of its provided observation
        """
        self.world_ = World()
        self.world_.set_initialize_from_prior(True)
        self.world_.set_maintain_graph(False)
        self.world_.set_cache_functionals(True)

        if stage != 0:
            weighted_sample = self.weighted_sample_draw()
            perturbed_sample_draw = self.perturb_kernel(weighted_sample, stage)
            # Although this method is used to set observations, here it sets the values
            # of the random variables to the generated perturbations.
            self.world_.set_observations(perturbed_sample_draw)

        weights = []
        for summary_statistic, observed_summary in self.observations_.items():
            # Calling the summary statistic node runs sample(node()), which adds its
            # Variable and the Variables it depends on to the world and computes its value.
            computed_summary = self.world_.call(summary_statistic)
            # observations may be passed as non-tensors; cast them so shapes can be compared
            if not torch.is_tensor(observed_summary):
                observed_summary = torch.tensor(observed_summary)
            if computed_summary.shape != observed_summary.shape:
                raise ValueError(
                    f"Shape mismatch in random variable {summary_statistic}"
                    + "\nshape does not match with observation\n"
                    + f"Expected observation shape: {computed_summary.shape}; "
                    + f"Provided observation shape: {observed_summary.shape}"
                )

            # Distance functions and tolerance schedules may be given either per summary
            # statistic (as a dict) or as a single value shared by all statistics.
            if isinstance(self.distance_function, dict):
                distance_function = self.distance_function[summary_statistic]
            else:
                distance_function = self.distance_function
            if isinstance(self.tolerance_schedule, dict):
                tolerance = self.tolerance_schedule[summary_statistic][stage]
            else:
                tolerance = self.tolerance_schedule[stage]

            # perform rejection: any single statistic outside tolerance rejects the sample
            distance = distance_function(
                computed_summary.float(), observed_summary.float()
            )
            if torch.gt(distance, tolerance):
                self._reject_sample(node_key=summary_statistic)
                return 0

            if tolerance == 0:
                # Only distance <= 0 reaches here; (0 - 0) / 0 would yield a NaN weight,
                # so treat it as a perfect match, matching the weight of 1 given to a
                # zero distance under any positive tolerance.
                weights.append(torch.ones_like(distance))
            else:
                # in [0, 1] when the distance is non-negative (assumed for distance
                # functions); closer matches get larger weights
                weights.append((tolerance - distance) / tolerance)
        self.queries_sample_weights.append(torch.mean(torch.stack(weights)))
        self._accept_sample()
        return 1