in apps/recon.py [0:0]
def _resolve_state_dict_path(opt):
    """Return the path of the netMR checkpoint that ``recon`` should load.

    Precedence:
      1. ``opt.load_netMR_checkpoint_path`` when it is not None;
      2. ``<checkpoints_path>/<name>_train_latest`` when ``opt.resume_epoch`` is
         negative -- ``opt.resume_epoch`` is reset to 0 as a side effect;
      3. ``<checkpoints_path>/<name>_train_epoch_<resume_epoch>`` otherwise.
    """
    if opt.load_netMR_checkpoint_path is not None:
        return opt.load_netMR_checkpoint_path
    if opt.resume_epoch < 0:
        opt.resume_epoch = 0
        return '%s/%s_train_latest' % (opt.checkpoints_path, opt.name)
    return '%s/%s_train_epoch_%d' % (opt.checkpoints_path, opt.name, opt.resume_epoch)


def _load_checkpoint(state_dict_path, opt, device):
    """Load a checkpoint and return ``(state_dict, merged_opt)``.

    ``merged_opt`` is the options object stored in the checkpoint
    (``state_dict['opt']``), with ``dataroot``, ``resolution``,
    ``results_path`` and ``loadSize`` copied over from the caller's ``opt``.

    Raises:
        Exception: if ``state_dict_path`` does not exist.
    """
    if not os.path.exists(state_dict_path):
        raise Exception('failed loading state dict!', state_dict_path)
    print('Resuming from ', state_dict_path)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from trusted sources.
    state_dict = torch.load(state_dict_path, map_location=device)
    print('Warning: opt is overwritten.')
    ckpt_opt = state_dict['opt']
    # Runtime settings chosen by the caller take precedence over the ones
    # saved at training time.
    for key in ('dataroot', 'resolution', 'results_path', 'loadSize'):
        setattr(ckpt_opt, key, getattr(opt, key))
    return state_dict, ckpt_opt


def recon(opt, use_rect=False, multi_person=False):
    """Reconstruct a mesh for each evaluation sample with a trained netMR.

    Args:
        opt: options namespace. Reads ``load_netMR_checkpoint_path``,
            ``resume_epoch``, ``checkpoints_path``, ``name``, ``start_id``,
            ``end_id``, ``gpu_id``, ``dataroot``, ``resolution``,
            ``results_path`` and ``loadSize``. ``opt.resume_epoch`` is set to 0
            when no explicit checkpoint is given and it was negative.
        use_rect: build the dataset with ``EvalDataset`` when True, otherwise
            with ``EvalWPoseDataset``.
        multi_person: when True, generate one mesh per person of each sample
            using ``test_dataset.get_n_person(i)`` / ``test_dataset.person_id``.
            NOTE(review): presumably only supported by ``EvalWPoseDataset`` --
            confirm before combining with ``use_rect=True``.

    Side effects:
        Creates ``opt.checkpoints_path``, ``opt.results_path`` and
        ``<results_path>/<name>/recon`` (using the checkpoint's options), and
        calls ``gen_mesh`` with an ``.obj`` ``save_path`` for every processed
        sample (presumably writing the mesh there).

    Raises:
        Exception: if the resolved checkpoint file does not exist.
    """
    state_dict_path = _resolve_state_dict_path(opt)
    # The index range comes from the caller's options, read before they are
    # replaced by the options stored in the checkpoint.
    start_id = opt.start_id
    end_id = opt.end_id
    cuda = torch.device('cuda:%d' % opt.gpu_id if torch.cuda.is_available() else 'cpu')
    state_dict, opt = _load_checkpoint(state_dict_path, opt, cuda)

    test_dataset = EvalDataset(opt) if use_rect else EvalWPoseDataset(opt)
    n_samples = len(test_dataset)
    print('test data size: ', n_samples)
    projection_mode = test_dataset.projection_mode

    opt_netG = state_dict['opt_netG']
    netG = HGPIFuNetwNML(opt_netG, projection_mode).to(device=cuda)
    netMR = HGPIFuMRNet(opt, netG, projection_mode).to(device=cuda)
    netMR.load_state_dict(state_dict['model_state_dict'])

    os.makedirs(opt.checkpoints_path, exist_ok=True)
    os.makedirs(opt.results_path, exist_ok=True)
    os.makedirs('%s/%s/recon' % (opt.results_path, opt.name), exist_ok=True)

    # Negative bounds mean "from the first" / "through the last" sample; an
    # end past the dataset is clamped so the progress bar total is accurate.
    start_id = max(start_id, 0)
    if end_id < 0 or end_id > n_samples:
        end_id = n_samples

    ## test
    with torch.no_grad():
        # Every module used for inference must be in eval mode: netMR is the
        # model handed to gen_mesh. netG.eval() is kept explicitly in case
        # netG is not registered as a submodule of netMR.
        netMR.eval()
        netG.eval()
        print('generate mesh (test) ...')
        for i in tqdm(range(start_id, end_id)):
            if not multi_person:
                test_data = test_dataset[i]
                save_path = '%s/%s/recon/result_%s_%d.obj' % (opt.results_path, opt.name, test_data['name'], opt.resolution)
                print(save_path)
                gen_mesh(opt.resolution, netMR, cuda, test_data, save_path, components=opt.use_compose)
            else:
                for j in range(test_dataset.get_n_person(i)):
                    test_dataset.person_id = j
                    test_data = test_dataset[i]
                    save_path = '%s/%s/recon/result_%s_%d.obj' % (opt.results_path, opt.name, test_data['name'], j)
                    gen_mesh(opt.resolution, netMR, cuda, test_data, save_path, components=opt.use_compose)