modules/SwissArmyTransformer/sat/data_utils/webds.py [20:86]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def pytorch_worker_info(group=None):  # sourcery skip: use-contextlib-suppress
    """Return node and worker info for PyTorch and some distributed environments."""
    rank = 0
    world_size = 1
    worker = 0
    num_workers = 1
    try:
        import torch.distributed

        if torch.distributed.is_available() and torch.distributed.is_initialized():
            group = group or torch.distributed.group.WORLD
            rank = torch.distributed.get_rank(group=group)
            world_size = torch.distributed.get_world_size(group=group)
    except ModuleNotFoundError:
        pass
    try:
        import torch.utils.data

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker = worker_info.id
            num_workers = worker_info.num_workers
    except ModuleNotFoundError:
        pass

    return rank, world_size, worker, num_workers


def pytorch_worker_seed(group=None):
    """Compute a distinct, deterministic RNG seed for each worker and node."""
    rank, world_size, worker, num_workers = pytorch_worker_info(group=group)
    return rank * 1000 + worker

def worker_seed_sat(group=None, seed=0):
    return pytorch_worker_seed(group=group) + seed * 23

class ConfiguredResampledShards(ResampledShards):
    def __init__(self, urls, seed, nshards=sys.maxsize, deterministic=True):
        from sat.helpers import print_rank0
        try:
            from megatron.core.parallel_state import get_data_parallel_group
            group = get_data_parallel_group()
            print_rank0("Using megatron data parallel group.")
        except Exception:
            from sat.mpu import get_data_parallel_group
            try:
                group = get_data_parallel_group()
                print_rank0("Using sat data parallel group.")
            except AssertionError:
                group = None
                print_rank0("No data parallel group is specified!")
        worker_seed_sat_this = partial(worker_seed_sat, group=group, seed=seed)
        super().__init__(urls, nshards, worker_seed_sat_this, deterministic)

class SimpleDistributedWebDataset(DataPipeline):
    def __init__(self, path, process_fn, seed, *, shuffle_buffer=1000):
        # Set shuffle_buffer = 1 to disable shuffling; with model parallelism enabled,
        # shuffling would give model-parallel ranks different samples, so it is disabled below.
        try:
            from sat.mpu import get_model_parallel_world_size
            if get_model_parallel_world_size() > 1:
                shuffle_buffer = 1
        except Exception:
            pass
        super().__init__(
            ConfiguredResampledShards(path, seed),  # many shards are recommended; with too few, samples may not spread evenly across ranks
            tarfile_to_samples(),
            wds.shuffle(shuffle_buffer),
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
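
The duplicated seeding logic above derives one seed per (data-parallel rank, DataLoader worker) pair: pytorch_worker_seed returns rank * 1000 + worker, and worker_seed_sat offsets it by seed * 23. A minimal sketch of that arithmetic, with hypothetical rank/worker counts rather than values from a real run:

    # Mirrors worker_seed_sat(): rank * 1000 + worker + seed * 23 (sketch only).
    def expected_seed(rank: int, worker: int, seed: int) -> int:
        return rank * 1000 + worker + seed * 23

    # Hypothetical job: 2 data-parallel ranks, 2 DataLoader workers each, seed=7.
    for rank in range(2):
        for worker in range(2):
            print(rank, worker, expected_seed(rank, worker, seed=7))
    # -> 0 0 161 / 0 1 162 / 1 0 1161 / 1 1 1162: distinct and deterministic,
    #    provided each rank runs fewer than 1000 DataLoader workers.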



sat/sgm/webds.py [17:90]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def pytorch_worker_info(group=None):  # sourcery skip: use-contextlib-suppress
    """Return node and worker info for PyTorch and some distributed environments."""
    rank = 0
    world_size = 1
    worker = 0
    num_workers = 1
    try:
        import torch.distributed

        if torch.distributed.is_available() and torch.distributed.is_initialized():
            group = group or torch.distributed.group.WORLD
            rank = torch.distributed.get_rank(group=group)
            world_size = torch.distributed.get_world_size(group=group)
    except ModuleNotFoundError:
        pass
    try:
        import torch.utils.data

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker = worker_info.id
            num_workers = worker_info.num_workers
    except ModuleNotFoundError:
        pass

    return rank, world_size, worker, num_workers


def pytorch_worker_seed(group=None):
    """Compute a distinct, deterministic RNG seed for each worker and node."""
    rank, world_size, worker, num_workers = pytorch_worker_info(group=group)
    return rank * 1000 + worker


def worker_seed_sat(group=None, seed=0):
    return pytorch_worker_seed(group=group) + seed * 23


class ConfiguredResampledShards(ResampledShards):
    def __init__(self, urls, seed, nshards=sys.maxsize, deterministic=True):
        from sat.helpers import print_rank0

        try:
            from megatron.core.parallel_state import get_data_parallel_group

            group = get_data_parallel_group()
            print_rank0("Using megatron data parallel group.")
        except Exception:
            from sat.mpu import get_data_parallel_group

            try:
                group = get_data_parallel_group()
                print_rank0("Using sat data parallel group.")
            except AssertionError:
                group = None
                print_rank0("No data parallel group is specified!")
        worker_seed_sat_this = partial(worker_seed_sat, group=group, seed=seed)
        super().__init__(urls, nshards, worker_seed_sat_this, deterministic)


class SimpleDistributedWebDataset(DataPipeline):
    def __init__(self, path, process_fn, seed, *, shuffle_buffer=1000):
        # Set shuffle_buffer = 1 to disable shuffling; with model parallelism enabled,
        # shuffling would give model-parallel ranks different samples, so it is disabled below.
        try:
            from sat.mpu import get_model_parallel_world_size

            if get_model_parallel_world_size() > 1:
                shuffle_buffer = 1
        except Exception:
            pass
        super().__init__(
            ConfiguredResampledShards(path, seed),  # many shards are recommended; with too few, samples may not spread evenly across ranks
            tarfile_to_samples(),
            wds.shuffle(shuffle_buffer),
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
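
For orientation, a hedged usage sketch of the duplicated SimpleDistributedWebDataset. The import path, shard pattern, and process_fn below are illustrative assumptions; process_fn is taken to be a webdataset-style stage that consumes the upstream sample iterator (the tail of the pipeline is cut off by the excerpt above):

    from sat.data_utils.webds import SimpleDistributedWebDataset  # import path is an assumption

    def process_fn(src):
        # Assumed webdataset-style stage: receives the iterator of sample dicts
        # produced by tarfile_to_samples()/shuffle() and yields training examples.
        for sample in src:
            yield sample  # e.g. decode sample["jpg"] / sample["txt"] here

    dataset = SimpleDistributedWebDataset(
        "/data/shards/part-{000000..000099}.tar",  # hypothetical brace-expanded shard list
        process_fn,
        seed=1234,
    )

    # Each data-parallel rank / DataLoader worker resamples shards with its own
    # deterministic seed (see worker_seed_sat above).
    for example in dataset:
        ...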



