in kats/models/globalmodel/utils.py [0:0]
def forward(self, input: Tensor, target: Tensor) -> Tensor:
    """Forward method of AdjustedPinballLoss module.

    The first ``n_steps`` columns of ``input`` are the forecasts for
    ``self.quantile[0]``. They are scored with a pinball loss multiplied by
    ``2 / (target + forecast)``. This part is computed on the original scale:
    values are exponentiated first when ``self.input_log`` is True.

    The remaining columns form one block of ``n_steps`` columns per entry of
    ``self.quantile[1:]``. They are scored with a plain pinball loss on the
    log scale. The loss value therefore does not depend on whether inputs are
    passed in log or raw representation.

    Positions where ``target`` is NaN contribute zero loss and are excluded
    from the per-series averaging. The per-quantile losses are multiplied by
    ``self.weight``.

    Args:
        input: A `torch.Tensor` object representing the forecasts of shape
            (num, n_steps * n_quantiles), where n_quantiles is the length of
            quantile. Values are on the log scale if ``self.input_log`` is True,
            otherwise raw values (which must then be positive for the log-scale
            quantile blocks).
        target: A `torch.Tensor` object representing true values of shape
            (num, n_steps), on the same scale as ``input``. It is not modified.

    Returns:
        If ``self.reduction`` is ``"mean"`` or ``"sum"``, a 1-dimensional
        `torch.Tensor` of length n_quantiles (losses reduced over series).
        Otherwise a `torch.Tensor` of shape (num, n_quantiles) holding the
        per-series losses.
    """
    self._check(input, target)
    n = len(input)
    num_quantiles = len(self.quantile)
    horizon = target.size(1)

    nans = torch.isnan(target)
    # Replace NaNs out-of-place so the caller's tensor is left untouched; the
    # filled positions are masked out of the loss below.
    target = target.masked_fill(nans, 1.0)
    num_not_nan = (~nans).float().sum(dim=1)
    # Series whose targets are all NaN would otherwise divide by zero.
    num_not_nan[num_not_nan == 0] += 1

    # Quantile 0: symmetric-relative pinball loss on the original scale.
    if self.input_log:
        target_exp = torch.exp(target)
        fcst_exp = torch.exp(input[:, :horizon])
    else:
        target_exp = target
        fcst_exp = input[:, :horizon]
    diff = target_exp - fcst_exp
    q0 = self.quantile[0]
    res = torch.max(diff * q0, diff * (q0 - 1.0)) / (target_exp + fcst_exp) * 2
    res = res.masked_fill(nans, 0.0)

    # Remaining quantiles: plain pinball loss on the log scale.
    if num_quantiles > 1:
        n_extra = num_quantiles - 1
        if self.input_log:
            fcst_log = input[:, horizon:]
            target_log = target
        else:
            # Both sides must be log-transformed; mixing a raw target with a
            # log forecast would compare values on different scales.
            fcst_log = torch.log(input[:, horizon:])
            target_log = torch.log(target)
        target_rep = target_log.repeat(1, n_extra)
        nans_rep = nans.repeat(1, n_extra)
        # Layout: [q1] * horizon + [q2] * horizon + ..., matching input's blocks.
        quants = self.quantile[1:].repeat(horizon, 1).t().flatten()
        diff_q = target_rep - fcst_log
        res_q = torch.max(diff_q * quants, diff_q * (quants - 1.0))
        res_q = res_q.masked_fill(nans_rep, 0.0)
        res = torch.cat([res, res_q], dim=1)

    # One weight per quantile, expanded over its block of horizon columns.
    weights = self.weight.repeat(horizon, 1).t().flatten()
    res = res * weights
    # Sum over the horizon, then average over the non-NaN steps of each series.
    res = res.view(n, -1, horizon).sum(dim=2) / num_not_nan[:, None]
    if self.reduction == "mean":
        return res.mean(dim=0)
    if self.reduction == "sum":
        return res.sum(dim=0)
    return res