def get_train_kwargs()

in src/entrypoint/train.py [0:0]


def get_train_kwargs(estimator, dataset) -> Dict[str, Any]:
    """Build the keyword arguments to pass to ``estimator.train``.

    The training split is always passed as ``training_data``. The test split
    is passed as the validation set, under whichever keyword the estimator's
    ``train`` method declares. Known cases:

    - NPTSEstimator (or any other based on DummyEstimator) uses
      ``validation_dataset=...``
    - Other estimators use ``validation_data=...``

    The keyword is chosen by scanning the parameter names of
    ``estimator.train`` for the substring ``"validation_data"``. When exactly
    one parameter matches, that name is used. When none or several match,
    ``"validation_data"`` is used as the fallback.

    Args:
        estimator: Object exposing a ``train`` method whose signature is
            inspected.
        dataset: Object exposing ``train`` and ``test`` attributes.

    Returns:
        A dict with the training split under ``"training_data"`` and the test
        split under the detected validation keyword.
    """
    param_names = inspect.signature(estimator.train).parameters
    matches = [name for name in param_names if "validation_data" in name]
    # Only trust the probe when it is unambiguous; otherwise use the common name.
    validation_key = matches[0] if len(matches) == 1 else "validation_data"
    return {"training_data": dataset.train, validation_key: dataset.test}