in src/datasets/data_files.py [0:0]
def _get_data_files_patterns(pattern_resolver: Callable[[str], list[str]]) -> dict[str, list[str]]:
    """
    Get the default pattern from a directory or repository by testing all the supported patterns.
    The first patterns to return a non-empty list of data files is returned.

    In order, it first tests the patterns in ALL_SPLIT_PATTERNS (the split name is part of the file name),
    otherwise it tests the patterns in ALL_DEFAULT_PATTERNS.

    Args:
        pattern_resolver: Callable that takes a pattern and returns the list of data files it matches.
            A `FileNotFoundError` raised by the resolver is treated as "no match" and the pattern is skipped.

    Returns:
        A dict mapping each detected split name to the list of patterns that select its data files.
        For split patterns, splits listed in DEFAULT_SPLITS come first (in that order), followed by the
        remaining split names in sorted order.

    Raises:
        ValueError: If a resolved data file does not fit the split pattern that selected it, or if an
            extracted split name does not match `_split_re`.
        FileNotFoundError: If none of the supported patterns resolves to at least one data file.
    """
    # Stays None only if both pattern collections are empty; this keeps the final
    # FileNotFoundError from degrading into a NameError on an unbound loop variable.
    pattern = None
    # first check the split patterns like data/{split}-00000-of-00001.parquet
    for split_pattern in ALL_SPLIT_PATTERNS:
        pattern = split_pattern.replace("{split}", "*")
        try:
            data_files = pattern_resolver(pattern)
        except FileNotFoundError:
            continue
        if not data_files:
            continue
        splits: set[str] = set()
        for p in data_files:
            # Only the base names are compared, so directory components never leak into the split name.
            p_parts = string_to_dict(xbasename(p), xbasename(split_pattern))
            # Explicit check instead of `assert`: assertions are removed under `python -O`, which would
            # turn a non-matching file name into an opaque TypeError on the subscript below.
            if p_parts is None:
                raise ValueError(f"Data file '{p}' doesn't match the split pattern '{split_pattern}'.")
            splits.add(p_parts["split"])
        if any(not re.match(_split_re, split) for split in splits):
            raise ValueError(f"Split name should match '{_split_re}' but got '{splits}'.")
        # Well-known splits first, in their canonical order, then any custom split names alphabetically.
        sorted_splits = [str(split) for split in DEFAULT_SPLITS if split in splits] + sorted(
            splits - {str(split) for split in DEFAULT_SPLITS}
        )
        return {split: [split_pattern.format(split=split)] for split in sorted_splits}
    # then check the default patterns based on train/valid/test splits
    for patterns_dict in ALL_DEFAULT_PATTERNS:
        non_empty_splits = []
        for split, patterns in patterns_dict.items():
            for pattern in patterns:
                try:
                    data_files = pattern_resolver(pattern)
                except FileNotFoundError:
                    continue
                if data_files:
                    # One matching pattern is enough to keep this split; skip its remaining patterns.
                    non_empty_splits.append(split)
                    break
        if non_empty_splits:
            return {split: patterns_dict[split] for split in non_empty_splits}
    # NOTE: `pattern` is just the last pattern that was tried, not the only one that failed.
    raise FileNotFoundError(f"Couldn't resolve pattern {pattern} with resolver {pattern_resolver}")