def eftAllInDB_3dpwtest()

in bodymocap/train/eftFitter.py [0:0]


    def eftAllInDB_3dpwtest(self, test_dataset_3dpw = None, test_dataset_h36m= None, bExportPKL = True):
        """Run exemplar fine-tuning (EFT) over every sample of self.train_ds.

        For each sample the model weights are reset (self.reloadModel) and
        fine-tuned for up to self.options.maxExemplarIter steps; the fitting
        output is optionally exported as a pickle file under
        config.EXEMPLAR_OUTPUT_ROOT/<date>_<db_set>_<name>/.

        Args:
            test_dataset_3dpw: optional 3DPW test set; evaluated per sample when
                options.bExemplar_analysis_testloss or
                options.bExemplar_badsample_finder is set.
            test_dataset_h36m: optional H36M test set (analysis mode only).
            bExportPKL: if True, export each sample's fitting output to a .pkl
                (panoptic/haggling outputs are chunked per 100 samples).
        """
        # Output directory name encodes date + dataset + run name. Both branches
        # build the same outputDir; only the timestamp granularity differs.
        now = datetime.datetime.now()
        if config.bIsDevfair:
            newName = '{:02d}-{:02d}'.format(now.month, now.day)
        else:
            # Include the second-of-day so repeated local runs do not collide.
            newName = 'test_{:02d}-{:02d}-{}'.format(now.month, now.day, now.hour*3600 + now.minute*60 + now.second)
        outputDir = newName + '_' + self.options.db_set + '_' + self.options.name

        exemplarOutputPath = os.path.join(config.EXEMPLAR_OUTPUT_ROOT, outputDir)
        # makedirs(exist_ok=True) replaces the exists()+mkdir pair: no
        # check-then-create race, and missing parent directories are created too.
        os.makedirs(exemplarOutputPath, exist_ok=True)

        # batch_size must stay 1: each sample is fitted independently, and
        # shuffle is off so sample_index maps deterministically to dataset order.
        train_data_loader = CheckpointDataLoader(self.train_ds, checkpoint=self.checkpoint,
                                                    batch_size=1,       #Always 1
                                                    num_workers=self.options.num_workers,
                                                    pin_memory=self.options.pin_memory,
                                                    shuffle=False)      #No Shuffle

        maxExemplarIter = self.options.maxExemplarIter

        outputList = {}     # panoptic/haggling: outputs buffered until the next %100 boundary
        reconError = []     # one dict per sample: iteration -> (r_error, loss_keypoints)
        for step, batch in enumerate(tqdm(train_data_loader)):

            # Analysis mode: only process roughly every 100th sample.
            if self.options.bExemplar_analysis_testloss:
                sampleIdx = batch['sample_index'][0].item()
                if sampleIdx % 100 != 0:
                    continue

            if self.options.bExemplar_badsample_finder:
                # NOTE(review): the sub-sampling 'continue' was disabled in the
                # original; the lookup is kept so a missing 'sample_index' key
                # still fails loudly here.
                sampleIdx = batch['sample_index'][0].item()

            bSkipExisting = not self.options.bNotSkipExemplar   # bNotSkipExemplar==True --> bSkipExisting==False
            if bSkipExisting:
                # All branches compute the target path the exporter below would
                # write, and skip the sample if that file already exists.
                sampleIdx = batch['sample_index'][0].item()
                if self.options.bExemplar_dataLoaderStart >= 0:
                    sampleIdx += self.options.bExemplar_dataLoaderStart

                if self.options.db_set == 'panoptic':
                    # Panoptic outputs are chunked per 100 samples, named by the
                    # next multiple-of-100 sample index.
                    sampleIdxSaveFrame = 100 * (int(sampleIdx / 100.0) + 1)
                    fileName = '{:08d}.pkl'.format(sampleIdxSaveFrame)
                else:
                    fileNameOnly = os.path.basename(batch['imgname'][0])[:-4]
                    if '3dpw' in self.options.db_set:
                        # 3DPW frame names repeat across sequences; prefix the
                        # sequence (parent directory) name to disambiguate.
                        seqName = os.path.basename(os.path.dirname(batch['imgname'][0]))
                        fileNameOnly = f"{seqName}_{fileNameOnly}"
                    fileName = '{}_{}.pkl'.format(fileNameOnly, sampleIdx)

                outputPath = os.path.join(exemplarOutputPath, fileName)
                if os.path.exists(outputPath):
                    print("Skipped: {}".format(outputPath))
                    continue

            # Reset model weights for each sample: exemplar tuning always starts
            # from the base checkpoint.
            self.reloadModel()

            batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            output_backup = {}
            reconErrorInfo = {}
            # NOTE(review): if maxExemplarIter == 0 this loop never runs and the
            # references to 'output'/'it' below raise NameError; callers are
            # expected to pass maxExemplarIter >= 1.
            for it in range(maxExemplarIter):

                g_timer.tic()
                if self.options.bUseHand3D:
                    output, losses = self.run_eft_step_wHand(batch)
                else:
                    output, losses = self.run_eft_step(batch)

                # Track reconstruction error per iteration for later analysis.
                reconErrorInfo[it] = (losses['r_error'], losses['loss_keypoints'])

                output['loss_keypoints_2d'] = losses['loss_keypoints']
                output['loss'] = losses['loss']

                if it == 0:
                    # Snapshot the first-iteration prediction so the export can
                    # carry the "_init" baseline alongside the final fit.
                    output_backup['pred_shape'] = output['pred_shape'].copy()
                    output_backup['pred_pose_rotmat'] = output['pred_pose_rotmat'].copy()
                    output_backup['pred_camera'] = output['pred_camera'].copy()
                    output_backup['loss_keypoints_2d'] = output['loss_keypoints_2d']
                    output_backup['loss'] = output['loss']

            reconError.append(reconErrorInfo)
            output['reconErrorInfo'] = reconErrorInfo

            # Attach the iteration-0 snapshot as the "_init" baseline.
            output['pred_shape_init'] = output_backup['pred_shape']
            output['pred_pose_rotmat_init'] = output_backup['pred_pose_rotmat']
            output['pred_camera_init'] = output_backup['pred_camera']
            output['loss_init'] = output_backup['loss']
            output['loss_keypoints_2d_init'] = output_backup['loss_keypoints_2d']
            output['numOfIteration'] = it

            output['smpltype'] = 'smplx' if self.options.bUseSMPLX else 'smpl'

            # Exemplar Tuning Analysis: measure test error after fitting this sample.
            if self.options.bExemplar_analysis_testloss and test_dataset_3dpw is not None:
                print(">> Testing : test set size:{}".format(len(test_dataset_3dpw)))
                output['test_error_3dpw'] = self.test(test_dataset_3dpw, '3dpw')
                output['test_error_h36m'] = self.test(test_dataset_h36m, 'h36m-p1')

            if self.options.bExemplar_badsample_finder and test_dataset_3dpw is not None:
                print(">> Testing : test set size:{}".format(len(test_dataset_3dpw)))
                output['test_error_3dpw'] = self.test(test_dataset_3dpw, '3dpw')

            if bExportPKL:    #Export Output to PKL files
                if self.options.db_set == 'panoptic' or "haggling" in self.options.db_set:
                    # NOTE(review): unlike the other branches this does not call
                    # .item() on sampleIdx; confirm it is already a plain int.
                    sampleIdx = output['sampleIdx'][0]
                    if self.options.bExemplar_dataLoaderStart >= 0:
                        sampleIdx += self.options.bExemplar_dataLoaderStart

                    # Buffer outputs and flush the chunk on every 100th sample.
                    outputList[sampleIdx] = output
                    if sampleIdx % 100 == 0:
                        fileName = '{:08d}.pkl'.format(sampleIdx)
                        outputPath = os.path.join(exemplarOutputPath, fileName)
                        print("Saved:{}".format(outputPath))
                        with open(outputPath, 'wb') as f:
                            pickle.dump(outputList, f)

                        outputList = {}      #reset
                else:
                    fileNameOnly = os.path.basename(output['imageName'][0])[:-4]
                    if "3dpw" in self.options.db_set:
                        # Prefix the sequence name to disambiguate repeated frame names.
                        seqName = os.path.basename(os.path.dirname(output['imageName'][0]))
                        fileNameOnly = f"{seqName}_{fileNameOnly}"

                    sampleIdx = output['sampleIdx'][0].item()
                    if self.options.bExemplar_dataLoaderStart >= 0:
                        sampleIdx += self.options.bExemplar_dataLoaderStart

                    fileName = '{}_{}.pkl'.format(fileNameOnly, sampleIdx)
                    outputPath = os.path.join(exemplarOutputPath, fileName)

                    print("Saved:{}".format(outputPath))
                    # 'with' closes the file; the redundant explicit close() was removed.
                    with open(outputPath, 'wb') as f:
                        pickle.dump(output, f)

        if False:       # Disabled debug analysis: report reconstruction error per iteration
            reconErrorPerIter = []
            for it in range(maxExemplarIter):
                print(it)
                reconErrorPerIter.append([d[it][0] for d in reconError])
            reconErrorPerIter = np.array(reconErrorPerIter)
            for it in range(maxExemplarIter):
                print("{}: reconError:{}".format(it, np.mean(reconErrorPerIter[it, :])))
            # Best achievable error if the optimal iteration were picked per sample.
            print("Best: reconError:{}".format(np.mean(np.min(reconErrorPerIter, axis=0))))