in main.py [0:0]
def run_experiment_worker(subprocess_index, world_size, model_mico, hparams, log_queue):
    """Worker function run in each sub-process for multi-process training with DistributedDataParallel.

    Parameters
    ----------
    subprocess_index : int
        Index (rank) of the current sub-process.
    world_size : int
        Total number of sub-processes.
    model_mico : `MutualInfoCotrain` object
        The initialized model, identical across all sub-processes.
    hparams : namespace
        The hyper-parameters for MICO, read via attribute access (e.g. ``hparams.seed``).
    log_queue : `torch.multiprocessing.Queue`
        Queue that collects log records from all sub-processes.
    """
    setup_worker_logging(subprocess_index, log_queue)
    logging.info("Running experiment on GPU %d", subprocess_index)
    # Rendezvous settings shared by every worker in the process group.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=subprocess_index, world_size=world_size)
    # For reproducible runs, derive a distinct seed for each sub-process.
    random.seed(hparams.seed * world_size + subprocess_index)
    np.random.seed(hparams.seed * world_size + subprocess_index)
    torch.manual_seed(hparams.seed * world_size + subprocess_index)
    train_mico(subprocess_index, model_mico)
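
For context, a minimal sketch of how this worker might be launched, assuming one sub-process per visible GPU via `torch.multiprocessing.spawn`; the `launch_experiment` wrapper name and the plain `mp.Queue()` used for `log_queue` are illustrative assumptions, not part of main.py.

import torch
import torch.multiprocessing as mp

def launch_experiment(model_mico, hparams):
    # Assumption: one worker per visible GPU.
    world_size = torch.cuda.device_count()
    log_queue = mp.Queue()  # stand-in for the project's actual logging queue
    # spawn() calls run_experiment_worker(i, *args) with i in [0, world_size),
    # so the worker index maps onto `subprocess_index` above.
    mp.spawn(
        run_experiment_worker,
        args=(world_size, model_mico, hparams, log_queue),
        nprocs=world_size,
        join=True,
    )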