in src/entrypoint/train.py [0:0]
def train(args: Namespace, algo_args: Dict[str, Any]) -> None:
    """Train a specified estimator on a specified dataset.

    Loads the dataset named by ``args``, overrides hyperparameters from the
    dataset metadata, trains the estimator, saves the resulting predictor,
    then backtests it and writes evaluation artifacts (aggregate metrics
    JSON, per-item metrics CSV, and a wMAPE-only CSV) under
    ``args.output_data_dir``.

    Args:
        args: CLI namespace. Fields read here: ``algo``, ``stop_before``,
            ``y_transform``, ``num_samples``, ``quantiles``,
            ``plot_transparent``, ``output_data_dir``.
        algo_args: Estimator hyperparameters; may be overridden by values
            derived from the dataset metadata.
    """
    dataset = load_dataset(args)
    algo_args = override_hp(algo_args, dataset.metadata)
    estimator = new_estimator(args.algo, kwargs=algo_args)

    # Debug/dev/test milestone: bail out before any training work.
    if args.stop_before == "train":
        logger.info("Early termination: before %s", args.stop_before)
        return

    # Train & save model.
    logger.info("Starting model training.")
    if args.y_transform == "log1p":
        # Train in log1p space; predictions are mapped back to the original
        # scale via the predictor's output_transform set just below.
        dataset = log1p_tds(dataset)
    train_kwargs = get_train_kwargs(estimator, dataset)
    predictor = estimator.train(**train_kwargs)
    predictor.output_transform = INVERSE[args.y_transform]
    save_model(predictor, args)

    # Debug/dev/test milestone: bail out before evaluation.
    if args.stop_before == "eval":
        logger.info("Early termination: before %s", args.stop_before)
        return

    # Backtesting.
    logger.info("Starting model evaluation.")
    forecast_it, ts_it = backtest.make_evaluation_predictions(
        dataset=dataset.test, predictor=predictor, num_samples=args.num_samples,
    )

    # Compute standard metrics over all samples or quantiles, and plot each
    # timeseries, all in one go! Ground truths were (optionally) log1p
    # transformed above, so gt_inverse_transform undoes that when scoring.
    logger.info("MyEvaluator: assume non-negative ground truths, hence no clip_to_zero performed on them.")
    gt_inverse_transform = np.expm1 if args.y_transform == "log1p" else None
    evaluator = MyEvaluator(
        out_dir=Path(args.output_data_dir),
        quantiles=args.quantiles,
        plot_transparent=bool(args.plot_transparent),
        gt_inverse_transform=gt_inverse_transform,
        clip_at_zero=True,
    )
    agg_metrics, item_metrics = evaluator(ts_it, forecast_it, num_series=len(dataset.test))

    # Required for metric tracking: one log line per aggregate metric.
    # Lazy %-style args keep this consistent with the other log calls above
    # (the original used an f-string here, formatting eagerly).
    for name, value in agg_metrics.items():
        logger.info("gluonts[metric-%s]: %s", name, value)

    # Save the evaluation results. Explicit UTF-8 avoids platform-dependent
    # default encodings; newline="" is required by pandas when handing an
    # open text handle to to_csv (prevents doubled CRs on Windows).
    metrics_output_dir = Path(args.output_data_dir)
    with open(metrics_output_dir / "agg_metrics.json", "w", encoding="utf-8") as f:
        json.dump(agg_metrics, f)
    with open(metrics_output_dir / "item_metrics.csv", "w", encoding="utf-8", newline="") as f:
        item_metrics.to_csv(f, index=False)

    # Specific requirement: output wmape to a separate file. The warning and
    # the column renaming are unrelated to the file handle, so they happen
    # before the file is opened (the original held the file open meanwhile).
    warnings.warn(
        "wmape csv uses daily or weekly according to frequency string, "
        "hence 7D still results in daily rather than weekly."
    )
    wmape_metrics = item_metrics[["item_id", "wMAPE"]].rename(
        {"item_id": "category", "wMAPE": "test_wMAPE"}, axis=1
    )
    wmape_path = metrics_output_dir / f"{freq_name(dataset.metadata.freq)}-wmapes.csv"
    with open(wmape_path, "w", encoding="utf-8", newline="") as f:
        wmape_metrics.to_csv(f, index=False)