def predict()

in kats/models/lstm.py [0:0]


    def predict(self, steps: int, *args: Any, **kwargs: Any) -> pd.DataFrame:
        """Produce a recursive multi-step forecast from the fitted LSTM.

        Starting from the last ``time_window`` normalized training values, the
        model predicts one value ahead, that prediction is appended to the
        input history, and the process repeats until ``steps`` values have
        been generated. The predictions are then mapped back to the original
        scale with the fitted scaler.

        Args:
            steps: number of future periods to forecast; must be >= 1.
            *args: unused; accepted for interface compatibility.
            **kwargs: optional ``freq`` (any frequency ``pd.date_range``
                accepts) used to build the forecast timestamps. If omitted,
                the frequency is inferred from the training timestamps.

        Returns:
            A pd.DataFrame with columns ``time``, ``fcst``, ``fcst_lower`` and
            ``fcst_upper``. The bounds are a fixed +/-5% band around ``fcst``
            (not a statistically derived interval).

        Raises:
            ValueError: if ``fit()`` has not been called, or ``steps < 1``.
        """
        model = self.model
        train_data_normalized = self.train_data_normalized
        scaler = self.scaler
        # Explicit checks instead of ``assert`` so they still run under ``python -O``.
        if model is None or train_data_normalized is None or scaler is None:
            raise ValueError("Call fit() before predict()")
        if steps < 1:
            raise ValueError(f"steps must be a positive integer, got {steps}")
        time_window = self.params.time_window
        hidden_size = self.params.hidden_size

        logging.debug(
            "Call predict() with parameters. steps:%s, kwargs:%s", steps, kwargs
        )
        # Infer the frequency only when the caller did not supply one:
        # pd.infer_freq raises for fewer than 3 timestamps, and evaluating it
        # eagerly (as a dict.get default) would fail even when freq is given.
        if "freq" in kwargs:
            self.freq = kwargs["freq"]
        else:
            self.freq = pd.infer_freq(self.data.time)

        model.eval()

        # Seed the recursion with the last training window. The window is
        # shorter than ``time_window`` when the training series is shorter,
        # so remember its real length to know where the forecasts begin.
        history = train_data_normalized[-time_window:].tolist()
        n_seed = len(history)

        with torch.no_grad():
            for _ in range(steps):
                seq = torch.FloatTensor(history[-time_window:])
                # Reset model.hidden_cell to zeros before every step
                # (presumably the LSTM (h, c) state — confirm in the model).
                model.hidden_cell = (
                    torch.zeros(1, 1, hidden_size),
                    torch.zeros(1, 1, hidden_size),
                )
                history.append(model(seq).item())

        # Inverse-transform only the generated values back to the data scale.
        fcst_denormalized = scaler.inverse_transform(
            np.array(history[n_seed:]).reshape(-1, 1)
        ).flatten()
        logging.info("Generated forecast data from LSTM model.")
        logging.debug("Forecast data: %s", fcst_denormalized)

        last_date = self.data.time.max()
        dates = pd.date_range(start=last_date, periods=steps + 1, freq=self.freq)
        # ``last_date`` is only part of the range when it lies on the frequency
        # anchor (e.g. not for freq="MS" with a mid-month last_date). Drop it if
        # present and keep exactly ``steps`` future dates either way.
        self.dates = dates[dates != last_date][:steps]
        self.y_fcst = fcst_denormalized
        # Symmetric +/-5% band; abs() keeps lower <= upper for negative
        # forecasts (plain *0.95 / *1.05 would invert the bounds there).
        margin = 0.05 * np.abs(fcst_denormalized)
        self.y_fcst_lower = fcst_denormalized - margin
        self.y_fcst_upper = fcst_denormalized + margin

        self.fcst_df = fcst_df = pd.DataFrame(
            {
                "time": self.dates,
                "fcst": self.y_fcst,
                "fcst_lower": self.y_fcst_lower,
                "fcst_upper": self.y_fcst_upper,
            }
        )

        logging.debug("Return forecast data: %s", self.fcst_df)

        return fcst_df