in pretrain/pointcontrast/lib/ddp_trainer.py
def __init__(
        self,
        config,
        data_loader):
    assert config.misc.use_gpu and torch.cuda.is_available(), "DDP mode must support GPU"
    num_feats = 3  # always 3 for finetuning.
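    # Only the master process (rank 0) performs logging and the TensorBoard / config output below.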
    self.is_master = du.is_master_proc(config.misc.num_gpus) if config.misc.num_gpus > 1 else True
    # Model initialization
    self.cur_device = torch.cuda.current_device()
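    # load_model returns the network class selected by name in the config; D=3 sets the
    # spatial dimension for 3D point clouds.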
    Model = load_model(config.net.model)
    model = Model(
        num_feats,
        config.net.model_n_out,
        config,
        D=3)
    model = model.cuda(device=self.cur_device)
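    # Wrap the model with DistributedDataParallel when training on multiple GPUs; gradients are
    # synchronized across processes, and broadcast_buffers=False skips buffer broadcasting.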
    if config.misc.num_gpus > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            module=model,
            device_ids=[self.cur_device],
            output_device=self.cur_device,
            broadcast_buffers=False,
        )
    self.config = config
    self.model = model
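    # The optimizer class is looked up by name in torch.optim (e.g. SGD), so the configured
    # optimizer must accept the momentum and weight_decay arguments passed below.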
    self.optimizer = getattr(optim, config.opt.optimizer)(
        model.parameters(),
        lr=config.opt.lr,
        momentum=config.opt.momentum,
        weight_decay=config.opt.weight_decay)
    self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, config.opt.exp_gamma)
    self.curr_iter = 0
    self.batch_size = data_loader.batch_size
    self.data_loader = data_loader
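    # Positive / negative pair thresholds from the trainer config, presumably consumed by the
    # contrastive loss terms defined elsewhere in this trainer.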
    self.neg_thresh = config.trainer.neg_thresh
    self.pos_thresh = config.trainer.pos_thresh
    #---------------- optional: resume checkpoint by given path ----------------------
    if config.misc.weight:
        if self.is_master:
            logging.info('===> Loading weights: ' + config.misc.weight)
        state = torch.load(config.misc.weight, map_location=lambda s, l: default_restore_location(s, 'cpu'))
        load_state(model, state['state_dict'], config.misc.lenient_weight_loading)
        if self.is_master:
            logging.info('===> Loaded weights: ' + config.misc.weight)
    #---------------- default: resume checkpoint in current folder ----------------------
    checkpoint_fn = 'weights/weights.pth'
    if osp.isfile(checkpoint_fn):
        if self.is_master:
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
        state = torch.load(checkpoint_fn, map_location=lambda s, l: default_restore_location(s, 'cpu'))
        self.curr_iter = state['curr_iter']
        load_state(model, state['state_dict'])
        self.optimizer.load_state_dict(state['optimizer'])
        self.scheduler.load_state_dict(state['scheduler'])
        if self.is_master:
            logging.info("=> loaded checkpoint '{}' (curr_iter {})".format(checkpoint_fn, state['curr_iter']))
    else:
        logging.info("=> no checkpoint found at '{}'".format(checkpoint_fn))
    if self.is_master:
        self.writer = SummaryWriter(logdir='logs')
        if not os.path.exists('weights'):
            os.makedirs('weights', mode=0o755)
        OmegaConf.save(config, 'config.yaml')