in eft/train/eftFitter.py [0:0]
def eftAllInDB(self, eft_out_dir = "./eft_out/", bExportPKL = True, test_dataset_3dpw = None, test_dataset_h36m= None):
    """Run EFT independently on every sample of ``self.train_ds`` and export the fits.

    Samples are visited once each, in dataset order, one at a time (batch
    size 1, no shuffling). For every sample the model is reloaded with
    ``self.reloadModel()``, optionally partially frozen according to the
    ``ablation_layerteset_*`` options, and optimised with the EFT step for at
    most ``self.options.maxExemplarIter`` iterations. Optimisation stops early
    once the 2D keypoint loss drops below ``self.options.eft_thresh_keyptErr_2d``.

    Results are written under
    ``<eft_out_dir>/eftout_<MM>-<DD>-<seconds since midnight>_<db_set>_<name>/``:

    * ``panoptic`` / ``pennaction`` / any ``db_set`` containing ``haggling``:
      outputs are accumulated into a ``{sample_index: output}`` dict. The dict
      is written as ``<K:08d>.pkl`` whenever sample index K is a multiple of
      100. Any remainder after the last such index is written at the end as
      ``<next multiple of 100:08d>.pkl``. Chunk file K thus holds the samples
      in ``(K-100, K]`` that this run processed.
    * ``db_set`` containing ``3dpw``: one ``<sequence>_<image>_<index>.pkl``
      per sample.
    * anything else: one ``<image>_<index>.pkl`` per sample.

    Sample indices are shifted by ``bExemplar_dataLoaderStart`` when that
    option is >= 0.

    Args:
        eft_out_dir: Root directory under which the timestamped output
            directory is created.
        bExportPKL: If False, nothing is pickled to disk.
        test_dataset_3dpw: Optional dataset evaluated with ``self.test`` after
            each sample when ``bExemplar_analysis_testloss`` or
            ``bExemplar_badsample_finder`` is set.
        test_dataset_h36m: Optional dataset evaluated (as ``'h36m-p1'``) when
            ``bExemplar_analysis_testloss`` is set.

    Raises:
        ValueError: If ``self.options.maxExemplarIter`` is smaller than 1.
    """
    options = self.options
    maxExemplarIter = options.maxExemplarIter
    if maxExemplarIter < 1:
        # With zero iterations no EFT output is ever produced for a sample, and
        # the per-sample bookkeeping below would fail with a NameError.
        raise ValueError("maxExemplarIter must be >= 1, got {}".format(maxExemplarIter))

    # Timestamped output directory, so each run writes to a fresh folder.
    now = datetime.datetime.now()
    newName = 'eftout_{:02d}-{:02d}-{}'.format(
        now.month, now.day, now.hour * 3600 + now.minute * 60 + now.second)
    outputDir = newName + '_' + options.db_set + '_' + options.name
    os.makedirs(eft_out_dir, exist_ok=True)
    exemplarOutputPath = os.path.join(eft_out_dir, outputDir)
    os.makedirs(exemplarOutputPath, exist_ok=True)

    # Datasets exported in chunks of up to 100 samples per pickle. Skip and
    # export must use the same rule, otherwise skipping never matches.
    is_chunked_db = (options.db_set == 'panoptic'
                     or 'haggling' in options.db_set
                     or options.db_set == 'pennaction')

    def _global_sample_index(raw_index):
        """Apply the optional data-loader start offset to a sample index."""
        if options.bExemplar_dataLoaderStart >= 0:
            return raw_index + options.bExemplar_dataLoaderStart
        return raw_index

    def _chunk_end(sample_idx):
        """Return the chunk index K (a multiple of 100) whose file holds sample_idx."""
        # Chunk K is flushed when sample K is processed, so it covers
        # (K-100, K]. That is the ceiling, not floor+1: sample 100 belongs
        # to chunk 100, not 200.
        return ((sample_idx + 99) // 100) * 100

    def _chunk_path(chunk_end):
        """Return the pickle path of a chunked-export file."""
        return os.path.join(exemplarOutputPath, '{:08d}.pkl'.format(chunk_end))

    def _per_sample_path(imgname, sample_idx):
        """Return the pickle path for a single non-chunked sample."""
        file_stem = os.path.basename(imgname)[:-4]  # drop a 4-char extension such as ".jpg"
        if '3dpw' in options.db_set:
            # 3DPW frame names repeat across sequences, so prefix the sequence folder.
            seq_name = os.path.basename(os.path.dirname(imgname))
            file_stem = f"{seq_name}_{file_stem}"
        return os.path.join(exemplarOutputPath, '{}_{}.pkl'.format(file_stem, sample_idx))

    def _dump(path, obj):
        """Pickle obj to path, announcing the write on stdout."""
        print("Saved:{}".format(path))
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    # Ablation options restricting which parameters EFT may update. Each
    # enabled rule freezes every parameter except those whose name contains
    # one of the listed substrings. Rules are applied in this order, so when
    # several options are enabled the last one wins.
    ablation_rules = (
        ('ablation_layerteset_onlyLayer4', ('layer4',)),
        ('ablation_layerteset_onlyAfterRes', ('fc', 'decpose', 'decshape', 'deccam')),
        ('ablation_layerteset_Layer4Later', ('layer4', 'fc', 'decpose', 'decshape', 'deccam')),
        ('ablation_layerteset_onlyRes', ('layer',)),
        ('ablation_layerteset_Layer3Later',
         ('layer3', 'layer4', 'fc', 'decpose', 'decshape', 'deccam')),
        ('ablation_layerteset_Layer2Later',
         ('layer2', 'layer3', 'layer4', 'fc', 'decpose', 'decshape', 'deccam')),
        ('ablation_layerteset_Layer1Later',
         ('layer1', 'layer2', 'layer3', 'layer4', 'fc', 'decpose', 'decshape', 'deccam')),
        # Despite its name, parameters matching none of these substrings stay
        # frozen. NOTE(review): e.g. a top-level `bn1`, if the model has one.
        ('ablation_layerteset_all', ('conv1', 'layer', 'fc', 'decpose', 'decshape', 'deccam')),
        ('ablation_layerteset_onlyRes_withconv1', ('conv1', 'layer')),
        ('ablation_layerteset_decOnly', ('decpose', 'decshape', 'deccam')),
        ('ablation_layerteset_fc2Later', ('fc2', 'decpose', 'decshape', 'deccam')),
        ('ablation_layerteset_onlyRes50LastConv', ('layer4.2.conv3',)),
    )

    def _apply_ablation_freeze():
        """Set requires_grad on self.model's parameters per the enabled ablation rules."""
        for flag, trainable_keys in ablation_rules:
            if not getattr(options, flag):
                continue
            for name, par in self.model.named_parameters():
                par.requires_grad = any(key in name for key in trainable_keys)

    # One sample at a time, in dataset order; the loader may resume from
    # self.checkpoint.
    train_data_loader = CheckpointDataLoader(self.train_ds, checkpoint=self.checkpoint,
                                             batch_size=1,  # EFT fits one sample at a time
                                             num_workers=options.num_workers,
                                             pin_memory=options.pin_memory,
                                             shuffle=False)

    # bNotSkipExemplar == True disables skipping of already-exported samples.
    # NOTE: exemplarOutputPath embeds the current time, so in practice this
    # only finds files if the output directory is reused.
    bSkipExisting = options.bNotSkipExemplar == False

    outputList = {}  # pending {sample_index: output} for chunked datasets
    for batch in tqdm(train_data_loader):
        # Test-loss analysis is only run on ~1% of the samples.
        if options.bExemplar_analysis_testloss:
            if batch['sample_index'][0].item() % 100 != 0:
                continue

        if bSkipExisting:
            sample_idx = _global_sample_index(batch['sample_index'][0].item())
            if is_chunked_db:
                existing_path = _chunk_path(_chunk_end(sample_idx))
            else:
                existing_path = _per_sample_path(batch['imgname'][0], sample_idx)
            if os.path.exists(existing_path):
                print("Skipped: {}".format(existing_path))
                continue

        g_timer.tic()
        # Reload the model for every sample so each fit starts from the same
        # state (presumably the initial weights; see reloadModel).
        self.reloadModel()
        _apply_ablation_freeze()

        batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                 for k, v in batch.items()}

        output_backup = {}
        for it in range(maxExemplarIter):
            if options.bUseHand3D:
                output, losses = self.run_eft_step_wHand(batch)
            elif options.bUseHand2D:
                output, losses = self.run_eft_step_with_2dhand(batch, iterIdx=it)
            else:
                output, losses = self.run_eft_step(batch, iterIdx=it)

            output['loss_keypoints_2d'] = losses['loss_keypoints']
            output['loss'] = losses['loss']
            if it == 0:
                # Keep the first iteration's prediction; exported as the *_init fields.
                for key in ('pred_shape', 'pred_pose_rotmat', 'pred_camera'):
                    output_backup[key] = output[key].copy()
                output_backup['loss_keypoints_2d'] = output['loss_keypoints_2d']
                output_backup['loss'] = output['loss']

            # Early stop once the 2D keypoint error is small enough.
            if output['loss_keypoints_2d'] < options.eft_thresh_keyptErr_2d:
                break

        if options.bDebug_visEFT:
            glViewer.show(0)

        output['pred_shape_init'] = output_backup['pred_shape']
        output['pred_pose_rotmat_init'] = output_backup['pred_pose_rotmat']
        output['pred_camera_init'] = output_backup['pred_camera']
        output['loss_init'] = output_backup['loss']
        output['loss_keypoints_2d_init'] = output_backup['loss_keypoints_2d']
        output['numOfIteration'] = it
        output['smpltype'] = 'smplx' if options.bUseSMPLX else 'smpl'

        # Exemplar tuning analysis: evaluate the per-sample fitted model on test sets.
        if options.bExemplar_analysis_testloss and test_dataset_3dpw is not None:
            print(">> Testing : test set size:{}".format(len(test_dataset_3dpw)))
            output['test_error_3dpw'] = self.test(test_dataset_3dpw, '3dpw')
            if test_dataset_h36m is not None:
                output['test_error_h36m'] = self.test(test_dataset_h36m, 'h36m-p1')
        if options.bExemplar_badsample_finder and test_dataset_3dpw is not None:
            print(">> Testing : test set size:{}".format(len(test_dataset_3dpw)))
            output['test_error_3dpw'] = self.test(test_dataset_3dpw, '3dpw')

        if not bExportPKL:
            continue

        sample_idx = _global_sample_index(output['sampleIdx'][0].item())
        if is_chunked_db:
            outputList[sample_idx] = output
            if sample_idx % 100 == 0:
                _dump(_chunk_path(sample_idx), outputList)
                outputList = {}
        else:
            _dump(_per_sample_path(output['imageName'][0], sample_idx), output)

    # Flush the last partial chunk; previously these samples were silently lost.
    if bExportPKL and outputList:
        _dump(_chunk_path(_chunk_end(max(outputList))), outputList)