def train_step()

in eft/train/trainer.py [0:0]


    def train_step(self, input_batch):
        """Run one training iteration of the regressor on ``input_batch``.

        Pipeline:
          1. Run SMPL on the ground-truth parameters and on the current best fits
             stored in ``self.fits_dict``.
          2. Estimate camera translations for both from the GT 2D keypoints.
          3. Predict rotation matrices, betas and a weak-perspective camera
             [s, tx, ty] with ``self.model`` and project the predicted joints to 2D.
          4. If ``options.run_smplify`` is set, refine the prediction with SMPLify
             and write back to ``self.fits_dict`` the samples whose fitting loss
             improved.
          5. Overwrite fits with GT where ``has_smpl`` is set, combine the losses
             selected by the ``ablation_*`` options and take one optimizer step.
          6. Optionally open debug viewers (``options.bDebug_visEFT``, and the
             opt-in ``options.bDebug_visSMPLify``, which defaults to False when
             the option is absent).

        Args:
            input_batch: dict with the keys read below ('img', 'keypoints', 'pose',
                'betas', 'pose_3d', 'has_smpl', 'has_pose_3d', 'is_flipped',
                'rot_angle', 'dataset_name', 'sample_index').

        Returns:
            tuple ``(output, losses)``: ``output`` holds predicted/optimized
            vertices and camera translations; ``losses`` maps loss names to
            Python floats.
        """
        self.model.train()

        # Get data from the batch
        images = input_batch['img']                     # input image
        gt_keypoints_2d = input_batch['keypoints']      # 2D keypoints          #[N,49,3]
        gt_pose = input_batch['pose']                   # SMPL pose parameters  #[N,72]
        gt_betas = input_batch['betas']                 # SMPL beta parameters  #[N,10]
        gt_joints = input_batch['pose_3d']              # 3D pose               #[N,24,4]
        has_smpl = input_batch['has_smpl'].byte() == 1          # SMPL parameters are valid
        has_pose_3d = input_batch['has_pose_3d'].byte() == 1    # 3D pose is valid
        is_flipped = input_batch['is_flipped']          # image was flipped during augmentation
        rot_angle = input_batch['rot_angle']            # rotation angle used for augmentation
        dataset_name = input_batch['dataset_name']      # dataset the image comes from
        indices = input_batch['sample_index']           # index of the example inside its dataset
        batch_size = images.shape[0]

        # GT vertices and model joints. gt_model_joints differs from gt_joints
        # because it comes from the SMPL model.
        gt_out = self.smpl(betas=gt_betas, body_pose=gt_pose[:, 3:], global_orient=gt_pose[:, :3])
        gt_model_joints = gt_out.joints.detach()    #[N, 49, 3]
        gt_vertices = gt_out.vertices

        # Current best fits from the dictionary; the same key (plus the update
        # mask) is used to write improved fits back.
        fit_key = (dataset_name, indices.cpu(), rot_angle.cpu(), is_flipped.cpu())
        opt_pose, opt_betas, opt_validity = self.fits_dict[fit_key]
        opt_pose = opt_pose.to(self.device)
        opt_betas = opt_betas.to(self.device)
        opt_output = self.smpl(betas=opt_betas, body_pose=opt_pose[:, 3:], global_orient=opt_pose[:, :3])
        opt_vertices = opt_output.vertices
        opt_joints = opt_output.joints.detach()

        # Samples whose stored fit is not valid must carry GT SMPL parameters.
        has_smpl_of_invalid_fits = has_smpl[opt_validity == 0]
        if len(has_smpl_of_invalid_fits) > 0:
            assert min(has_smpl_of_invalid_fits)  # all should be True

        # De-normalize 2D keypoints from [-1,1] to pixel space
        half_res = 0.5 * self.options.img_res
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = half_res * (gt_keypoints_2d_orig[:, :, :-1] + 1)

        # Estimate camera translation given the model joints and 2D keypoints
        # by minimizing a weighted least squares loss
        gt_cam_t = estimate_translation(gt_model_joints, gt_keypoints_2d_orig, focal_length=self.focal_length, img_size=self.options.img_res)
        opt_cam_t = estimate_translation(opt_joints, gt_keypoints_2d_orig, focal_length=self.focal_length, img_size=self.options.img_res)

        # Camera center passed to SMPLify (both the fitting loss and the optimizer)
        smplify_cam_center = half_res * torch.ones(batch_size, 2, device=self.device)    #(N,2)
        opt_joint_loss = self.smplify.get_fitting_loss(opt_pose, opt_betas, opt_cam_t,     #opt_pose (N,72)  (N,10)  opt_cam_t: (N,3)
                                                       smplify_cam_center,
                                                       gt_keypoints_2d_orig).mean(dim=-1)

        # Feed images in the network to predict camera and SMPL parameters
        pred_rotmat, pred_betas, pred_camera = self.model(images)
        pred_output = self.smpl(betas=pred_betas, body_pose=pred_rotmat[:, 1:], global_orient=pred_rotmat[:, 0].unsqueeze(1), pose2rot=False)
        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints

        # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
        # This camera translation can be used in a full perspective projection
        pred_cam_t = torch.stack([pred_camera[:, 1],
                                  pred_camera[:, 2],
                                  2 * self.focal_length / (self.options.img_res * pred_camera[:, 0] + 1e-9)], dim=-1)

        identity_rot = torch.eye(3, device=self.device).unsqueeze(0).expand(batch_size, -1, -1)
        camera_center = torch.zeros(batch_size, 2, device=self.device)
        pred_keypoints_2d = perspective_projection(pred_joints,
                                                   rotation=identity_rot,
                                                   translation=pred_cam_t,
                                                   focal_length=self.focal_length,
                                                   camera_center=camera_center)
        # Normalize keypoints to [-1,1]
        pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.)

        # Weak projection replaces the perspective projection when enabled
        if self.options.bUseWeakProj:
            pred_keypoints_2d = weakProjection_gpu(pred_joints, pred_camera[:, 0], pred_camera[:, 1:])    #N, 49, 2

        bFootOriLoss = False
        if bFootOriLoss:    # Ignore hips and hip centers, foot
            # Disable parts (confidence channel set to 0)
            gt_keypoints_2d[:, 2 + 25, 2] = 0
            gt_keypoints_2d[:, 3 + 25, 2] = 0
            gt_keypoints_2d[:, 14 + 25, 2] = 0

            # Disable feet
            gt_keypoints_2d[:, 5 + 25, 2] = 0     # Left foot
            gt_keypoints_2d[:, 0 + 25, 2] = 0     # Right foot

        if self.options.run_smplify:
            # Convert predicted rotation matrices to axis-angle
            pred_rotmat_hom = torch.cat([pred_rotmat.detach().view(-1, 3, 3),
                                         torch.tensor([0, 0, 1], dtype=torch.float32,
                                                      device=self.device).view(1, 3, 1).expand(batch_size * 24, -1, -1)], dim=-1)
            pred_pose = rotation_matrix_to_angle_axis(pred_rotmat_hom).contiguous().view(batch_size, -1)
            # tgm.rotation_matrix_to_angle_axis returns NaN for 0 rotation, so manually hack it
            pred_pose[torch.isnan(pred_pose)] = 0.0

            # Run SMPLify optimization starting from the network prediction
            (new_opt_vertices, new_opt_joints,
             new_opt_pose, new_opt_betas,
             new_opt_cam_t, new_opt_joint_loss) = self.smplify(
                pred_pose.detach(), pred_betas.detach(),
                pred_cam_t.detach(),
                smplify_cam_center,
                gt_keypoints_2d_orig)
            new_opt_joint_loss = new_opt_joint_loss.mean(dim=-1)

            # Update the dictionary for the examples where the new loss is lower than the current one
            update = new_opt_joint_loss < opt_joint_loss

            # Opt-in debug view of old fit / GT / new fit. Off unless the option
            # exists and is truthy, so training is never interrupted by default.
            if getattr(self.options, 'bDebug_visSMPLify', False):
                self._debug_vis_smplify(images, pred_camera, opt_vertices, gt_vertices, new_opt_vertices, has_smpl)

            opt_joint_loss[update] = new_opt_joint_loss[update]
            opt_vertices[update, :] = new_opt_vertices[update, :]
            opt_joints[update, :] = new_opt_joints[update, :]
            opt_pose[update, :] = new_opt_pose[update, :]
            opt_betas[update, :] = new_opt_betas[update, :]
            opt_cam_t[update, :] = new_opt_cam_t[update, :]

            self.fits_dict[fit_key + (update.cpu(),)] = (opt_pose.cpu(), opt_betas.cpu())

        # Replace the optimized parameters with the ground truth parameters, if available
        opt_vertices[has_smpl, :, :] = gt_vertices[has_smpl, :, :]
        opt_cam_t[has_smpl, :] = gt_cam_t[has_smpl, :]
        opt_joints[has_smpl, :, :] = gt_model_joints[has_smpl, :, :]
        opt_pose[has_smpl, :] = gt_pose[has_smpl, :]
        opt_betas[has_smpl, :] = gt_betas[has_smpl, :]

        # A fit is valid when its joint loss is below the threshold
        valid_fit = (opt_joint_loss < self.options.smplify_threshold).to(self.device)
        if self.options.ablation_no_pseudoGT:
            valid_fit[:] = False    # Disable all pseudo-GT
        # Examples with GT parameters are always valid
        valid_fit = valid_fit | has_smpl

        # Compute loss on SMPL parameters
        loss_regr_pose, loss_regr_betas = self.smpl_losses(pred_rotmat, pred_betas, opt_pose, opt_betas, valid_fit)

        # Compute 2D reprojection loss for the keypoints
        loss_keypoints = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d,
                                            self.options.openpose_train_weight,
                                            self.options.gt_train_weight)

        # Compute 3D keypoint loss
        loss_keypoints_3d = self.keypoint_3d_loss(pred_joints, gt_joints, has_pose_3d)

        # Per-vertex loss for the shape
        loss_shape = self.shape_loss(pred_vertices, opt_vertices, valid_fit)

        # Regularization term for shape (applied to all samples, no fit rejection)
        loss_regr_betas_noReject = torch.mean(pred_betas ** 2)

        # Penalizes small/negative weak-perspective scale, i.e. forces the
        # network to predict positive depth values
        loss_cam_scale = ((torch.exp(-pred_camera[:, 0] * 10)) ** 2).mean()

        # Compute total loss
        if self.options.ablation_loss_2dkeyonly:        # 2D keypoint only
            loss = self.options.keypoint_loss_weight * loss_keypoints +\
                loss_cam_scale +\
                self.options.beta_loss_weight * loss_regr_betas_noReject        # Beta regularization

        elif self.options.ablation_loss_noSMPLloss:     # 2D/3D keypoints, no SMPL parameter loss
            loss = self.options.keypoint_loss_weight * loss_keypoints +\
                self.options.keypoint_loss_weight * loss_keypoints_3d +\
                loss_cam_scale +\
                self.options.beta_loss_weight * loss_regr_betas_noReject        # Beta regularization

        else:
            loss = self.options.shape_loss_weight * loss_shape +\
                self.options.keypoint_loss_weight * loss_keypoints +\
                self.options.keypoint_loss_weight * loss_keypoints_3d +\
                loss_regr_pose + self.options.beta_loss_weight * loss_regr_betas +\
                loss_cam_scale

        # Fixed global loss scale (its origin is not documented here)
        loss *= 60

        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Pack output arguments for tensorboard logging
        output = {'pred_vertices': pred_vertices.detach(),
                  'opt_vertices': opt_vertices,
                  'pred_cam_t': pred_cam_t.detach(),
                  'opt_cam_t': opt_cam_t}
        losses = {'loss': loss.detach().item(),
                  'loss_keypoints': loss_keypoints.detach().item(),
                  'loss_keypoints_3d': loss_keypoints_3d.detach().item(),
                  'loss_regr_pose': loss_regr_pose.detach().item(),
                  'loss_regr_betas': loss_regr_betas.detach().item(),
                  'loss_shape': loss_shape.detach().item()}

        if self.options.bDebug_visEFT:    # Debug-visualize input and fits
            self._debug_vis_eft(images, pred_camera, pred_keypoints_2d,
                                pred_vertices, opt_vertices, pred_joints, gt_joints)

        return output, losses

    def _to_vis_image(self, img):
        """Convert one normalized image tensor (C,H,W) to a contiguous uint8 BGR array (H,W,C)."""
        vis = self.de_normalize_img(img).cpu().numpy()
        vis = np.transpose(vis, (1, 2, 0)) * 255.0
        vis = vis[:, :, [2, 1, 0]]
        return np.ascontiguousarray(vis, dtype=np.uint8)

    @staticmethod
    def _weak_proj_to_render(points, cam, half_res=112):
        """Map (K,3) numpy points into the debug renderer's frame.

        Computes ``(points * s + [tx, ty, 0]) * half_res`` for ``cam = [s, tx, ty]``
        and returns a NEW array. The input is never modified in place: an ndarray
        obtained via ``.cpu().numpy()`` from a CPU tensor shares its storage, so
        in-place edits would silently change the tensor.
        ``half_res=112`` matches the original hard-coded value (0.5 * 224);
        presumably it should equal ``0.5 * img_res`` — TODO confirm.
        """
        out = points * cam[0]
        out[:, 0] += cam[1]     # no need +1 (or 112): rendering has this offset already
        out[:, 1] += cam[2]
        return out * half_res

    def _debug_vis_smplify(self, images, pred_camera, opt_vertices, gt_vertices, new_opt_vertices, has_smpl):
        """Per sample, render the previous fit, the GT mesh and the new SMPLify fit over the input image."""
        pred_camera_vis = pred_camera.detach().cpu().numpy()
        for b in range(images.shape[0]):
            cur_img = self._to_vis_image(images[b])
            viewer2D.ImShow(cur_img, name='rawIm')
            original_img = cur_img.copy()

            cam = pred_camera_vis[b]
            meshes = [{'ver': self._weak_proj_to_render(verts[b].detach().cpu().numpy(), cam),
                       'f': self.smpl.faces}
                      for verts in (opt_vertices, gt_vertices, new_opt_vertices)]
            glViewer.setMeshData(meshes, bComputeNormal=True)

            glViewer.setBackgroundTexture(original_img)
            glViewer.setWindowSize(cur_img.shape[1], cur_img.shape[0])
            glViewer.SetOrthoCamera(True)

            print(has_smpl[b])
            glViewer.show()

    def _debug_vis_eft(self, images, pred_camera, pred_keypoints_2d, pred_vertices, opt_vertices, pred_joints, gt_joints):
        """Per sample, show predicted 2D keypoints, predicted vs. optimized meshes and predicted vs. GT 3D skeletons."""
        half_res = 0.5 * self.options.img_res
        pred_camera_vis = pred_camera.detach().cpu().numpy()
        for b in range(images.shape[0]):
            cur_img = self._to_vis_image(images[b])
            viewer2D.ImShow(cur_img, name='rawIm')
            original_img = cur_img.copy()

            # Predicted 2D keypoints, from [-1,1] back to pixel space
            pred_keypoints_2d_vis = pred_keypoints_2d[b, :, :2].detach().cpu().numpy()
            pred_keypoints_2d_vis = half_res * (pred_keypoints_2d_vis + 1)        #49: (25+24) x 3
            cur_img = viewer2D.Vis_Skeleton_2D_general(pred_keypoints_2d_vis, bVis=False, image=cur_img)
            viewer2D.ImShow(cur_img, scale=2.0, waitTime=1)

            cam = pred_camera_vis[b]

            ############### Visualize Mesh ###############
            pred_meshes = {'ver': self._weak_proj_to_render(pred_vertices[b].detach().cpu().numpy(), cam),
                           'f': self.smpl.faces}
            opt_meshes = {'ver': self._weak_proj_to_render(opt_vertices[b].detach().cpu().numpy(), cam),
                          'f': self.smpl.faces}
            glViewer.setMeshData([pred_meshes, opt_meshes], bComputeNormal=True)

            ############### Visualize Skeletons ###############
            # Predicted SMPL joints, flattened to a (3*J, 1) column
            pred_joints_vis = self._weak_proj_to_render(pred_joints[b, :, :3].detach().cpu().numpy(), cam)
            glViewer.setSkeleton([pred_joints_vis.ravel()[:, np.newaxis]])

            # GT 3D joints
            gt_joints_vis = self._weak_proj_to_render(gt_joints[b, :, :3].cpu().numpy(), cam)
            glViewer.addSkeleton([gt_joints_vis.ravel()[:, np.newaxis]], jointType='spin')

            glViewer.setBackgroundTexture(original_img)
            glViewer.setWindowSize(cur_img.shape[1], cur_img.shape[0])
            glViewer.SetOrthoCamera(True)
            glViewer.show(0)