def extract_patch_data_rgb_detection()

in tools/data_prepare/patch_data_prepare_val.py [0:0]


def _load_image_xyz_rgb(dataset, data_idx):
    ''' Load calibration, per-pixel xyz map and RGB array for one sample.

    Input:
        dataset: KittiDataset providing get_calib / get_depth / get_image.
        data_idx: sample index, as read from the detection file.
    Output:
        (calib, xyz, rgb): xyz has shape (height, width, 3) and holds every
            pixel back-projected with its depth into the *rect camera* coord
            system via calib.img_to_rect; rgb is np.array(image).
    Raises:
        ValueError: if the depth map and the image differ in size.
    '''
    calib = dataset.get_calib(data_idx)
    depth = dataset.get_depth(data_idx)
    image = dataset.get_image(data_idx)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if depth.size != image.size:
        raise ValueError('depth map size %s does not match image size %s for sample %s'
                         % (depth.size, image.size, data_idx))
    width, height = depth.size  # size tuple is unpacked as (width, height)
    # Division by 256: presumably the depth map is stored as scaled integers
    # (KITTI-style depth PNG) -- TODO confirm against the depth generator.
    depth = np.array(depth).astype(np.float32) / 256
    # Vectorized pixel grid (row index v, column index u) in float32. This
    # replaces a Python double loop over every pixel and yields the same values
    # in the same row-major order as the previous (height*width, 3) layout.
    v_grid, u_grid = np.indices((height, width), dtype=np.float32)
    xyz = calib.img_to_rect(u_grid.ravel(), v_grid.ravel(), depth.ravel())  # rect coord sys
    xyz = xyz.reshape(height, width, 3)
    rgb = np.array(image)
    return calib, xyz, rgb


def extract_patch_data_rgb_detection(det_filename, split, output_filename,
                                     whitelist=('Car',),
                                     img_height_threshold=25):
    ''' Extract per-pixel xyz patches and RGB patches inside 2D detection boxes.
        Update: xyz coordinates are in the *rect camera* coord system
            (as that in 3d box label files)

    Input:
        det_filename: string, each line is
            img_path typeid confidence xmin ymin xmax ymax
        split: string, dataset split name passed to KittiDataset
            (e.g. training or testing)
        output_filename: string, the name for output .pickle file
        whitelist: a collection of strings, object types we are interested in.
            Any container supporting `in` works (list, tuple, set).
        img_height_threshold: number, skip detections whose 2D box height
            (ymax - ymin in pixels, before clipping to the image) is lower
            than this.
    Output:
        None (writes a .pickle file to the disk). The file holds seven
        consecutively pickled lists of equal length, in this order:
        id_list, box2d_list, patch_xyz_list, patch_rgb_list, type_list,
        frustum_angle_list, prob_list.
    '''
    data_dir = os.path.join(ROOT_DIR, 'data')
    dataset = KittiDataset(root_dir=data_dir, split=split)
    det_id_list, det_type_list, det_box2d_list, det_prob_list = read_det_file(det_filename)

    # Per-image data (calib, xyz map, rgb) of the most recently loaded sample;
    # consecutive detections from the same image reuse it instead of reloading.
    cache_id = -1
    cache = None

    id_list = []
    type_list = []
    box2d_list = []
    prob_list = []
    patch_xyz_list = []
    patch_rgb_list = []
    frustum_angle_list = []

    desc = '%s split patch data gen (from 2d detections)' % split
    detections = zip(det_id_list, det_type_list, det_box2d_list, det_prob_list)
    # Context manager closes the progress bar even if an exception escapes.
    with tqdm.tqdm(total=len(det_id_list), leave=True, desc=desc) as progress_bar:
        for data_idx, det_type, det_box2d, det_prob in detections:
            progress_bar.update()

            if det_type not in whitelist:
                continue

            # 2D BOX
            xmin, ymin, xmax, ymax = det_box2d

            # Pass objects that are too small (height of the unclipped box)
            if ymax - ymin < img_height_threshold:
                continue

            # Load image data lazily: only samples that still have a kept
            # detection are read and back-projected.
            if cache_id != data_idx:
                cache = _load_image_xyz_rgb(dataset, data_idx)
                cache_id = data_idx
            calib, xyz, rgb = cache

            # Get frustum angle (according to center pixel in 2D BOX): place the
            # box center at an arbitrary fixed depth (20) and take the heading
            # of the resulting rect-frame point from its z and x components.
            uvdepth = np.array([[(xmin + xmax) / 2.0, (ymin + ymax) / 2.0, 20.0]])
            box2d_center_rect = calib.img_to_rect(uvdepth[:, 0], uvdepth[:, 1], uvdepth[:, 2])
            frustum_angle = -1 * np.arctan2(box2d_center_rect[0, 2], box2d_center_rect[0, 0])

            # Clip the box to the image, then cut the patches.
            height, width, _ = xyz.shape
            x0, y0 = int(max(xmin, 0)), int(max(ymin, 0))
            x1, y1 = int(min(xmax, width)), int(min(ymax, height))
            # .copy(): a bare slice is a view that would keep the full-image
            # xyz/rgb arrays of every processed sample alive until the end.
            patch_xyz = xyz[y0:y1, x0:x1, :].copy()
            patch_rgb = rgb[y0:y1, x0:x1, :].copy()

            id_list.append(data_idx)
            box2d_list.append(det_box2d)
            patch_xyz_list.append(patch_xyz)
            patch_rgb_list.append(patch_rgb)
            type_list.append(det_type)
            frustum_angle_list.append(frustum_angle)
            prob_list.append(det_prob)

    with open(output_filename, 'wb') as fp:
        pickle.dump(id_list, fp)
        pickle.dump(box2d_list, fp)
        pickle.dump(patch_xyz_list, fp)
        pickle.dump(patch_rgb_list, fp)
        pickle.dump(type_list, fp)
        pickle.dump(frustum_angle_list, fp)
        pickle.dump(prob_list, fp)