# trend_getter/holidays.py
import holidays
import numpy as np
import pandas as pd
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from dateutil.easter import easter
from typing import Optional
from trend_getter.plotting import Line, plot
from trend_getter.metric_calculations import moving_average, year_over_year
# Country codes whose calendars do not observe Western (Paschal-cycle) holidays
NO_PASCHAL_CYCLE = ["IN", "JP", "IR", "CN"]
class PaschalCycleHolidays(holidays.HolidayBase):
"""
Custom holiday calendar for key Christian holidays, mostly related to the
Paschal cycle. Includes Mardi Gras, Palm Sunday, Good Friday, Easter,
Ascension Day, Corpus Christi. Also includes All Saints' Day, which is not
part of the Paschal Cycle.
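
    Example (illustrative; dates follow from the Easter computus, so
    ``easter(2024)`` is 2024-03-31)::

        cal = PaschalCycleHolidays(years=[2024])
        # cal maps 2024-03-29 -> "Good Friday" and 2024-03-31 -> "Easter Sunday"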
"""
def _populate(self, year):
# Get Easter Sunday for the given year
easter_sunday = easter(year)
# Add pre- and post-Easter holidays
self[easter_sunday - pd.Timedelta(days=47)] = "Mardi Gras"
self[easter_sunday - pd.Timedelta(days=7)] = "Palm Sunday"
self[easter_sunday - pd.Timedelta(days=2)] = "Good Friday"
self[easter_sunday] = "Easter Sunday"
self[easter_sunday + pd.Timedelta(days=39)] = "Ascension Day"
self[easter_sunday + pd.Timedelta(days=60)] = "Corpus Christi"
# Fixed-date holiday
self[pd.Timestamp(year=year, month=11, day=1)] = "All Saints' Day"
class MozillaHolidays(holidays.HolidayBase):
"""
Custom holiday calendar for Mozilla-specific historical incidents.
Currently includes Data Loss events in May and July 2019.
"""
def _populate(self, year):
        # Only populate when the requested year is 2019
        if year == 2019:
# Data Loss in May 6–13, 2019
for day in range(6, 14):
self[datetime(2019, 5, day).date()] = "Data Loss"
# Additional Data Loss in July 15–17, 2019
for day in range(15, 18):
self[datetime(2019, 7, day).date()] = "Data Loss"
def get_calendar(
country: str,
holiday_years: list,
exclude_paschal_cycle: list = NO_PASCHAL_CYCLE,
split_concurrent_holidays: bool = False,
) -> pd.DataFrame:
"""
Generate a cleaned and formatted DataFrame of holidays for a specific country.
Args:
country (str): Country code (e.g., "US", "FR", "ROW").
holiday_years (list): List of years to include holidays from.
        exclude_paschal_cycle (list): Country codes for which Paschal-cycle holidays are not added.
split_concurrent_holidays (bool): Whether to split semicolon-delimited holidays into multiple rows.
Returns:
pd.DataFrame: A DataFrame with holidays, labeled by date, country, and cleaned holiday name.
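
    Example (illustrative; exact rows depend on the installed ``holidays`` data)::

        cal = get_calendar(country="US", holiday_years=[2023])
        # cal.head(1) -> submission_date 2023-01-01, holiday "US New Year's Day"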
"""
# Use US holidays as a default for ROW (Rest of World)
if country == "ROW":
country_holidays = holidays.US(years=holiday_years)
else:
country_holidays = getattr(holidays, country)(years=holiday_years)
# Optionally add Paschal cycle holidays
if country not in exclude_paschal_cycle:
country_holidays += PaschalCycleHolidays(years=holiday_years)
# Include Mozilla-specific holidays for 2019
if 2019 in holiday_years:
country_holidays += MozillaHolidays(years=holiday_years)
# Convert holiday dictionary into DataFrame
df = pd.DataFrame(
{
"submission_date": list(country_holidays.keys()),
"holiday": list(country_holidays.values()),
"country": country,
}
)
# Clean holiday name text
df["holiday"] = df["holiday"].str.replace(
r"Day off \(substituted.*", "day off (substituted)", regex=True
)
df["holiday"] = df["holiday"].str.replace(" (observed)", "", regex=False)
# Split concurrent holidays into separate rows if requested
if split_concurrent_holidays:
df = df.assign(holiday=df["holiday"].str.split("; ")).explode(
"holiday", ignore_index=True
)
else:
# Otherwise, append the country name to each concurrent holiday
df["holiday"] = df["holiday"].str.replace(";", f"; {country}", regex=False)
# Prefix holiday names with country for clarity
df["holiday"] = df["country"] + " " + df["holiday"]
df["submission_date"] = pd.to_datetime(df["submission_date"])
return df.sort_values(by="submission_date").reset_index(drop=True)
def detrend(
df: pd.DataFrame,
holiday_df: pd.DataFrame,
threshold: float = -0.05,
max_radius: int = 7,
min_radius: int = 3,
spike_correction: Optional[float] = None,
) -> pd.DataFrame:
"""
Applies a physics-inspired detrending algorithm to smooth out holiday-driven dips in DAU.
Parameters:
df (pd.DataFrame): DataFrame with submission_date and dau columns.
holiday_df (pd.DataFrame): DataFrame with submission_date and holiday columns.
        threshold (float): Relative shortfall (observed / expected - 1) below
            which the holiday adjustment is applied; negative by convention.
        max_radius (int): Maximum distance (in days) from a holiday for the
            adjustment to apply.
        min_radius (int): Minimum distance (in days) treated as meaningful for
            weighting. Must satisfy 1 <= min_radius <= max_radius. A value of 1
            means the holiday and the days +/-1 around it receive equal weight.
        spike_correction (Optional[float]): Correction multiple used to clamp
            the x, v, or a values; derived from max DAU via a sigmoid when None.
Returns:
pd.DataFrame: Input dataframe with additional columns:
'x' (position), 'v' (velocity), 'a' (acceleration), and 'expected' (detrended DAU).
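
    With e the expected series, the extrapolation over weekly lags is
    x = e[t-7], v = e[t-7] - e[t-14], a = e[t-7] - 2*e[t-14] + e[t-21],
    and expected = x + v + 0.5*a (constant-acceleration kinematics).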
"""
df = df.copy()
# Calculate a spike-correction factor:
def scaled_sigmoid(x, L=1.5, U=8, x0=5e5, k=1e-5):
return L + (U - L) / (1 + np.exp(-k * (x - x0)))
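
    # The logistic above maps max DAU to a clamp multiplier between L=1.5 and
    # U=8; e.g. scaled_sigmoid(5e5) = 1.5 + 6.5 / 2 = 4.75 at the midpoint x0,
    # so larger populations tolerate larger kinematic spikes before clamping.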
if spike_correction is None:
spike_correction = scaled_sigmoid(df.dau.max())
# Merge in holiday labels
df["submission_date"] = pd.to_datetime(df["submission_date"])
df = df.merge(holiday_df, how="left", on="submission_date")
# Create a holiday date lookup
holiday_dates = df.loc[df["holiday"].notna(), "submission_date"]
    # Compute days from the nearest holiday, clamped from below at min_radius
df["days_from_holiday"] = df["submission_date"].apply(
lambda date: max((holiday_dates - date).abs().min().days, min_radius)
)
# Initialize series for kinematic components and expected values
_x, _v, _a, _e = [], [], [], []
for i in df.itertuples():
idx = i.Index
if idx >= 21:
# Get lagged expected values for position/velocity/acceleration
lag07, lag14, lag21 = _e[idx - 7], _e[idx - 14], _e[idx - 21]
x = lag07
v = lag07 - lag14
a = lag07 - 2 * lag14 + lag21
# Compute rolling averages for recent values
x_bar = np.mean(_x[-7:])
v_bar = np.mean(_v[-7:])
a_bar = np.mean(_a[-7:])
# Clamp spikes using relative thresholds
if abs(x) > spike_correction * abs(x_bar):
x = x_bar
if abs(v) > spike_correction * abs(v_bar):
v = v_bar
if abs(a) > spike_correction * abs(a_bar):
a = a_bar
            # Constant-acceleration extrapolation: x + v + 0.5 * a;
            # `or` falls back to observed DAU if the estimate is exactly zero
            e = (x + v + 0.5 * a) or i.dau
_x.append(x)
_v.append(v)
_a.append(a)
# If within holiday radius and relative error is below threshold, apply smoothing
if (i.days_from_holiday <= max_radius) and (i.dau / abs(e) - 1) < threshold:
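                # Inverse-distance blend, e.g. with min_radius=3: offsets of
                # 0-3 days give weight 1 (fully expected), 7 days gives 4/8 = 0.5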
weight = (min_radius + 1) / (i.days_from_holiday + 1)
blended = e * weight + i.dau * (1 - weight)
_e.append(blended)
else:
_e.append(i.dau)
else:
# For early points, fall back to observed DAU
_x.append(np.nan)
_v.append(np.nan)
_a.append(np.nan)
_e.append(i.dau)
# Attach kinematic components and expected values to DataFrame
df["x"] = _x
df["v"] = _v
df["a"] = _a
df["expected"] = _e
return df
def estimate_impacts(
dau_dfs: dict,
holiday_dfs: dict,
last_observed_date=None,
dau_column: str = "dau",
expected_dau_column: str = "expected",
) -> dict:
"""
Estimate holiday impacts by comparing observed and expected DAU near holidays.
Parameters:
dau_dfs (dict): Dictionary of DataFrames with actual and expected DAU per country.
holiday_dfs (dict): Dictionary of DataFrames with holidays per country.
last_observed_date (str or pd.Timestamp, optional): Filter out dates >= this value.
dau_column (str): Column name for actual DAU.
expected_dau_column (str): Column name for expected DAU.
Returns:
dict: Nested dictionary of estimated holiday impacts:
{holiday: {date_diff: {"impact": [...], "years": set(), "average_impact": float}}}
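
    Example of the returned structure (illustrative values)::

        {"US Thanksgiving": {-1: {"impact": [-1234.5], "years": {2021, 2022},
                                  "average_impact": -617.25}}}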
"""
holiday_impacts = defaultdict(
lambda: defaultdict(lambda: {"impact": [], "years": set()})
)
print("Calculating holiday impacts for: ", end="")
for country in dau_dfs:
print(country, end=", ")
dau_df = dau_dfs[country].copy()
holiday_df = holiday_dfs[country].copy()
# Optional filter to exclude future dates
if last_observed_date is not None:
dau_df = dau_df[
dau_df["submission_date"] < pd.to_datetime(last_observed_date)
].copy()
# Cross-join DAU and holiday dates
merged_df = dau_df.merge(holiday_df, how="cross", suffixes=("_dau", "_holiday"))
# Calculate date difference between submission_dates and holiday dates
merged_df["date_diff"] = (
merged_df["submission_date_dau"] - merged_df["submission_date_holiday"]
).dt.days
# Keep only rows where a holiday is within ±7 days of the date
merged_df = merged_df[merged_df["date_diff"].between(-7, 7)].copy()
# Exclude rows with "Data Loss" holidays
merged_df = merged_df[
~merged_df["holiday_holiday"].str.contains("Data Loss", na=False)
].copy()
        # Calculate the DAU impact: (observed - expected)
        grouped = merged_df.groupby("submission_date_dau")
        merged_df["impact"] = (
            grouped[dau_column].transform("first")
            - grouped[expected_dau_column].transform("first")
        )
# Apply inverse-distance weighting by date offset
merged_df["weight"] = 1 / (1 + merged_df["date_diff"].abs())
merged_df["scale"] = merged_df["weight"] / merged_df.groupby(
"submission_date_dau"
)["weight"].transform("sum")
# Scale the impact by the weight
merged_df["impact"] *= merged_df["scale"]
        # Expand semicolon-delimited concurrent holidays (the label on the
        # nearby holiday date) into one row per holiday so each is credited
        merged_df = merged_df.assign(
            holiday=merged_df["holiday_holiday"].str.split("; ")
        ).explode("holiday", ignore_index=True)
# Aggregate impacts into the nested dictionary
for row in merged_df.itertuples():
holiday_impacts[row.holiday][row.date_diff]["impact"].append(row.impact)
# For substituted holidays, store full date; otherwise, store year only
if "day off (substituted)" in row.holiday:
holiday_impacts[row.holiday][row.date_diff]["years"].add(
row.submission_date_holiday
)
else:
holiday_impacts[row.holiday][row.date_diff]["years"].add(
row.submission_date_holiday.year
)
print()
# Compute average impact for each (holiday, date_diff) pair
for diffs in holiday_impacts.values():
for data in diffs.values():
if len(data["years"]) > 0:
data["average_impact"] = sum(data["impact"]) / len(data["years"])
else:
data["average_impact"] = 0.0 # Prevent division by zero
return holiday_impacts
def predict_impacts(countries, holiday_impacts, start_date, end_date):
    """
    Predict daily holiday impacts over [start_date, end_date] by summing, for
    each date, the historical average impact of every known holiday within
    +/-7 days, while flagging holidays never seen during estimation.

    Returns:
        pd.DataFrame: Columns submission_date and impact (estimated DAU delta).
    """
    future_dates = pd.date_range(start_date, end_date)
holiday_dates = (
pd.concat(
get_calendar(
country=country,
holiday_years=np.unique(future_dates.year),
split_concurrent_holidays=True,
)
for country in countries
)
.sort_values(by="submission_date")
.reset_index(drop=True)
)
impacts = [] # List to store predicted impact values
new_holidays = set() # Track unknown holidays for diagnostic output
for target_date in future_dates:
# Compute date difference between target_date and all holiday dates
date_diffs = pd.to_datetime(target_date) - holiday_dates.submission_date
# Filter holidays within ±7 days
nearby = holiday_dates[abs(date_diffs) <= timedelta(days=7)].copy()
impact = 0
if len(nearby) and not nearby["holiday"].str.contains("Data Loss").any():
            # Convert timedeltas to integer day offsets for dictionary lookups
            nearby["date_diff"] = date_diffs[nearby.index].dt.days
            # Accumulate known holiday impacts; (holiday, offset) pairs never
            # observed during estimation contribute 0
            for row in nearby.itertuples():
                if row.holiday in holiday_impacts:
                    impact += holiday_impacts[row.holiday][row.date_diff].get(
                        "average_impact", 0.0
                    )
                else:
                    new_holidays.add(row.holiday)
impacts.append(impact)
print("Unaccounted Holidays:\n - " + "\n - ".join(new_holidays))
return pd.DataFrame({"submission_date": future_dates, "impact": impacts})
@dataclass
class HolidayImpacts:
    """
    End-to-end workflow: detrend per-country DAU around holidays, estimate
    per-holiday impacts from history, and predict impacts over the forecast
    window.
    """
df: pd.DataFrame
forecast_start: str
forecast_end: str
detrend_threshold: float = -0.05
detrend_max_radius: int = 5
detrend_min_radius: int = 3
detrend_spike_correction: Optional[float] = None
    calendar_exclude_paschal_cycle: list = field(
        # copy so instances don't mutate the shared module-level default
        default_factory=lambda: list(NO_PASCHAL_CYCLE)
    )
def __post_init__(self):
self.countries = np.unique(self.df.country)
self.observed_years = pd.to_datetime(self.df.submission_date).dt.year.unique()
self.dau_dfs = {}
self.holiday_dfs = {}
def fit(self):
for country in self.countries:
self.holiday_dfs[country] = get_calendar(
country=country,
holiday_years=self.observed_years,
exclude_paschal_cycle=self.calendar_exclude_paschal_cycle,
split_concurrent_holidays=False,
)
self.dau_dfs[country] = detrend(
df=self.df[self.df.country == country],
holiday_df=self.holiday_dfs[country],
threshold=self.detrend_threshold,
max_radius=self.detrend_max_radius,
min_radius=self.detrend_min_radius,
spike_correction=self.detrend_spike_correction,
)
self.dau_dfs[country]["dau_28ma"] = moving_average(
self.dau_dfs[country]["dau"]
)
self.dau_dfs[country]["edau_28ma"] = moving_average(
self.dau_dfs[country]["expected"]
)
self.dau_dfs[country]["dau_yoy"] = year_over_year(
self.dau_dfs[country], "dau_28ma"
)
self.dau_dfs[country]["edau_yoy"] = year_over_year(
self.dau_dfs[country], "edau_28ma"
)
self.all_countries = (
pd.concat(
[
i[["submission_date", "dau", "expected", "dau_28ma", "edau_28ma"]]
for i in self.dau_dfs.values()
]
)
.groupby("submission_date", as_index=False)
.sum(min_count=1)
)
self.all_countries["dau_yoy"] = year_over_year(self.all_countries, "dau_28ma")
self.all_countries["edau_yoy"] = year_over_year(self.all_countries, "edau_28ma")
self.holiday_impacts = estimate_impacts(
dau_dfs=self.dau_dfs,
holiday_dfs=self.holiday_dfs,
last_observed_date=self.forecast_start,
)
self.future_impacts = None
def predict(self):
if self.future_impacts is None:
self.future_impacts = predict_impacts(
self.countries,
self.holiday_impacts,
self.forecast_start,
self.forecast_end,
)
return self.future_impacts
def plot_countries(self):
for country, df in self.dau_dfs.items():
plot(
df,
[
Line("dau", "#ff9900", "DAU"),
Line("expected", "black", "Detrended DAU", opacity=0.5),
],
f"Holiday Detrended DAU ({country})",
"DAU",
)
plot(
df,
[
Line("dau_28ma", "#ff9900", "DAU 28MA"),
Line("edau_28ma", "black", "Detrended DAU 28MA", opacity=0.5),
],
f"Holiday Detrended DAU 28MA ({country})",
"DAU 28MA",
)
plot(
df,
[
Line("dau_yoy", "#ff9900", "DAU YoY"),
Line("edau_yoy", "black", "Detrended DAU YoY", opacity=0.5),
],
f"Holiday Detrended DAU YoY ({country})",
"DAU YoY",
)
def plot_overall(self):
plot(
self.all_countries,
[
Line("dau_28ma", "#ff9900", "DAU 28MA"),
Line("edau_28ma", "black", "Detrended DAU 28MA", opacity=0.5),
],
"Holiday Detrended DAU 28MA",
"DAU 28MA",
)
plot(
self.all_countries,
[
Line("dau_yoy", "#ff9900", "DAU 28MA YoY"),
Line("edau_yoy", "black", "Holidays Removed YoY", opacity=0.5),
],
"YoY Dashboard Comparisons",
"DAU YoY",
)
def plot_future_impacts(self):
plot(
self.predict(),
[Line("impact", "black", "DAU Impact")],
"Estimated Holiday Impacts",
"Estimated DAU",
)
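

# Minimal usage sketch (assumes a hypothetical frame `dau_df` with the
# submission_date, country, and dau columns this module expects):
#
#     hi = HolidayImpacts(df=dau_df, forecast_start="2024-01-01",
#                         forecast_end="2024-12-31")
#     hi.fit()                 # detrend and estimate per-holiday impacts
#     future = hi.predict()    # DataFrame of submission_date, impact
#     hi.plot_future_impacts()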