in contactopt/train_deepcontact.py [0:0]
def train_epoch(epoch):
    """Run one training epoch of the DeepContact model.

    Relies on module-level state defined elsewhere in this script: ``model``,
    ``optimizer``, ``scheduler``, ``train_loader``, ``device``, ``args``,
    ``writer``, ``calc_losses``, ``util`` and ``tqdm``.

    For every batch this moves the batch dict to ``device``, runs the model on
    the augmented hand/object inputs, combines the object- and hand-contact
    losses (weighted by ``args.loss_c_obj`` and ``args.loss_c_hand``) and takes
    one optimizer step. Every 10 batches it prints the running average loss and
    logs both contact losses plus the current learning rate to ``writer``.
    The LR scheduler is stepped once, after the epoch's optimizer steps.

    Args:
        epoch: Epoch index; combined with the batch index to form the global
            step used for TensorBoard logging (presumably zero-based — confirm
            against the caller).
    """
    model.train()
    loss_meter = util.AverageMeter('Loss', ':.2f')
    num_batches = len(train_loader)

    for idx, data in enumerate(tqdm(train_loader)):
        data = util.dict_to_device(data, device)
        batch_size = data['hand_pose_gt'].shape[0]

        optimizer.zero_grad()
        out = model(data['hand_verts_aug'], data['hand_feats_aug'],
                    data['obj_sampled_verts_aug'], data['obj_feats_aug'])
        losses = calc_losses(out, data['obj_contact_gt'], data['hand_contact_gt'], data['obj_sampled_idx'])
        loss = losses['contact_obj'] * args.loss_c_obj + losses['contact_hand'] * args.loss_c_hand

        loss_meter.update(loss.item(), batch_size)  # TODO better loss monitoring
        loss.backward()
        optimizer.step()

        if idx % 10 == 0:
            print('{} / {}'.format(idx, num_batches), loss_meter)

            global_iter = epoch * num_batches + idx
            writer.add_scalar('training/loss_contact_obj', losses['contact_obj'], global_iter)
            writer.add_scalar('training/loss_contact_hand', losses['contact_hand'], global_iter)
            # Read the LR the optimizer is actually using. scheduler.get_lr() is
            # deprecated, can return a wrong value when called outside
            # scheduler.step(), and returns a list rather than a scalar.
            writer.add_scalar('training/lr', optimizer.param_groups[0]['lr'], global_iter)

    # Since PyTorch 1.1 the scheduler must be stepped after optimizer.step();
    # stepping it at the start of the epoch skips the first value of the
    # LR schedule (every decay would happen one epoch early).
    scheduler.step()

    print('Train epoch: {}. Avg loss {:.4f} --------------------'.format(epoch, loss_meter.avg))