in src/entrypoint/train.py [0:0]
def load_dataset(args: Namespace) -> TrainDatasets:
    """Return the train/test datasets for this run.

    If ``args.s3_dataset`` is set, it is treated as a directory path and the
    ``metadata``, ``train`` and ``test`` sub-paths beneath it are passed to
    ``load_datasets``. Otherwise the public dataset named by ``args.dataset``
    is obtained through ``datasets.get_dataset``.

    Args:
        args: Parsed command-line arguments; ``s3_dataset`` and ``dataset``
            are read from it.

    Returns:
        Whatever ``load_datasets`` / ``datasets.get_dataset`` return
        (annotated as ``TrainDatasets``).
    """
    custom_root = args.s3_dataset
    if custom_root is not None:
        # A custom location was supplied (e.g. a mounted channel): load from it.
        logger.info("Loading dataset from %s", custom_root)
        root_dir = Path(custom_root)
        return load_datasets(
            metadata=root_dir / "metadata",
            train=root_dir / "train",
            test=root_dir / "test",
        )

    # No custom location given: fall back to a named built-in dataset.
    logger.info("Downloading dataset %s", args.dataset)
    return datasets.get_dataset(args.dataset)