def generate_contactpose_dataset()

in contactopt/create_dataset_contactpose.py


import multiprocessing
import pickle

import numpy as np
from joblib import Parallel, delayed
from tqdm import tqdm

# Assumed repo-local modules: HandObject and SAMPLE_VERTS_NUM live elsewhere in the repo
from contactopt.hand_object import HandObject
from contactopt.util import SAMPLE_VERTS_NUM

object_cut_list = []    # Object names to hold out of the dataset; assumed module-level, filled in elsewhere in the script

def generate_contactpose_dataset(dataset, output_file, low_p, high_p, num_pert=1, aug_trans=0.02, aug_rot=0.05, aug_pca=0.3):
    """
    Generates a dataset pkl file and performs preprocessing for the PyTorch dataloader
    :param dataset: List of ContactPose samples
    :param output_file: Path to the output pkl file
    :param low_p: Lower bound of the dataset split, in [0, 1)
    :param high_p: Upper bound of the dataset split, in [0, 1)
    :param num_pert: Number of random perturbations generated for every ground-truth dataset sample
    :param aug_trans: Standard deviation of the hand translation noise, in meters
    :param aug_rot: Standard deviation of the hand rotation noise, in axis-angle radians
    :param aug_pca: Standard deviation of the hand pose noise, in MANO PCA units
    """
    low_split = int(len(dataset) * low_p)
    high_split = int(len(dataset) * high_p)
    dataset = dataset[low_split:high_split]

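    # Optionally hold out entire object categories; s[2] holds the sample's object name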
    if len(object_cut_list) > 0:
        dataset = [s for s in dataset if s[2] not in object_cut_list]
        print('Removing objects from the dataset:', object_cut_list)

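    # Build num_pert perturbed training examples from one ContactPose sample; s[3] is the ContactPose object itself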
    def process_sample(s, idx):
        ho_gt = HandObject()
        ho_gt.load_from_contactpose(s[3])
        sample_list = []
        # print('Processing', idx)

        for i in range(num_pert):
            # Each perturbed sample stores a reference to the shared ho_gt, keeping memory usage low
            sample_data = dict()

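            # Perturb the GT hand: Gaussian translation noise plus pose noise
            # (first 3 entries: axis-angle rotation, next 15: MANO PCA components)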
            ho_aug = HandObject()
            aug_t = np.random.randn(3) * aug_trans
            aug_p = np.concatenate((np.random.randn(3) * aug_rot, np.random.randn(15) * aug_pca)).astype(np.float32)
            ho_aug.load_from_ho(ho_gt, aug_p, aug_t)

            sample_data['ho_gt'] = ho_gt
            sample_data['ho_aug'] = ho_aug
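            # Randomly subsample the object mesh vertices, then compute the
            # PointNet-style input features for both hand and object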
            sample_data['obj_sampled_idx'] = np.random.randint(0, len(ho_gt.obj_verts), SAMPLE_VERTS_NUM)
            sample_data['hand_feats_aug'], sample_data['obj_feats_aug'] = ho_aug.generate_pointnet_features(sample_data['obj_sampled_idx'])

            sample_list.append(sample_data)

        return sample_list

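    # Process every sample, either in parallel across all CPU cores or serially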
    parallel = True
    if parallel:
        num_cores = multiprocessing.cpu_count()
        print('Running on {} cores'.format(num_cores))
        all_data_2d = Parallel(n_jobs=num_cores)(delayed(process_sample)(s, idx) for idx, s in enumerate(tqdm(dataset)))
        all_data = [item for sublist in all_data_2d for item in sublist] # flatten 2d list
    else:
        all_data = []   # Do non-parallel
        for idx, s in enumerate(tqdm(dataset)):
            all_data.extend(process_sample(s, idx))

    print('Writing pickle file, this may be slow for large datasets')
    with open(output_file, 'wb') as f:
        pickle.dump(all_data, f)
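

# Usage sketch, not part of the original function: get_all_contactpose_samples() is a
# hypothetical helper returning (participant_num, intent, object_name, ContactPose)
# tuples, and the output paths and split points below are illustrative. Shuffling once
# up front keeps the two slices disjoint but randomized.
import random

if __name__ == '__main__':
    all_samples = get_all_contactpose_samples()     # Hypothetical loader
    random.seed(0)                                  # Fixed seed so the split is reproducible
    random.shuffle(all_samples)

    # First 80% -> training set, 16 perturbations per ground-truth grasp
    generate_contactpose_dataset(all_samples, 'data/perturbed_contactpose_train.pkl', 0.0, 0.8, num_pert=16)
    # Remaining 20% -> test set, one perturbation per grasp
    generate_contactpose_dataset(all_samples, 'data/perturbed_contactpose_test.pkl', 0.8, 1.0, num_pert=1)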