def __init__()

in modules/SwissArmyTransformer/sat/data_utils/webds.py [0:0]


    def __init__(self, path, process_fn, seed, *, meta_names=None, nshards=sys.maxsize, shuffle_buffer=1000, include_dirs=None):
        """Assemble the pipeline: resampled shards -> tar samples -> shuffle -> ``process_fn``.

        Args:
            path: Shard specification. When ``include_dirs`` is given it must
                be a string: a non-empty string is brace-expanded and the
                discovered shards are appended to it, an empty string means
                "use only the discovered shards". When ``include_dirs`` is
                ``None`` it is forwarded unchanged to
                ``ConfiguredResampledShards``.
            process_fn: Passed as the last positional argument to the parent
                constructor.
            seed: Forwarded to ``ConfiguredResampledShards``.
            meta_names: Forwarded to ``tarfile_samples_with_meta``. ``None``
                (the default) is treated as an empty list.
            nshards: Forwarded to ``ConfiguredResampledShards``.
            shuffle_buffer: Argument for ``wds.shuffle``. Overridden to ``1``
                when the model-parallel world size is greater than one.
            include_dirs: Optional comma-separated list of directories, e.g.
                ``"/webdatasets/A,/webdatasets/C"``. Each directory is walked
                recursively and every non-empty file whose name ends in
                ``tar`` is added as a shard. A ``*N`` suffix
                (``"/webdatasets/A*3"``) lists each shard of that directory
                ``N`` times.

        Raises:
            ValueError: If the text after the last ``*`` of an
                ``include_dirs`` entry is not an integer.
        """
        # os.environ['WDS_SHOW_SEED'] = '1'
        if meta_names is None:
            # Fresh list per call instead of a shared mutable default.
            meta_names = []

        if include_dirs is not None:
            other_paths = []
            for spec in include_dirs.split(','):
                if '*' in spec:
                    # Split on the LAST '*' only, so a directory name that
                    # itself contains '*' no longer breaks the unpacking.
                    include_dir, _, repeat = spec.rpartition('*')
                    repeat = int(repeat)
                else:
                    include_dir, repeat = spec, 1
                for cur_dir, subdirs, files in os.walk(include_dir):
                    # os.walk yields entries in arbitrary (filesystem
                    # dependent) order. Sorting in place fixes the traversal
                    # order so the shard list is reproducible.
                    # NOTE(review): presumably every rank must build the same
                    # list for seeded resampling to agree -- confirm.
                    subdirs.sort()
                    for fname in sorted(files):
                        shard = os.path.join(cur_dir, fname)
                        # Zero-byte archives are skipped.
                        if fname.endswith('tar') and os.path.getsize(shard) > 0:
                            other_paths.extend([shard] * repeat)
            # print(f'Adding dataset paths {",".join(other_paths)}')
            if len(path) > 0: # not ""
                # Imported lazily: only needed when there is a spec to expand.
                from braceexpand import braceexpand
                path = list(braceexpand(path)) + other_paths
            else:
                path = other_paths

        tarfile_samples = partial(tarfile_samples_with_meta, meta_names=meta_names)
        tarfile_to_samples = pipelinefilter(tarfile_samples)

        # if model parallel, shuffle_buffer should be 1 to disable shuffling
        try:
            from sat.mpu import get_model_parallel_world_size
            if get_model_parallel_world_size() > 1:
                shuffle_buffer = 1
        except Exception:
            # Deliberately best-effort: if sat.mpu cannot be imported or the
            # query fails, keep the caller-supplied shuffle_buffer.
            pass

        super().__init__(
            ConfiguredResampledShards(path, seed, nshards=nshards),
            tarfile_to_samples(),
            wds.shuffle(shuffle_buffer),
            process_fn
        )