in main.py [0:0]
def main_worker(gpu, ngpus, args, cfg):
    args.gpu = gpu
    ngpus_per_node = ngpus

    # Set up the distributed backend and logging environment
    args = main_utils.initialize_distributed_backend(args, ngpus_per_node)
    logger, tb_writter, model_dir = main_utils.prep_environment(args, cfg)

    # Define model and move it to the GPU(s)
    model = main_utils.build_model(cfg['model'], logger)
    model, args = main_utils.distribute_model_to_cuda(model, args)

    # Define dataloaders
    train_loader = main_utils.build_dataloaders(
        cfg['dataset'], cfg['num_workers'], args.multiprocessing_distributed, logger)

    # Define criterion
    train_criterion = main_utils.build_criterion(cfg['loss'], logger=logger)
    train_criterion = train_criterion.cuda()

    # Define optimizer and LR scheduler over both model and criterion parameters
    optimizer, scheduler = main_utils.build_optimizer(
        params=list(model.parameters()) + list(train_criterion.parameters()),
        cfg=cfg['optimizer'],
        logger=logger)
    ckp_manager = main_utils.CheckpointManager(
        model_dir, rank=args.rank, dist=args.multiprocessing_distributed)

    # Optionally resume from a checkpoint
    start_epoch, end_epoch = 0, cfg['optimizer']['num_epochs']
    if cfg['resume']:
        if ckp_manager.checkpoint_exists(last=True):
            start_epoch = ckp_manager.restore(
                restore_last=True, model=model, optimizer=optimizer, train_criterion=train_criterion)
            scheduler.step(start_epoch)
            logger.add_line("Checkpoint loaded: '{}' (epoch {})".format(
                ckp_manager.last_checkpoint_fn(), start_epoch))
        else:
            logger.add_line("No checkpoint found at '{}'".format(ckp_manager.last_checkpoint_fn()))

    cudnn.benchmark = True

    ############################ TRAIN #########################################
    test_freq = cfg['test_freq'] if 'test_freq' in cfg else 1
    for epoch in range(start_epoch, end_epoch):
        # Keep a numbered snapshot every 10 epochs in addition to the rolling checkpoint
        if (epoch % 10) == 0:
            ckp_manager.save(epoch, model=model, train_criterion=train_criterion, optimizer=optimizer,
                             filename='checkpoint-ep{}.pth.tar'.format(epoch))
        if args.multiprocessing_distributed:
            # Reshuffle the distributed sampler so each epoch uses a different data ordering
            train_loader.sampler.set_epoch(epoch)

        # Train for one epoch
        logger.add_line('=' * 30 + ' Epoch {} '.format(epoch) + '=' * 30)
        logger.add_line('LR: {}'.format(scheduler.get_lr()))
        run_phase('train', train_loader, model, optimizer, train_criterion, epoch, args, cfg, logger, tb_writter)
        scheduler.step(epoch)

        if ((epoch % test_freq) == 0) or (epoch == end_epoch - 1):
            ckp_manager.save(epoch + 1, model=model, optimizer=optimizer, train_criterion=train_criterion)
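For context, main_worker takes the GPU index as its first argument, which is the calling convention expected by torch.multiprocessing.spawn. The sketch below shows a typical launch wrapper under that assumption; the argument parsing and config loading are simplified placeholders, not the repository's actual entry point.

# Minimal launch sketch (assumption: standard one-process-per-GPU spawn pattern).
import torch
import torch.multiprocessing as mp

def main(args, cfg):
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Spawn one process per GPU; each process receives its GPU index as the
        # first positional argument, matching main_worker(gpu, ngpus, args, cfg).
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args, cfg))
    else:
        # Single-process fallback: run main_worker directly on the requested GPU.
        main_worker(args.gpu, ngpus_per_node, args, cfg)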