jobs/kpi-forecasting/kpi_forecasting/models/base_forecast.py (116 lines of code) (raw):
import abc
import json
from dataclasses import dataclass
from datetime import date, datetime, timedelta, timezone
from typing import Dict, List, Optional

import numpy as np
import pandas as pd

from kpi_forecasting.metric_hub import MetricHub
@dataclass
class BaseForecast(abc.ABC):
    """
    A base class for fitting, forecasting, and summarizing forecasts. This class
    should not be invoked directly; it should be inherited by a child class. The
    child class needs to implement the `_fit`, `_predict`,
    `_validate_forecast_df` and `_summarize` methods in order to work.

    Args:
        model_type (str): The name of the forecasting model that's being used.
        parameters (Dict): Parameters that should be passed to the forecasting model.
        use_all_us_holidays (bool): Whether or not the forecasting model should use holidays.
            The base model does not apply holiday logic; that logic needs to be built
            in the child class.
        start_date (str): A 'YYYY-MM-DD' formatted-string that specifies the first
            date that should be forecasted. If falsy, `_default_start_date` is used.
        end_date (str): A 'YYYY-MM-DD' formatted-string that specifies the last
            date that should be forecasted. If falsy, `_default_end_date` is used.
        metric_hub (MetricHub): A MetricHub object that provides details about the
            metric to be forecasted.
        predict_historical_dates (bool): If True, the forecast starts at the first
            date in the observed data, and `start_date` must not be set. If False,
            the forecast starts at `start_date` when it is set, and otherwise at
            the first day after the observed data ends.

    Attributes set by `__post_init__`:
        collected_at: naive datetime holding the UTC time of initialization.
        observed_df: result of `metric_hub.fetch()` (only set if `metric_hub` is truthy).
        start_date, end_date: replaced by the result of `pd.to_datetime`.
        dates_to_predict: dataframe with one `submission_date` column covering
            `start_date` through `end_date`.
        model, forecast_df, summary_df: initialized to None.
        metadata_params: JSON string describing the model configuration.
    """

    model_type: str
    parameters: Dict
    use_all_us_holidays: bool
    start_date: str
    end_date: str
    metric_hub: MetricHub
    predict_historical_dates: bool = False

    def _get_observed_data(self) -> None:
        """Fetch observed data from `metric_hub` and store it on `self.observed_df`.

        Nothing is set when `metric_hub` is falsy.
        """
        if self.metric_hub:
            # the columns in this dataframe
            # are "value" for the metric, submission_date
            # and any segments where the column name
            # is the name of the segment
            self.observed_df = self.metric_hub.fetch()

    def __post_init__(self) -> None:
        """Fetch observed data, resolve the forecast window and build metadata.

        Raises:
            ValueError: if `start_date` is set while `predict_historical_dates`
                is True.
        """
        # record the collection time, then fetch the observed data
        self.collected_at = datetime.now(timezone.utc).replace(tzinfo=None)
        self._get_observed_data()

        # raise an error if predict_historical_dates is True and start_date is set
        if self.start_date and self.predict_historical_dates:
            raise ValueError(
                "forecast start_date set while predict_historical_dates is True"
            )

        # use default start/end dates if the user doesn't specify them
        self.start_date = pd.to_datetime(self.start_date or self._default_start_date)
        self.end_date = pd.to_datetime(self.end_date or self._default_end_date)
        self.dates_to_predict = pd.DataFrame(
            {"submission_date": pd.date_range(self.start_date, self.end_date).date}
        )

        # initialize unset attributes
        self.model = None
        self.forecast_df = None
        self.summary_df = None

        # metadata
        self.metadata_params = json.dumps(
            {
                "model_type": self.model_type.lower(),
                "model_params": self.parameters,
                "use_all_us_holidays": self.use_all_us_holidays,
            }
        )

    @abc.abstractmethod
    def _fit(self, observed_df: pd.DataFrame) -> None:
        """Fit a forecasting model using `observed_df`. This will typically
        be the data that was generated using
        Metric Hub in `__post_init__`.
        This method should update (and potentially set) `self.model`.

        Args:
            observed_df (pd.DataFrame): observed data used to fit the model
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _predict(self, dates_to_predict: pd.DataFrame) -> pd.DataFrame:
        """Forecast using `self.model` on dates in `dates_to_predict`.
        This method should return a dataframe that will
        be validated by `_validate_forecast_df`.

        Args:
            dates_to_predict (pd.DataFrame): dataframe of dates to forecast for

        Returns:
            pd.DataFrame: dataframe of predictions
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _validate_forecast_df(self, forecast_df: pd.DataFrame) -> None:
        """Method to validate results produced by `_predict`.

        Args:
            forecast_df (pd.DataFrame): dataframe produced by `_predict`
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _summarize(
        self,
        forecast_df: pd.DataFrame,
        observed_df: pd.DataFrame,
        period: str,
        numpy_aggregations: List[str],
        percentiles: List[int],
    ) -> pd.DataFrame:
        """Calculate summary metrics for `forecast_df` over a given period, and
        add metadata.

        Args:
            forecast_df (pd.DataFrame): forecast dataframe created by `predict`
            observed_df (pd.DataFrame): observed data used to generate prediction
            period (str): aggregation period up to which metrics are aggregated
            numpy_aggregations (List[str]): List of numpy aggregation names
            percentiles (List[int]): List of percentiles to aggregate up to

        Returns:
            pd.DataFrame: dataframe containing metrics listed in numpy_aggregations
                and percentiles
        """
        raise NotImplementedError

    @property
    def _default_start_date(self) -> date:
        """Default first date to forecast.

        The earliest `submission_date` in the observed data when
        `predict_historical_dates` is True; otherwise the first day after the
        latest `submission_date` in the observed data.
        """
        if self.predict_historical_dates:
            return self.observed_df["submission_date"].min()
        else:
            return self.observed_df["submission_date"].max() + timedelta(days=1)

    @property
    def _default_end_date(self) -> date:
        """78 weeks (18 months) ahead of the current UTC date."""
        return (
            datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(weeks=78)
        ).date()

    def _set_seed(self) -> None:
        """Set random seed to ensure that fits and predictions are reproducible."""
        np.random.seed(42)

    def fit(self) -> None:
        """Fit a model using historic metric data provided by `metric_hub`.

        Records the (naive, UTC) training time on `self.trained_at`.
        """
        print(f"Fitting {self.model_type} model.", flush=True)
        self._set_seed()
        self.trained_at = datetime.now(timezone.utc).replace(tzinfo=None)
        self._fit(self.observed_df)

    def predict(self) -> None:
        """Generate a forecast from `start_date` to `end_date`.

        The result is set to `self.forecast_df` and then passed to
        `_validate_forecast_df`. Records the (naive, UTC) prediction time on
        `self.predicted_at`.
        """
        print(f"Forecasting from {self.start_date} to {self.end_date}.", flush=True)
        self._set_seed()
        self.predicted_at = datetime.now(timezone.utc).replace(tzinfo=None)
        self.forecast_df = self._predict(self.dates_to_predict)
        self._validate_forecast_df(self.forecast_df)

    def summarize(
        self,
        periods: Optional[List[str]] = None,
        numpy_aggregations: Optional[List[str]] = None,
        percentiles: Optional[List[int]] = None,
    ) -> pd.DataFrame:
        """
        Calculate summary metrics for `forecast_df` and add metadata.
        The dataframe returned here will be reported in Big Query when
        `write_results` is called.

        Reads `self.trained_at` and `self.predicted_at`, which are set by `fit`
        and `predict` respectively.

        Args:
            periods (List[str]): A list of the time periods that the data should be aggregated and
                summarized by. For example ["day", "month"]. Defaults to ["day", "month"].
            numpy_aggregations (List[str]): A list of numpy methods (represented as strings) that can
                be applied to summarize numeric values in a dataframe. For example, ["mean"].
                Defaults to ["mean"].
            percentiles (List[int]): A list of integers representing the percentiles that should be reported
                in the summary. For example [50] would calculate the 50th percentile (i.e. the median).
                Defaults to [10, 50, 90].

        Returns:
            pd.DataFrame: metric dataframe for all metrics and aggregations
        """
        # Build the defaults per call: list literals in the signature would be
        # shared between calls (and with `_summarize` implementations).
        if periods is None:
            periods = ["day", "month"]
        if numpy_aggregations is None:
            numpy_aggregations = ["mean"]
        if percentiles is None:
            percentiles = [10, 50, 90]

        summary_df = pd.concat(
            [
                self._summarize(
                    self.forecast_df,
                    self.observed_df,
                    period,
                    numpy_aggregations,
                    percentiles,
                )
                for period in periods
            ]
        )

        # add Metric Hub metadata columns
        summary_df["metric_alias"] = self.metric_hub.alias.lower()
        summary_df["metric_hub_app_name"] = self.metric_hub.app_name.lower()
        summary_df["metric_hub_slug"] = self.metric_hub.slug.lower()
        summary_df["metric_start_date"] = pd.to_datetime(self.metric_hub.min_date)
        summary_df["metric_end_date"] = pd.to_datetime(self.metric_hub.max_date)
        summary_df["metric_collected_at"] = self.collected_at

        # add forecast model metadata columns
        summary_df["forecast_start_date"] = self.start_date
        summary_df["forecast_end_date"] = self.end_date
        summary_df["forecast_trained_at"] = self.trained_at
        summary_df["forecast_predicted_at"] = self.predicted_at
        summary_df["forecast_parameters"] = self.metadata_params

        self.summary_df = summary_df
        return self.summary_df