def backproject()

in downstream/semseg/lib/pc_utils.py [0:0]


  def backproject(self,
                  depth_map,
                  labels=None,
                  max_depth=None,
                  max_height=None,
                  min_height=None,
                  rgb_img=None,
                  extrinsics=None,
                  prune=True):
    """Backproject a depth map into 3D points (camera coordinate system).

    Each pixel at column ``a`` / row ``b`` with depth ``Z`` is lifted with
    ``self.K_inv @ (a*Z, b*Z, Z)``. The X and Y of that product are kept, and Z is taken
    directly from the depth map. A color from ``rgb_img`` is attached to every point. When no
    RGB image is given, white [255 255 255] is used.

    When ``prune`` is truthy, points at Z = 0 or Z = 65535 are dropped. So are points that
    violate the optional depth/height limits.

    Args:
      depth_map: HxW depth image, or HxWxC (only channel 0 is used).
      labels: Optional HxW or HxWxC array with the same height/width as ``depth_map``.
      max_depth: Points with Z greater than this are dropped. Same units as ``depth_map``
        (documented upstream as cm).
      max_height: Points whose camera-frame Y is greater than this are dropped.
      min_height: Points whose camera-frame Y is smaller than this are dropped.
      rgb_img: Optional HxW, HxWx1 or HxWx3 color image aligned with ``depth_map``.
      extrinsics: If given, the points are passed through ``self.camera2world``.
      prune: Whether to drop invalid / out-of-range points.

    Returns:
      points_3d: Nx6 array (X, Y, Z, R, G, B). It is transformed by ``self.camera2world``
        when ``extrinsics`` is given.
      If ``labels`` is given, the tuple ``(points_3d, points_3d_labels)`` is returned instead.
      ``points_3d_labels`` is the first three columns of ``points_3d`` followed by the label
      channels.

    Raises:
      ValueError: If ``labels`` / ``rgb_img`` do not match the depth map's height and width,
        or have an unsupported number of dimensions/channels.
    """
    # Reduce the depth map to a single float32 channel. Index instead of squeezing, so that
    # 1-pixel-wide or 1-pixel-tall images keep both axes.
    depth_map = np.asarray(depth_map)
    if depth_map.ndim == 3:
      depth_map = depth_map[:, :, 0]
    depth_map = depth_map.astype(np.float32)
    H, W = depth_map.shape

    labels_flat = None
    if labels is not None:
      labels = np.asarray(labels)
      if labels.shape[:2] != (H, W):
        raise ValueError('labels shape %s does not match depth map shape %s' %
                         (labels.shape, (H, W)))
      if labels.ndim == 2:
        n_label_channels = 1
      elif labels.ndim == 3:
        n_label_channels = labels.shape[2]
      else:
        raise ValueError('labels must be HxW or HxWxC, got shape %s' % (labels.shape,))
      labels_flat = labels.reshape((-1, n_label_channels))

    if rgb_img is None:
      # No color supplied: build an explicit HxWx3 white image so the reshape below always
      # yields exactly one color row per pixel, whatever the depth map's dimensionality.
      rgb_img = np.full((H, W, 3), 255, dtype=np.uint8)
    else:
      rgb_img = np.asarray(rgb_img)
      if rgb_img.shape[:2] != (H, W):
        raise ValueError('rgb_img shape %s does not match depth map shape %s' %
                         (rgb_img.shape, (H, W)))
      if rgb_img.ndim == 2:
        rgb_img = rgb_img[:, :, np.newaxis]
      # Convert from 1-channel to 3-channel
      if rgb_img.shape[2] == 1:
        rgb_img = np.tile(rgb_img, [1, 1, 3])
      if rgb_img.ndim != 3 or rgb_img.shape[2] != 3:
        raise ValueError('rgb_img must have 1 or 3 channels, got shape %s' % (rgb_img.shape,))

    # Pixel coordinates as float32, so the products below stay float32 as before.
    cols, rows = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32))
    Z = depth_map
    # 3xN with one column (a*Z, b*Z, Z) per pixel, in row-major pixel order.
    scaled_pix = np.stack((cols * Z, rows * Z, Z), axis=0).reshape((3, -1))
    prod = np.dot(self.K_inv, scaled_pix)
    XYZ = np.concatenate((prod[:2, :].T, Z.reshape((-1, 1))), axis=1)  # Nx3
    points_3d = np.hstack((XYZ, rgb_img.reshape((-1, 3))))

    if prune:
      cur_y = points_3d[:, 1]
      cur_z = points_3d[:, 2]
      # Don't show things at 0 distance or max distance.
      valid = (cur_z != 0) & (cur_z != 65535)
      # Written as ~(x > lim) rather than (x <= lim) so NaN values are kept, as before.
      if max_depth is not None:
        valid &= ~(cur_z > max_depth)
      if max_height is not None:
        valid &= ~(cur_y > max_height)
      if min_height is not None:
        valid &= ~(cur_y < min_height)
      # A boolean mask also works when nothing survives (returns an empty Nx6 array).
      points_3d = points_3d[valid]
      if labels_flat is not None:
        labels_flat = labels_flat[valid]

    if extrinsics is not None:
      points_3d = self.camera2world(extrinsics, points_3d)

    if labels_flat is not None:
      points_3d_labels = np.hstack((points_3d[:, :3], labels_flat))
      return points_3d, points_3d_labels
    return points_3d