in train_vclr.py [0:0]
def main(args, writer):
    # data loader; its sampler must support set_epoch() (e.g. DistributedSampler)
    train_loader = get_loader(args)
    n_data = len(train_loader.dataset)
    logger.info("length of training dataset: {}".format(n_data))

    # online encoder and its momentum (EMA) counterpart
    model, model_ema = build_model(args)
    logger.info('{}'.format(model))

    # memory banks (feature dim 128) and the NCE softmax loss;
    # contrast_tsn serves the TSN-level branch
    contrast = MemorySeCo(128, args.nce_k, args.nce_t, args.nce_t_intra).cuda()
    contrast_tsn = MemoryVCLR(128, args.nce_k, args.nce_t).cuda()
    criterion = NCESoftmaxLoss().cuda()

    # SGD with the linear scaling rule: lr = base_lr * (global batch size / 256)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.batch_size * dist.get_world_size() / 256 * args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = get_scheduler(optimizer, len(train_loader), args)

    # wrap only the online encoder in DDP; the EMA model stays outside DDP
    model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                    broadcast_buffers=args.broadcast_buffer)
    logger.info('Distributed Enabled')

    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume)
        load_checkpoint(args, model, model_ema, contrast, contrast_tsn,
                        optimizer, scheduler, logger.info)

    # training routine
    logger.info('Training')
    timer = mmcv.Timer()
    for epoch in range(args.start_epoch, args.epochs + 1):
        train_loader.sampler.set_epoch(epoch)
        loss = train_vclr(epoch, train_loader, model, model_ema, contrast, contrast_tsn,
                          criterion, optimizer, scheduler, writer, args)
        logger.info('epoch {}, total time {:.2f}, loss={}'.format(epoch, timer.since_last_check(), loss))

        # checkpoint on rank 0 only, then synchronize all ranks
        if dist.get_rank() == 0:
            save_checkpoint(args, epoch, model, model_ema, contrast, contrast_tsn,
                            optimizer, scheduler, logger.info)
        dist.barrier()
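
For reference, a minimal sketch of how main(args, writer) could be driven from the module's entry point under torch.distributed. The names parse_option(), args.local_rank and args.tb_folder are hypothetical placeholders, not taken from the repository; the real entry point in train_vclr.py may differ:

    import torch
    import torch.distributed as dist
    import torch.backends.cudnn as cudnn
    from torch.utils.tensorboard import SummaryWriter

    if __name__ == '__main__':
        # hypothetical CLI parser producing the args namespace used above
        # (batch_size, base_lr, nce_k, nce_t, nce_t_intra, resume, ...)
        args = parse_option()

        # one process per GPU, launched e.g. via torchrun / torch.distributed.launch
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        cudnn.benchmark = True

        # TensorBoard writer on rank 0 only (assumption: other ranks pass None)
        writer = SummaryWriter(log_dir=args.tb_folder) if dist.get_rank() == 0 else None
        main(args, writer)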