kats/models/globalmodel/utils.py [659:691]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def _check(self, input: Tensor, target: Tensor) -> None:
        """
        Check input tensor and target tensor size.

        Args:
            input: Forecast tensor, expected to be 2-D with shape
                (num, n_steps * n_quantiles), where n_quantiles is ``len(self.quantile)``.
            target: True-value tensor, expected to be 2-D with shape (num, n_steps).

        Raises:
            ValueError: If either tensor is not 2-dimensional, if the batch sizes
                (dimension 0) differ, or if ``input`` does not have
                ``target.size()[1] * len(self.quantile)`` columns.
        """
        # Validate rank first: indexing ``size()[1]`` on a 0-D/1-D tensor would
        # otherwise raise an uninformative IndexError, and extra trailing
        # dimensions on ``input`` would slip through the column check below.
        if input.dim() != 2 or target.dim() != 2:
            msg = (
                "Input and target should be 2-dimensional tensors but received "
                f"shapes {tuple(input.size())} and {tuple(target.size())}."
            )
            logging.error(msg)
            raise ValueError(msg)
        if target.size()[0] != input.size()[0]:
            msg = "Input batch size is not equal to target batch size."
            logging.error(msg)
            raise ValueError(msg)
        # Each quantile contributes one block of n_steps columns to ``input``.
        num_feature = target.size()[1] * len(self.quantile)
        if input.size()[1] != num_feature:
            msg = f"Input should contain {num_feature} features but receive {input.size()[1]}."
            logging.error(msg)
            raise ValueError(msg)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """
        Args:
            input: A `torch.Tensor` object representing forecasted values of shape (num, n_steps * n_quantiles), where n_quantiles is the length of quantile.
            target: A `torch.Tensor` object contianing true values of shape (num, n_steps).

        Returns:
            A 1-dimensional `torch.Tensor` object representing the computed pinball loss of length the number of quantiles.
        """

        self._check(input, target)
        n = len(input)
        m = len(self.quantile)
        horizon = target.size()[1]
        nans = torch.isnan(target).detach()
        # clean up NaNs to avoid NaNs in gradient
        target[nans] = 1.0
        num_not_nan = (~nans).float().sum(dim=1)
        num_not_nan[num_not_nan == 0] += 1
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



kats/models/globalmodel/utils.py [766:799]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def _check(self, input: Tensor, target: Tensor) -> None:
        """
        Check input tensor and target tensor size.

        Args:
            input: Forecast tensor, expected to be 2-D with shape
                (num, n_steps * n_quantiles), where n_quantiles is ``len(self.quantile)``.
            target: True-value tensor, expected to be 2-D with shape (num, n_steps).

        Raises:
            ValueError: If either tensor is not 2-dimensional, if the batch sizes
                (dimension 0) differ, or if ``input`` does not have
                ``target.size()[1] * len(self.quantile)`` columns.
        """
        # Validate rank first: indexing ``size()[1]`` on a 0-D/1-D tensor would
        # otherwise raise an uninformative IndexError, and extra trailing
        # dimensions on ``input`` would slip through the column check below.
        if input.dim() != 2 or target.dim() != 2:
            msg = (
                "Input and target should be 2-dimensional tensors but received "
                f"shapes {tuple(input.size())} and {tuple(target.size())}."
            )
            logging.error(msg)
            raise ValueError(msg)
        if target.size()[0] != input.size()[0]:
            msg = "Input batch size is not equal to target batch size."
            logging.error(msg)
            raise ValueError(msg)
        # Each quantile contributes one block of n_steps columns to ``input``.
        num_feature = target.size()[1] * len(self.quantile)
        if input.size()[1] != num_feature:
            msg = f"Input should contain {num_feature} features but receive {input.size()[1]}."
            logging.error(msg)
            raise ValueError(msg)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Forward method of AdjustedPinballLoss module.

        Args:
            input: A `torch.Tensor` object representing the forecasts of shape (num, n_steps * n_quantiles), where n_quantiles is the length of quantile.
            target: A `torch.Tensor` object representing true values of shape (num, n_steps)

        Returns:
            A 1-dimensional `torch.Tensor` object representing the computed pinball loss of length the number of quantiles.
        """

        self._check(input, target)
        n = len(input)
        m = len(self.quantile)
        horizon = target.size()[1]
        nans = torch.isnan(target).detach()
        # avoid nans appear in the loss
        target[nans] = 1.0
        num_not_nan = (~nans).float().sum(dim=1)
        num_not_nan[num_not_nan == 0] += 1
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



