in modules/SwissArmyTransformer/sat/arguments.py
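# Module-level imports assumed for this excerpt (the full arguments.py defines its imports
# at the top of the file; the exact mpu import path is an assumption):
import os
import warnings

import torch

from sat import mpu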
def initialize_distributed(args):
    """Initialize torch.distributed."""
    if torch.distributed.is_initialized():
        if mpu.model_parallel_is_initialized():
            if args.model_parallel_size != mpu.get_model_parallel_world_size():
                raise ValueError('model_parallel_size is inconsistent with prior configuration. '
                                 'We currently do not support changing model_parallel_size.')
            return False
        else:
            if args.model_parallel_size > 1:
                warnings.warn('model_parallel_size > 1 but torch.distributed is not initialized via SAT. '
                              'Please carefully make sure the setup is correct on your own.')
            mpu.initialize_model_parallel(args.model_parallel_size)
        return True
    # The automatic assignment of devices has been moved to arguments.py.
    if args.device != 'cpu':
        torch.cuda.set_device(args.device)
    # Call the init process
    init_method = 'tcp://'
    args.master_ip = os.getenv('MASTER_ADDR', 'localhost')
    if args.world_size == 1:
        from sat.helpers import get_free_port
        default_master_port = str(get_free_port())
    else:
        default_master_port = '6000'
    args.master_port = os.getenv('MASTER_PORT', default_master_port)
    init_method += args.master_ip + ':' + args.master_port
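    # e.g. init_method becomes 'tcp://localhost:6000' with the default env vars when
    # world_size > 1; for world_size == 1 a free local port is picked instead.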
    torch.distributed.init_process_group(
        backend=args.distributed_backend,
        world_size=args.world_size, rank=args.rank,
        init_method=init_method)
    # Set the model-parallel / data-parallel communicators.
    mpu.initialize_model_parallel(args.model_parallel_size)
    # Optional DeepSpeed Activation Checkpointing features
    if args.deepspeed:
        import deepspeed
        deepspeed.init_distributed(
            dist_backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank, init_method=init_method)
        # Configuring this appears to have no negative effect even when activation checkpointing is not used.
        deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
    else:
        # In model-only mode we don't want to init DeepSpeed, but we still need to init the
        # RNG tracker for model parallelism, because dropout saves the seed by default.
        try:
            import deepspeed
            from deepspeed.runtime.activation_checkpointing.checkpointing import _CUDA_RNG_STATE_TRACKER, _MODEL_PARALLEL_RNG_TRACKER_NAME
            _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, 1)  # default seed 1
        except Exception as e:
            from sat.helpers import print_rank0
            print_rank0(str(e), level="DEBUG")
    return True
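# A minimal usage sketch (not part of the original file): initialize_distributed expects an
# argparse-style namespace carrying the fields used above (rank, world_size, device,
# distributed_backend, model_parallel_size, deepspeed, ...), typically the one produced by
# SAT's own argument parsing:
#
#     from sat.arguments import get_args, initialize_distributed  # get_args assumed here
#     args = get_args()
#     initialize_distributed(args)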