in downstream/semseg/lib/ddp_trainer.py [0:0]
def train(self):
    # Set up the train flag for batch normalization
    self.model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    fw_timer, bw_timer, ddp_timer = Timer(), Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    fw_time_avg, bw_time_avg, ddp_time_avg = AverageMeter(), AverageMeter(), AverageMeter()

    scores = AverageMeter()
    losses = {
        'semantic_loss': AverageMeter(),
        'total_loss': AverageMeter()
    }

    # Train the network
    logging.info('===> Start training on {} GPUs, batch-size={}'.format(
        get_world_size(), self.config.data.batch_size))
    data_iter = iter(self.train_data_loader)  # (distributed) infinite sampler
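    # The sampler is assumed to be an "infinite" distributed sampler, so
    # next(data_iter) below never raises StopIteration; epochs are emulated
    # by the len(loader) // iter_size loop that follows.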
    while self.is_training:
        for _ in range(len(self.train_data_loader) // self.config.optimizer.iter_size):
            self.optimizer.zero_grad()
            data_time, batch_score = 0, 0
            batch_losses = {
                'semantic_loss': 0.0,
                'total_loss': 0.0}

            iter_timer.tic()

            # Set the random seed every iteration for reproducibility
            self.set_seed()
            for sub_iter in range(self.config.optimizer.iter_size):
                # Get training data
                data_timer.tic()
                if self.config.data.return_transformation:
                    coords, input, target, _ = next(data_iter)
                else:
                    coords, input, target = next(data_iter)

                # Preprocess input
                color = input[:, :3].int()  # raw color features (unused below)
                if self.config.augmentation.normalize_color:
                    input[:, :3] = input[:, :3] / 255. - 0.5
                sinput = SparseTensor(input, coords).to(self.cur_device)
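                # NOTE (assumption): SparseTensor is the MinkowskiEngine-style
                # sparse tensor taking (features, coordinates); the coordinate
                # tensor is expected to carry a batch-index column added by the
                # collate function, so voxels from different scenes stay apart.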
                data_time += data_timer.toc(False)

                # Feed forward
                fw_timer.tic()
                inputs = (sinput,)
                soutput, _ = self.model(*inputs)

                # The output of the network is not sorted
                target = target.long().to(self.cur_device)
                semantic_loss = self.criterion(soutput.F, target)
                total_loss = semantic_loss

                # Compute and accumulate gradient
                total_loss /= self.config.optimizer.iter_size
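                # Dividing each sub-iteration loss by iter_size makes the
                # gradients summed across the iter_size backward() calls equal
                # the gradient of the mean loss over the effective batch, so a
                # single optimizer.step() emulates a batch iter_size times larger.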
                pred = get_prediction(self.train_data_loader.dataset, soutput.F, target)
                score = precision_at_one(pred, target)

                # Backpropagate the loss
                fw_timer.toc(False)
                bw_timer.tic()
                total_loss.backward()
                bw_timer.toc(False)

                # Gather information for logging
                logging_output = {
                    'total_loss': total_loss.item(),
                    'semantic_loss': semantic_loss.item(),
                    'score': score / self.config.optimizer.iter_size}

                ddp_timer.tic()
                if self.config.misc.num_gpus > 1:
                    logging_output = all_gather(logging_output)
                    logging_output = {
                        w: np.mean([a[w] for a in logging_output])
                        for w in logging_output[0]}
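                # With more than one GPU, every rank's scalar stats are gathered
                # and averaged so each process logs identical, world-averaged
                # numbers. This only affects logging; gradient averaging is
                # handled by DDP itself during backward().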
                batch_losses['total_loss'] += logging_output['total_loss']
                batch_losses['semantic_loss'] += logging_output['semantic_loss']
                batch_score += logging_output['score']
                ddp_timer.toc(False)

            # Apply the accumulated gradients and advance the LR schedule
            self.optimizer.step()
            self.scheduler.step()
            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))
            fw_time_avg.update(fw_timer.diff)
            bw_time_avg.update(bw_timer.diff)
            ddp_time_avg.update(ddp_timer.diff)

            losses['total_loss'].update(batch_losses['total_loss'], target.size(0))
            losses['semantic_loss'].update(batch_losses['semantic_loss'], target.size(0))
            scores.update(batch_score, target.size(0))

            if self.curr_iter >= self.config.optimizer.max_iter:
                self.is_training = False
                break
            if self.curr_iter % self.config.train.stat_freq == 0 or self.curr_iter == 1:
                lrs = ', '.join(['{:.3e}'.format(x) for x in self.scheduler.get_last_lr()])
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}, Sem {:.4f}\tLR: {}\t".format(
                    self.epoch, self.curr_iter,
                    len(self.train_data_loader) // self.config.optimizer.iter_size,
                    losses['total_loss'].avg, losses['semantic_loss'].avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Forward time: {:.4f}, " \
                             "Backward time: {:.4f}, DDP time: {:.4f}, Total iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, fw_time_avg.avg,
                    bw_time_avg.avg, ddp_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)

                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()

                # Write logs
                if self.is_master:
                    self.writer.add_scalar('train/loss', losses['total_loss'].avg, self.curr_iter)
                    self.writer.add_scalar('train/semantic_loss', losses['semantic_loss'].avg, self.curr_iter)
                    self.writer.add_scalar('train/precision_at_1', scores.avg, self.curr_iter)
                    self.writer.add_scalar('train/learning_rate', self.scheduler.get_last_lr()[0], self.curr_iter)

                # Clear the running statistics
                losses['total_loss'].reset()
                losses['semantic_loss'].reset()
                scores.reset()
            # Validation
            if self.curr_iter % self.config.train.val_freq == 0 and self.is_master:
                self.validate()
                self.model.train()

            if self.curr_iter % self.config.train.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

            # End of iteration
            self.curr_iter += 1

        self.epoch += 1
    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Run a final validation pass (which also saves the final model)
    if self.is_master:
        self.validate()
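

# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not this repo's actual implementation):
# Timer and AverageMeter above come from the project's utility module, which
# is not shown here. A minimal version consistent with how they are used --
# tic()/toc(average) returning the last interval when average=False, a .diff
# attribute, and update(val, n)/.avg/reset() -- might look like this:
import time


class Timer:
    """Accumulate wall-clock time between tic() and toc() calls."""

    def __init__(self):
        self.total_time, self.calls = 0.0, 0
        self.start_time, self.diff = 0.0, 0.0

    def tic(self):
        self.start_time = time.time()

    def toc(self, average=True):
        self.diff = time.time() - self.start_time
        self.total_time += self.diff
        self.calls += 1
        return self.total_time / self.calls if average else self.diff


class AverageMeter:
    """Keep a count-weighted running average of a scalar statistic."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count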