def convert_amass_to_lmdb()

in scripts/preprocess_amass_3dpw.py [0:0]


def convert_amass_to_lmdb(output_file, amass_root):
    """Convert AMASS to LMDB format that we can use during training."""
    print("Converting AMASS data under {} and exporting it to {} ...".format(amass_root, output_file))
    npz_file_ids = get_all_amass_file_ids(amass_root)
    smpl_layer = create_default_smpl_model(C.DEVICE)
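    # Process at most this many frames per SMPL forward pass to avoid running out of GPU memory.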
    max_possible_len = 1000
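    # map_size of 1 << 33 bytes (8 GiB) is the maximum size the LMDB database is allowed to grow to.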
    env = lmdb.open(output_file, map_size=1 << 33)
    cache = dict()

    for i in tqdm(range(len(npz_file_ids))):
        file_id = npz_file_ids[i]
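        # file_id is the npz path relative to amass_root; it also serves as the sequence ID.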
        sample = np.load(os.path.join(amass_root, file_id))
        poses = sample['poses'][:, :C.MAX_INDEX_ROOT_AND_BODY]  # (N_FRAMES, 66)
        betas = sample['betas'][:C.N_SHAPE_PARAMS]  # (N_SHAPE_PARAMS, )
        trans = sample['trans']  # (N_FRAMES, 3)
        fps = sample['mocap_framerate'].tolist()
        gender = sample['gender'].tolist()
        n_frames = poses.shape[0]
        n_coords = poses.shape[1]

        # Resample to 60 FPS.
        poses = resample_rotations(poses.reshape(n_frames, -1, 3), fps, C.FPS).reshape(-1, n_coords)
        trans = resample_positions(trans, fps, C.FPS)

        # Extract joint information, watch out for CUDA out of memory.
        n_frames = poses.shape[0]
        n_shards = n_frames // max_possible_len
        joints = []
        for j in range(n_shards+1):
            sf = j*max_possible_len
            if sf >= n_frames:
                break  # n_frames is an exact multiple of max_possible_len; no leftover empty shard.
            ef = None if j == n_shards else (j+1)*max_possible_len
            with torch.no_grad():
                ps = torch.from_numpy(poses[sf:ef]).to(dtype=torch.float32, device=C.DEVICE)
                ts = torch.from_numpy(trans[sf:ef]).to(dtype=torch.float32, device=C.DEVICE)
                bs = torch.from_numpy(betas).to(dtype=torch.float32, device=C.DEVICE)
                _, js = smpl_layer(poses_body=ps[:, 3:], betas=bs, poses_root=ps[:, :3], trans=ts)
                joints.append(js[:, :(1 + C.N_JOINTS)].reshape(-1, (1 + C.N_JOINTS)*3).detach().cpu().numpy())

        joints = np.concatenate(joints, axis=0)  # (N_FRAMES, (1 + C.N_JOINTS)*3)
        assert joints.shape[0] == n_frames

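        # AMASS stores gender as a string that may be loaded back as bytes, depending on how it was saved.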
        if not isinstance(gender, str):
            gender = gender.decode()

        # Store.
        cache["poses{}".format(i)] = poses.astype(np.float32).tobytes()
        cache["betas{}".format(i)] = betas.astype(np.float32).tobytes()
        cache["trans{}".format(i)] = trans.astype(np.float32).tobytes()
        cache["joints{}".format(i)] = joints.astype(np.float32).tobytes()
        cache["n_frames{}".format(i)] = "{}".format(n_frames).encode()
        cache["id{}".format(i)] = file_id.encode()
        cache["gender{}".format(i)] = gender.encode()

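        # Flush the cached records to LMDB every 1000 sequences to keep memory usage bounded.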
        if i % 1000 == 0:
            with env.begin(write=True) as txn:
                for k, v in cache.items():
                    txn.put(k.encode(), v)
            cache = dict()
            torch.cuda.empty_cache()

    with env.begin(write=True) as txn:
        for k, v in cache.items():
            txn.put(k.encode(), v)
        txn.put('__len__'.encode(), "{}".format(len(npz_file_ids)).encode())
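
For reference, here is a minimal sketch of how one stored sequence could be decoded again at training time. Only the key layout ("poses{i}", "betas{i}", "trans{i}", "joints{i}", "n_frames{i}", "id{i}", "gender{i}" and "__len__") comes from the function above; the helper name read_sequence and its signature are illustrative and not part of the script.

import lmdb
import numpy as np


def read_sequence(lmdb_path, index):
    """Illustrative sketch: decode sequence `index` from an LMDB written by convert_amass_to_lmdb."""
    env = lmdb.open(lmdb_path, readonly=True, lock=False)
    with env.begin(write=False) as txn:
        # All arrays were stored as raw float32 bytes, so np.frombuffer recovers them.
        n_frames = int(txn.get("n_frames{}".format(index).encode()).decode())
        poses = np.frombuffer(txn.get("poses{}".format(index).encode()), dtype=np.float32).reshape(n_frames, -1)
        betas = np.frombuffer(txn.get("betas{}".format(index).encode()), dtype=np.float32)
        trans = np.frombuffer(txn.get("trans{}".format(index).encode()), dtype=np.float32).reshape(n_frames, 3)
        joints = np.frombuffer(txn.get("joints{}".format(index).encode()), dtype=np.float32).reshape(n_frames, -1)
        file_id = txn.get("id{}".format(index).encode()).decode()
        gender = txn.get("gender{}".format(index).encode()).decode()
    env.close()
    return dict(id=file_id, gender=gender, n_frames=n_frames, poses=poses, betas=betas, trans=trans, joints=joints)

The total number of stored sequences can be read back from the "__len__" key in the same way.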