def load_dataset()

in src/entrypoint/train.py [0:0]


def load_dataset(args: Namespace) -> TrainDatasets:
    """Return the dataset selected by the parsed command-line arguments.

    Two sources are supported, chosen by ``args.s3_dataset``:

    * If ``args.s3_dataset`` is not ``None``, it is treated as a directory
      and its ``metadata``, ``train`` and ``test`` entries are handed to
      ``load_datasets``.
    * Otherwise ``args.dataset`` is passed to ``datasets.get_dataset`` to
      obtain a built-in dataset by name.

    Args:
        args: Parsed arguments; ``s3_dataset`` and ``dataset`` are read.

    Returns:
        Whatever the selected loader returns.
    """
    custom_location = args.s3_dataset

    # Custom dataset: a location was supplied, so read from it directly.
    if custom_location is not None:
        logger.info("Loading dataset from %s", custom_location)
        root = Path(custom_location)
        return load_datasets(
            metadata=root / "metadata",
            train=root / "train",
            test=root / "test",
        )

    # Fallback: no custom location given, use the named built-in dataset.
    logger.info("Downloading dataset %s", args.dataset)
    return datasets.get_dataset(args.dataset)