in bodymocap/train/eftFitter.py [0:0]
def eftAllInDB_3dpwtest(self, test_dataset_3dpw = None, test_dataset_h36m= None, bExportPKL = True, bPrintReconError = False):
    """Run per-sample EFT over ``self.train_ds`` and optionally export the results.

    For every sample the model is reloaded (``self.reloadModel()``) and
    ``self.options.maxExemplarIter`` EFT steps are run on that single sample
    (``run_eft_step`` or ``run_eft_step_wHand``). The output of the last step
    is augmented with the output of the first step (``*_init`` keys), the
    per-iteration ``(r_error, loss_keypoints)`` pairs and the SMPL type, and is
    pickled under ``config.EXEMPLAR_OUTPUT_ROOT/<outputDir>``.

    Exported file layout:
      * panoptic / haggling: samples are grouped into chunks of 100. Each file
        ``{idx:08d}.pkl`` holds a dict ``{sampleIdx: output}`` and is named
        after the largest sample index it contains.
      * 3dpw: one file per sample, ``{seqName}_{imageName}_{sampleIdx}.pkl``.
      * everything else: one file per sample, ``{imageName}_{sampleIdx}.pkl``.

    Args:
        test_dataset_3dpw: optional 3DPW test set, evaluated with ``self.test``
            after each sample when ``bExemplar_analysis_testloss`` or
            ``bExemplar_badsample_finder`` is enabled.
        test_dataset_h36m: optional H36M test set, evaluated alongside 3DPW
            when ``bExemplar_analysis_testloss`` is enabled (skipped if None).
        bExportPKL: write the per-sample outputs to disk.
        bPrintReconError: after the run, print the mean recorded ``r_error``
            per iteration and the mean of each sample's lowest ``r_error``.

    Raises:
        ValueError: if ``self.options.maxExemplarIter`` is smaller than 1.
    """
    options = self.options
    maxExemplarIter = options.maxExemplarIter
    if maxExemplarIter < 1:
        # Each sample's exported output and its *_init fields come from the steps below.
        raise ValueError("maxExemplarIter must be >= 1, got {}".format(maxExemplarIter))

    # ---- Output directory ------------------------------------------------
    now = datetime.datetime.now()
    if config.bIsDevfair:
        # Date-only name: runs on the same day share a directory, which lets
        # the skip-existing check below resume them.
        newName = '{:02d}-{:02d}'.format(now.month, now.day)
    else:
        # Seconds-since-midnight suffix gives each run its own directory.
        newName = 'test_{:02d}-{:02d}-{}'.format(now.month, now.day, now.hour*3600 + now.minute*60 + now.second)
    outputDir = newName + '_' + options.db_set + '_' + options.name
    exemplarOutputPath = os.path.join(config.EXEMPLAR_OUTPUT_ROOT, outputDir)
    os.makedirs(exemplarOutputPath, exist_ok=True)

    # ---- File naming (shared by the skip check and the export) -----------
    CHUNK = 100
    dbSet = options.db_set
    bChunked = dbSet == 'panoptic' or 'haggling' in dbSet
    b3dpw = '3dpw' in dbSet
    # A non-negative bExemplar_dataLoaderStart is added to loader-relative
    # indices (presumably for runs that start part-way into the dataset).
    idxOffset = options.bExemplar_dataLoaderStart if options.bExemplar_dataLoaderStart >= 0 else 0

    def _globalIdx(rawIdx):
        """Convert a scalar index (int / numpy / 0-d tensor) to an offset Python int."""
        return int(rawIdx) + idxOffset

    def _chunkPklPath(chunkIdx):
        """Path of a chunk file for the chunked (panoptic/haggling) datasets."""
        return os.path.join(exemplarOutputPath, '{:08d}.pkl'.format(chunkIdx))

    def _samplePklPath(imgName, sampleIdx):
        """Path of a per-sample file for the non-chunked datasets."""
        fileNameOnly = os.path.basename(imgName)[:-4]
        if b3dpw:
            # Prefix the sequence folder name so the file also identifies the sequence.
            seqName = os.path.basename(os.path.dirname(imgName))
            fileNameOnly = f"{seqName}_{fileNameOnly}"
        return os.path.join(exemplarOutputPath, '{}_{}.pkl'.format(fileNameOnly, sampleIdx))

    def _dumpPkl(obj, outputPath):
        print("Saved:{}".format(outputPath))
        with open(outputPath, 'wb') as f:
            pickle.dump(obj, f)

    def _alreadyExported(batch):
        """Return True (and log) if this sample's output file already exists."""
        sampleIdx = _globalIdx(batch['sample_index'][0])
        if bChunked:
            # Sample s is flushed into the chunk named ceil(s / CHUNK) * CHUNK
            # (a multiple of CHUNK flushes itself together with its predecessors).
            outputPath = _chunkPklPath(((sampleIdx + CHUNK - 1) // CHUNK) * CHUNK)
        else:
            outputPath = _samplePklPath(batch['imgname'][0], sampleIdx)
        if os.path.exists(outputPath):
            print("Skipped: {}".format(outputPath))
            return True
        return False

    # ---- Per-sample fitting ----------------------------------------------
    train_data_loader = CheckpointDataLoader(self.train_ds, checkpoint=self.checkpoint,
                                             batch_size=1,      # one sample per EFT fit
                                             num_workers=options.num_workers,
                                             pin_memory=options.pin_memory,
                                             shuffle=False)     # chunking below relies on ascending sample order

    bSkipExisting = not options.bNotSkipExemplar
    outputList = {}     # pending chunk for panoptic/haggling: {sampleIdx: output}
    reconError = []     # one {iteration: (r_error, loss_keypoints)} dict per processed sample

    for batch in tqdm(train_data_loader):
        if options.bExemplar_analysis_testloss:
            # Test-loss analysis only processes every 100th sample.
            if int(batch['sample_index'][0]) % 100 != 0:
                continue
        if bSkipExisting and _alreadyExported(batch):
            continue

        self.reloadModel()  # called once per sample, before any EFT step
        batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

        output_backup = {}
        reconErrorInfo = {}
        for it in range(maxExemplarIter):
            g_timer.tic()   # the matching toc() is not called here
            if options.bUseHand3D:
                output, losses = self.run_eft_step_wHand(batch)
            else:
                output, losses = self.run_eft_step(batch)
            reconErrorInfo[it] = (losses['r_error'], losses['loss_keypoints'])
            output['loss_keypoints_2d'] = losses['loss_keypoints']
            output['loss'] = losses['loss']
            if it == 0:
                # Keep the first step's output; it is exported with an "_init" suffix.
                output_backup['pred_shape'] = output['pred_shape'].copy()
                output_backup['pred_pose_rotmat'] = output['pred_pose_rotmat'].copy()
                output_backup['pred_camera'] = output['pred_camera'].copy()
                output_backup['loss_keypoints_2d'] = output['loss_keypoints_2d']
                output_backup['loss'] = output['loss']

        reconError.append(reconErrorInfo)
        output['reconErrorInfo'] = reconErrorInfo
        output['pred_shape_init'] = output_backup['pred_shape']
        output['pred_pose_rotmat_init'] = output_backup['pred_pose_rotmat']
        output['pred_camera_init'] = output_backup['pred_camera']
        output['loss_init'] = output_backup['loss']
        output['loss_keypoints_2d_init'] = output_backup['loss_keypoints_2d']
        output['numOfIteration'] = maxExemplarIter - 1   # index of the last step
        output['smpltype'] = 'smplx' if options.bUseSMPLX else 'smpl'

        # ---- Exemplar-tuning analysis --------------------------------------
        if options.bExemplar_analysis_testloss and test_dataset_3dpw is not None:
            print(">> Testing : test set size:{}".format(len(test_dataset_3dpw)))
            output['test_error_3dpw'] = self.test(test_dataset_3dpw, '3dpw')
            if test_dataset_h36m is not None:
                output['test_error_h36m'] = self.test(test_dataset_h36m, 'h36m-p1')
        if options.bExemplar_badsample_finder and test_dataset_3dpw is not None:
            print(">> Testing : test set size:{}".format(len(test_dataset_3dpw)))
            output['test_error_3dpw'] = self.test(test_dataset_3dpw, '3dpw')

        # ---- Export ----------------------------------------------------------
        if not bExportPKL:
            continue
        sampleIdx = _globalIdx(output['sampleIdx'][0])
        if bChunked:
            outputList[sampleIdx] = output
            if sampleIdx % CHUNK == 0:
                # The file holds every pending sample and is named after the largest one.
                _dumpPkl(outputList, _chunkPklPath(sampleIdx))
                outputList = {}
        else:
            _dumpPkl(output, _samplePklPath(output['imageName'][0], sampleIdx))

    if outputList:
        # Flush the trailing partial chunk; otherwise the samples after the last
        # multiple of CHUNK are never written. Its name (the largest index it holds)
        # is not a multiple of CHUNK, so it cannot overwrite a regular chunk file.
        _dumpPkl(outputList, _chunkPklPath(max(outputList)))

    if bPrintReconError and reconError:
        # rows: iteration, columns: processed samples.
        # NOTE(review): assumes the r_error values are numpy-convertible scalars.
        reconErrorPerIter = np.array([[info[it][0] for info in reconError] for it in range(maxExemplarIter)])
        for it in range(maxExemplarIter):
            print("{}: reconError:{}".format(it, np.mean(reconErrorPerIter[it, :])))
        # Mean over samples of the lowest r_error reached at any iteration.
        print("Best: reconError:{}".format(np.mean(np.min(reconErrorPerIter, axis=0))))