convert_3dpw_to_lmdb()

Defined in scripts/preprocess_amass_3dpw.py


import os
import pickle as pkl

import lmdb
import numpy as np
import torch
from tqdm import tqdm

# Project-local imports are assumed to be available in this script: a constants
# module `C` (providing DEVICE, MAX_INDEX_ROOT_AND_BODY, N_SHAPE_PARAMS,
# N_JOINTS) and the helper `create_default_smpl_model`.


def convert_3dpw_to_lmdb(output_file, threedpw_root):
    """Convert 3DPW to LMDB format that we can use during training."""
    print("Converting 3DPW data under {} and exporting it to {} ...".format(threedpw_root, output_file))
    smpl_layer = create_default_smpl_model(C.DEVICE)
    max_possible_len = 1000  # Maximum number of frames per SMPL forward pass.
    env = lmdb.open(output_file, map_size=1 << 29)  # 512 MiB maximum map size.
    cache = dict()  # Buffers key-value pairs that are written in batches.

    # Find all pickle files.
    pkl_files = []
    for root_dir, dirs, files in os.walk(threedpw_root):
        for f in files:
            if f.endswith('.pkl'):
                pkl_files.append(os.path.join(root_dir, f))

    idx = 0
    for i in tqdm(range(len(pkl_files))):
        file_id = os.path.split(pkl_files[i])[-1]
        with open(pkl_files[i], 'rb') as f:
            sample = pkl.load(f, encoding='latin1')
        n_subjects = len(sample['poses_60Hz'])

        for s in range(n_subjects):
            poses = sample['poses_60Hz'][s][:, :C.MAX_INDEX_ROOT_AND_BODY]  # (N_FRAMES, 66)
            betas = sample['betas'][s][:C.N_SHAPE_PARAMS]  # (N_SHAPE_PARAMS, )
            trans = sample['trans_60Hz'][s]  # (N_FRAMES, 3)
            gender = sample['genders'][s]
            n_frames = poses.shape[0]

            # Run the SMPL forward pass in shards of at most max_possible_len
            # frames to bound GPU memory usage.
            joints = []
            for sf in range(0, n_frames, max_possible_len):
                ef = min(sf + max_possible_len, n_frames)
                with torch.no_grad():
                    ps = torch.from_numpy(poses[sf:ef]).to(dtype=torch.float32, device=C.DEVICE)
                    ts = torch.from_numpy(trans[sf:ef]).to(dtype=torch.float32, device=C.DEVICE)
                    bs = torch.from_numpy(betas).to(dtype=torch.float32, device=C.DEVICE)
                    _, js = smpl_layer(poses_body=ps[:, 3:], betas=bs, poses_root=ps[:, :3], trans=ts)
                    # Keep root + body joints and flatten to (n, (1 + N_JOINTS) * 3).
                    joints.append(js[:, :(1 + C.N_JOINTS)].reshape(-1, (1 + C.N_JOINTS) * 3).detach().cpu().numpy())

            joints = np.concatenate(joints, axis=0)  # (N_FRAMES, (1 + N_JOINTS) * 3)
            assert joints.shape[0] == n_frames

            gender = 'female' if gender == 'f' else 'male'  # 3DPW stores genders as 'f'/'m'.

            # Store. Each entry is keyed by the running sample index, e.g.
            # "poses{idx}", so a reader can address samples directly.
            cache["poses{}".format(idx)] = poses.astype(np.float32).tobytes()
            cache["betas{}".format(idx)] = betas.astype(np.float32).tobytes()
            cache["trans{}".format(idx)] = trans.astype(np.float32).tobytes()
            cache["joints{}".format(idx)] = joints.astype(np.float32).tobytes()
            cache["n_frames{}".format(idx)] = "{}".format(n_frames).encode()
            cache["id{}".format(idx)] = file_id.encode()
            cache["gender{}".format(idx)] = gender.encode()

            # Periodically flush the cache to LMDB to keep host memory bounded.
            if idx > 0 and idx % 1000 == 0:
                with env.begin(write=True) as txn:
                    for k, v in cache.items():
                        txn.put(k.encode(), v)
                cache = dict()
                torch.cuda.empty_cache()

            idx += 1

    # Flush any remaining samples and store the total sample count. `__len__`
    # must be the number of stored sequences (idx), not the number of pickle
    # files, since a single file may contain multiple subjects.
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            txn.put(k.encode(), v)
        txn.put('__len__'.encode(), "{}".format(idx).encode())
    env.close()
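
For reference, here is a minimal sketch of how a training-time reader could decode one sample from the resulting LMDB. The key layout mirrors the writer above; the function name read_sample and the opening flags are illustrative assumptions, not part of this script.

import lmdb
import numpy as np


def read_sample(lmdb_path, idx):
    """Hypothetical reader for samples written by convert_3dpw_to_lmdb."""
    env = lmdb.open(lmdb_path, readonly=True, lock=False)
    with env.begin(write=False) as txn:
        n_frames = int(txn.get("n_frames{}".format(idx).encode()).decode())
        # Arrays were serialized with float32 tobytes(), so frombuffer + reshape
        # restores them; the second axis is inferred from the stored frame count.
        poses = np.frombuffer(txn.get("poses{}".format(idx).encode()), dtype=np.float32).reshape(n_frames, -1)
        betas = np.frombuffer(txn.get("betas{}".format(idx).encode()), dtype=np.float32)
        trans = np.frombuffer(txn.get("trans{}".format(idx).encode()), dtype=np.float32).reshape(n_frames, 3)
        joints = np.frombuffer(txn.get("joints{}".format(idx).encode()), dtype=np.float32).reshape(n_frames, -1)
        seq_id = txn.get("id{}".format(idx).encode()).decode()
        gender = txn.get("gender{}".format(idx).encode()).decode()
    env.close()
    return poses, betas, trans, joints, seq_id, gender

Opening the environment with readonly=True and lock=False allows several DataLoader workers to read concurrently; in practice each worker would keep its own environment open rather than reopening it per sample. A hypothetical invocation of the converter itself (paths are placeholders): convert_3dpw_to_lmdb("./data/3dpw_lmdb", "./data/3DPW/sequenceFiles").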