std::tuple Renderer::arg_check()

in pytorch3d/csrc/pulsar/pytorch/renderer.cpp [203:656]


/**
 * Validate the arguments of a render call and derive its batch layout.
 *
 * The checks run in this order: tensor ranks and sizes, device placement,
 * scalar types, and finally value ranges. A failing check is reported through
 * `TORCH_CHECK_ARG`; the numeric index passed to it is the 1-based position
 * of the offending parameter in this function's signature.
 *
 * Batch processing is selected when `vert_pos` is 3-dimensional. In that case
 * every per-sample tensor must carry the same leading batch dimension and
 * `ensure_n_renderers_gte(batch_size)` is called. Otherwise `vert_pos` must
 * be 2-dimensional and all other tensors must be unbatched.
 *
 * Several checks read tensor contents on the host via `item<>()`.
 *
 * NOTE(review): `focal_length` is not device-checked and `bg_col` is not
 * dtype-checked here; presumably intentional — confirm against the callers.
 *
 * @param vert_pos Vertex positions; last dimension must be 3.
 * @param vert_col Vertex colors; last dimension must equal
 *     `renderer_vec[0].cam.n_channels`.
 * @param vert_radii Vertex radii; every value must be > FEPS.
 * @param cam_pos Camera position; last dimension must be 3.
 * @param pixel_0_0_center Last dimension must be 3.
 * @param pixel_vec_x Last dimension must be 3.
 * @param pixel_vec_y Last dimension must be 3.
 * @param focal_length One value per sample; must be 0 if `orthogonal()`
 *     holds and > FEPS otherwise.
 * @param principal_point_offsets int32 tensor; last dimension must be 2.
 * @param gamma Must lie in [1E-5, 1].
 * @param max_depth Must be > `min_depth + FEPS`.
 * @param min_depth In/out. A value of exactly 0 is replaced by
 *     `max(focal_length) + 2 * FEPS`; the result must be > max(focal_length).
 * @param bg_col Optional 1-dimensional background color with `n_channels`
 *     entries on the renderer device.
 * @param opacity Optional per-point float opacity with values in [0, 1].
 * @param percent_allowed_difference Must lie in [0, 1).
 * @param max_n_hits Must be > 0.
 * @param mode Must be 0 or 1.
 * @return Tuple `(batch_size, n_points, batch_processing, real_bg_col)`.
 *     `real_bg_col` is `bg_col` if given, otherwise a float tensor of ones
 *     with `n_channels` entries on the renderer device.
 */
std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
    const torch::Tensor& vert_pos,
    const torch::Tensor& vert_col,
    const torch::Tensor& vert_radii,
    const torch::Tensor& cam_pos,
    const torch::Tensor& pixel_0_0_center,
    const torch::Tensor& pixel_vec_x,
    const torch::Tensor& pixel_vec_y,
    const torch::Tensor& focal_length,
    const torch::Tensor& principal_point_offsets,
    const float& gamma,
    const float& max_depth,
    float& min_depth,
    const c10::optional<torch::Tensor>& bg_col,
    const c10::optional<torch::Tensor>& opacity,
    const float& percent_allowed_difference,
    const uint& max_n_hits,
    const uint& mode) {
  LOG_IF(INFO, PULSAR_LOG_FORWARD || PULSAR_LOG_BACKWARD) << "Arg check.";
  size_t batch_size = 1;
  size_t n_points = 0;
  bool batch_processing = false;
  if (vert_pos.ndimension() == 3) {
    // Check all parameters adhere batch size.
    batch_processing = true;
    batch_size = vert_pos.size(0);
    TORCH_CHECK_ARG(
        vert_col.ndimension() == 3 &&
            vert_col.size(0) == static_cast<int64_t>(batch_size),
        2,
        "vert_col needs to have batch size.");
    TORCH_CHECK_ARG(
        vert_radii.ndimension() == 2 &&
            vert_radii.size(0) == static_cast<int64_t>(batch_size),
        3,
        "vert_radii must be specified per batch.");
    TORCH_CHECK_ARG(
        cam_pos.ndimension() == 2 &&
            cam_pos.size(0) == static_cast<int64_t>(batch_size),
        4,
        "cam_pos must be specified per batch and have the correct batch size.");
    TORCH_CHECK_ARG(
        pixel_0_0_center.ndimension() == 2 &&
            pixel_0_0_center.size(0) == static_cast<int64_t>(batch_size),
        5,
        "pixel_0_0_center must be specified per batch.");
    TORCH_CHECK_ARG(
        pixel_vec_x.ndimension() == 2 &&
            pixel_vec_x.size(0) == static_cast<int64_t>(batch_size),
        6,
        "pixel_vec_x must be specified per batch.");
    TORCH_CHECK_ARG(
        pixel_vec_y.ndimension() == 2 &&
            pixel_vec_y.size(0) == static_cast<int64_t>(batch_size),
        7,
        "pixel_vec_y must be specified per batch.");
    TORCH_CHECK_ARG(
        focal_length.ndimension() == 1 &&
            focal_length.size(0) == static_cast<int64_t>(batch_size),
        8,
        "focal_length must be specified per batch.");
    TORCH_CHECK_ARG(
        principal_point_offsets.ndimension() == 2 &&
            principal_point_offsets.size(0) == static_cast<int64_t>(batch_size),
        9,
        "principal_point_offsets must be specified per batch.");
    if (opacity.has_value()) {
      TORCH_CHECK_ARG(
          opacity.value().ndimension() == 2 &&
              opacity.value().size(0) == static_cast<int64_t>(batch_size),
          14,
          "Opacity needs to be specified batch-wise.");
    }
    // Check all parameters are for a matching number of points.
    n_points = vert_pos.size(1);
    TORCH_CHECK_ARG(
        vert_col.size(1) == static_cast<int64_t>(n_points),
        2,
        ("The number of points for vertex positions (" +
         std::to_string(n_points) + ") and vertex colors (" +
         std::to_string(vert_col.size(1)) + ") doesn't agree.")
            .c_str());
    TORCH_CHECK_ARG(
        vert_radii.size(1) == static_cast<int64_t>(n_points),
        3,
        ("The number of points for vertex positions (" +
         std::to_string(n_points) + ") and vertex radii (" +
         std::to_string(vert_radii.size(1)) + ") doesn't agree.")
            .c_str());
    if (opacity.has_value()) {
      TORCH_CHECK_ARG(
          opacity.value().size(1) == static_cast<int64_t>(n_points),
          14,
          "Opacity needs to be specified per point.");
    }
    // Check all parameters have the correct last dimension size.
    TORCH_CHECK_ARG(
        vert_pos.size(2) == 3,
        1,
        ("Vertex positions must be 3D (have shape " +
         std::to_string(vert_pos.size(2)) + ")!")
            .c_str());
    TORCH_CHECK_ARG(
        vert_col.size(2) == this->renderer_vec[0].cam.n_channels,
        2,
        ("Vertex colors must have the right number of channels (have shape " +
         std::to_string(vert_col.size(2)) + ", need " +
         std::to_string(this->renderer_vec[0].cam.n_channels) + ")!")
            .c_str());
    TORCH_CHECK_ARG(
        cam_pos.size(1) == 3,
        4,
        ("Camera position must be 3D (has shape " +
         std::to_string(cam_pos.size(1)) + ")!")
            .c_str());
    TORCH_CHECK_ARG(
        pixel_0_0_center.size(1) == 3,
        5,
        ("pixel_0_0_center must be 3D (has shape " +
         std::to_string(pixel_0_0_center.size(1)) + ")!")
            .c_str());
    TORCH_CHECK_ARG(
        pixel_vec_x.size(1) == 3,
        6,
        ("pixel_vec_x must be 3D (has shape " +
         std::to_string(pixel_vec_x.size(1)) + ")!")
            .c_str());
    TORCH_CHECK_ARG(
        pixel_vec_y.size(1) == 3,
        7,
        ("pixel_vec_y must be 3D (has shape " +
         std::to_string(pixel_vec_y.size(1)) + ")!")
            .c_str());
    TORCH_CHECK_ARG(
        principal_point_offsets.size(1) == 2,
        9,
        "principal_point_offsets must contain x and y offsets.");
    // Ensure enough renderers are available for the batch.
    ensure_n_renderers_gte(batch_size);
  } else {
    // Check all parameters are of correct dimension. vert_pos is checked
    // explicitly so that a tensor of unexpected rank is rejected with a clear
    // message instead of failing (or slipping through) in the size checks.
    TORCH_CHECK_ARG(
        vert_pos.ndimension() == 2,
        1,
        "vert_pos needs to have dimension 2 (or 3 for batch processing).");
    TORCH_CHECK_ARG(
        vert_col.ndimension() == 2, 2, "vert_col needs to have dimension 2.");
    TORCH_CHECK_ARG(
        vert_radii.ndimension() == 1, 3, "vert_radii must have dimension 1.");
    TORCH_CHECK_ARG(
        cam_pos.ndimension() == 1, 4, "cam_pos must have dimension 1.");
    TORCH_CHECK_ARG(
        pixel_0_0_center.ndimension() == 1,
        5,
        "pixel_0_0_center must have dimension 1.");
    TORCH_CHECK_ARG(
        pixel_vec_x.ndimension() == 1, 6, "pixel_vec_x must have dimension 1.");
    TORCH_CHECK_ARG(
        pixel_vec_y.ndimension() == 1, 7, "pixel_vec_y must have dimension 1.");
    TORCH_CHECK_ARG(
        focal_length.ndimension() == 0,
        8,
        "focal_length must have dimension 0.");
    TORCH_CHECK_ARG(
        principal_point_offsets.ndimension() == 1,
        9,
        "principal_point_offsets must have dimension 1.");
    if (opacity.has_value()) {
      TORCH_CHECK_ARG(
          opacity.value().ndimension() == 1,
          14,
          "Opacity needs to be specified per sample.");
    }
    // Check all parameters are for a matching number of points.
    n_points = vert_pos.size(0);
    TORCH_CHECK_ARG(
        vert_col.size(0) == static_cast<int64_t>(n_points),
        2,
        ("The number of points for vertex positions (" +
         std::to_string(n_points) + ") and vertex colors (" +
         std::to_string(vert_col.size(0)) + ") doesn't agree.")
            .c_str());
    TORCH_CHECK_ARG(
        vert_radii.size(0) == static_cast<int64_t>(n_points),
        3,
        ("The number of points for vertex positions (" +
         std::to_string(n_points) + ") and vertex radii (" +
         std::to_string(vert_radii.size(0)) + ") doesn't agree.")
            .c_str());
    if (opacity.has_value()) {
      TORCH_CHECK_ARG(
          opacity.value().size(0) == static_cast<int64_t>(n_points),
          14,
          "Opacity needs to be specified per point.");
    }
    // Check all parameters have the correct last dimension size.
    TORCH_CHECK_ARG(
        vert_pos.size(1) == 3,
        1,
        ("Vertex positions must be 3D (have shape " +
         std::to_string(vert_pos.size(1)) + ")!")
            .c_str());
    TORCH_CHECK_ARG(
        vert_col.size(1) == this->renderer_vec[0].cam.n_channels,
        2,
        ("Vertex colors must have the right number of channels (have shape " +
         std::to_string(vert_col.size(1)) + ", need " +
         std::to_string(this->renderer_vec[0].cam.n_channels) + ")!")
            .c_str());
    TORCH_CHECK_ARG(
        cam_pos.size(0) == 3,
        4,
        ("Camera position must be 3D (has shape " +
         std::to_string(cam_pos.size(0)) + ")!")
            .c_str());
    TORCH_CHECK_ARG(
        pixel_0_0_center.size(0) == 3,
        5,
        ("pixel_0_0_center must be 3D (has shape " +
         std::to_string(pixel_0_0_center.size(0)) + ")!")
            .c_str());
    TORCH_CHECK_ARG(
        pixel_vec_x.size(0) == 3,
        6,
        ("pixel_vec_x must be 3D (has shape " +
         std::to_string(pixel_vec_x.size(0)) + ")!")
            .c_str());
    TORCH_CHECK_ARG(
        pixel_vec_y.size(0) == 3,
        7,
        ("pixel_vec_y must be 3D (has shape " +
         std::to_string(pixel_vec_y.size(0)) + ")!")
            .c_str());
    TORCH_CHECK_ARG(
        principal_point_offsets.size(0) == 2,
        9,
        "principal_point_offsets must have x and y component.");
  }
  // Check device placement. Every checked tensor must live on the renderer's
  // device (type and index); `verb` only selects the grammatical form
  // ("Are"/"Is") used in the error message.
  const auto check_device = [this](
                                const torch::Tensor& tensor,
                                const int arg_idx,
                                const std::string& name,
                                const std::string& verb) {
    const auto dev = torch::device_of(tensor).value();
    TORCH_CHECK_ARG(
        dev.type() == this->device_type && dev.index() == this->device_index,
        arg_idx,
        (name + " must be stored on device " +
         c10::DeviceTypeName(this->device_type) + ", index " +
         std::to_string(this->device_index) + "! " + verb + " stored on " +
         c10::DeviceTypeName(dev.type()) + ", index " +
         std::to_string(dev.index()) + ".")
            .c_str());
  };
  check_device(vert_pos, 1, "Vertex positions", "Are");
  check_device(vert_col, 2, "Vertex colors", "Are");
  check_device(vert_radii, 3, "Vertex radii", "Are");
  check_device(cam_pos, 4, "Camera position", "Are");
  check_device(pixel_0_0_center, 5, "pixel_0_0_center", "Are");
  check_device(pixel_vec_x, 6, "pixel_vec_x", "Are");
  check_device(pixel_vec_y, 7, "pixel_vec_y", "Are");
  check_device(principal_point_offsets, 9, "principal_point_offsets", "Are");
  if (opacity.has_value()) {
    check_device(opacity.value(), 14, "opacity", "Is");
  }
  // Type checks.
  TORCH_CHECK_ARG(
      vert_pos.scalar_type() == c10::kFloat, 1, "pulsar requires float types.");
  TORCH_CHECK_ARG(
      vert_col.scalar_type() == c10::kFloat, 2, "pulsar requires float types.");
  TORCH_CHECK_ARG(
      vert_radii.scalar_type() == c10::kFloat,
      3,
      "pulsar requires float types.");
  TORCH_CHECK_ARG(
      cam_pos.scalar_type() == c10::kFloat, 4, "pulsar requires float types.");
  TORCH_CHECK_ARG(
      pixel_0_0_center.scalar_type() == c10::kFloat,
      5,
      "pulsar requires float types.");
  TORCH_CHECK_ARG(
      pixel_vec_x.scalar_type() == c10::kFloat,
      6,
      "pulsar requires float types.");
  TORCH_CHECK_ARG(
      pixel_vec_y.scalar_type() == c10::kFloat,
      7,
      "pulsar requires float types.");
  TORCH_CHECK_ARG(
      focal_length.scalar_type() == c10::kFloat,
      8,
      "pulsar requires float types.");
  TORCH_CHECK_ARG(
      // Unfortunately, the PyTorch interface is inconsistent for
      // Int32: in Python, there exists an explicit int32 type, in
      // C++ this is currently `c10::kInt`.
      principal_point_offsets.scalar_type() == c10::kInt,
      9,
      "principal_point_offsets must be provided as int32.");
  if (opacity.has_value()) {
    TORCH_CHECK_ARG(
        opacity.value().scalar_type() == c10::kFloat,
        14,
        "opacity must be a float type.");
  }
  // Content checks.
  TORCH_CHECK_ARG(
      (vert_radii > FEPS).all().item<bool>(),
      3,
      ("Vertex radii must be > FEPS (min is " +
       std::to_string(vert_radii.min().item<float>()) + ").")
          .c_str());
  if (this->orthogonal()) {
    TORCH_CHECK_ARG(
        (focal_length == 0.f).all().item<bool>(),
        8,
        ("for an orthogonal projection focal length must be zero (abs max: " +
         std::to_string(focal_length.abs().max().item<float>()) + ").")
            .c_str());
  } else {
    TORCH_CHECK_ARG(
        (focal_length > FEPS).all().item<bool>(),
        8,
        ("for a perspective projection focal length must be > FEPS (min " +
         std::to_string(focal_length.min().item<float>()) + ").")
            .c_str());
  }
  TORCH_CHECK_ARG(
      gamma <= 1.f && gamma >= 1E-5f,
      10,
      ("gamma must be in [1E-5, 1] (" + std::to_string(gamma) + ").").c_str());
  // The largest focal length is needed for the default value, the check and
  // the error message below; reduce the tensor only once.
  const float max_focal_length = focal_length.max().item<float>();
  if (min_depth == 0.f) {
    // A min_depth of exactly zero requests the default: just beyond the
    // largest focal length.
    min_depth = max_focal_length + 2.f * FEPS;
  }
  TORCH_CHECK_ARG(
      min_depth > max_focal_length,
      12,
      ("min_depth must be > focal_length (" + std::to_string(min_depth) +
       " vs. " + std::to_string(max_focal_length) + ").")
          .c_str());
  TORCH_CHECK_ARG(
      max_depth > min_depth + FEPS,
      11,
      ("max_depth must be > min_depth + FEPS (" + std::to_string(max_depth) +
       " vs. " + std::to_string(min_depth + FEPS) + ").")
          .c_str());
  TORCH_CHECK_ARG(
      percent_allowed_difference >= 0.f && percent_allowed_difference < 1.f,
      15,
      ("percent_allowed_difference must be in [0., 1.[ (" +
       std::to_string(percent_allowed_difference) + ").")
          .c_str());
  TORCH_CHECK_ARG(max_n_hits > 0, 16, "max_n_hits must be > 0!");
  TORCH_CHECK_ARG(mode < 2, 17, "mode must be in {0, 1}.");
  // Resolve the background color: use the caller's tensor if provided,
  // otherwise fall back to an all-ones float tensor on the renderer device.
  torch::Tensor real_bg_col;
  if (bg_col.has_value()) {
    TORCH_CHECK_ARG(
        bg_col.value().device().type() == this->device_type &&
            bg_col.value().device().index() == this->device_index,
        13,
        "bg_col must be stored on the renderer device!");
    TORCH_CHECK_ARG(
        bg_col.value().ndimension() == 1 &&
            bg_col.value().size(0) == renderer_vec[0].cam.n_channels,
        13,
        "bg_col must have the same number of channels as the image.");
    real_bg_col = bg_col.value();
  } else {
    real_bg_col = torch::ones(
                      {renderer_vec[0].cam.n_channels},
                      c10::Device(this->device_type, this->device_index))
                      .to(c10::kFloat);
  }
  if (opacity.has_value()) {
    TORCH_CHECK_ARG(
        (opacity.value() >= 0.f).all().item<bool>(),
        14,
        "opacity must be >= 0.");
    TORCH_CHECK_ARG(
        (opacity.value() <= 1.f).all().item<bool>(),
        14,
        "opacity must be <= 1.");
  }
  LOG_IF(INFO, PULSAR_LOG_FORWARD || PULSAR_LOG_BACKWARD)
      << "  batch_size: " << batch_size;
  LOG_IF(INFO, PULSAR_LOG_FORWARD || PULSAR_LOG_BACKWARD)
      << "  n_points: " << n_points;
  LOG_IF(INFO, PULSAR_LOG_FORWARD || PULSAR_LOG_BACKWARD)
      << "  batch_processing: " << batch_processing;
  return std::tuple<size_t, size_t, bool, torch::Tensor>(
      batch_size, n_points, batch_processing, real_bg_col);
}