in eft/train/eftFitter.py [0:0]
def run_eft_step_with_2dhand(self, input_batch, iterIdx=0):
"""Setting up EFT mode"""
self.model.train()
if self.options.bExemplarMode:
self.exemplerTrainingMode()
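# In exemplar mode the network is fine-tuned on this single sample; exemplerTrainingMode()
# presumably fixes batch-norm/dropout behavior so that training with a batch of one is stable.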
""" Get data from the batch """
images = input_batch['img'] # input image
gt_keypoints_2d = input_batch['keypoints'].clone()# 2D keypoints #[N,49,3]
assert 'kp_rightHand_gt' in input_batch.keys()
assert 'kp_leftHand_gt' in input_batch.keys()
gt_kp_rhand_2d = input_batch['kp_rightHand_gt'].clone()# 2D keypoints #[N,21,3]
gt_kp_lhand_2d = input_batch['kp_leftHand_gt'].clone()# 2D keypoints #[N,21,3]
gt_pose = input_batch['pose'] # SMPL pose parameters #[N,72]
gt_betas = input_batch['betas'] # SMPL beta parameters #[N,10]
JOINT_SCALING_3D = 1.0 #3D joint scaling
gt_joints = input_batch['pose_3d']*JOINT_SCALING_3D # 3D pose #[N,24,4]
has_pose_3d = input_batch['has_pose_3d'].byte()==1 # flag that indicates whether 3D pose is valid
indices = input_batch['sample_index'] # index of example inside its dataset
batch_size = images.shape[0]
is_flipped = input_batch['is_flipped'] # flag that indicates whether image was flipped during data augmentation
rot_angle = input_batch['rot_angle'] # rotation angle used for data augmentation
dataset_name = input_batch['dataset_name'] # name of the dataset the image comes from
index_cpu = indices.cpu()
if self.options.bExemplar_dataLoaderStart>=0:
index_cpu += self.options.bExemplar_dataLoaderStart    #Offset indices when the dataloader starts from a nonzero sample
#Check existing SPIN fits
opt_pose, opt_betas, opt_validity = self.fits_dict[(dataset_name, index_cpu, rot_angle.cpu(), is_flipped.cpu())]
opt_pose = opt_pose.to(self.device)
opt_betas = opt_betas.to(self.device)
""" Run model to make a prediction """
# Feed images in the network to predict camera and SMPL parameters
pred_rotmat, pred_betas, pred_camera = self.model(images)
pred_output = self.smpl(betas=pred_betas, body_pose=pred_rotmat[:,1:], global_orient=pred_rotmat[:,[0]], pose2rot=False)
pred_vertices = pred_output.vertices
pred_joints_3d = pred_output.joints
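# pose2rot=False: the predicted pose is already given as rotation matrices
# (index 0 = global orientation, indices 1..23 = body joints), so no axis-angle conversion is applied.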
""" Computing Loss """
# Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
# This camera translation can be used in a full perspective projection
pred_cam_t = torch.stack([pred_camera[:,1],
pred_camera[:,2],
2*self.focal_length/(self.options.img_res * pred_camera[:,0] +1e-9)],dim=-1)
camera_center = torch.zeros(batch_size, 2, device=self.device)
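# Derivation: in normalized [-1,1] image coordinates a perspective projection gives x ~ (2*f/(img_res*tz))*X,
# so the weak-perspective scale is s = 2*f/(img_res*tz), i.e. tz = 2*f/(img_res*s) as used above.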
# weakProjection_gpu################
pred_keypoints_2d = weakProjection_gpu(pred_joints_3d, pred_camera[:,0], pred_camera[:,1:] ) #N, 49, 2
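# weakProjection_gpu presumably applies the weak-perspective model x_2d = s * X[:, :2] + t,
# yielding keypoints in the same normalized [-1,1] space as the ground-truth annotations.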
#2D hand supervision requires hand mode with the SMPL-X model
assert self.options.bUseHand2D and self.options.bUseSMPLX
pred_right_hand_2d = weakProjection_gpu(pred_output.right_hand_joints, pred_camera[:,0], pred_camera[:,1:] ) #N, 21, 2
pred_left_hand_2d = weakProjection_gpu(pred_output.left_hand_joints, pred_camera[:,0], pred_camera[:,1:] ) #N, 21, 2
if True: #Ignore hips, hip center, and feet
LENGTH_THRESHOLD = 0.0089 #~1/112: minimum bone length in normalized [-1,1] coords (about one pixel at 224x224)
#Disable Hips by default
if self.options.eft_withHip2D==False:
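# In SPIN's 49-joint convention (25 OpenPose + 24 GT joints), indices 25+2 / 25+3 are the
# right / left hip and 25+14 is the pelvis; their confidences are zeroed so hips are not supervised.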
gt_keypoints_2d[:,2+25,2]=0
gt_keypoints_2d[:,3+25,2]=0
gt_keypoints_2d[:,14+25,2]=0
#Compute lower-leg (knee-to-ankle) orientations and compare predicted vs. ground-truth directions.
#Note: the scalar length checks below assume a single-example batch, as is typical for EFT.
gt_boneOri_leftLeg = gt_keypoints_2d[:,5+25,:2] - gt_keypoints_2d[:,4+25,:2] #Left lower leg orientation #(N,2)
gt_boneOri_leftLeg, leftLegLeng = normalize_2dvector(gt_boneOri_leftLeg)
if leftLegLeng>LENGTH_THRESHOLD:
leftLegValidity = gt_keypoints_2d[:,5+25, 2] * gt_keypoints_2d[:,4+25, 2]
pred_boneOri_leftLeg = pred_keypoints_2d[:,5+25,:2] - pred_keypoints_2d[:,4+25,:2]
pred_boneOri_leftLeg, _ = normalize_2dvector(pred_boneOri_leftLeg)
loss_legOri_left = torch.ones(1).to(self.device) - torch.dot(gt_boneOri_leftLeg.view(-1),pred_boneOri_leftLeg.view(-1))
else:
loss_legOri_left = torch.zeros(1).to(self.device)
leftLegValidity = torch.zeros(1).to(self.device)
gt_boneOri_rightLeg = gt_keypoints_2d[:,0+25,:2] - gt_keypoints_2d[:,1+25,:2] #Right lower leg orientation
gt_boneOri_rightLeg, rightLegLeng = normalize_2dvector(gt_boneOri_rightLeg)
if rightLegLeng>LENGTH_THRESHOLD:
rightLegValidity = gt_keypoints_2d[:,0+25, 2] * gt_keypoints_2d[:,1+25, 2]
pred_boneOri_rightLeg = pred_keypoints_2d[:,0+25,:2] - pred_keypoints_2d[:,1+25,:2]
pred_boneOri_rightLeg, _ = normalize_2dvector(pred_boneOri_rightLeg)
loss_legOri_right = torch.ones(1).to(self.device) - torch.dot(gt_boneOri_rightLeg.view(-1),pred_boneOri_rightLeg.view(-1))
else:
loss_legOri_right = torch.zeros(1).to(self.device)
rightLegValidity = torch.zeros(1).to(self.device)
# print("leftLegLeng: {}, rightLegLeng{}".format(leftLegLeng,rightLegLeng ))
loss_legOri = leftLegValidity* loss_legOri_left + rightLegValidity* loss_legOri_right
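# Each leg term is (1 - cosine similarity) between the normalized GT and predicted lower-leg directions,
# gated by the product of the two keypoint confidences; bones shorter than the threshold are ignored.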
#Disable feet
# gt_keypoints_2d[:,5+25,2]=0 #Left foot
# gt_keypoints_2d[:,0+25,2]=0 #Right foot
# Compute 2D reprojection loss for the keypoints
loss_keypoints_2d = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d,
self.options.openpose_train_weight,
self.options.gt_train_weight)
loss_keypoints_2d_hand = self.keypoint_loss_keypoint21(pred_right_hand_2d, gt_kp_rhand_2d,1.0) + self.keypoint_loss_keypoint21(pred_left_hand_2d, gt_kp_lhand_2d,1.0)
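# The hand loss compares the 21 projected hand joints of each side against their GT keypoints;
# the GT confidence (3rd column) presumably weights each keypoint inside keypoint_loss_keypoint21.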
# Compute 3D keypoint loss
if self.options.bExemplarWith3DSkel:
# loss_keypoints_3d = self.keypoint_3d_loss(pred_joints_3d, gt_joints, has_pose_3d)
loss_keypoints_3d = self.keypoint_3d_loss_panopticDB(pred_joints_3d, gt_joints, has_pose_3d)
else:
loss_keypoints_3d = torch.tensor(0)
# loss_keypoints_3d = self.keypoint_3d_loss_modelSkel(pred_joints_3d, gt_model_joints[:,25:,:], has_pose_3d)
loss_regr_betas_noReject = torch.mean(pred_betas**2)
loss = self.options.keypoint_loss_weight * loss_keypoints_2d + \
self.options.beta_loss_weight * loss_regr_betas_noReject + \
((torch.exp(-pred_camera[:,0]*10)) ** 2 ).mean()
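# The last term regularizes the camera scale: exp(-10*s)^2 explodes for negative s, equals 1 at s=0,
# and decays quickly for positive scales, which discourages degenerate (collapsed) cameras.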
#Add hand keypoint 2D loss
if True:
loss = loss + self.options.keypoint_loss_weight*loss_keypoints_2d_hand *2.0 #weight the hand term 2x to emphasize it
if self.options.bExemplarWith3DSkel:
loss = loss + self.options.keypoint_loss_weight * loss_keypoints_3d
# loss = loss_keypoints_3d #TODO: DEBUGGING
if False: #Leg orientation loss
loss = loss + 0.005*loss_legOri
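# Overall loss scaling; the same factor of 60 is used in the original SPIN trainer.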
loss *= 60
# print("loss2D: {}, loss3D: {}".format( self.options.keypoint_loss_weight * loss_keypoints_2d,self.options.keypoint_loss_weight * loss_keypoints_3d ) )
""" Do backprop """
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
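# Each call performs one gradient step on this example; EFT repeats the step (see iterIdx)
# until the 2D reprojection fit converges.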
""" Handling output """
#Save result
output = {}
output['pred_pose_rotmat'] = pred_rotmat.detach().cpu().numpy()
output['pred_shape'] = pred_betas.detach().cpu().numpy()
output['pred_camera'] = pred_camera.detach().cpu().numpy()
#If SPIN fits exist, save them for later comparison
output['opt_pose'] = opt_pose.detach().cpu().numpy()
output['opt_beta'] = opt_betas.detach().cpu().numpy()
output['sampleIdx'] = input_batch['sample_index'].detach().cpu().numpy() #Sample index, so results can be matched back through the dataloader
output['imageName'] = input_batch['imgname']
output['scale'] = input_batch['scale'].detach().cpu().numpy()
output['center'] = input_batch['center'].detach().cpu().numpy()
if 'annotId' in input_batch.keys():
output['annotId'] = input_batch['annotId'].detach().cpu().numpy()
if 'subjectId' in input_batch.keys():
if input_batch['subjectId'][0]!="":
output['subjectId'] = input_batch['subjectId'][0].item()
#To save new db file
output['keypoint2d'] = input_batch['keypoints_original'].detach().cpu().numpy()
output['keypoint2d_cropped'] = input_batch['keypoints'].detach().cpu().numpy()
losses = {'loss': loss.detach().item(),
'loss_keypoints': loss_keypoints_2d.detach().item(),
'loss_keypoints_3d': loss_keypoints_3d.detach().item(),
# 'loss_regr_pose': loss_regr_pose.detach().item(),
'loss_regr_betas': loss_regr_betas_noReject.detach().item()}
# 'loss_shape': loss_shape.detach().item()}
""" Visualiztion for Debuggin """
if self.options.bDebug_visEFT: #Debug visualization of the input and the current fit
# For visualization, de-normalize 2D keypoints from [-1,1] to pixel space
gt_keypoints_2d_orig = gt_keypoints_2d.clone()
gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (gt_keypoints_2d_orig[:, :, :-1] + 1) #49: (25+24) x 3
# print("Image Name: {}".format(output['imageName']))
for b in range(batch_size):
# DenormalizeImg
curImgVis = deNormalizeBatchImg(images[b].cpu())
viewer2D.ImShow(curImgVis, name='rawIm', scale=4.0)
# Visualize GT 2D keypoints
if True:
gt_keypoints_2d_orig_vis = gt_keypoints_2d_orig.detach().cpu().numpy()
gt_keypoints_2d_orig_vis[b,:25,2] = 0 #Don't show openpose
curImgVis = viewer2D.Vis_Skeleton_2D_SPIN49(gt_keypoints_2d_orig_vis[b,:,:2], gt_keypoints_2d_orig_vis[b,:,2], bVis= False, image=curImgVis)
#De-normalize hand keypoints from [-1,1] to pixel space for the current batch element
curImgVis = viewer2D.Vis_Skeleton_2D_Openpose_hand( (gt_kp_rhand_2d[b,:,:2].cpu().numpy()+1)*0.5*self.options.img_res, gt_kp_rhand_2d[b,:,2].cpu().numpy(), image=curImgVis)
curImgVis = viewer2D.Vis_Skeleton_2D_Openpose_hand( (gt_kp_lhand_2d[b,:,:2].cpu().numpy()+1)*0.5*self.options.img_res, gt_kp_lhand_2d[b,:,2].cpu().numpy(), image=curImgVis)
# curImgVis = viewer2D.Vis_Skeleton_2D_Openpose18(gt_keypoints_2d_orig[b,:,:2].cpu().numpy(), gt_keypoints_2d_orig[b,:,2], bVis= False, image=curImgVis)
############### Visualize Mesh #################
#Visualize SMPL in image space
pred_smpl_output, pred_smpl_output_bbox = smpl_utils.visSMPLoutput_bboxSpace(self.smpl, {"pred_rotmat":pred_rotmat, "pred_shape":pred_betas, "pred_camera":pred_camera}
, image = curImgVis, waittime=-1)
#Visualize GT Mesh
if False:
gtOut = {"pred_pose":gt_pose, "pred_shape":gt_betas, "pred_camera":pred_camera}
# _, gt_smpl_output_bbox = smpl_utils.getSMPLoutput_bboxSpace(self.smpl, gtOut)
_, gt_smpl_output_bbox = smpl_utils.getSMPLoutput_bboxSpace(self.smpl_male, gtOut) #Assuming Male model
gt_smpl_output_bbox['body_mesh']['color'] = glViewer.g_colorSet['hand']
glViewer.addMeshData( [gt_smpl_output_bbox['body_mesh']], bComputeNormal=True)
############### Visualize Skeletons ###############
glViewer.setSkeleton( [pred_smpl_output_bbox['body_joints_vis'] ])
if False:
glViewer.addSkeleton( [gt_smpl_output_bbox['body_joints_vis'] ], colorRGB= glViewer.g_colorSet['hand'] )
if True:
glViewer.show(1)
elif False: #Render to files in the original image space
#Get Skeletons
img_original = cv2.imread(input_batch['imgname'][0])
# viewer2D.ImShow(img_original, waitTime=0)
bboxCenter = input_batch['center'].detach().cpu()[0]
bboxScale = input_batch['scale'].detach().cpu()[0]
imgShape = img_original.shape[:2]
smpl_output, smpl_output_bbox, smpl_output_imgspace = smpl_utils.getSMPLoutput_imgSpace(self.smpl, {"pred_rotmat":pred_rotmat, "pred_shape":pred_betas, "pred_camera":pred_camera},
bboxCenter, bboxScale, imgShape)
glViewer.setBackgroundTexture(img_original) #Use the raw image as background
glViewer.setWindowSize(img_original.shape[1]*2, img_original.shape[0]*2) #Match the window to the (2x upscaled) raw image
glViewer.setMeshData([smpl_output_imgspace['body_mesh']], bComputeNormal = True ) #Mesh posed in the original image space
glViewer.setSkeleton([])
imgname = os.path.basename(input_batch['imgname'][0])[:-4]
fileName = "{0}_{1}_{2:04d}".format(dataset_name[0], imgname, iterIdx)
# rawImg = cv2.putText(rawImg,data['subjectId'],(100,100), cv2.FONT_HERSHEY_PLAIN, 2, (255,255,0),2)
glViewer.render_on_image('/home/hjoo/temp/render_eft', fileName, img_original, scaleFactor=2)
else:
#Render
if True:
imgname = output['imageName'][b]
root_imgname = os.path.basename(imgname)[:-4]
renderRoot=f'/home/hjoo/temp/render_eft/eft_{root_imgname}'
imgname='{:04d}'.format(iterIdx)
# smpl_utils.renderSMPLoutput(renderRoot,'overlaid','raw',imgname=imgname)
smpl_utils.renderSMPLoutput(renderRoot,'overlaid','mesh',imgname=imgname)
smpl_utils.renderSMPLoutput(renderRoot,'overlaid','skeleton',imgname=imgname)
smpl_utils.renderSMPLoutput(renderRoot,'side','mesh',imgname=imgname)
# # Show projection of the SMPL skeleton
# if False:
# pred_keypoints_2d_vis = pred_keypoints_2d[b,:,:2].detach().cpu().numpy()
# pred_keypoints_2d_vis = 0.5 * self.options.img_res * (pred_keypoints_2d_vis + 1) #49: (25+24) x 3
# if glViewer.g_bShowSkeleton:
# curImgVis = viewer2D.Vis_Skeleton_2D_general(pred_keypoints_2d_vis, bVis= False, image=curImgVis)
# viewer2D.ImShow(curImgVis, scale=2.0, waitTime=1)
bCompute3DError = True
if bCompute3DError and self.options.bUseSMPLX==False: #Note: with bUseSMPLX asserted True above, this block is skipped and r_error stays 0
pred_output = self.smpl(betas=pred_betas, body_pose=pred_rotmat[:,1:], global_orient=pred_rotmat[:,0].unsqueeze(1), pose2rot=False)
pred_vertices = pred_output.vertices
gt_output = self.smpl_male(betas=gt_betas, body_pose=gt_pose[:,3:], global_orient=gt_pose[:,:3])
gt_vertices = gt_output.vertices
# Reconstruction error
J_regressor = torch.from_numpy(np.load(config.JOINT_REGRESSOR_H36M)).float() #17,6890
J_regressor_batch = J_regressor[None, :].expand(pred_vertices.shape[0], -1, -1).cuda()
joint_mapper_h36m = constants.H36M_TO_J17 if dataset_name[0] == 'mpi-inf-3dhp' else constants.H36M_TO_J14 #dataset_name is a per-sample list of strings
r_error = reconstruction_error_fromMesh(J_regressor_batch, joint_mapper_h36m, pred_vertices, gt_vertices)
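# Reconstruction error here is the Procrustes-aligned per-joint error (PA-MPJPE): H3.6M joints are
# regressed from both meshes via J_regressor and compared after rigid alignment; *1000 below converts m to mm.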
# print("r_error:{}".format(r_error[0]*1000) )
losses['r_error'] = r_error[0]*1000
else:
losses['r_error'] = 0
return output, losses