def initialize_distributed()

in modules/SwissArmyTransformer/sat/arguments.py


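# Module-level imports this excerpt relies on; the exact form of the mpu
# import is an assumption about the rest of sat/arguments.py.
import os
import warnings
import torch
from sat import mpu
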
def initialize_distributed(args):
    """Initialize torch.distributed."""
    if torch.distributed.is_initialized():
        if mpu.model_parallel_is_initialized():
            if args.model_parallel_size != mpu.get_model_parallel_world_size():
                raise ValueError('model_parallel_size is inconsistent with prior configuration. '
                                 'We currently do not support changing model_parallel_size.')
            return False
        else:
            if args.model_parallel_size > 1:
                warnings.warn('model_parallel_size > 1, but torch.distributed was not initialized via SAT. '
                              'Please make sure the setup is correct on your own.')
            mpu.initialize_model_parallel(args.model_parallel_size)
        return True
    # the automatic assignment of devices has been moved to arguments.py
    if args.device != 'cpu':
        torch.cuda.set_device(args.device)
    # Build the TCP init method and initialize the default process group.
    init_method = 'tcp://'
    args.master_ip = os.getenv('MASTER_ADDR', 'localhost')
    
    if args.world_size == 1:
        from sat.helpers import get_free_port
        default_master_port = str(get_free_port())
    else:
        default_master_port = '6000'
    args.master_port = os.getenv('MASTER_PORT', default_master_port)
    init_method += args.master_ip + ':' + args.master_port
    torch.distributed.init_process_group(
        backend=args.distributed_backend,
        world_size=args.world_size, rank=args.rank,
        init_method=init_method)

    # Set the model-parallel / data-parallel communicators.
    mpu.initialize_model_parallel(args.model_parallel_size)
    # Optional DeepSpeed Activation Checkpointing Features
    if args.deepspeed: 
        import deepspeed
        deepspeed.init_distributed(
            dist_backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank, init_method=init_method)
        # Configuring this appears to have no negative effect even when activation checkpointing is not used.
        deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
    else:
        # In model-only mode we do not want to init DeepSpeed, but we still need to init the RNG tracker
        # for model parallelism, because dropout saves the seed through it by default.
        try:
            import deepspeed
            from deepspeed.runtime.activation_checkpointing.checkpointing import _CUDA_RNG_STATE_TRACKER, _MODEL_PARALLEL_RNG_TRACKER_NAME
            _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, 1) # default seed 1
        except Exception as e:
            from sat.helpers import print_rank0
            print_rank0(str(e), level="DEBUG")


    return True
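
A minimal usage sketch, assuming a single process and an argparse-style args object. The attribute values below are illustrative and only cover the fields that initialize_distributed reads; SAT's own argument parser supplies the real defaults.

# Hypothetical, hand-built namespace; in practice these fields come from
# SAT's argument parsing. Values here are assumptions for illustration.
import argparse
import torch

args = argparse.Namespace(
    rank=0,                      # process rank passed to init_process_group
    world_size=1,                # world_size == 1 makes SAT pick a free master port
    distributed_backend='nccl' if torch.cuda.is_available() else 'gloo',
    model_parallel_size=1,       # forwarded to mpu.initialize_model_parallel
    device=0 if torch.cuda.is_available() else 'cpu',
    deepspeed=False,             # skip deepspeed.init_distributed and checkpointing.configure
)

initialize_distributed(args)     # True on first call; False if model parallelism
                                 # was already initialized with the same size

For multi-process runs, MASTER_ADDR and MASTER_PORT are read from the environment (defaulting to localhost and port 6000), so a launcher should export them and populate args.rank and args.world_size per process before this function is called.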