in segmentation/tool/train.py [0:0]
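# Module-level dependencies this function relies on (assumed; not shown in this excerpt):
#   time, numpy as np, torch, torch.nn.functional as F, torch.distributed as dist,
#   plus the repository's AverageMeter, poly_learning_rate, intersectionAndUnionGPU,
#   main_process, logger, writer and args globals/helpers.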
def train(train_loader, model, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()
    end = time.time()
    max_iter = args.epochs * len(train_loader)
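    # One full pass over the training set; data_time/batch_time track data-loading and per-batch compute cost.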
    for i, (input, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
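        # Resize the label map to the model's output resolution, which is governed by zoom_factor on the 1/8-stride feature size.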
        if args.zoom_factor != 8:
            h = int((target.size()[1] - 1) / 8 * args.zoom_factor + 1)
            w = int((target.size()[2] - 1) / 8 * args.zoom_factor + 1)
            # 'nearest' interpolation doesn't support align_corners, and 'bilinear' is fine for downsampling the label map
            target = F.interpolate(target.unsqueeze(1).float(), size=(h, w), mode='bilinear', align_corners=True).squeeze(1).long()
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
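        # Forward pass: the model returns its prediction together with the main and auxiliary (deep-supervision) losses.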
        output, main_loss, aux_loss = model(input, target)
        if not args.multiprocessing_distributed:
            main_loss, aux_loss = torch.mean(main_loss), torch.mean(aux_loss)
        loss = main_loss + args.aux_weight * aux_loss
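        # Standard optimization step on the weighted sum of the two losses.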
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = input.size(0)
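        # In distributed training, all-reduce the batch-size-weighted losses across GPUs to obtain globally averaged values.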
        if args.multiprocessing_distributed:
            main_loss, aux_loss, loss = main_loss.detach() * n, aux_loss * n, loss * n  # not considering ignore pixels
            count = target.new_tensor([n], dtype=torch.long)
            dist.all_reduce(main_loss), dist.all_reduce(aux_loss), dist.all_reduce(loss), dist.all_reduce(count)
            n = count.item()
            main_loss, aux_loss, loss = main_loss / n, aux_loss / n, loss / n
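        # Accumulate per-class intersection/union/target counts for the running mIoU, mAcc and allAcc statistics.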
        intersection, union, target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label)
        if args.multiprocessing_distributed:
            dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(target)
        intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        main_loss_meter.update(main_loss.item(), n)
        aux_loss_meter.update(aux_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()
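        # Poly learning-rate decay; the parameter groups from index_split onward (typically the newly added layers) use a 10x learning rate.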
        current_iter = epoch * len(train_loader) + i + 1
        current_lr = poly_learning_rate(args.base_lr, current_iter, max_iter, power=args.power)
        for index in range(0, args.index_split):
            optimizer.param_groups[index]['lr'] = current_lr
        for index in range(args.index_split, len(optimizer.param_groups)):
            optimizer.param_groups[index]['lr'] = current_lr * 10
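        # Estimate the remaining wall-clock time from the running average batch time.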
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))
        if (i + 1) % args.print_freq == 0 and main_process():
            logger.info('Epoch: [{}/{}][{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Remain {remain_time} '
                        'MainLoss {main_loss_meter.val:.4f} '
                        'AuxLoss {aux_loss_meter.val:.4f} '
                        'Loss {loss_meter.val:.4f} '
                        'Accuracy {accuracy:.4f}.'.format(epoch + 1, args.epochs, i + 1, len(train_loader),
                                                          batch_time=batch_time,
                                                          data_time=data_time,
                                                          remain_time=remain_time,
                                                          main_loss_meter=main_loss_meter,
                                                          aux_loss_meter=aux_loss_meter,
                                                          loss_meter=loss_meter,
                                                          accuracy=accuracy))
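        # Per-batch TensorBoard scalars are logged only on the main process.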
        if main_process():
            writer.add_scalar('loss_train_batch', main_loss_meter.val, current_iter)
            writer.add_scalar('mIoU_train_batch', np.mean(intersection / (union + 1e-10)), current_iter)
            writer.add_scalar('mAcc_train_batch', np.mean(intersection / (target + 1e-10)), current_iter)
            writer.add_scalar('allAcc_train_batch', accuracy, current_iter)
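    # Epoch-level metrics computed from the accumulated per-class counts.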
    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    if main_process():
        logger.info('Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(epoch + 1, args.epochs, mIoU, mAcc, allAcc))
    return main_loss_meter.avg, mIoU, mAcc, allAcc
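For reference, a minimal sketch of the poly schedule assumed by the poly_learning_rate call above; the actual helper lives elsewhere in the repository and may differ in detail.

def poly_learning_rate(base_lr, curr_iter, max_iter, power=0.9):
    """Polynomial decay (sketch): lr = base_lr * (1 - curr_iter / max_iter) ** power."""
    return base_lr * (1 - float(curr_iter) / max_iter) ** power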