in src/datasets/packaged_modules/webdataset/webdataset.py [0:0]
def _split_generators(self, dl_manager):
"""We handle string, list and dicts in datafiles"""
# Download the data files
if not self.config.data_files:
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
data_files = dl_manager.download(self.config.data_files)
splits = []
for split_name, tar_paths in data_files.items():
if isinstance(tar_paths, str):
tar_paths = [tar_paths]
tar_iterators = [dl_manager.iter_archive(tar_path) for tar_path in tar_paths]
splits.append(
datasets.SplitGenerator(
name=split_name, gen_kwargs={"tar_paths": tar_paths, "tar_iterators": tar_iterators}
)
)
if not self.info.features:
# Get one example to get the feature types
pipeline = self._get_pipeline_from_tar(tar_paths[0], tar_iterators[0])
first_examples = list(islice(pipeline, self.NUM_EXAMPLES_FOR_FEATURES_INFERENCE))
if any(example.keys() != first_examples[0].keys() for example in first_examples):
raise ValueError(
"The TAR archives of the dataset should be in WebDataset format, "
"but the files in the archive don't share the same prefix or the same types."
)
pa_tables = [
pa.Table.from_pylist(cast_to_python_objects([example], only_1d_for_numpy=True))
for example in first_examples
]
inferred_arrow_schema = pa.concat_tables(pa_tables, promote_options="default").schema
features = datasets.Features.from_arrow_schema(inferred_arrow_schema)
# Set Image types
for field_name in first_examples[0]:
extension = field_name.rsplit(".", 1)[-1]
if extension in self.IMAGE_EXTENSIONS:
features[field_name] = datasets.Image()
# Set Audio types
for field_name in first_examples[0]:
extension = field_name.rsplit(".", 1)[-1]
if extension in self.AUDIO_EXTENSIONS:
features[field_name] = datasets.Audio()
# Set Video types
for field_name in first_examples[0]:
extension = field_name.rsplit(".", 1)[-1]
if extension in self.VIDEO_EXTENSIONS:
features[field_name] = datasets.Video()
self.info.features = features
return splits