in scripts/preprocess_amass_3dpw.py [0:0]
def convert_3dpw_to_lmdb(output_file, threedpw_root, *, max_possible_len=1000,
                         map_size=1 << 29, flush_every=1000):
    """Convert 3DPW to LMDB format that we can use during training.

    Every ``.pkl`` file found (recursively) under ``threedpw_root`` is loaded and
    each subject it contains becomes one sample ``i`` in the LMDB database,
    stored under the keys:

    * ``poses{i}``: ``poses_60Hz`` truncated to ``C.MAX_INDEX_ROOT_AND_BODY``
      columns, float32 raw bytes.
    * ``betas{i}``: the first ``C.N_SHAPE_PARAMS`` shape parameters, float32 raw bytes.
    * ``trans{i}``: ``trans_60Hz``, float32 raw bytes.
    * ``joints{i}``: the first ``1 + C.N_JOINTS`` SMPL joints, flattened to
      ``(1 + C.N_JOINTS) * 3`` values per frame, float32 raw bytes.
    * ``n_frames{i}``: number of frames as an encoded decimal string.
    * ``id{i}``: basename of the source pickle file.
    * ``gender{i}``: ``'female'`` if the pickle says ``'f'``, else ``'male'``.
    * ``__len__``: total number of samples written. This counts subjects, not
      files, because a 3DPW sequence can contain several subjects.

    Args:
        output_file: Path of the LMDB environment to write.
        threedpw_root: Directory searched recursively for 3DPW ``.pkl`` files.
        max_possible_len: Maximum number of frames passed to the SMPL layer in
            a single forward call.
        map_size: Maximum size in bytes of the LMDB memory map.
        flush_every: Buffered samples are committed to LMDB whenever the sample
            index is a positive multiple of this value.

    Raises:
        ValueError: If ``max_possible_len`` or ``flush_every`` is not positive.
    """
    if max_possible_len < 1:
        raise ValueError("max_possible_len must be positive, got {}".format(max_possible_len))
    if flush_every < 1:
        raise ValueError("flush_every must be positive, got {}".format(flush_every))

    print("Converting 3DPW data under {} and exporting it to {} ...".format(threedpw_root, output_file))
    smpl_layer = create_default_smpl_model(C.DEVICE)
    n_joint_coords = (1 + C.N_JOINTS) * 3

    # Find all pickle files. Sorting makes the sample indices reproducible;
    # the order of os.walk depends on the filesystem.
    pkl_files = sorted(
        os.path.join(root_dir, f)
        for root_dir, _, files in os.walk(threedpw_root)
        for f in files
        if f.endswith('.pkl'))

    env = lmdb.open(output_file, map_size=map_size)

    def commit(items):
        """Write all ``(str key, bytes value)`` pairs in one write transaction."""
        with env.begin(write=True) as txn:
            for k, v in items:
                txn.put(k.encode(), v)

    try:
        cache = dict()
        idx = 0  # Running sample index: one sample per subject, across all files.
        for pkl_file in tqdm(pkl_files):
            file_id = os.path.basename(pkl_file)
            # NOTE(review): unpickling can execute arbitrary code; only run this
            # on trusted 3DPW downloads.
            with open(pkl_file, 'rb') as fh:
                sample = pkl.load(fh, encoding='latin1')
            n_subjects = len(sample['poses_60Hz'])
            for s in range(n_subjects):
                poses = sample['poses_60Hz'][s][:, :C.MAX_INDEX_ROOT_AND_BODY]  # (N_FRAMES, 66)
                betas = sample['betas'][s][:C.N_SHAPE_PARAMS]  # (N_SHAPE_PARAMS, )
                trans = sample['trans_60Hz'][s]  # (N_FRAMES, 3)
                gender = 'female' if sample['genders'][s] == 'f' else 'male'
                n_frames = poses.shape[0]

                # Run SMPL on chunks of at most max_possible_len frames. Iterating
                # over chunk start offsets never produces an empty trailing chunk,
                # even when n_frames is an exact multiple of max_possible_len.
                joints = []
                with torch.no_grad():
                    # Betas are per subject, so convert them once, not per chunk.
                    bs = torch.from_numpy(betas).to(dtype=torch.float32, device=C.DEVICE)
                    for sf in range(0, n_frames, max_possible_len):
                        ef = sf + max_possible_len
                        ps = torch.from_numpy(poses[sf:ef]).to(dtype=torch.float32, device=C.DEVICE)
                        ts = torch.from_numpy(trans[sf:ef]).to(dtype=torch.float32, device=C.DEVICE)
                        # First 3 pose values are the root orientation, the rest the body.
                        _, js = smpl_layer(poses_body=ps[:, 3:], betas=bs, poses_root=ps[:, :3], trans=ts)
                        joints.append(
                            js[:, :(1 + C.N_JOINTS)].reshape(-1, n_joint_coords).detach().cpu().numpy())
                if joints:
                    joints = np.concatenate(joints, axis=0)  # (N_FRAMES, (1 + N_JOINTS) * 3)
                else:
                    # Zero-frame subject: keep a consistently-shaped empty array.
                    joints = np.zeros((0, n_joint_coords), dtype=np.float32)
                assert joints.shape[0] == n_frames, \
                    "joint frames ({}) != pose frames ({}) in {}".format(joints.shape[0], n_frames, file_id)

                # Store.
                cache["poses{}".format(idx)] = poses.astype(np.float32).tobytes()
                cache["betas{}".format(idx)] = betas.astype(np.float32).tobytes()
                cache["trans{}".format(idx)] = trans.astype(np.float32).tobytes()
                cache["joints{}".format(idx)] = joints.astype(np.float32).tobytes()
                cache["n_frames{}".format(idx)] = "{}".format(n_frames).encode()
                cache["id{}".format(idx)] = file_id.encode()
                cache["gender{}".format(idx)] = gender.encode()

                # Periodically commit to keep the in-memory cache bounded.
                if idx > 0 and idx % flush_every == 0:
                    commit(cache.items())
                    cache = dict()
                    torch.cuda.empty_cache()
                idx += 1

        # __len__ is the number of stored samples (idx), not the number of
        # files: sequences with several subjects yield several samples.
        cache['__len__'] = "{}".format(idx).encode()
        commit(cache.items())
    finally:
        env.close()