in contactopt/run_contactopt.py [0:0]
def run_contactopt(args):
"""
Actually run ContactOpt approach. Estimates target contact with DeepContact,
then optimizes it. Performs random restarts if selected.
Saves results to a pkl file.
:param args: input settings
"""
print('Running split', args.split)
dataset = ContactDBDataset(args.test_dataset, min_num_cont=args.min_cont)
shuffle = args.vis or args.partial > 0
print('Shuffle:', shuffle)
test_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle, num_workers=6, collate_fn=ContactDBDataset.collate_fn)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_newest_checkpoint()
model.to(device)
model.eval()
all_data = list()
for idx, data in enumerate(tqdm(test_loader)):
data_gpu = util.dict_to_device(data, device)
batch_size = data['obj_sampled_idx'].shape[0]
if args.split != 'fine':
with torch.no_grad():
network_out = model(data_gpu['hand_verts_aug'], data_gpu['hand_feats_aug'], data_gpu['obj_sampled_verts_aug'], data_gpu['obj_feats_aug'])
hand_contact_target = util.class_to_val(network_out['contact_hand']).unsqueeze(2)
obj_contact_target = util.class_to_val(network_out['contact_obj']).unsqueeze(2)
else:
hand_contact_target = data_gpu['hand_contact_gt']
obj_contact_target = util.batched_index_select(data_gpu['obj_contact_gt'], 1, data_gpu['obj_sampled_idx'])
if args.sharpen_thresh > 0: # If flag, sharpen contact
print('Sharpening')
obj_contact_target = util.sharpen_contact(obj_contact_target, slope=2, thresh=args.sharpen_thresh)
hand_contact_target = util.sharpen_contact(hand_contact_target, slope=2, thresh=args.sharpen_thresh)
if args.rand_re > 1: # If we desire random restarts
mtc_orig = data_gpu['hand_mTc_aug'].detach().clone()
print('Doing random optimization restarts')
best_loss = torch.ones(batch_size) * 100000
for re_it in range(args.rand_re):
# Add noise to hand translation and rotation
data_gpu['hand_mTc_aug'] = mtc_orig.detach().clone()
random_rot_mat = pytorch3d.transforms.euler_angles_to_matrix(torch.randn((batch_size, 3), device=device) * args.rand_re_rot / 180 * np.pi, 'ZYX')
data_gpu['hand_mTc_aug'][:, :3, :3] = torch.bmm(random_rot_mat, data_gpu['hand_mTc_aug'][:, :3, :3])
data_gpu['hand_mTc_aug'][:, :3, 3] += torch.randn((batch_size, 3), device=device) * args.rand_re_trans
cur_result = optimize_pose(data_gpu, hand_contact_target, obj_contact_target, n_iter=args.n_iter, lr=args.lr,
w_cont_hand=args.w_cont_hand, w_cont_obj=1, save_history=args.vis, ncomps=args.ncomps,
w_cont_asym=args.w_cont_asym, w_opt_trans=args.w_opt_trans, w_opt_pose=args.w_opt_pose,
w_opt_rot=args.w_opt_rot,
caps_top=args.caps_top, caps_bot=args.caps_bot, caps_rad=args.caps_rad,
caps_on_hand=args.caps_hand,
contact_norm_method=args.cont_method, w_pen_cost=args.w_pen_cost,
w_obj_rot=args.w_obj_rot, pen_it=args.pen_it)
if re_it == 0:
out_pose = torch.zeros_like(cur_result[0])
out_mTc = torch.zeros_like(cur_result[1])
obj_rot = torch.zeros_like(cur_result[2])
opt_state = cur_result[3]
loss_val = cur_result[3][-1]['loss']
for b in range(batch_size):
if loss_val[b] < best_loss[b]:
best_loss[b] = loss_val[b]
out_pose[b, :] = cur_result[0][b, :]
out_mTc[b, :, :] = cur_result[1][b, :, :]
obj_rot[b, :, :] = cur_result[2][b, :, :]
# print('Loss, re', re_it, loss_val)
# print('Best loss', best_loss)
else:
result = optimize_pose(data_gpu, hand_contact_target, obj_contact_target, n_iter=args.n_iter, lr=args.lr,
w_cont_hand=args.w_cont_hand, w_cont_obj=1, save_history=args.vis, ncomps=args.ncomps,
w_cont_asym=args.w_cont_asym, w_opt_trans=args.w_opt_trans, w_opt_pose=args.w_opt_pose,
w_opt_rot=args.w_opt_rot,
caps_top=args.caps_top, caps_bot=args.caps_bot, caps_rad=args.caps_rad,
caps_on_hand=args.caps_hand,
contact_norm_method=args.cont_method, w_pen_cost=args.w_pen_cost,
w_obj_rot=args.w_obj_rot, pen_it=args.pen_it)
out_pose, out_mTc, obj_rot, opt_state = result
obj_contact_upscale = util.upscale_contact(data_gpu['mesh_aug'], data_gpu['obj_sampled_idx'], obj_contact_target)
for b in range(obj_contact_upscale.shape[0]): # Loop over batch
gt_ho = HandObject()
in_ho = HandObject()
out_ho = HandObject()
gt_ho.load_from_batch(data['hand_beta_gt'], data['hand_pose_gt'], data['hand_mTc_gt'], data['hand_contact_gt'], data['obj_contact_gt'], data['mesh_gt'], b)
in_ho.load_from_batch(data['hand_beta_aug'], data['hand_pose_aug'], data['hand_mTc_aug'], hand_contact_target, obj_contact_upscale, data['mesh_aug'], b)
out_ho.load_from_batch(data['hand_beta_aug'], out_pose, out_mTc, data['hand_contact_gt'], data['obj_contact_gt'], data['mesh_aug'], b, obj_rot=obj_rot)
# out_ho.calc_dist_contact(hand=True, obj=True)
all_data.append({'gt_ho': gt_ho, 'in_ho': in_ho, 'out_ho': out_ho})
if args.vis:
show_optimization(data, opt_state, hand_contact_target.detach().cpu().numpy(), obj_contact_upscale.detach().cpu().numpy(),
is_video=args.video, vis_method=args.vis_method)
if idx >= args.partial > 0: # Speed up for eval
break
out_file = 'data/optimized_{}.pkl'.format(args.split)
print('Saving to {}. Len {}'.format(out_file, len(all_data)))
pickle.dump(all_data, open(out_file, 'wb'))