in src/gluonts/nursery/spliced_binned_pareto/spliced_binned_pareto.py [0:0]
def log_binned_p(self, xx):
    """
    Log probability for one datapoint.

    Builds a (possibly smoothed) one-hot weighting over the bins and
    returns its dot product with the per-bin log probabilities.

    Parameters
    ----------
    xx
        A single datapoint as a one-element tensor. Values falling
        outside ``self.bin_edges`` are assigned to the outermost bin
        on that side.

    Returns
    -------
    torch.Tensor
        Scalar tensor holding the log probability of ``xx``.
    """
    assert xx.shape.numel() == 1, "log_binned_p() expects univariate"

    # One-hot encode the bin containing xx: for the bin [e_i, e_{i+1}]
    # that holds xx, both (xx - e_{i+1}) and (e_i - xx) are <= 0, so
    # their product is >= 0; for every other bin the product is < 0.
    vect_above = xx - self.bin_edges[1:]
    vect_below = self.bin_edges[:-1] - xx
    one_hot_bin_indicator = (vect_above * vect_below >= 0).float()

    # Out-of-support values are clamped into the outermost bins.
    if xx > self.bin_edges[-1]:
        one_hot_bin_indicator[-1] = 1.0
    elif xx < self.bin_edges[0]:
        one_hot_bin_indicator[0] = 1.0

    # NOTE(review): xx landing exactly on an interior edge marks two
    # bins (both products are zero), which double-counts in the dot
    # product below; only a warning is emitted, matching the original
    # behavior.
    if not (one_hot_bin_indicator == 1).sum() == 1:
        print(
            f"Warning in log_p(self, xx): for xx={xx.item()}, one_hot_bin_indicator value_counts are {one_hot_bin_indicator.unique(return_counts=True)}"
        )

    num_bins = len(one_hot_bin_indicator)
    if self.smooth_indicator == "kernel":
        # The kernel variant is better but slows down training quite a
        # bit. Convert the argmax tensor to a plain int once: using a
        # 0-dim tensor as a `range(...)` membership probe forces a
        # linear scan with tensor equality on every kernel tap.
        center = int(torch.argmax(one_hot_bin_indicator))
        kernel = [0.006, 0.061, 0.242, 0.383, 0.242, 0.061, 0.006]
        half_width = len(kernel) // 2
        # Spread the kernel weights over the neighbouring bins,
        # silently dropping taps that fall outside the support.
        for i, weight in enumerate(kernel):
            idx = center + i - half_width
            if 0 <= idx < num_bins:
                one_hot_bin_indicator[idx] = weight
    elif self.smooth_indicator == "cheap":
        # This variant is cheaper in computation time: fixed weights on
        # the two nearest neighbours on each side; the center keeps 1.
        center = int(torch.argmax(one_hot_bin_indicator))
        for offset, weight in ((1, 0.5), (-1, 0.5), (2, 0.25), (-2, 0.25)):
            idx = center + offset
            if 0 <= idx < num_bins:
                one_hot_bin_indicator[idx] = weight

    # Weighted sum of the per-bin log probabilities.
    logp = torch.dot(one_hot_bin_indicator, self.log_bins_prob())
    return logp