in contactopt/optimize_pose.py [0:0]
def optimize_pose(data, hand_contact_target, obj_contact_target, n_iter=250, lr=0.01, w_cont_hand=2, w_cont_obj=1,
save_history=False, ncomps=15, w_cont_asym=2, w_opt_trans=0.3, w_opt_pose=1, w_opt_rot=1,
caps_top=0.0005, caps_bot=-0.001, caps_rad=0.001, caps_on_hand=False,
contact_norm_method=0, w_pen_cost=600, w_obj_rot=0, pen_it=0):
"""Runs differentiable optimization to align the hand with the target contact map.
Minimizes the loss between ground truth contact and contact calculated with DiffContact"""
batch_size = data['hand_pose_aug'].shape[0]
device = data['hand_pose_aug'].device
opt_vector = torch.zeros((batch_size, ncomps + 6 + 3), device=device) # 3 hand rot, 3 hand trans, 3 obj rot
opt_vector.requires_grad = True
mano_model = ManoLayer(mano_root='mano/models', use_pca=True, ncomps=ncomps, side='right', flat_hand_mean=False).to(device)
if data['obj_sampled_idx'].numel() > 1:
obj_normals_sampled = util.batched_index_select(data['obj_normals_aug'], 1, data['obj_sampled_idx'])
else: # If we're optimizing over all verts
obj_normals_sampled = data['obj_normals_aug']
optimizer = torch.optim.Adam([opt_vector], lr=lr, amsgrad=True) # AMSgrad helps
loss_criterion = torch.nn.L1Loss(reduction='none') # Benchmarked, L1 performs best vs MSE/SmoothL1
opt_state = []
is_thin = mesh_is_thin(data['mesh_aug'].num_verts_per_mesh())
# print('is thin', is_thin, data['mesh_aug'].num_verts_per_mesh())
for it in range(n_iter):
optimizer.zero_grad()
mano_pose_out = torch.cat([opt_vector[:, 0:3] * w_opt_rot, opt_vector[:, 3:ncomps+3] * w_opt_pose], dim=1)
mano_pose_out[:, :18] += data['hand_pose_aug']
tform_out = util.translation_to_tform(opt_vector[:, ncomps+3:ncomps+6] * w_opt_trans)
hand_verts, hand_joints = util.forward_mano(mano_model, mano_pose_out, data['hand_beta_aug'], [data['hand_mTc_aug'], tform_out]) # 2.2ms
if contact_norm_method != 0 and not caps_on_hand:
with torch.no_grad(): # We need to calculate hand normals if using more complicated methods
mano_mesh = Meshes(verts=hand_verts, faces=mano_model.th_faces.repeat(batch_size, 1, 1))
hand_normals = mano_mesh.verts_normals_padded()
else:
hand_normals = torch.zeros(hand_verts.shape, device=device)
obj_verts = data['obj_sampled_verts_aug']
obj_normals = obj_normals_sampled
obj_rot_mat = rodrigues_layer.batch_rodrigues(opt_vector[:, ncomps+6:])
obj_rot_mat = obj_rot_mat.view(batch_size, 3, 3)
if w_obj_rot > 0:
obj_verts = util.apply_rot(obj_rot_mat, obj_verts, around_centroid=True)
obj_normals = util.apply_rot(obj_rot_mat, obj_normals)
contact_obj, contact_hand = calculate_contact.calculate_contact_capsule(hand_verts, hand_normals, obj_verts, obj_normals,
caps_top=caps_top, caps_bot=caps_bot, caps_rad=caps_rad, caps_on_hand=caps_on_hand, contact_norm_method=contact_norm_method)
contact_obj_sub = obj_contact_target - contact_obj
contact_obj_weighted = contact_obj_sub + torch.nn.functional.relu(contact_obj_sub) * w_cont_asym # Loss for 'missing' contact higher
loss_contact_obj = loss_criterion(contact_obj_weighted, torch.zeros_like(contact_obj_weighted)).mean(dim=(1, 2))
contact_hand_sub = hand_contact_target - contact_hand
contact_hand_weighted = contact_hand_sub + torch.nn.functional.relu(contact_hand_sub) * w_cont_asym # Loss for 'missing' contact higher
loss_contact_hand = loss_criterion(contact_hand_weighted, torch.zeros_like(contact_hand_weighted)).mean(dim=(1, 2))
loss = loss_contact_obj * w_cont_obj + loss_contact_hand * w_cont_hand
if w_pen_cost > 0 and it >= pen_it:
pen_cost = calculate_contact.calculate_penetration_cost(hand_verts, hand_normals, data['obj_sampled_verts_aug'], obj_normals_sampled, is_thin, contact_norm_method)
loss += pen_cost.mean(dim=1) * w_pen_cost
out_dict = {'loss': loss.detach().cpu()}
if save_history:
out_dict['hand_verts'] = hand_verts.detach().cpu()#.numpy()
out_dict['hand_joints'] = hand_joints.detach().cpu()#.numpy()
out_dict['contact_obj'] = contact_obj.detach().cpu()#.numpy()
out_dict['contact_hand'] = contact_hand.detach().cpu()#.numpy()
out_dict['obj_rot'] = obj_rot_mat.detach().cpu()#.numpy()
opt_state.append(out_dict)
loss.mean().backward()
optimizer.step()
tform_full_out = util.aggregate_tforms([data['hand_mTc_aug'], tform_out])
return mano_pose_out, tform_full_out, obj_rot_mat, opt_state