def log_binned_p()

in src/gluonts/nursery/spliced_binned_pareto/spliced_binned_pareto.py


    def log_binned_p(self, xx):
        """
        Log probability for one datapoint.
        """
        assert xx.shape.numel() == 1, "log_binned_p() expects univariate"

        # Transform xx into a one-hot encoded vector marking its bin:
        # xx lies in bin i exactly when xx - upper_edge_i <= 0 and
        # lower_edge_i - xx <= 0, i.e. when the product of the two is >= 0
        vect_above = xx - self.bin_edges[1:]
        vect_below = self.bin_edges[:-1] - xx
        one_hot_bin_indicator = (vect_above * vect_below >= 0).float()

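        # Out-of-range values yield an all-zero indicator above,
        # so clamp them into the first/last bin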
        if xx > self.bin_edges[-1]:
            one_hot_bin_indicator[-1] = 1.0
        elif xx < self.bin_edges[0]:
            one_hot_bin_indicator[0] = 1.0
        if (one_hot_bin_indicator == 1).sum() != 1:
            print(
                f"Warning in log_binned_p(self, xx): for xx={xx.item()}, "
                "one_hot_bin_indicator value_counts are "
                f"{one_hot_bin_indicator.unique(return_counts=True)}"
            )

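        # Optionally spread the indicator mass over neighboring bins, so
        # that nearby bin probabilities also receive gradient signal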
        if self.smooth_indicator == "kernel":
            # The kernel variant is better but slows down training quite a bit
            idx_one_hot = torch.argmax(one_hot_bin_indicator)
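            # Bell-shaped weights that sum to ~1, centered on xx's bin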
            kernel = [0.006, 0.061, 0.242, 0.383, 0.242, 0.061, 0.006]
            len_kernel = len(kernel)
            for i in range(len_kernel):
                idx = idx_one_hot + i - len_kernel // 2
                if 0 <= idx < len(one_hot_bin_indicator):
                    one_hot_bin_indicator[idx] = kernel[i]

        elif self.smooth_indicator == "cheap":
            # This variant is cheaper in computation time
            idx_one_hot = torch.argmax(one_hot_bin_indicator)
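            # Keep weight 1.0 on xx's bin; give immediate neighbors 0.5
            # and second neighbors 0.25 where they exist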
            if idx_one_hot + 1 < len(one_hot_bin_indicator):
                one_hot_bin_indicator[idx_one_hot + 1] = 0.5
            if idx_one_hot - 1 >= 0:
                one_hot_bin_indicator[idx_one_hot - 1] = 0.5
            if idx_one_hot + 2 < len(one_hot_bin_indicator):
                one_hot_bin_indicator[idx_one_hot + 2] = 0.25
            if idx_one_hot - 2 >= 0:
                one_hot_bin_indicator[idx_one_hot - 2] = 0.25

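        # Weighted sum of per-bin log probabilities; with an exact one-hot
        # indicator this reduces to the log probability of xx's bin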
        logp = torch.dot(one_hot_bin_indicator, self.log_bins_prob())
        return logp
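
To see the bin-lookup trick in isolation, here is a minimal, self-contained
sketch. The bin edges and the uniform bin probabilities are hypothetical
stand-ins chosen for illustration, not values taken from the GluonTS class:

    import torch

    # Hypothetical: 8 equal-width bins on [-2, 2] with uniform probability
    bin_edges = torch.linspace(-2.0, 2.0, steps=9)
    log_bins_prob = torch.log(torch.full((8,), 1.0 / 8))

    xx = torch.tensor(0.3)

    # Same sign trick as above: (xx - upper_edge) * (lower_edge - xx) is
    # non-negative only in the bin that contains xx
    vect_above = xx - bin_edges[1:]
    vect_below = bin_edges[:-1] - xx
    one_hot = (vect_above * vect_below >= 0).float()

    print(one_hot)                            # tensor([0., 0., 0., 0., 1., 0., 0., 0.])
    print(torch.dot(one_hot, log_bins_prob))  # log(1/8) ~= tensor(-2.0794)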