std::tuple<torch::Tensor, torch::Tensor> Renderer::forward()

in pytorch3d/csrc/pulsar/pytorch/renderer.cpp [658:888]
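
Renders the given spheres (positions, colors, and radii) with the provided camera parameters and returns a tuple of two tensors: the rendered image and a per-pixel forward-information buffer that the backward pass consumes. The body falls into four steps: argument checking, CamInfo construction, dispatch to the CUDA or CPU implementation of PRE::forward, and wrapping of the result buffers as tensors.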


std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
    const torch::Tensor& vert_pos,
    const torch::Tensor& vert_col,
    const torch::Tensor& vert_radii,
    const torch::Tensor& cam_pos,
    const torch::Tensor& pixel_0_0_center,
    const torch::Tensor& pixel_vec_x,
    const torch::Tensor& pixel_vec_y,
    const torch::Tensor& focal_length,
    const torch::Tensor& principal_point_offsets,
    const float& gamma,
    const float& max_depth,
    float min_depth,
    const c10::optional<torch::Tensor>& bg_col,
    const c10::optional<torch::Tensor>& opacity,
    const float& percent_allowed_difference,
    const uint& max_n_hits,
    const uint& mode) {
  // Parameter checks.
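  // arg_check (not shown in this listing) validates devices, dtypes, and
  // shapes; it reports the inferred batch size and point count, whether the
  // inputs carry a batch dimension, and a concrete background color
  // (substituting a default when bg_col is unset).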
  this->ensure_on_device(this->device_tracker.device());
  size_t batch_size;
  size_t n_points;
  bool batch_processing;
  torch::Tensor real_bg_col;
  std::tie(batch_size, n_points, batch_processing, real_bg_col) =
      this->arg_check(
          vert_pos,
          vert_col,
          vert_radii,
          cam_pos,
          pixel_0_0_center,
          pixel_vec_x,
          pixel_vec_y,
          focal_length,
          principal_point_offsets,
          gamma,
          max_depth,
          min_depth,
          bg_col,
          opacity,
          percent_allowed_difference,
          max_n_hits,
          mode);
  LOG_IF(INFO, PULSAR_LOG_FORWARD) << "Extracting camera objects...";
  // Create the camera information.
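  // One CamInfo per batch element (a single one otherwise): position,
  // orientation, and intrinsics vary per example, while film size and
  // handedness always come from the preconfigured renderer.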
  std::vector<CamInfo> cam_infos(batch_size);
  if (batch_processing) {
    for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) {
      cam_infos[batch_i] = cam_info_from_params(
          cam_pos[batch_i],
          pixel_0_0_center[batch_i],
          pixel_vec_x[batch_i],
          pixel_vec_y[batch_i],
          principal_point_offsets[batch_i],
          focal_length[batch_i].item<float>(),
          this->renderer_vec[0].cam.film_width,
          this->renderer_vec[0].cam.film_height,
          min_depth,
          max_depth,
          this->renderer_vec[0].cam.right_handed);
    }
  } else {
    cam_infos[0] = cam_info_from_params(
        cam_pos,
        pixel_0_0_center,
        pixel_vec_x,
        pixel_vec_y,
        principal_point_offsets,
        focal_length.item<float>(),
        this->renderer_vec[0].cam.film_width,
        this->renderer_vec[0].cam.film_height,
        min_depth,
        max_depth,
        this->renderer_vec[0].cam.right_handed);
  }
  LOG_IF(INFO, PULSAR_LOG_FORWARD) << "Processing...";
  // Let's go!
  // Contiguous version of opacity, if available. We need to create this object
  // in scope to keep it alive.
  torch::Tensor opacity_contiguous;
  float const* opacity_ptr = nullptr;
  if (opacity.has_value()) {
    opacity_contiguous = opacity.value().contiguous();
    opacity_ptr = opacity_contiguous.data_ptr<float>();
  }
  if (this->device_type == c10::DeviceType::CUDA) {
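    // GPU path: PRE::forward<true> launches the CUDA implementation on the
    // current PyTorch CUDA stream for this renderer's device.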
// No else check necessary - if not compiled with CUDA
// we can't even reach this code (the renderer can't be
// moved to a CUDA device).
#ifdef WITH_CUDA
    int prev_active;
    cudaGetDevice(&prev_active);
    cudaSetDevice(this->device_index);
#ifdef PULSAR_TIMINGS_BATCHED_ENABLED
    START_TIME_CU(batch_forward);
#endif
    if (batch_processing) {
      for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) {
        // These calls are non-blocking and just kick off the computations.
        PRE::forward<true>(
            &this->renderer_vec[batch_i],
            vert_pos[batch_i].contiguous().data_ptr<float>(),
            vert_col[batch_i].contiguous().data_ptr<float>(),
            vert_radii[batch_i].contiguous().data_ptr<float>(),
            cam_infos[batch_i],
            gamma,
            percent_allowed_difference,
            max_n_hits,
            real_bg_col.contiguous().data_ptr<float>(),
            opacity_ptr,
            n_points,
            mode,
            at::cuda::getCurrentCUDAStream());
      }
    } else {
      PRE::forward<true>(
          this->renderer_vec.data(),
          vert_pos.contiguous().data_ptr<float>(),
          vert_col.contiguous().data_ptr<float>(),
          vert_radii.contiguous().data_ptr<float>(),
          cam_infos[0],
          gamma,
          percent_allowed_difference,
          max_n_hits,
          real_bg_col.contiguous().data_ptr<float>(),
          opacity_ptr,
          n_points,
          mode,
          at::cuda::getCurrentCUDAStream());
    }
#ifdef PULSAR_TIMINGS_BATCHED_ENABLED
    STOP_TIME_CU(batch_forward);
    float time_ms;
    GET_TIME_CU(batch_forward, &time_ms);
    std::cout << "Forward render batched time per example: "
              << time_ms / static_cast<float>(batch_size) << "ms" << std::endl;
#endif
    cudaSetDevice(prev_active);
#endif
  } else {
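    // CPU path: PRE::forward<false> selects the host implementation; there
    // is no CUDA stream, so nullptr is passed in its place.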
#ifdef PULSAR_TIMINGS_BATCHED_ENABLED
    START_TIME(batch_forward);
#endif
    if (batch_processing) {
      for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) {
        // On the CPU path each call runs to completion before the next
        // example is processed.
        PRE::forward<false>(
            &this->renderer_vec[batch_i],
            vert_pos[batch_i].contiguous().data_ptr<float>(),
            vert_col[batch_i].contiguous().data_ptr<float>(),
            vert_radii[batch_i].contiguous().data_ptr<float>(),
            cam_infos[batch_i],
            gamma,
            percent_allowed_difference,
            max_n_hits,
            real_bg_col.contiguous().data_ptr<float>(),
            opacity_ptr,
            n_points,
            mode,
            nullptr);
      }
    } else {
      PRE::forward<false>(
          this->renderer_vec.data(),
          vert_pos.contiguous().data_ptr<float>(),
          vert_col.contiguous().data_ptr<float>(),
          vert_radii.contiguous().data_ptr<float>(),
          cam_infos[0],
          gamma,
          percent_allowed_difference,
          max_n_hits,
          real_bg_col.contiguous().data_ptr<float>(),
          opacity_ptr,
          n_points,
          mode,
          nullptr);
    }
#ifdef PULSAR_TIMINGS_BATCHED_ENABLED
    STOP_TIME(batch_forward);
    float time_ms;
    GET_TIME(batch_forward, &time_ms);
    std::cout << "Forward render batched time per example: "
              << time_ms / static_cast<float>(batch_size) << "ms" << std::endl;
#endif
  }
  LOG_IF(INFO, PULSAR_LOG_FORWARD) << "Extracting results...";
  // Create the results.
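  // from_blob (a local helper) exposes the renderer-owned result and
  // forward-info buffers as tensors on the renderer's device. Per pixel,
  // forw_info carries 3 + 2 * n_track() channels of intermediate state for
  // the backward pass.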
  std::vector<torch::Tensor> results(batch_size);
  std::vector<torch::Tensor> forw_infos(batch_size);
  for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) {
    results[batch_i] = from_blob(
        this->renderer_vec[batch_i].result_d,
        {this->renderer_vec[0].cam.film_height,
         this->renderer_vec[0].cam.film_width,
         this->renderer_vec[0].cam.n_channels},
        this->device_type,
        this->device_index,
        torch::kFloat,
        this->device_type == c10::DeviceType::CUDA
#ifdef WITH_CUDA
            ? at::cuda::getCurrentCUDAStream()
#else
            ? (cudaStream_t) nullptr
#endif
            : (cudaStream_t) nullptr);
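    // mode 1 produces a single-channel result; keep only channel 0.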
    if (mode == 1) {
      results[batch_i] = results[batch_i].slice(2, 0, 1, 1);
    }
    forw_infos[batch_i] = from_blob(
        this->renderer_vec[batch_i].forw_info_d,
        {this->renderer_vec[0].cam.film_height,
         this->renderer_vec[0].cam.film_width,
         3 + 2 * this->n_track()},
        this->device_type,
        this->device_index,
        torch::kFloat,
        this->device_type == c10::DeviceType::CUDA
#ifdef WITH_CUDA
            ? at::cuda::getCurrentCUDAStream()
#else
            ? (cudaStream_t) nullptr
#endif
            : (cudaStream_t) nullptr);
  }
  LOG_IF(INFO, PULSAR_LOG_FORWARD) << "Forward render complete.";
  if (batch_processing) {
    return std::tuple<torch::Tensor, torch::Tensor>(
        torch::stack(results), torch::stack(forw_infos));
  } else {
    return std::tuple<torch::Tensor, torch::Tensor>(results[0], forw_infos[0]);
  }
}
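
For orientation, a minimal sketch of a single-example call. Everything here is illustrative: renderer is assumed to be an already-constructed Renderer whose film size matches the camera vectors, the input tensors are assumed to exist with the shapes noted in the comments (inferred from the indexing above), and the scalar values are placeholders, not recommended defaults; the real shape and dtype requirements are enforced by arg_check.

// Hypothetical usage sketch (single example, no batch dimension).
torch::Tensor image, forw_info;
std::tie(image, forw_info) = renderer.forward(
    vert_pos, // float32, (n_points, 3): sphere centers
    vert_col, // float32, (n_points, n_channels): sphere colors
    vert_radii, // float32, (n_points,): sphere radii
    cam_pos, // float32, (3,): camera center
    pixel_0_0_center, // float32, (3,): center of pixel (0, 0)
    pixel_vec_x, // float32, (3,): step between horizontal pixel centers
    pixel_vec_y, // float32, (3,): step between vertical pixel centers
    focal_length, // float32 scalar tensor
    principal_point_offsets, // integer offsets, (2,)
    1.0f, // gamma: blending softness (illustrative value)
    45.0f, // max_depth (illustrative)
    0.1f, // min_depth (illustrative)
    c10::nullopt, // bg_col: let arg_check substitute a default
    c10::nullopt, // opacity: no per-sphere opacities
    0.01f, // percent_allowed_difference (illustrative)
    4096u, // max_n_hits (illustrative)
    0u); // mode (0 renders all channels; see the mode == 1 slicing above)
// image: (film_height, film_width, n_channels) on the renderer's device;
// forw_info: (film_height, film_width, 3 + 2 * n_track), needed for backward.
// With batched inputs (leading batch dimension B), both results gain a
// leading B dimension via torch::stack, as in the return statement above.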