in src/entrypoint/train.py [0:0]
def get_train_kwargs(estimator, dataset) -> Dict[str, Any]:
    """Build the keyword arguments to pass to ``estimator.train``.

    Estimators disagree on the name of the validation-data parameter, so it
    is discovered from the signature of ``estimator.train``. Known cases:

    - NPTSEstimator (or any other based on DummyEstimator) uses
      ``validation_dataset=...``
    - Other estimators use ``validation_data=...``

    Every parameter whose name contains ``"validation_data"`` is a candidate.
    If exactly one candidate exists, its name is used. Otherwise (no
    candidates, or more than one), the key falls back to ``"validation_data"``.

    Args:
        estimator: object exposing a ``train`` callable whose signature is
            inspected.
        dataset: object with ``train`` and ``test`` attributes. ``train`` is
            passed as ``training_data`` and ``test`` as the validation data.

    Returns:
        A dict with ``"training_data"`` first, followed by the chosen
        validation key.
    """
    param_names = inspect.signature(estimator.train).parameters
    matches = [name for name in param_names if "validation_data" in name]
    # Only trust the probe when it is unambiguous; otherwise use the common name.
    validation_key = matches[0] if len(matches) == 1 else "validation_data"
    return {"training_data": dataset.train, validation_key: dataset.test}