# train() — entry point defined in src/entrypoint/train.py

def train(args: Namespace, algo_args: Dict[str, Any]) -> None:
    """Train a specified estimator on a specified dataset.

    Pipeline: load dataset -> build estimator -> train & save model ->
    backtest & write metrics. ``args.stop_before`` ("train" or "eval")
    allows early termination at either milestone for debug/dev/test runs.

    Args:
        args: Parsed CLI arguments (algo, y_transform, stop_before,
            num_samples, quantiles, plot_transparent, output_data_dir, ...).
        algo_args: Estimator hyperparameters; may be overridden based on
            dataset metadata before the estimator is constructed.
    """
    dataset = load_dataset(args)
    # Let dataset metadata (e.g. frequency) override selected hyperparameters.
    algo_args = override_hp(algo_args, dataset.metadata)
    estimator = new_estimator(args.algo, kwargs=algo_args)

    # Debug/dev/test milestone: stop before training.
    if args.stop_before == "train":
        logger.info("Early termination: before %s", args.stop_before)
        return

    # Train & save model
    logger.info("Starting model training.")
    if args.y_transform == "log1p":
        # Train on log1p-transformed targets; predictions are mapped back
        # to the original scale via predictor.output_transform below.
        dataset = log1p_tds(dataset)
    train_kwargs = get_train_kwargs(estimator, dataset)
    predictor = estimator.train(**train_kwargs)
    predictor.output_transform = INVERSE[args.y_transform]
    save_model(predictor, args)

    # Debug/dev/test milestone: stop before evaluation.
    if args.stop_before == "eval":
        logger.info("Early termination: before %s", args.stop_before)
        return

    # Backtesting
    logger.info("Starting model evaluation.")
    forecast_it, ts_it = backtest.make_evaluation_predictions(
        dataset=dataset.test, predictor=predictor, num_samples=args.num_samples,
    )

    # Compute standard metrics over all samples or quantiles, and plot each
    # timeseries, all in one go! Ground truths were transformed for training,
    # so gt_inverse_transform maps them back before metrics are computed.
    logger.info("MyEvaluator: assume non-negative ground truths, hence no clip_to_zero performed on them.")
    gt_inverse_transform = np.expm1 if args.y_transform == "log1p" else None
    # Single Path instance, reused for the evaluator and all metric files.
    output_dir = Path(args.output_data_dir)
    evaluator = MyEvaluator(
        out_dir=output_dir,
        quantiles=args.quantiles,
        plot_transparent=bool(args.plot_transparent),
        gt_inverse_transform=gt_inverse_transform,
        clip_at_zero=True,
    )
    agg_metrics, item_metrics = evaluator(ts_it, forecast_it, num_series=len(dataset.test))

    # Required for metric tracking: one log line per aggregate metric.
    # Lazy %-args for consistency with the other log calls in this function.
    for name, value in agg_metrics.items():
        logger.info("gluonts[metric-%s]: %s", name, value)

    # save the evaluation results
    with open(output_dir / "agg_metrics.json", "w") as f:
        json.dump(agg_metrics, f)
    with open(output_dir / "item_metrics.csv", "w") as f:
        item_metrics.to_csv(f, index=False)

    # Specific requirement: output wmape to a separate file.
    # Warn and prepare the frame before opening, so the handle is held only
    # for the actual write.
    warnings.warn(
        "wmape csv uses daily or weekly according to frequency string, "
        "hence 7D still results in daily rather than weekly."
    )
    wmape_metrics = item_metrics[["item_id", "wMAPE"]].rename(
        {"item_id": "category", "wMAPE": "test_wMAPE"}, axis=1
    )
    with open(output_dir / f"{freq_name(dataset.metadata.freq)}-wmapes.csv", "w") as f:
        wmape_metrics.to_csv(f, index=False)