# trend_getter/scenarios.py
import numpy as np
import pandas as pd
import sqlglot
import prophet
from dataclasses import dataclass
from functools import reduce
from google.cloud import bigquery
import plotly.express as px
import plotly.graph_objects as go
from scipy.interpolate import interp1d
from trend_getter import holidays
from typing import List, Dict
@dataclass
class Scenario:
    """A single what-if scenario applied to a named sub-population.

    ``cdf_x``/``cdf_y`` are the points of a cumulative distribution over the
    fractional DAU decrease this scenario could cause; ``cdf`` and
    ``quantile`` are linear interpolations of that curve and its inverse.
    """

    name: str  # human-readable title (used in plot headers)
    slug: str  # short identifier; becomes a SQL column name downstream
    reference: str  # free-form pointer to the scenario's source/rationale
    start_date: str  # parsed into a pandas Timestamp in __post_init__
    end_date: str  # parsed into a pandas Timestamp in __post_init__
    population: str  # SQL boolean expression selecting the sub-population
    cdf_x: List[float]  # fractional DAU-decrease values of the CDF
    cdf_y: List[float]  # cumulative probabilities at those values

    def __post_init__(self) -> None:
        """Validate the slug, parse dates, and build the interpolators."""
        # "_" is reserved as the join character for population labels, and
        # "total"/"other" are reserved population names.
        if "_" in self.slug:
            raise ValueError("slug cannot contain '_'.")
        if self.slug in ["total", "other"]:
            raise ValueError("slug cannot use the name 'total' or 'other'.")

        self.start_date = pd.to_datetime(self.start_date)
        self.end_date = pd.to_datetime(self.end_date)

        # SELECT-list fragment that materializes the population flag column.
        self.query_select = f"IFNULL({self.population}, FALSE) AS {self.slug},"

        # Piecewise-linear CDF and its inverse. Inputs below/above the known
        # range clamp to 0 / 1 respectively.
        shared_kwargs = dict(kind="linear", bounds_error=False, fill_value=(0, 1))
        self.cdf = interp1d(self.cdf_x, self.cdf_y, **shared_kwargs)
        self.quantile = interp1d(self.cdf_y, self.cdf_x, **shared_kwargs)

    def sample(self, samples=1):
        """Draw ``samples`` DAU-decrease values via inverse-transform sampling."""
        uniform_draws = np.random.uniform(0, 1, size=samples)
        return self.quantile(uniform_draws)

    def plot_distribution(self):
        """Render the scenario CDF (extended to a 100% decrease) with plotly."""
        # Extend both axes to the (100% decrease, probability 1) corner.
        xs = self.cdf_x + [1]
        ys = self.cdf_y + [1]
        fig = px.line(
            x=xs,
            y=ys,
            template="plotly_white",
            title=f'{self.name}<br><sup>"Y% of scenarios have X% or lower subpopulation DAU decrease"</sup>',
        )
        fig.add_trace(go.Scatter(x=xs, y=ys, mode="markers", showlegend=False))
        fig.update_traces(
            line=dict(color="black", width=3),
            marker=dict(size=10, color="black"),
        )
        fig.update_layout(
            autosize=False,
            xaxis=dict(
                tickformat=".0%",
                tickangle=0,
                title="Subpopulation DAU Decrease",
                tickmode="array",
                tickvals=xs,
            ),
            yaxis=dict(
                tickformat=".1%",
                title="Cumulative Probability",
                tickmode="array",
                tickvals=ys,
            ),
        )
        fig.show()
@dataclass
class ScenarioForecasts:
product_group: List[str]
scenarios: List[Scenario]
countries: List[str]
historical_start_date: str
historical_end_date: str
forecast_end_date: str
project: str = "mozdata"
number_of_simulations: int = 1000
metric: str = "dau"
@property
def column_names_map(self) -> Dict[str, str]:
return {"submission_date": "ds", "value": "y"}
def __post_init__(self) -> None:
self.scenarios = {i.slug: i for i in self.scenarios}
self.historical_dfs = {}
self.historical_forecasts = {}
self.historical_forecast_models = {}
self.scaled_historical_forecasts = {}
self.scenario_forecasts = {}
self.scenario_percent_impacts = {}
self.raw_df = None
def _query_(self) -> None:
if self.metric == "dau":
query = f"""
SELECT submission_date,
IF(country IN ({",".join([f"'{i}'" for i in self.countries])}), country, "ROW") AS country,
{" ".join([i.query_select for i in self.scenarios.values()])}
SUM(dau) AS dau,
FROM `mozdata.telemetry.active_users_aggregates`
WHERE app_name IN ({",".join([f"'{i}'" for i in self.product_group])})
AND submission_date BETWEEN "{self.historical_start_date}" AND "{self.historical_end_date}"
GROUP BY ALL
ORDER BY {", ".join([str(i + 1) for i in range(len(self.scenarios) + 2)])}
"""
elif self.metric == "mau":
query = f"""
SELECT submission_date,
IF(country IN ({",".join([f"'{i}'" for i in self.countries])}), country, "ROW") AS country,
{" ".join([i.query_select for i in self.scenarios.values()])}
SUM(mau) AS dau,
FROM `mozdata.telemetry.active_users_aggregates`
WHERE app_name IN ({",".join([f"'{i}'" for i in self.product_group])})
AND submission_date BETWEEN "{self.historical_start_date}" AND "{self.historical_end_date}"
GROUP BY ALL
ORDER BY {", ".join([str(i + 1) for i in range(len(self.scenarios) + 2)])}
"""
elif self.metric == "engagement":
query = f"""
SELECT submission_date,
IF(country IN ({",".join([f"'{i}'" for i in self.countries])}), country, "ROW") AS country,
{" ".join([i.query_select for i in self.scenarios.values()])}
SUM(dau) / SUM(mau) AS dau,
FROM `moz-fx-data-shared-prod.telemetry.desktop_engagement`
WHERE app_name IN ({",".join([f"'{i}'" for i in self.product_group])})
AND submission_date BETWEEN "{self.historical_start_date}" AND "{self.historical_end_date}"
AND lifecycle_stage = "existing_users"
GROUP BY ALL
ORDER BY {", ".join([str(i + 1) for i in range(len(self.scenarios) + 2)])}
"""
return sqlglot.transpile(query, read="bigquery", pretty=True)[0]
def fetch_data(self, filtered_end_date: str = None) -> None:
if self.raw_df is None:
query = self._query_()
print(f"Fetching Data:\n\n{query}")
self.raw_df = (
bigquery.Client(project=self.project).query(query).to_dataframe()
)
df = self.raw_df.copy(deep=True)
else:
if filtered_end_date is not None:
self.historical_end_date = filtered_end_date
print(f"Truncating data to before:\n\n{self.historical_end_date}")
df = self.raw_df[
pd.to_datetime(self.raw_df.submission_date)
<= pd.to_datetime(self.historical_end_date)
].copy(deep=True)
self.dates_to_predict = pd.DataFrame(
{
"submission_date": pd.date_range(
self.historical_end_date, self.forecast_end_date
).date[1:]
}
)
cols = list(set(df.columns) - {"submission_date", "country", "dau"})
df["population"] = (
df[cols]
.apply(lambda row: "_".join(col for col in cols if row[col]), axis=1)
.replace("", "other") # Replace empty strings with "other"
)
# Pivot to wide format
df = (
df.pivot_table(
index=["submission_date", "country"],
columns="population",
values="dau",
aggfunc="sum",
fill_value=0,
)
.reset_index()
.rename_axis(columns=None)
.replace({0: np.nan})
)
df["total"] = df.drop(columns=["submission_date", "country"]).sum(axis=1)
self.historical_dates = df["submission_date"]
self.population_names = (
["total"]
+ sorted(set(df.columns) - {"total", "other", "submission_date", "country"})
+ ["other"]
)
sub_populations = []
for pop in self.population_names:
a = (
df.groupby(["submission_date", "country"], as_index=False)[pop]
.sum(min_count=1)
.dropna()
)
a["dau"] = a[pop]
hi = holidays.HolidayImpacts(
df=a,
forecast_start=self.dates_to_predict["submission_date"].min(),
forecast_end=self.dates_to_predict["submission_date"].max(),
# detrend_spike_correction=8.0,
)
hi.fit()
b = hi.all_countries
sub_populations.append(
b.rename(columns={"dau": f"{pop}_original", "expected": pop})[
["submission_date", pop, f"{pop}_original"]
]
)
if pop == "total":
self.holiday_impacts = hi
self.future_holiday_impacts = hi.predict()
self.populations = (
reduce(
lambda left, right: pd.merge(
left, right, on="submission_date", how="outer"
),
sub_populations,
)
.sort_values("submission_date")
.reset_index(drop=True)
)
def _get_historical_forecasts(
self,
seed=42,
changepoint_range=0.8,
seasonality_prior_scale=0.00825,
changepoint_prior_scale=0.15983,
) -> None:
sub_populations = []
print("\nForecasting Populations: ", end="")
for i in self.population_names:
print(f"{i}", end=" | ")
np.random.seed(seed)
observed_df = self.populations[["submission_date", i]].copy(deep=True)
observed_df["y"] = observed_df[i]
self.historical_dfs[i] = observed_df
params = {
"daily_seasonality": False,
"weekly_seasonality": True,
"yearly_seasonality": len(observed_df.dropna()) > (365 * 2),
"uncertainty_samples": self.number_of_simulations,
"changepoint_range": changepoint_range,
"growth": "logistic",
}
if observed_df["y"].max() >= 10e6:
params["seasonality_prior_scale"] = seasonality_prior_scale
params["changepoint_prior_scale"] = changepoint_prior_scale
m = prophet.Prophet(**params)
self.historical_forecast_models[i] = m
observed = observed_df.rename(columns=self.column_names_map).copy(deep=True)
future = self.dates_to_predict.rename(columns=self.column_names_map).copy(
deep=True
)
if "growth" in params:
if observed_df["y"].max() >= 10e6:
cap = observed_df["y"].max() * 2.0
floor = observed_df["y"].min() * 0.8
observed["cap"] = cap
observed["floor"] = floor
future["cap"] = cap
future["floor"] = floor
else:
cap = observed_df["y"].max() * 2.0
observed["cap"] = cap
observed["floor"] = 0.0
future["cap"] = cap
future["floor"] = 0.0
m.fit(observed)
forecast_df = pd.DataFrame(m.predictive_samples(future)["yhat"])
self.historical_forecasts[i] = forecast_df
if i != "total":
sub_populations.append(self.historical_forecasts[i])
self.rescaler = self.historical_forecasts["total"] / sum(sub_populations)
for i in self.population_names:
if i == "total":
self.scaled_historical_forecasts[i] = self.historical_forecasts[i]
else:
self.scaled_historical_forecasts[i] = (
self.historical_forecasts[i] * self.rescaler
)
print("done.")
def _get_scenario_forecasts(self) -> None:
start_date = self.populations["submission_date"].min()
end_date = self.dates_to_predict["submission_date"].max()
self.all_dates = pd.date_range(start=start_date, end=end_date, freq="D")
filler = pd.concat(
[
self.populations.total * np.nan
for i in range(self.number_of_simulations)
],
axis=1,
)
filler.columns = range(self.number_of_simulations)
print("Running Scenarios: ", end="")
for population_name, df in self.scaled_historical_forecasts.items():
self.scenario_forecasts[population_name] = pd.concat(
[filler, df.copy(deep=True)]
).reset_index(drop=True)
for scenario_name, s in self.scenarios.items():
print(f"{scenario_name}", end=" | ")
samples = s.sample(self.number_of_simulations)
pct_impacts = pd.DataFrame(
np.column_stack(
[
np.interp(
self.all_dates,
pd.to_datetime(
[start_date, s.start_date, s.end_date, end_date]
),
[0, 0, i, i],
)
for i in samples
]
)
)
ix = np.argmax(self.all_dates == self.historical_end_date)
pct_impacts.iloc[:ix] = 0
pct_impacts.iloc[ix:] = pct_impacts.iloc[ix:] - pct_impacts.iloc[ix].values
self.scenario_percent_impacts[scenario_name] = pct_impacts
for population_name in self.population_names:
if scenario_name in population_name.split("_"):
self.scenario_forecasts[population_name] *= (
1 - self.scenario_percent_impacts[scenario_name]
)
self.scenario_forecasts["total"] *= 0
for population_name in self.population_names:
if population_name != "total":
self.scenario_forecasts["total"] += self.scenario_forecasts[
population_name
]
print("done.")