in empose/eval/helpers.py [0:0]
def load_model_and_eval_data(config, shuffle=False, partition='valid'):
    """Load a trained model and build the data loader for one evaluation partition.

    Args:
        config: Object providing ``model_id``. If it also has ``seq_length``
            and/or ``n_samples`` attributes, they override the window size
            and batch size used for the ``'valid'`` partition.
        shuffle: Passed to the loader of the ``'valid'`` partition only; the
            other partitions are never shuffled.
        partition: One of ``'valid'``, ``'test_real'`` or ``'test_real_0715'``.

    Returns:
        A tuple ``(net, eval_loader, preprocess_fn, model_config)`` where
        ``net`` has already been switched to evaluation mode via ``net.eval()``.

    Raises:
        AssertionError: If ``partition`` is not one of the supported names.
    """
    assert partition in ('valid', 'test_real', 'test_real_0715')

    use_validation = partition == 'valid'
    # The third value returned by `load_model` is not needed here.
    net, model_config, _, preprocess_fn = load_model(config.model_id, is_valid=use_validation)

    # Attributes present on `config` take precedence over the defaults.
    if hasattr(config, 'seq_length'):
        window_size = config.seq_length
    else:
        window_size = model_config.window_size

    if hasattr(config, 'n_samples'):
        batch_size = config.n_samples
    else:
        batch_size = 6

    if use_validation:
        # A window of `window_size` frames is extracted in 'middle' mode
        # before the sample is converted to tensors.
        pipeline = transforms.Compose([ExtractWindow(window_size, mode='middle'), ToTensor()])
        lmdb_root = os.path.join(os.path.dirname(C.DATA_DIR), "3dpw_lmdb")
        dataset = LMDBDataset(lmdb_root, transform=pipeline)
        loader_kwargs = dict(batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=model_config.data_workers,
                             collate_fn=AMASSBatch.from_sample_list)
    else:
        pipeline = transforms.Compose([NormalizeRealMarkers(), ToTensor()])
        # Both candidate directories are resolved up front; the requested
        # partition then selects one of them.
        directories = {
            'test_real': C.DATA_DIR_TEST,
            'test_real_0715': os.path.join(C.DATA_DIR_TEST, 'hold_out'),
        }
        dataset = RealDataset(directories[partition], transform=pipeline)
        # These partitions are always loaded one sample at a time, in order.
        loader_kwargs = dict(batch_size=1,
                             shuffle=False,
                             num_workers=1,
                             collate_fn=RealBatch.from_sample_list)

    eval_loader = DataLoader(dataset, **loader_kwargs)

    net.eval()
    return net, eval_loader, preprocess_fn, model_config