def forward()

in kats/models/globalmodel/utils.py [0:0]


    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Forward method of AdjustedPinballLoss module.

        The first quantile uses an adjusted pinball loss on the original scale:
        ``max(d * q, d * (q - 1)) / (y + y_hat) * 2`` with ``d = y - y_hat``.
        The remaining quantiles use the standard pinball loss
        ``max(d * q, d * (q - 1))`` on the logarithmic scale, with both target
        and forecasts brought to log scale. Positions where ``target`` is NaN
        contribute zero loss and are excluded from the per-sample normalization.

        Args:
            input: A `torch.Tensor` object representing the forecasts of shape (num, n_steps * n_quantiles), where n_quantiles is the length of quantile. The first n_steps columns hold the forecasts for the first quantile, the next n_steps columns those for the second quantile, and so on.
            target: A `torch.Tensor` object representing true values of shape (num, n_steps). NaN entries are treated as missing. The tensor is not modified.

        Returns:
            If reduction is "mean" or "sum", a 1-dimensional `torch.Tensor` object of length the number of quantiles holding the reduced pinball loss; otherwise a `torch.Tensor` of shape (num, n_quantiles) holding the per-sample losses.
        """

        self._check(input, target)
        n = len(input)
        m = len(self.quantile)
        horizon = target.size()[1]
        nans = torch.isnan(target).detach()
        # Replace NaNs out of place: keeps them out of the loss computation
        # without mutating the caller's ``target`` tensor.
        target = target.masked_fill(nans, 1.0)
        # Observed values per sample; clamp to 1 so all-NaN samples do not
        # cause a division by zero.
        num_not_nan = (~nans).float().sum(dim=1).clamp(min=1.0)

        # First quantile: adjusted pinball loss on the original scale.
        if self.input_log:
            target_exp = torch.exp(target)
            fcst_exp = torch.exp(input[:, :horizon])
        else:
            target_exp = target
            fcst_exp = input[:, :horizon]
        diff = target_exp - fcst_exp
        q0 = self.quantile[0]
        res = torch.max(diff * q0, diff * (q0 - 1.0)) / (target_exp + fcst_exp) * 2
        res = res.masked_fill(nans, 0.0)

        if m > 1:
            # Remaining quantiles: standard pinball loss on the log scale.
            # Target and forecasts must be on the same scale, so both are
            # log-transformed when the inputs are not already logarithmic.
            if self.input_log:
                fcst = input[:, horizon:]
                target_log = target
            else:
                fcst = torch.log(input[:, horizon:])
                target_log = torch.log(target)
            n_rest = m - 1

            # Tile target/mask once per extra quantile to match fcst's column
            # layout [q_1 block | q_2 block | ...], each block ``horizon`` wide.
            target_log = target_log.repeat(1, n_rest)
            nans_rest = nans.repeat(1, n_rest)
            quants = self.quantile[1:].repeat_interleave(horizon)

            diff_q = target_log - fcst
            res_q = torch.max(diff_q * quants, diff_q * (quants - 1.0))
            res_q = res_q.masked_fill(nans_rest, 0.0)

            res = torch.cat([res, res_q], dim=1)

        # Weight each quantile block, then average over observed steps per sample.
        weights = self.weight.repeat_interleave(horizon)
        res = res * weights
        res = res.reshape(n, -1, horizon).sum(dim=2) / num_not_nan[:, None]

        if self.reduction == "mean":
            return res.mean(dim=0)
        if self.reduction == "sum":
            return res.sum(dim=0)

        return res