def run_eft_step_with_2dhand()

in eft/train/eftFitter.py [0:0]


    def run_eft_step_with_2dhand(self, input_batch, iterIdx=0):
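        """Run a single EFT (Exemplar Fine-Tuning) step supervised by 2D body and 2D hand keypoints.

        input_batch must contain the keys read below (cropped image, body/hand 2D keypoints,
        SMPL pose/shape, 3D pose, and augmentation metadata). Returns (output, losses): output
        packs the current predictions, the cached SPIN fits, and metadata as numpy arrays;
        losses is a dict of scalar values for logging.
        """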

        """Setting up EFT mode"""
        self.model.train()
        if self.options.bExemplarMode:
            self.exemplerTrainingMode()


        """ Get data from the batch """
        images = input_batch['img'] # input image
        gt_keypoints_2d = input_batch['keypoints'].clone()    # 2D keypoints    [N,49,3]

        assert 'kp_rightHand_gt' in input_batch.keys()
        assert 'kp_leftHand_gt' in input_batch.keys()
        gt_kp_rhand_2d = input_batch['kp_rightHand_gt'].clone()    # right-hand 2D keypoints    [N,21,3]
        gt_kp_lhand_2d = input_batch['kp_leftHand_gt'].clone()     # left-hand 2D keypoints     [N,21,3]

        gt_pose = input_batch['pose'] # SMPL pose parameters                #[N,72]
        gt_betas = input_batch['betas'] # SMPL beta parameters              #[N,10]
        
        JOINT_SCALING_3D = 1.0 #3D joint scaling 
        gt_joints = input_batch['pose_3d']*JOINT_SCALING_3D # 3D pose                        #[N,24,4]
        has_pose_3d = input_batch['has_pose_3d'].byte()==1 # flag that indicates whether 3D pose is valid
        indices = input_batch['sample_index'] # index of example inside its dataset
        batch_size = images.shape[0]

        is_flipped = input_batch['is_flipped'] # flag that indicates whether image was flipped during data augmentation
        rot_angle = input_batch['rot_angle'] # rotation angle used for data augmentation
        dataset_name = input_batch['dataset_name'] # name of the dataset the image comes from

        index_cpu = indices.cpu()
        if self.options.bExemplar_dataLoaderStart>=0:
            index_cpu += self.options.bExemplar_dataLoaderStart      # offset by the dataloader start index when fitting begins mid-dataset

        #Check existing SPIN fits
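        # fits_dict caches per-sample SMPL fits keyed by (dataset, sample index, rotation, flip);
        # opt_validity is assumed to mark which samples have a valid stored fit.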
        opt_pose, opt_betas, opt_validity = self.fits_dict[(dataset_name, index_cpu, rot_angle.cpu(), is_flipped.cpu())]
        opt_pose = opt_pose.to(self.device)
        opt_betas = opt_betas.to(self.device)
        

        """ Run model to make a prediction """
        # Feed images in the network to predict camera and SMPL parameters
        pred_rotmat, pred_betas, pred_camera = self.model(images)
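        # pred_rotmat: per-joint rotation matrices (global orientation at index 0, body pose after it)
        # pred_betas:  shape coefficients
        # pred_camera: weak-perspective camera [s, tx, ty]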

        pred_output = self.smpl(betas=pred_betas, body_pose=pred_rotmat[:,1:], global_orient=pred_rotmat[:,[0]], pose2rot=False)
        pred_vertices = pred_output.vertices
        pred_joints_3d = pred_output.joints


        """ Computing Loss """
        # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
        # This camera translation can be used in a full perspective projection
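        #   tz = 2 * focal_length / (img_res * s), i.e. a larger weak-perspective scale s implies a closer camera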
        pred_cam_t = torch.stack([pred_camera[:,1],
                                  pred_camera[:,2],
                                  2*self.focal_length/(self.options.img_res * pred_camera[:,0] +1e-9)],dim=-1)

        camera_center = torch.zeros(batch_size, 2, device=self.device)
      
        # Weak-perspective projection of the predicted 3D joints to the 2D image plane
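        # weakProjection_gpu is assumed to implement the weak-perspective model  x2d = s * X3d[..., :2] + t,
        # producing 2D keypoints in the same normalized [-1, 1] space as gt_keypoints_2d.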
        pred_keypoints_2d = weakProjection_gpu(pred_joints_3d, pred_camera[:,0], pred_camera[:,1:] )           #N, 49, 2

        # Hand 2D supervision requires the SMPL-X body model
        assert self.options.bUseHand2D and self.options.bUseSMPLX

        pred_right_hand_2d = weakProjection_gpu(pred_output.right_hand_joints, pred_camera[:,0], pred_camera[:,1:] )           #[N, 21, 2]
        pred_left_hand_2d = weakProjection_gpu(pred_output.left_hand_joints, pred_camera[:,0], pred_camera[:,1:] )           #[N, 21, 2]

        if True:    # Optionally ignore hips / hip center and compute a lower-leg orientation term
            LENGTH_THRESHOLD = 0.0089    # ~1/112.0 in normalized [-1,1] coordinates (roughly one pixel of the crop)

            #Disable Hips by default
            if self.options.eft_withHip2D==False:     
                gt_keypoints_2d[:,2+25,2]=0
                gt_keypoints_2d[:,3+25,2]=0
                gt_keypoints_2d[:,14+25,2]=0

            # Compute knee-to-ankle (lower-leg) orientation
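            # The orientation term is 1 - dot(gt_dir, pred_dir), i.e. 1 - cos(angle) between the GT and
            # predicted lower-leg directions, and it is used only when the GT bone is longer than
            # LENGTH_THRESHOLD and both endpoint keypoints are annotated (via the validity weight).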
            gt_boneOri_leftLeg = gt_keypoints_2d[:,5+25,:2]  -  gt_keypoints_2d[:,4+25,:2]             #Left lower leg orientation     #(N,2)
            gt_boneOri_leftLeg, leftLegLeng = normalize_2dvector(gt_boneOri_leftLeg)

            if leftLegLeng>LENGTH_THRESHOLD:
                leftLegValidity = gt_keypoints_2d[:,5+25, 2]  * gt_keypoints_2d[:,4+25, 2]
                pred_boneOri_leftLeg = pred_keypoints_2d[:,5+25,:2]  -  pred_keypoints_2d[:,4+25,:2]
                pred_boneOri_leftLeg, _ = normalize_2dvector(pred_boneOri_leftLeg)
                loss_legOri_left  = torch.ones(1).to(self.device) - torch.dot(gt_boneOri_leftLeg.view(-1),pred_boneOri_leftLeg.view(-1))
            else:
                loss_legOri_left = torch.zeros(1).to(self.device)
                leftLegValidity  = torch.zeros(1).to(self.device)

            gt_boneOri_rightLeg = gt_keypoints_2d[:,0+25,:2] -  gt_keypoints_2d[:,1+25,:2]            #Right lower leg orientation
            gt_boneOri_rightLeg, rightLegLeng = normalize_2dvector(gt_boneOri_rightLeg)
            if rightLegLeng>LENGTH_THRESHOLD:

                rightLegValidity = gt_keypoints_2d[:,0+25, 2]  * gt_keypoints_2d[:,1+25, 2]
                pred_boneOri_rightLeg = pred_keypoints_2d[:,0+25,:2]  -  pred_keypoints_2d[:,1+25,:2]
                pred_boneOri_rightLeg, _ = normalize_2dvector(pred_boneOri_rightLeg)
                loss_legOri_right  = torch.ones(1).to(self.device) - torch.dot(gt_boneOri_rightLeg.view(-1),pred_boneOri_rightLeg.view(-1))
            else:
                loss_legOri_right = torch.zeros(1).to(self.device)
                rightLegValidity = torch.zeros(1).to(self.device)
            # print("leftLegLeng: {}, rightLegLeng{}".format(leftLegLeng,rightLegLeng ))
            loss_legOri = leftLegValidity* loss_legOri_left + rightLegValidity* loss_legOri_right

            #Disable feet
            # gt_keypoints_2d[:,5+25,2]=0     #Left foot
            # gt_keypoints_2d[:,0+25,2]=0     #Right foot

        # Compute 2D reprojection loss for the keypoints
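        # openpose_train_weight and gt_train_weight weight the OpenPose-format joints (first 25)
        # and the dataset-annotated joints (last 24) separately inside keypoint_loss.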
        loss_keypoints_2d = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d,
                                            self.options.openpose_train_weight,
                                            self.options.gt_train_weight)

        loss_keypoints_2d_hand = self.keypoint_loss_keypoint21(pred_right_hand_2d, gt_kp_rhand_2d, 1.0) \
                                 + self.keypoint_loss_keypoint21(pred_left_hand_2d, gt_kp_lhand_2d, 1.0)


        # Compute 3D keypoint loss (only when exemplar fitting uses 3D skeleton supervision)
        if self.options.bExemplarWith3DSkel:
            # loss_keypoints_3d = self.keypoint_3d_loss(pred_joints_3d, gt_joints, has_pose_3d)
            loss_keypoints_3d = self.keypoint_3d_loss_panopticDB(pred_joints_3d, gt_joints, has_pose_3d)
        else:
            loss_keypoints_3d = torch.tensor(0)
       
        # loss_keypoints_3d = self.keypoint_3d_loss_modelSkel(pred_joints_3d, gt_model_joints[:,25:,:], has_pose_3d)

        loss_regr_betas_noReject = torch.mean(pred_betas**2)    # L2 prior on the predicted shape (betas), applied to all samples without rejection

        # Total loss: weighted 2D keypoint loss + shape prior + camera regularizer
        # (the exponential term penalizes a very small or negative weak-perspective scale s)
        loss = self.options.keypoint_loss_weight * loss_keypoints_2d  + \
                self.options.beta_loss_weight * loss_regr_betas_noReject  + \
                ((torch.exp(-pred_camera[:,0]*10)) ** 2 ).mean()

        # Add the 2D hand keypoint loss
        if True:
            loss = loss + self.options.keypoint_loss_weight * loss_keypoints_2d_hand * 2.0     # factor 2.0 weights the hands more heavily
        
        if self.options.bExemplarWith3DSkel:
            loss = loss + self.options.keypoint_loss_weight * loss_keypoints_3d
            # loss = loss_keypoints_3d        #TODO: DEBUGGIN

        if False:        #Leg orientation loss
            loss = loss + 0.005*loss_legOri

        loss *= 60    # global scaling of the total loss
        # print("loss2D: {}, loss3D: {}".format( self.options.keypoint_loss_weight * loss_keypoints_2d,self.options.keypoint_loss_weight * loss_keypoints_3d  )  )


        """ Do backprop """
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        """ Handling output """
        # Pack output arguments for tensorboard logging
        output = {'pred_vertices': 0, #pred_vertices.detach(),
                  'opt_vertices': 0,
                  'pred_cam_t': 0,#pred_cam_t.detach(),
                  'opt_cam_t': 0}
        #Save result
        output={}
        output['pred_pose_rotmat'] = pred_rotmat.detach().cpu().numpy()
        output['pred_shape'] = pred_betas.detach().cpu().numpy()
        output['pred_camera'] = pred_camera.detach().cpu().numpy()
        
        # If SPIN fits exist, save them for later comparison
        output['opt_pose'] = opt_pose.detach().cpu().numpy()
        output['opt_beta'] = opt_betas.detach().cpu().numpy()
        
        output['sampleIdx'] = input_batch['sample_index'].detach().cpu().numpy()     # sample index within the dataset (so the loader can be indexed directly)
        output['imageName'] = input_batch['imgname']
        output['scale'] = input_batch['scale'].detach().cpu().numpy()
        output['center'] = input_batch['center'].detach().cpu().numpy()

        if 'annotId' in input_batch.keys():
            output['annotId'] = input_batch['annotId'].detach().cpu().numpy()

        if 'subjectId' in input_batch.keys():
            if input_batch['subjectId'][0]!="":
                output['subjectId'] = input_batch['subjectId'][0].item()

        #To save new db file
        output['keypoint2d'] = input_batch['keypoints_original'].detach().cpu().numpy()
        output['keypoint2d_cropped'] = input_batch['keypoints'].detach().cpu().numpy()


        losses = {'loss': loss.detach().item(),
                  'loss_keypoints': loss_keypoints_2d.detach().item(),
                  'loss_keypoints_3d': loss_keypoints_3d.detach().item(),
                #   'loss_regr_pose': loss_regr_pose.detach().item(),
                  'loss_regr_betas': loss_regr_betas_noReject.detach().item()}
                #   'loss_shape': loss_shape.detach().item()}

        """ Visualiztion for Debuggin """
        if self.options.bDebug_visEFT:#g_debugVisualize:    #Debug Visualize input
            
            # For visualization, de-normalize 2D keypoints from [-1,1] to pixel space
            gt_keypoints_2d_orig = gt_keypoints_2d.clone()                     
            gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (gt_keypoints_2d_orig[:, :, :-1] + 1)        #49: (25+24) x 3 


            # print("Image Name: {}".format(output['imageName']))
            for b in range(batch_size):

                # DenormalizeImg
                curImgVis = deNormalizeBatchImg(images[b].cpu())
                viewer2D.ImShow(curImgVis, name='rawIm', scale=4.0)

                # Visualize GT 2D keypoints
                if True:
                    gt_keypoints_2d_orig_vis = gt_keypoints_2d_orig.detach().cpu().numpy()
                    gt_keypoints_2d_orig_vis[b,:25,2] = 0       #Don't show openpose
                    curImgVis = viewer2D.Vis_Skeleton_2D_SPIN49(gt_keypoints_2d_orig_vis[b,:,:2], gt_keypoints_2d_orig_vis[b,:,2], bVis= False, image=curImgVis)
                
                    # De-normalize the GT hand keypoints from [-1,1] to pixel coordinates before drawing
                    curImgVis = viewer2D.Vis_Skeleton_2D_Openpose_hand( 0.5*self.options.img_res*(gt_kp_rhand_2d[b,:,:2].cpu().numpy()+1), gt_kp_rhand_2d[b,:,2].cpu().numpy(), image=curImgVis)
                    curImgVis = viewer2D.Vis_Skeleton_2D_Openpose_hand( 0.5*self.options.img_res*(gt_kp_lhand_2d[b,:,:2].cpu().numpy()+1), gt_kp_lhand_2d[b,:,2].cpu().numpy(), image=curImgVis)

                # curImgVis = viewer2D.Vis_Skeleton_2D_Openpose18(gt_keypoints_2d_orig[b,:,:2].cpu().numpy(), gt_keypoints_2d_orig[b,:,2], bVis= False, image=curImgVis)

                ############### Visualize Mesh #################
                #Visualize SMPL in image space
                pred_smpl_output, pred_smpl_output_bbox = smpl_utils.visSMPLoutput_bboxSpace(
                    self.smpl, {"pred_rotmat": pred_rotmat, "pred_shape": pred_betas, "pred_camera": pred_camera},
                    image=curImgVis, waittime=-1)

            
                #Visualize GT Mesh
                if False:
                    gtOut = {"pred_pose":gt_pose, "pred_shape":gt_betas, "pred_camera":pred_camera}
                    # _, gt_smpl_output_bbox = smpl_utils.getSMPLoutput_bboxSpace(self.smpl, gtOut)
                    _, gt_smpl_output_bbox = smpl_utils.getSMPLoutput_bboxSpace(self.smpl_male, gtOut)          #Assuming Male model
                    gt_smpl_output_bbox['body_mesh']['color']  = glViewer.g_colorSet['hand']
                    glViewer.addMeshData( [gt_smpl_output_bbox['body_mesh']], bComputeNormal=True)

                ############### Visualize Skeletons ############### 
                glViewer.setSkeleton( [pred_smpl_output_bbox['body_joints_vis'] ])

                if False:
                    glViewer.addSkeleton( [gt_smpl_output_bbox['body_joints_vis'] ], colorRGB= glViewer.g_colorSet['hand'] )

                if True:
                    glViewer.show(1)
                elif False:   #Render to Files in original image space
            
                    #Get Skeletons
                    img_original = cv2.imread(input_batch['imgname'][0])
                    # viewer2D.ImShow(img_original, waitTime=0)
                    bboxCenter =  input_batch['center'].detach().cpu()[0]
                    bboxScale = input_batch['scale'].detach().cpu()[0]
                    imgShape = img_original.shape[:2]
                    smpl_output, smpl_output_bbox, smpl_output_imgspace  = smpl_utils.getSMPLoutput_imgSpace(self.smpl, {"pred_rotmat":pred_rotmat, "pred_shape":pred_betas, "pred_camera":pred_camera},
                                                                    bboxCenter, bboxScale, imgShape)

                    glViewer.setBackgroundTexture(img_original)       #Vis raw video as background
                    glViewer.setWindowSize(img_original.shape[1]*2, img_original.shape[0]*2)       #Vis raw video as background
                    glViewer.setMeshData([smpl_output_imgspace['body_mesh']], bComputeNormal = True )       #Vis raw video as background
                    glViewer.setSkeleton([])
                
                    imgname = os.path.basename(input_batch['imgname'][0])[:-4]
                    fileName = "{0}_{1}_{2:04d}".format(dataset_name[0],  imgname, iterIdx)

                    # rawImg = cv2.putText(rawImg,data['subjectId'],(100,100), cv2.FONT_HERSHEY_PLAIN, 2, (255,255,0),2)
                    glViewer.render_on_image('/home/hjoo/temp/render_eft', fileName, img_original, scaleFactor=2)

                else:
                    #Render
                    if True:
                        imgname = output['imageName'][b]
                        root_imgname = os.path.basename(imgname)[:-4]
                        renderRoot=f'/home/hjoo/temp/render_eft/eft_{root_imgname}'
                        imgname='{:04d}'.format(iterIdx)

                        # smpl_utils.renderSMPLoutput(renderRoot,'overlaid','raw',imgname=imgname)
                        smpl_utils.renderSMPLoutput(renderRoot,'overlaid','mesh',imgname=imgname)
                        smpl_utils.renderSMPLoutput(renderRoot,'overlaid','skeleton',imgname=imgname)
                        smpl_utils.renderSMPLoutput(renderRoot,'side','mesh',imgname=imgname)




                # # Show projection of the SMPL skeleton
                # if False:
                #     pred_keypoints_2d_vis = pred_keypoints_2d[b,:,:2].detach().cpu().numpy()
                #     pred_keypoints_2d_vis = 0.5 * self.options.img_res * (pred_keypoints_2d_vis + 1)        #49: (25+24) x 3 
                #     if glViewer.g_bShowSkeleton:
                #         curImgVis = viewer2D.Vis_Skeleton_2D_general(pred_keypoints_2d_vis, bVis= False, image=curImgVis)
                # viewer2D.ImShow(curImgVis, scale=2.0, waitTime=1)


        bCompute3DError = True
        if bCompute3DError and self.options.bUseSMPLX==False:
            pred_output = self.smpl(betas=pred_betas, body_pose=pred_rotmat[:,1:], global_orient=pred_rotmat[:,0].unsqueeze(1), pose2rot=False)
            pred_vertices = pred_output.vertices

            gt_output = self.smpl_male(betas=gt_betas, body_pose=gt_pose[:,3:], global_orient=gt_pose[:,:3])
            gt_vertices = gt_output.vertices
            # Reconstruction error: regress H36M joints from both meshes and compare
            J_regressor = torch.from_numpy(np.load(config.JOINT_REGRESSOR_H36M)).float()        #17,6890
            J_regressor_batch = J_regressor[None, :].expand(pred_vertices.shape[0], -1, -1).cuda()
            joint_mapper_h36m = constants.H36M_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.H36M_TO_J14

            r_error = reconstruction_error_fromMesh(J_regressor_batch, joint_mapper_h36m, pred_vertices, gt_vertices)
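            # reconstruction_error_fromMesh is assumed to return per-sample errors in meters;
            # the value is scaled to millimeters below for logging.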

            # print("r_error:{}".format(r_error[0]*1000) )

            losses['r_error'] = r_error[0]*1000
        else:
            losses['r_error'] = 0

        return output, losses
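
A minimal driver sketch (not part of eftFitter.py): it assumes a fitter object exposing this method and a device attribute, plus a dataloader whose batches contain the keys read above; the function name, loader, and iteration count below are illustrative assumptions, not the repository's actual training script.

    import torch

    def run_eft_on_exemplars(fitter, exemplar_loader, num_iters=20):
        """Hypothetical driver (sketch only): apply the EFT step repeatedly to each exemplar batch."""
        results = []
        for input_batch in exemplar_loader:
            # Move tensor fields to the fitter's device; keep string fields (e.g. 'imgname') as-is.
            input_batch = {k: (v.to(fitter.device) if torch.is_tensor(v) else v)
                           for k, v in input_batch.items()}
            for it in range(num_iters):
                output, losses = fitter.run_eft_step_with_2dhand(input_batch, iterIdx=it)
            results.append((output, losses))
        return results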