def main()

in downstream/semseg/ddp_main.py


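# NOTE: this excerpt assumes the module-level imports of ddp_main.py:
# torch, logging, and repo helpers such as distributed_utils, setup_logging,
# load_dataset, initialize_data_loader, load_model, load_wrapper,
# count_parameters, load_state_with_same_shape, train, and test.
# default_restore_location is importable from torch.serialization.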
def main(config, init_distributed=False):

  if not torch.cuda.is_available():
    raise Exception('No GPUs found.')
  
  # Pin this process to its assigned GPU and seed both CPU and CUDA RNGs.
  torch.cuda.set_device(config.distributed.device_id)
  torch.manual_seed(config.misc.seed)
  torch.cuda.manual_seed(config.misc.seed)

  device = config.distributed.device_id
  distributed = config.distributed.distributed_world_size > 1

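  # When launched once per rank, set up the process group and record this
  # process's global rank.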
  if init_distributed:
    config.distributed.distributed_rank = distributed_utils.distributed_init(config.distributed)

  setup_logging(config)

  logging.info('===> Configurations')
  logging.info(config.pretty())

  DatasetClass = load_dataset(config.data.dataset)
  if config.test.test_original_pointcloud:
    if not DatasetClass.IS_FULL_POINTCLOUD_EVAL:
      raise ValueError('This dataset does not support full pointcloud evaluation.')

  if config.test.evaluate_original_pointcloud:
    if not config.data.return_transformation:
      raise ValueError('Pointcloud evaluation requires config.return_transformation=true.')

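  # The two flags must agree: the XOR below fires exactly when one is set
  # without the other.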
  if config.data.return_transformation ^ config.test.evaluate_original_pointcloud:
    raise ValueError('config.data.return_transformation and '
                     'config.test.evaluate_original_pointcloud must be set together.')

  logging.info('===> Initializing dataloader')
  if config.train.is_train:
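    # Training loader: augmented, shuffled, and repeated indefinitely;
    # train_limit_numpoints (assumed semantics) caps total points per batch.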
    train_data_loader = initialize_data_loader(
        DatasetClass,
        config,
        phase=config.train.train_phase,
        num_workers=config.data.num_workers,
        augment_data=True,
        shuffle=True,
        repeat=True,
        batch_size=config.data.batch_size,
        limit_numpoints=config.data.train_limit_numpoints)

    val_data_loader = initialize_data_loader(
        DatasetClass,
        config,
        num_workers=config.data.num_val_workers,
        phase=config.train.val_phase,
        augment_data=False,
        shuffle=True,
        repeat=False,
        batch_size=config.data.val_batch_size,
        limit_numpoints=False)

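    # Infer the input feature dimension and the label count from the dataset.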
    if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
      num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
    else:
      num_in_channel = 3  # RGB color

    num_labels = train_data_loader.dataset.NUM_LABELS
  
  else:
    
    test_data_loader = initialize_data_loader(
        DatasetClass,
        config,
        num_workers=config.data.num_workers,
        phase=config.data.test_phase,
        augment_data=False,
        shuffle=False,
        repeat=False,
        batch_size=config.data.test_batch_size,
        limit_numpoints=False)
    
    if test_data_loader.dataset.NUM_IN_CHANNEL is not None:
      num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL
    else:
      num_in_channel = 3  # RGB color

    num_labels = test_data_loader.dataset.NUM_LABELS

  logging.info('===> Building model')
  NetClass = load_model(config.net.model)
  if config.net.wrapper_type is None:
    model = NetClass(num_in_channel, num_labels, config)
    logging.info('===> Number of trainable parameters: {}: {}'.format(NetClass.__name__,
                                                                      count_parameters(model)))
  else:
    wrapper = load_wrapper(config.net.wrapper_type)
    model = wrapper(NetClass, num_in_channel, num_labels, config)
    logging.info('===> Number of trainable parameters: {}: {}'.format(
        wrapper.__name__ + NetClass.__name__, count_parameters(model)))

  logging.info(model)
  
  if config.net.weights == 'modelzoo':  # Load modelzoo weights if possible.
    logging.info('===> Loading modelzoo weights')
    model.preload_modelzoo()

  # Load weights if specified by the parameter.
  elif config.net.weights.lower() != 'none':
    logging.info('===> Loading weights: ' + config.net.weights)
    # Load the checkpoint onto CPU first; tensors move to the GPU with the model.
    state = torch.load(
        config.net.weights, map_location=lambda s, l: default_restore_location(s, 'cpu'))
   
    if 'state_dict' in state:
      state_key_name = 'state_dict'
    elif 'model_state' in state:
      state_key_name = 'model_state'
    else:
      raise NotImplementedError(
          'Unrecognized checkpoint format: expected a "state_dict" or "model_state" key.')

    if config.net.weights_for_inner_model:
      model.model.load_state_dict(state[state_key_name])
    else:
      if config.train.lenient_weight_loading:
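        # Lenient loading (assumed behavior): keep only checkpoint tensors whose
        # names and shapes match the current model; unmatched layers keep their
        # fresh initialization.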
        matched_weights = load_state_with_same_shape(model, state[state_key_name])
        model_dict = model.state_dict()
        model_dict.update(matched_weights)
        model.load_state_dict(model_dict)
      else:
        model.load_state_dict(state[state_key_name])

  model = model.cuda()
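  # Wrap in DistributedDataParallel so gradients are all-reduced across ranks;
  # broadcast_buffers=False skips per-step buffer syncs (e.g. BN running stats).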
  if distributed:
    model = torch.nn.parallel.DistributedDataParallel(
        module=model, device_ids=[device], output_device=device,
        broadcast_buffers=False, bucket_cap_mb=config.distributed.bucket_cap_mb)

  if config.train.is_train:
    train(model, train_data_loader, val_data_loader, config)
  else:
    test(model, test_data_loader, config)
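

# --- Illustrative sketches (not part of ddp_main.py) ---

# A minimal sketch of what load_state_with_same_shape is assumed to do: keep
# only checkpoint entries whose name and tensor shape match the model, so a
# partially compatible checkpoint can still be loaded leniently.
def load_state_with_same_shape_sketch(model, checkpoint_state):
  model_state = model.state_dict()
  return {
      name: tensor
      for name, tensor in checkpoint_state.items()
      if name in model_state and model_state[name].shape == tensor.shape
  }


# A hypothetical single-node launcher: run main() once per GPU with
# torch.multiprocessing.spawn. The worker is module-level so the 'spawn'
# start method can pickle it; field names mirror the config used above.
def spawn_worker_sketch(rank, config):
  config.distributed.device_id = rank          # one process per GPU
  config.distributed.distributed_rank = rank   # per-rank bookkeeping (assumed)
  main(config, init_distributed=True)


def launch_sketch(config):
  import torch.multiprocessing as mp
  mp.spawn(spawn_worker_sketch, args=(config,),
           nprocs=config.distributed.distributed_world_size)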