def train_epoch()

in contactopt/train_deepcontact.py [0:0]


def train_epoch(epoch):
    """Run one optimisation pass over ``train_loader`` and log progress.

    Relies on module-level state defined elsewhere in this file: ``model``,
    ``optimizer``, ``scheduler``, ``train_loader``, ``device``, ``writer``,
    ``args``, ``util`` and ``calc_losses``.

    For every batch the contact losses for object and hand are combined using
    the weights ``args.loss_c_obj`` and ``args.loss_c_hand``, back-propagated,
    and the optimizer is stepped. Every 10th batch the running loss is printed
    and the individual losses and current learning rate are written to
    ``writer``. The learning-rate scheduler is advanced once, after all
    batches of the epoch have been processed.

    Args:
        epoch: Index of the current epoch. Used to derive the global
            iteration number for the summary writer and in the printed
            end-of-epoch summary.
    """
    model.train()
    loss_meter = util.AverageMeter('Loss', ':.2f')
    num_batches = len(train_loader)

    for idx, data in enumerate(tqdm(train_loader)):
        data = util.dict_to_device(data, device)
        batch_size = data['hand_pose_gt'].shape[0]

        optimizer.zero_grad()
        out = model(data['hand_verts_aug'], data['hand_feats_aug'], data['obj_sampled_verts_aug'], data['obj_feats_aug'])
        losses = calc_losses(out, data['obj_contact_gt'], data['hand_contact_gt'], data['obj_sampled_idx'])
        loss = losses['contact_obj'] * args.loss_c_obj + losses['contact_hand'] * args.loss_c_hand

        loss_meter.update(loss.item(), batch_size)   # TODO better loss monitoring
        loss.backward()
        optimizer.step()

        if idx % 10 == 0:
            print('{} / {}'.format(idx, num_batches), loss_meter)

            global_iter = epoch * num_batches + idx
            writer.add_scalar('training/loss_contact_obj', losses['contact_obj'], global_iter)
            writer.add_scalar('training/loss_contact_hand', losses['contact_hand'], global_iter)
            # Read the rate straight from the optimizer: it is a plain scalar
            # and is the value actually applied in optimizer.step(), whereas
            # scheduler.get_lr() returns a list and is not meant to be called
            # outside of scheduler.step().
            writer.add_scalar('training/lr', optimizer.param_groups[0]['lr'], global_iter)

    # Advance the schedule only after this epoch's optimizer updates; stepping
    # the scheduler before optimizer.step() skips the first learning-rate value
    # (PyTorch >= 1.1.0).
    scheduler.step()

    print('Train epoch: {}. Avg loss {:.4f} --------------------'.format(epoch, loss_meter.avg))