prediction_generation/cpdbench_prophet.py (71 lines of code) (raw):
import argparse
import time
import pandas as pd
from prophet import Prophet
import json
import sys
import copy
from cpdbench_utils import load_dataset, exit_success, make_param_dict, exit_with_error
def _str_to_bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the
    raw string, and every non-empty string (including ``"False"``) is truthy.

    Accepts true/false, t/f, yes/no, y/n and 1/0 in any letter case. The empty
    string maps to False, which is what ``bool("")`` produced before.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def parse_args():
    """Parse the command-line options of the Prophet wrapper.

    Only ``--input`` is required. Every option that is not supplied is left
    as ``None`` in the returned namespace.

    Returns:
        argparse.Namespace: the parsed options, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="Wrapper for Prophet")
    parser.add_argument("-i", "--input", help="path to the input data file", required=True)
    parser.add_argument("-o", "--output", help="path to the output file")
    parser.add_argument("-N", "--Nmax", help="maximum number of changepoints", choices=['default', 'max'])
    # Boolean flags take an explicit value (e.g. "-w False"); see _str_to_bool
    # for why type=bool cannot be used here.
    parser.add_argument("-w", "--WeeklySeasonality", type=_str_to_bool, help="Weekly Seasonality")
    parser.add_argument("-d", "--DailySeasonality", type=_str_to_bool, help="Daily Seasonality")
    parser.add_argument("-r", "--ChangepointRange", type=float, help="Changepoint Range")
    parser.add_argument("-p", "--ChangepointPriorScale", type=float, help="Changepoint Prior Scale")
    parser.add_argument("-t", "--IntervalWidth", type=float, help="Interval Width")
    parser.add_argument("-g", "--growth", type=str, help="Growth type: 'linear' or 'logistic'", choices=['linear', 'logistic'])
    parser.add_argument("-c", "--cap", type=float, help="Capacity for logistic growth (required for logistic growth)")
    return parser.parse_args()
# Function to convert Timestamp objects to a JSON serializable format
def convert_timestamps(obj):
    """Replace pandas Timestamps inside ``obj`` with ISO 8601 strings.

    Lists and dicts are rebuilt recursively (dict keys are kept as they are,
    only values are converted). A bare ``pd.Timestamp`` becomes its
    ``isoformat()`` string. Any other object is returned unchanged.
    """
    if isinstance(obj, dict):
        return {name: convert_timestamps(entry) for name, entry in obj.items()}
    if isinstance(obj, list):
        return list(map(convert_timestamps, obj))
    if isinstance(obj, pd.Timestamp):
        return obj.isoformat()
    return obj
def main():
    """Fit Prophet to the first series of the input dataset and report results.

    Reads the command-line options, loads the dataset, fits a Prophet model
    and hands the index labels of the model's changepoints to
    ``exit_success``. Any problem is reported through ``exit_with_error``.
    """
    args = parse_args()
    # Keep an untouched copy: args.Nmax is rewritten below, and the helpers
    # receive both the raw namespace and the resolved parameters.
    raw_args = copy.deepcopy(args)

    # The second value returned by load_dataset is not needed here.
    data, _ = load_dataset(args.input)
    start_time = time.time()

    # Resolve the symbolic Nmax into a number. Anything other than 'default'
    # (including an omitted option) means "as many as the data allows".
    if args.Nmax == 'default':
        args.Nmax = 25
    else:
        args.Nmax = data['n_obs'] - 1

    if 'series' not in data or len(data['series']) == 0:
        exit_with_error(data, raw_args, vars(args), "No time series data available.", __file__)
        # NOTE(review): exit_with_error is expected to terminate the process;
        # return defensively in case it does not.
        return

    # Prophet expects a frame with a 'ds' (time) and a 'y' (value) column.
    df = pd.DataFrame({
        'ds': data['time']['raw'],
        'y': data['series'][0]['raw'],
    })

    # Logistic growth additionally needs a 'cap' column.
    if args.growth == 'logistic':
        if args.cap is None:
            exit_with_error(data, raw_args, vars(args), "Capacity ('cap') must be provided for logistic growth.", __file__)
            return
        df['cap'] = args.cap

    # Forward only the options that were actually supplied, so that an omitted
    # option falls back to Prophet's own default instead of an explicit None.
    requested = {
        'changepoint_range': args.ChangepointRange,
        'n_changepoints': args.Nmax,
        'weekly_seasonality': args.WeeklySeasonality,
        'daily_seasonality': args.DailySeasonality,
        'growth': args.growth,
        'changepoint_prior_scale': args.ChangepointPriorScale,
        'interval_width': args.IntervalWidth,
    }
    model_kwargs = {name: value for name, value in requested.items() if value is not None}

    try:
        model = Prophet(**model_kwargs)
        model.fit(df)
        # The reported locations are the index labels of the changepoints
        # Series, not its timestamp values.
        # NOTE(review): presumably these are row positions in the fitted
        # history -- confirm against Prophet.
        changepoints = model.changepoints
        locs = changepoints.index.tolist()
    except Exception as e:
        exit_with_error(data, raw_args, vars(args), str(e), __file__)
        # locs is unbound on this path, so never fall through.
        return

    stop_time = time.time()
    runtime = stop_time - start_time

    # Make everything JSON serializable before handing it to the reporter.
    locs = convert_timestamps(locs)
    data = convert_timestamps(data)
    exit_success(data, raw_args, vars(args), locs, runtime, __file__)
# Run the wrapper only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()