def _set_levs_and_seas()

in orbit/template/ktr.py [0:0]


    def _set_levs_and_seas(self, df, training_meta):
        """Fit a KTRLite model and copy its level / seasonality results onto self.

        A ``KTRLite`` instance is configured from this object's level and
        seasonality settings, fitted on ``df`` with the ``"stan-map"``
        estimator, and its ``"map"`` point posteriors are then used to
        populate the attributes listed below.

        Parameters
        ----------
        df : object
            Training data; handed unchanged to ``KTRLite.fit`` and to
            ``self._generate_seas``.  (Presumably a pandas DataFrame --
            confirm against callers.)
        training_meta : mapping
            Must provide ``"response_col"``, ``"response_sd"`` (only read when
            ``self.residuals_scale_upper`` is None) and the entries keyed by
            ``TrainingMetaKeys.DATE_COL``, ``NUM_OF_OBS`` and ``DATE_ARRAY``.

        Notes
        -----
        Attributes always assigned: ``_seasonality``, ``_seasonality_fs_order``,
        ``_seasonality_labels``, ``_level_knots``, ``_level_knot_dates``,
        ``_level_knots_idx``, ``knots_tp_level``, ``_kernel_level`` and
        ``_num_knots_level``.

        ``_residuals_scale_upper`` is assigned only when
        ``residuals_scale_upper`` is None.

        ``_seasonality_coef_knot_dates``, ``_seasonality_coef_knots`` and
        ``_seas_term`` are assigned only when ``_seasonality`` is truthy.
        """
        response_col = training_meta["response_col"]
        date_col = training_meta[TrainingMetaKeys.DATE_COL.value]
        n_obs = training_meta[TrainingMetaKeys.NUM_OF_OBS.value]
        dates = training_meta[TrainingMetaKeys.DATE_ARRAY.value]

        # KTRLite supplies the level and seasonality estimates.
        ktrlite = KTRLite(
            response_col=response_col,
            date_col=date_col,
            level_knot_scale=self.level_knot_scale,
            level_segments=self.level_segments,
            level_knot_dates=self.level_knot_dates,
            level_knot_distance=self.level_knot_distance,
            seasonality=self.seasonality,
            seasonality_fs_order=self.seasonality_fs_order,
            seasonal_initial_knot_scale=self.seasonal_initial_knot_scale,
            seasonal_knot_scale=self.seasonal_knot_scale,
            seasonality_segments=self.seasonality_segments,
            degree_of_freedom=self.degree_of_freedom,
            date_freq=self.date_freq,
            estimator="stan-map",
            **self.ktrlite_optim_args,
        )
        ktrlite.fit(df=df)
        map_posteriors = ktrlite.get_point_posteriors()["map"]
        obs_scale = map_posteriors["obs_scale"]
        lite_model = ktrlite._model

        # Take over the seasonality settings as resolved by KTRLite and build
        # one label per seasonality entry.
        self._seasonality = lite_model._seasonality
        self._seasonality_fs_order = lite_model._seasonality_fs_order
        self._seasonality_labels = [
            "seasonality_{}".format(seas) for seas in self._seasonality
        ]

        # Without a user-supplied upper bound for the residuals scale, derive
        # one from the data: 5x the KTRLite observation scale (buffer in case
        # KTRLite over-fits), capped by the response standard deviation.
        if self.residuals_scale_upper is None:
            self._residuals_scale_upper = min(
                obs_scale * 5, training_meta["response_sd"]
            )

        # Level: knot values and knot dates come straight from KTRLite.
        self._level_knots = np.squeeze(map_posteriors["lev_knot"])
        self._level_knot_dates = lite_model._level_knot_dates
        # Observation positions 1..n divided by n.
        time_points = np.arange(1, n_obs + 1) / n_obs

        self._level_knots_idx = get_knot_idx(
            date_array=dates,
            num_of_obs=None,
            knot_dates=self._level_knot_dates,
            knot_distance=None,
            num_of_segments=None,
        )
        # Knot positions use the same 1-based scaling as the time points.
        self.knots_tp_level = (1 + self._level_knots_idx) / n_obs
        self._kernel_level = sandwich_kernel(time_points, self.knots_tp_level)
        self._num_knots_level = len(self._level_knot_dates)

        if not self._seasonality:
            return

        # Seasonality: cut the flattened coefficient knots into one slice per
        # seasonality label; each label owns 2 * order consecutive entries on
        # the second-to-last axis.
        self._seasonality_coef_knot_dates = lite_model._coef_knot_dates
        flat_coef_knots = map_posteriors["coef_knot"]
        knots_by_label = dict()
        start = 0
        for i, label in enumerate(self._seasonality_labels):
            stop = start + 2 * self._seasonality_fs_order[i]
            knots_by_label[label] = flat_coef_knots[..., start:stop, :]
            start = stop
        self._seasonality_coef_knots = knots_by_label

        # Only the first element of the returned pair is kept (presumably the
        # total seasonal term -- confirm against _generate_seas).
        seas_total, _ = self._generate_seas(
            df,
            training_meta,
            self._seasonality_coef_knot_dates,
            self._seasonality_coef_knots,
            self._seasonality,
            self._seasonality_fs_order,
            self._seasonality_labels,
        )
        # Drop the leading axis (the batch dimension, per the original note)
        # before the term is used as a model input.
        self._seas_term = np.squeeze(seas_total, 0)