in modules/SwissArmyTransformer/sat/data_utils/webds.py [0:0]
def __init__(self, path, process_fn, seed, *, meta_names=[], nshards=sys.maxsize, shuffle_buffer=1000, include_dirs=None):
    """Assemble the shard -> sample -> shuffle -> process pipeline.

    Args:
        path: Shard specification. When ``include_dirs`` is given and
            ``path`` is non-empty it is brace-expanded into a list and the
            discovered tar files are appended; otherwise it is handed to
            ``ConfiguredResampledShards`` unchanged.
        process_fn: Final pipeline stage, passed through as-is.
        seed: Forwarded to ``ConfiguredResampledShards``.
        meta_names: Forwarded to ``tarfile_samples_with_meta``. The default
            list is shared between calls; this method only forwards it.
        nshards: Forwarded to ``ConfiguredResampledShards``.
        shuffle_buffer: Argument for ``wds.shuffle``. Forced to 1 when the
            model-parallel world size is greater than 1.
        include_dirs: Optional comma-separated directories, e.g.
            ``/webdatasets/A,/webdatasets/C``. An entry may carry a ``*N``
            suffix, in which case every tar file found under that
            directory is listed N times.
    """
    if include_dirs is not None:
        discovered = []
        for entry in include_dirs.split(','):
            # "dir*N" lists every tar under dir N times; a bare "dir" once.
            if '*' in entry:
                entry, repeat = entry.split('*')
                repeat = int(repeat)
            else:
                repeat = 1
            for root, _dirs, filenames in os.walk(entry):
                for filename in filenames:
                    if not filename.endswith('tar'):
                        continue
                    tar_path = os.path.join(root, filename)
                    # Zero-byte archives are skipped.
                    if os.path.getsize(tar_path) > 0:
                        discovered.extend([tar_path] * repeat)
        from braceexpand import braceexpand
        # An empty `path` ("") means: use only the discovered tar files.
        path = list(braceexpand(path)) + discovered if len(path) > 0 else discovered

    to_samples = pipelinefilter(partial(tarfile_samples_with_meta, meta_names=meta_names))

    # If model parallel, shuffle_buffer should be 1 to disable shuffling.
    # Any failure while probing the model-parallel state is ignored and the
    # caller-supplied buffer size is kept.
    try:
        from sat.mpu import get_model_parallel_world_size
        is_model_parallel = get_model_parallel_world_size() > 1
    except Exception:
        is_model_parallel = False
    if is_model_parallel:
        shuffle_buffer = 1

    super().__init__(
        ConfiguredResampledShards(path, seed, nshards=nshards),
        to_samples(),
        wds.shuffle(shuffle_buffer),
        process_fn
    )