downstream/semseg/lib/train.py
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
device = config.distributed.device_id
distributed = get_world_size() > 1
# Switch the model to training mode so batch normalization (and dropout) behave as during training
model.train()
# Configuration
if not distributed or get_rank() == 0:
writer = SummaryWriter(log_dir='tensorboard')
data_timer, iter_timer = Timer(), Timer()
fw_timer, bw_timer, ddp_timer = Timer(), Timer(), Timer()
data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
fw_time_avg, bw_time_avg, ddp_time_avg = AverageMeter(), AverageMeter(), AverageMeter()
losses, scores = AverageMeter(), AverageMeter()
optimizer = initialize_optimizer(model.parameters(), config.optimizer)
scheduler = initialize_scheduler(optimizer, config.optimizer)
criterion = nn.CrossEntropyLoss(ignore_index=config.data.ignore_label)
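# Points labeled with config.data.ignore_label (e.g. unannotated points) do not contribute to the loss.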
# Train the network
logging.info('===> Start training on {} GPUs, batch-size={}'.format(
get_world_size(), config.data.batch_size * get_world_size()
))
best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True
if config.train.resume:
checkpoint_fn = config.train.resume + '/weights.pth'
if osp.isfile(checkpoint_fn):
logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
state = torch.load(checkpoint_fn, map_location=lambda s, l: default_restore_location(s, 'cpu'))
curr_iter = state['iteration'] + 1
epoch = state['epoch']
load_state(model, state['state_dict'])
if config.train.resume_optimizer:
scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
optimizer.load_state_dict(state['optimizer'])
if 'best_val' in state:
best_val_miou = state['best_val']
best_val_iter = state['best_val_iter']
logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))
else:
raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))
data_iter = iter(data_loader)  # (distributed) infinite sampler
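# The loader is consumed as a manually advanced, effectively infinite iterator; epoch is only
# bookkeeping, and the run length is governed by config.optimizer.max_iter.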
while is_training:
for iteration in range(len(data_loader) // config.optimizer.iter_size):
optimizer.zero_grad()
data_time, batch_loss, batch_score = 0, 0, 0
iter_timer.tic()
# Set the random seed every iteration so runs are reproducible and resumable
_set_seed(config, curr_iter)
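# Each outer iteration accumulates gradients over config.optimizer.iter_size sub-batches
# before a single optimizer step, emulating a larger effective batch size.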
for sub_iter in range(config.optimizer.iter_size):
# Get training data
data_timer.tic()
coords, input, target = next(data_iter)
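# coords: integer point coordinates, input: per-point features (the first three channels hold color),
# target: per-point semantic labels.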
# Jitter coordinates by a random integer translation; for some networks, invariance to even/odd coordinate parity is important
coords[:, :3] += (torch.rand(3) * 100).type_as(coords)
# Preprocess input
color = input[:, :3].int()
if config.augmentation.normalize_color:
input[:, :3] = input[:, :3] / 255. - 0.5
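# Pack the per-point features and their integer coordinates into the sparse tensor consumed by the network.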
sinput = SparseTensor(input, coords).to(device)
data_time += data_timer.toc(False)
# Feed forward
fw_timer.tic()
inputs = (sinput,) if config.net.wrapper_type is None else (sinput, coords, color)
# model.initialize_coords(*init_args)
soutput = model(*inputs)
# The output of the network is not sorted
target = target.long().to(device)
loss = criterion(soutput.F, target)
# Compute and accumulate gradient
loss /= config.optimizer.iter_size
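# Scaling by iter_size makes the accumulated gradient the mean over sub-batches rather than their sum.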
pred = get_prediction(data_loader.dataset, soutput.F, target)
score = precision_at_one(pred, target)
fw_timer.toc(False)
bw_timer.tic()
# Backpropagate and accumulate gradients
loss.backward()
bw_timer.toc(False)
# Collect per-rank statistics for logging
logging_output = {'loss': loss.item(), 'score': score / config.optimizer.iter_size}
ddp_timer.tic()
if distributed:
logging_output = all_gather_list(logging_output)
logging_output = {w: np.mean([
a[w] for a in logging_output]
) for w in logging_output[0]}
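# In distributed mode the per-rank loss/score are averaged across processes so the
# logged values reflect the global batch.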
batch_loss += logging_output['loss']
batch_score += logging_output['score']
ddp_timer.toc(False)
# Apply the accumulated gradients and advance the learning-rate schedule
optimizer.step()
scheduler.step()
data_time_avg.update(data_time)
iter_time_avg.update(iter_timer.toc(False))
fw_time_avg.update(fw_timer.diff)
bw_time_avg.update(bw_timer.diff)
ddp_time_avg.update(ddp_timer.diff)
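# The loss/score running averages are weighted by the point count of the (last) sub-batch.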
losses.update(batch_loss, target.size(0))
scores.update(batch_score, target.size(0))
if curr_iter >= config.optimizer.max_iter:
is_training = False
break
if curr_iter % config.train.stat_freq == 0 or curr_iter == 1:
lrs = ', '.join(['{:.3e}'.format(x) for x in scheduler.get_last_lr()])
debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
epoch, curr_iter,
len(data_loader) // config.optimizer.iter_size, losses.avg, lrs)
debug_str += "Score {:.3f}\tData time: {:.4f}, Forward time: {:.4f}, Backward time: {:.4f}, DDP time: {:.4f}, Total iter time: {:.4f}".format(
scores.avg, data_time_avg.avg, fw_time_avg.avg, bw_time_avg.avg, ddp_time_avg.avg, iter_time_avg.avg)
logging.info(debug_str)
# Reset timers
data_time_avg.reset()
iter_time_avg.reset()
# Write logs
if not distributed or get_rank() == 0:
writer.add_scalar('training/loss', losses.avg, curr_iter)
writer.add_scalar('training/precision_at_1', scores.avg, curr_iter)
writer.add_scalar('training/learning_rate', scheduler.get_last_lr()[0], curr_iter)
losses.reset()
scores.reset()
# Save the current state; checkpoint before validation to guard against occasional memory overflow
if curr_iter % config.train.save_freq == 0:
checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)
# Validation
if curr_iter % config.train.val_freq == 0 and (not distributed or get_rank() == 0):
val_miou = validate(model, val_data_loader, writer, curr_iter, config, transform_data_fn)
if val_miou > best_val_miou:
best_val_miou = val_miou
best_val_iter = curr_iter
checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
"best_val")
logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))
# Switch back to training mode after validation
model.train()
if curr_iter % config.train.empty_cache_freq == 0:
# Periodically release cached, unused GPU memory
torch.cuda.empty_cache()
# End of iteration
curr_iter += 1
epoch += 1
# Explicit memory cleanup
if hasattr(data_iter, 'cleanup'):
data_iter.cleanup()
# Save the final model
checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)
val_miou = validate(model, val_data_loader, writer, curr_iter, config, transform_data_fn)
if val_miou > best_val_miou:
best_val_miou = val_miou
best_val_iter = curr_iter
checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter, "best_val")
logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))