def __init__()

in pretrain/pointcontrast/lib/ddp_trainer.py [0:0]


  def __init__(
      self,
      config,
      data_loader):
    """Set up the model, optimizer and LR scheduler for (distributed) training.

    Args:
      config: OmegaConf config; this method reads its ``misc``, ``net``,
        ``opt`` and ``trainer`` sections.
      data_loader: training data loader; it and its ``batch_size`` are
        cached on the trainer.

    Weights are restored in two optional stages, in this order:
      1. ``config.misc.weight`` (model weights only, optionally lenient).
      2. ``weights/weights.pth`` in the working directory, which also restores
         ``curr_iter`` and the optimizer/scheduler state. Because it runs
         second, it takes precedence over stage 1.

    Only the master process logs progress, creates ``self.writer`` (a
    ``SummaryWriter`` writing to ``logs/``), ensures ``weights/`` exists and
    saves the config to ``config.yaml``. Non-master processes therefore have
    no ``writer`` attribute.

    Raises:
      AssertionError: if GPU use is disabled in the config or CUDA is not
        available.
    """
    assert config.misc.use_gpu and torch.cuda.is_available(), "DDP mode must support GPU"
    num_feats = 3  # always 3 for finetuning.

    # In the single-GPU case there is no process group, so this process is the master.
    self.is_master = du.is_master_proc(config.misc.num_gpus) if config.misc.num_gpus > 1 else True

    # Model initialization
    self.cur_device = torch.cuda.current_device()
    Model = load_model(config.net.model)
    model = Model(
        num_feats,
        config.net.model_n_out,
        config,
        D=3)
    model = model.cuda(device=self.cur_device)
    if config.misc.num_gpus > 1:
      model = torch.nn.parallel.DistributedDataParallel(
          module=model,
          device_ids=[self.cur_device],
          output_device=self.cur_device,
          broadcast_buffers=False,
      )

    self.config = config
    self.model = model

    # The optimizer class is looked up by name in `optim`. `momentum` is always
    # passed, so the chosen optimizer must accept that argument.
    self.optimizer = getattr(optim, config.opt.optimizer)(
        model.parameters(),
        lr=config.opt.lr,
        momentum=config.opt.momentum,
        weight_decay=config.opt.weight_decay)

    self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, config.opt.exp_gamma)
    self.curr_iter = 0
    self.batch_size = data_loader.batch_size
    self.data_loader = data_loader

    self.neg_thresh = config.trainer.neg_thresh
    self.pos_thresh = config.trainer.pos_thresh

    #---------------- optional: resume checkpoint by given path ----------------------
    if config.misc.weight:
      if self.is_master:
        logging.info('===> Loading weights: ' + config.misc.weight)
      # Restore storages onto the CPU; load_state copies them into the model.
      state = torch.load(config.misc.weight, map_location=lambda s, l: default_restore_location(s, 'cpu'))
      load_state(model, state['state_dict'], config.misc.lenient_weight_loading)
      if self.is_master:
        logging.info('===> Loaded weights: ' + config.misc.weight)

    #---------------- default: resume checkpoint in current folder ----------------------
    checkpoint_fn = 'weights/weights.pth'
    if osp.isfile(checkpoint_fn):
      if self.is_master:
        logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
      state = torch.load(checkpoint_fn, map_location=lambda s, l: default_restore_location(s, 'cpu'))
      self.curr_iter = state['curr_iter']
      load_state(model, state['state_dict'])
      self.optimizer.load_state_dict(state['optimizer'])
      self.scheduler.load_state_dict(state['scheduler'])
      if self.is_master:
        logging.info("=> loaded checkpoint '{}' (curr_iter {})".format(checkpoint_fn, state['curr_iter']))
    elif self.is_master:
      # Guarded like every other log call so DDP ranks do not emit duplicates.
      logging.info("=> no checkpoint found at '{}'".format(checkpoint_fn))

    if self.is_master:
      self.writer = SummaryWriter(logdir='logs')
      os.makedirs('weights', mode=0o755, exist_ok=True)
      OmegaConf.save(config, 'config.yaml')