def maybe_setup_distributed()

in uimnet/workers/base.py


  def maybe_setup_distributed(self, cfg):

    # Temporarily turn off the read-only flag on the config so it can be updated
    OmegaConf.set_readonly(cfg, False)

    if cfg.experiment.distributed:
      utils.message('Distributed job detected')

      cfg.experiment.world_size = os.environ.get('WORLD_SIZE', cfg.experiment.world_size)
      cfg.experiment.rank = os.environ.get('RANK', cfg.experiment.rank)

      # Scenario 1: Initializing from environment
      if cfg.experiment.dist_protocol == 'env':
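        # env:// reads MASTER_ADDR/MASTER_PORT (and, unless passed explicitly,
        # RANK/WORLD_SIZE) from the environment, e.g. as exported by torchrun.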
        cfg.experiment.dist_url = f'{cfg.experiment.dist_protocol}://'

      # Scenario 2: Check if using SLURM on the FAIR cluster
      elif 'SLURM_JOB_ID' in os.environ and cfg.experiment.platform == 'slurm':
        # TODO Improve SLURM support:
        # - n different tasks on the same node will require n different IPs
        # - Specify the local rank for multi-node jobs with unequal
        #   numbers of workers per node.
        utils.message('Slurm job detected')
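        # NOTE: SLURM_TASKS_PER_NODE has the form e.g. "2" or "2(x3)"; taking
        # its first character only works for fewer than 10 tasks per node.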
        _world_size = (int(os.environ['SLURM_NNODES'])
                       * int(os.environ['SLURM_TASKS_PER_NODE'][0]))
        utils.message(f'Slurm world size={_world_size}')
        cfg.experiment.world_size = int(os.environ['SLURM_NTASKS'])
        assert _world_size == cfg.experiment.world_size

        procid = int(os.environ['SLURM_PROCID'])
        cfg.experiment.rank = procid
        # Using TCP/IP by default on SLURM
        cfg.experiment.dist_protocol = 'tcp'
        # Using first node as master node
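        # NOTE: this only handles plain comma-separated hostnames; compressed
        # ranges such as "node[001-004]" would need a proper hostlist expansion
        # (e.g. via `scontrol show hostnames`).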
        master_addr = os.getenv("SLURM_JOB_NODELIST").split(',')[0].replace(
            '[', ''
        )
        master_port = f'4000'
        cfg.experiment.dist_url = f'{cfg.experiment.dist_protocol}://{master_addr}:{master_port}'
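        # One process per GPU: give each its share of the node's CPU cores,
        # minus two, with a floor of two workers.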
        cfg.experiment.num_workers = max(2, (torch.multiprocessing.cpu_count() // torch.cuda.device_count()) - 2)

      # Scenario 3: TCP/IP
      elif cfg.experiment.dist_protocol == 'tcp':
        # Each process should be launched with MASTER_ADDR and MASTER_PORT set;
        # otherwise fall back to the local hostname and port 4000.
        master_addr = str(os.getenv('MASTER_ADDR', socket.gethostname()))
        master_port = str(os.getenv('MASTER_PORT', 4000))
        cfg.experiment.dist_url = f'{cfg.experiment.dist_protocol}://{master_addr}:{master_port}'

      # Scenario 4: File initialization
      elif cfg.experiment.dist_protocol == 'file':
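        # file:// init requires a path on a filesystem shared by (and writable
        # from) every rank; it is taken from the CONNECTION_FILE variable here.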
        connection_file = os.environ.get('CONNECTION_FILE')
        cfg.experiment.dist_url = f'{cfg.experiment.dist_protocol}://{connection_file}'

      # Scenario 5: explicitly provided connection address
      else:
        if cfg.experiment.dist_protocol is None and cfg.experiment.dist_url is None:
          err_msg = 'Specify dist_url or a valid dist_protocol'
          raise ValueError(err_msg)

      assert cfg.experiment.dist_url is not None
      assert cfg.experiment.rank is not None
      assert cfg.experiment.world_size is not None

      cfg.experiment.rank = int(cfg.experiment.rank)
      cfg.experiment.world_size = int(cfg.experiment.world_size)
      dist.init_process_group(
          backend='nccl',
          init_method=cfg.experiment.dist_url,
          rank=cfg.experiment.rank,
          world_size=cfg.experiment.world_size
      )

      local_rank = cfg.experiment.rank % torch.cuda.device_count()
      cfg.experiment.local_rank = local_rank
      cfg.experiment.device = f"cuda:{local_rank}"

    if cfg.experiment.device is None:
      err_msg = 'Please specify device'
      raise ValueError(err_msg)

    OmegaConf.set_readonly(cfg, True)
    return
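
A minimal usage sketch of the env:// path (Scenario 1), assuming a
hypothetical SomeWorker subclass of this base worker and an OmegaConf config
whose experiment node carries the fields the method reads; the class name and
the exact config layout are illustrative, not taken from the repo.

  # Launch one process per GPU, e.g. with:
  #   torchrun --nproc_per_node=8 train.py
  # torchrun exports MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE, which the
  # 'env' scenario above consumes.
  from omegaconf import OmegaConf

  cfg = OmegaConf.create({
      'experiment': {
          'distributed': True,
          'dist_protocol': 'env',   # selects Scenario 1
          'dist_url': None,
          'rank': None,
          'world_size': None,
          'local_rank': None,
          'device': None,
          'platform': None,
      }
  })

  worker = SomeWorker()  # hypothetical subclass of the base worker
  worker.maybe_setup_distributed(cfg)
  # Afterwards cfg.experiment.device is e.g. 'cuda:3' on local rank 3 and
  # torch.distributed collectives are ready to use.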