trend_getter/scenarios.py
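"""Scenario-based forecasting for Mozilla product usage metrics.

A ``Scenario`` encodes a hypothesized impact on a subpopulation as a
piecewise-linear CDF of DAU decrease. ``ScenarioForecasts`` queries
historical data from BigQuery, corrects it for holiday effects via
``trend_getter.holidays``, fits a Prophet model per population, rescales
sub-population forecasts to sum to the total forecast, and applies sampled
scenario impacts to produce simulated forecast paths.
"""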

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import prophet
import sqlglot
from dataclasses import dataclass
from functools import reduce
from google.cloud import bigquery
from scipy.interpolate import interp1d
from typing import Dict, List, Optional

from trend_getter import holidays


@dataclass
class Scenario:
    name: str
    slug: str
    reference: str
    start_date: str
    end_date: str
    population: str
    cdf_x: List[float]
    cdf_y: List[float]

    def __post_init__(self) -> None:
        if "_" in self.slug:
            raise ValueError("slug cannot contain '_'.")
        if self.slug in ["total", "other"]:
            raise ValueError("slug cannot use the name 'total' or 'other'.")
        self.start_date = pd.to_datetime(self.start_date)
        self.end_date = pd.to_datetime(self.end_date)
        self.query_select = f"IFNULL({self.population}, FALSE) AS {self.slug},"
        # Piecewise-linear CDF of the subpopulation DAU decrease, plus its
        # inverse (the quantile function) for inverse-transform sampling.
        self.cdf = interp1d(
            self.cdf_x, self.cdf_y, kind="linear", bounds_error=False, fill_value=(0, 1)
        )
        self.quantile = interp1d(
            self.cdf_y, self.cdf_x, kind="linear", bounds_error=False, fill_value=(0, 1)
        )

    def sample(self, samples: int = 1) -> np.ndarray:
        """Draw impact sizes from the scenario's distribution."""
        return self.quantile(np.random.uniform(0, 1, size=samples))

    def plot_distribution(self) -> None:
        fig = px.line(
            x=self.cdf_x + [1],
            y=self.cdf_y + [1],
            template="plotly_white",
            title=(
                f"{self.name}<br><sup>"
                '"Y% of scenarios have X% or lower subpopulation DAU decrease"</sup>'
            ),
        )
        fig.add_trace(
            go.Scatter(
                x=self.cdf_x + [1], y=self.cdf_y + [1], mode="markers", showlegend=False
            )
        )
        fig.update_traces(
            line=dict(color="black", width=3), marker=dict(size=10, color="black")
        )
        fig.update_layout(
            autosize=False,
            xaxis=dict(
                tickformat=".0%",
                tickangle=0,
                title="Subpopulation DAU Decrease",
                tickmode="array",
                tickvals=self.cdf_x + [1],
            ),
            yaxis=dict(
                tickformat=".1%",
                title="Cumulative Probability",
                tickmode="array",
                tickvals=self.cdf_y + [1],
            ),
        )
        fig.show()


@dataclass
class ScenarioForecasts:
    product_group: List[str]
    scenarios: List[Scenario]
    countries: List[str]
    historical_start_date: str
    historical_end_date: str
    forecast_end_date: str
    project: str = "mozdata"
    number_of_simulations: int = 1000
    metric: str = "dau"

    @property
    def column_names_map(self) -> Dict[str, str]:
        return {"submission_date": "ds", "value": "y"}

    def __post_init__(self) -> None:
        self.scenarios = {i.slug: i for i in self.scenarios}
        self.historical_dfs = {}
        self.historical_forecasts = {}
        self.historical_forecast_models = {}
        self.scaled_historical_forecasts = {}
        self.scenario_forecasts = {}
        self.scenario_percent_impacts = {}
        self.raw_df = None

    def _query_(self) -> str:
        # Each branch aliases the metric to "dau" so downstream code is
        # metric-agnostic.
        if self.metric == "dau":
            query = f"""
            SELECT
                submission_date,
                IF(country IN ({",".join([f"'{i}'" for i in self.countries])}), country, "ROW") AS country,
                {" ".join([i.query_select for i in self.scenarios.values()])}
                SUM(dau) AS dau,
            FROM `mozdata.telemetry.active_users_aggregates`
            WHERE app_name IN ({",".join([f"'{i}'" for i in self.product_group])})
                AND submission_date BETWEEN "{self.historical_start_date}" AND "{self.historical_end_date}"
            GROUP BY ALL
            ORDER BY {", ".join([str(i + 1) for i in range(len(self.scenarios) + 2)])}
            """
        elif self.metric == "mau":
            query = f"""
            SELECT
                submission_date,
                IF(country IN ({",".join([f"'{i}'" for i in self.countries])}), country, "ROW") AS country,
                {" ".join([i.query_select for i in self.scenarios.values()])}
                SUM(mau) AS dau,
            FROM `mozdata.telemetry.active_users_aggregates`
            WHERE app_name IN ({",".join([f"'{i}'" for i in self.product_group])})
                AND submission_date BETWEEN "{self.historical_start_date}" AND "{self.historical_end_date}"
            GROUP BY ALL
            ORDER BY {", ".join([str(i + 1) for i in range(len(self.scenarios) + 2)])}
            """
        elif self.metric == "engagement":
            query = f"""
            SELECT
                submission_date,
                IF(country IN ({",".join([f"'{i}'" for i in self.countries])}), country, "ROW") AS country,
                {" ".join([i.query_select for i in self.scenarios.values()])}
                SUM(dau) / SUM(mau) AS dau,
            FROM `moz-fx-data-shared-prod.telemetry.desktop_engagement`
            WHERE app_name IN ({",".join([f"'{i}'" for i in self.product_group])})
                AND submission_date BETWEEN "{self.historical_start_date}" AND "{self.historical_end_date}"
                AND lifecycle_stage = "existing_users"
            GROUP BY ALL
            ORDER BY {", ".join([str(i + 1) for i in range(len(self.scenarios) + 2)])}
            """
        else:
            raise ValueError(f"Unsupported metric: {self.metric}")
        return sqlglot.transpile(query, read="bigquery", pretty=True)[0]

    def fetch_data(self, filtered_end_date: Optional[str] = None) -> None:
        if self.raw_df is None:
            query = self._query_()
            print(f"Fetching Data:\n\n{query}")
            self.raw_df = (
                bigquery.Client(project=self.project).query(query).to_dataframe()
            )
            df = self.raw_df.copy(deep=True)
        else:
            # Reuse the cached query result, optionally truncated to an
            # earlier end date.
            if filtered_end_date is not None:
                self.historical_end_date = filtered_end_date
            print(f"Truncating data to before:\n\n{self.historical_end_date}")
            df = self.raw_df[
                pd.to_datetime(self.raw_df.submission_date)
                <= pd.to_datetime(self.historical_end_date)
            ].copy(deep=True)
        # The forecast horizon starts the day after the historical end date.
        self.dates_to_predict = pd.DataFrame(
            {
                "submission_date": pd.date_range(
                    self.historical_end_date, self.forecast_end_date
                ).date[1:]
            }
        )
        # Sorted for deterministic population labels across runs.
        cols = sorted(set(df.columns) - {"submission_date", "country", "dau"})
        # Label each row with the scenario populations it belongs to, joined
        # by "_"; rows in no scenario population are labeled "other".
        df["population"] = (
            df[cols]
            .apply(lambda row: "_".join(col for col in cols if row[col]), axis=1)
            .replace("", "other")  # Replace empty strings with "other"
        )
        # Pivot to wide format: one column per population.
        df = (
            df.pivot_table(
                index=["submission_date", "country"],
                columns="population",
                values="dau",
                aggfunc="sum",
                fill_value=0,
            )
            .reset_index()
            .rename_axis(columns=None)
            .replace({0: np.nan})
        )
        df["total"] = df.drop(columns=["submission_date", "country"]).sum(axis=1)
        self.historical_dates = df["submission_date"]
        self.population_names = (
            ["total"]
            + sorted(set(df.columns) - {"total", "other", "submission_date", "country"})
            + ["other"]
        )
        sub_populations = []
        for pop in self.population_names:
            a = (
                df.groupby(["submission_date", "country"], as_index=False)[pop]
                .sum(min_count=1)
                .dropna()
            )
            a["dau"] = a[pop]
            # Fit a holiday-effect model per population and keep the
            # holiday-corrected ("expected") series for forecasting.
            hi = holidays.HolidayImpacts(
                df=a,
                forecast_start=self.dates_to_predict["submission_date"].min(),
                forecast_end=self.dates_to_predict["submission_date"].max(),
                # detrend_spike_correction=8.0,
            )
            hi.fit()
            b = hi.all_countries
            sub_populations.append(
                b.rename(columns={"dau": f"{pop}_original", "expected": pop})[
                    ["submission_date", pop, f"{pop}_original"]
                ]
            )
            if pop == "total":
                self.holiday_impacts = hi
                self.future_holiday_impacts = hi.predict()
        # Merge the per-population frames into one wide frame on submission_date.
        self.populations = (
            reduce(
                lambda left, right: pd.merge(
                    left, right, on="submission_date", how="outer"
                ),
                sub_populations,
            )
            .sort_values("submission_date")
            .reset_index(drop=True)
        )

    def _get_historical_forecasts(
        self,
        seed: int = 42,
        changepoint_range: float = 0.8,
        seasonality_prior_scale: float = 0.00825,
        changepoint_prior_scale: float = 0.15983,
    ) -> None:
        sub_populations = []
        print("\nForecasting Populations: ", end="")
        for i in self.population_names:
            print(f"{i}", end=" | ")
            np.random.seed(seed)
            observed_df = self.populations[["submission_date", i]].copy(deep=True)
            observed_df["y"] = observed_df[i]
            self.historical_dfs[i] = observed_df
            params = {
                "daily_seasonality": False,
                "weekly_seasonality": True,
                "yearly_seasonality": len(observed_df.dropna()) > (365 * 2),
                "uncertainty_samples": self.number_of_simulations,
                "changepoint_range": changepoint_range,
                "growth": "logistic",
            }
            # Large populations (peak >= 10 million) get tuned prior scales.
            if observed_df["y"].max() >= 10e6:
                params["seasonality_prior_scale"] = seasonality_prior_scale
                params["changepoint_prior_scale"] = changepoint_prior_scale
            m = prophet.Prophet(**params)
            self.historical_forecast_models[i] = m
            observed = observed_df.rename(columns=self.column_names_map).copy(deep=True)
            future = self.dates_to_predict.rename(columns=self.column_names_map).copy(
                deep=True
            )
            if "growth" in params:
                # Logistic growth requires cap/floor columns on both frames.
                if observed_df["y"].max() >= 10e6:
                    cap = observed_df["y"].max() * 2.0
                    floor = observed_df["y"].min() * 0.8
                    observed["cap"] = cap
                    observed["floor"] = floor
                    future["cap"] = cap
                    future["floor"] = floor
                else:
                    cap = observed_df["y"].max() * 2.0
                    observed["cap"] = cap
                    observed["floor"] = 0.0
                    future["cap"] = cap
                    future["floor"] = 0.0
            m.fit(observed)
            forecast_df = pd.DataFrame(m.predictive_samples(future)["yhat"])
            self.historical_forecasts[i] = forecast_df
            if i != "total":
                sub_populations.append(self.historical_forecasts[i])
        # Rescale sub-population forecasts so they sum to the total forecast.
        self.rescaler = self.historical_forecasts["total"] / sum(sub_populations)
        for i in self.population_names:
            if i == "total":
                self.scaled_historical_forecasts[i] = self.historical_forecasts[i]
            else:
                self.scaled_historical_forecasts[i] = (
                    self.historical_forecasts[i] * self.rescaler
                )
        print("done.")

    def _get_scenario_forecasts(self) -> None:
        start_date = self.populations["submission_date"].min()
        end_date = self.dates_to_predict["submission_date"].max()
        self.all_dates = pd.date_range(start=start_date, end=end_date, freq="D")
        # NaN padding for the historical portion so the simulated paths align
        # with all_dates.
        filler = pd.concat(
            [
                self.populations.total * np.nan
                for _ in range(self.number_of_simulations)
            ],
            axis=1,
        )
        filler.columns = range(self.number_of_simulations)
        print("Running Scenarios: ", end="")
        for population_name, df in self.scaled_historical_forecasts.items():
            self.scenario_forecasts[population_name] = pd.concat(
                [filler, df.copy(deep=True)]
            ).reset_index(drop=True)
        for scenario_name, s in self.scenarios.items():
            print(f"{scenario_name}", end=" | ")
            samples = s.sample(self.number_of_simulations)
            # Each sampled impact ramps linearly from 0 at the scenario start
            # to its full size at the scenario end, then holds.
            pct_impacts = pd.DataFrame(
                np.column_stack(
                    [
                        np.interp(
                            self.all_dates,
                            pd.to_datetime(
                                [start_date, s.start_date, s.end_date, end_date]
                            ),
                            [0, 0, i, i],
                        )
                        for i in samples
                    ]
                )
            )
            # Zero out the historical portion and re-baseline the forecast
            # portion so impacts start at zero on the forecast start date.
            ix = np.argmax(self.all_dates == self.historical_end_date)
            pct_impacts.iloc[:ix] = 0
            pct_impacts.iloc[ix:] = pct_impacts.iloc[ix:] - pct_impacts.iloc[ix].values
            self.scenario_percent_impacts[scenario_name] = pct_impacts
            # Apply the impact to every population this scenario belongs to.
            for population_name in self.population_names:
                if scenario_name in population_name.split("_"):
                    self.scenario_forecasts[population_name] *= (
                        1 - self.scenario_percent_impacts[scenario_name]
                    )
        # Rebuild the total as the sum of the adjusted sub-populations.
        self.scenario_forecasts["total"] *= 0
        for population_name in self.population_names:
            if population_name != "total":
                self.scenario_forecasts["total"] += self.scenario_forecasts[
                    population_name
                ]
        print("done.")