GLOBAL void calc_gradients()

in pytorch3d/csrc/pulsar/include/renderer.calc_gradients.device.h [22:193]


/**
 * Backward pass of the renderer, evaluated per film pixel.
 *
 * For every sphere index that the forward pass recorded for a pixel in
 * `forw_info_d`, this re-runs `draw` in gradient mode and accumulates the
 * per-sphere gradients into the output buffers with atomic adds. Color
 * gradients are written by `draw` itself into `grad_col_d`; all other
 * gradients are accumulated here.
 *
 * Layout of `forw_info_d` as read by this kernel: every pixel owns
 * `3 + 2 * n_track` floats. Slot 0 holds `sm_m`, slot 1 holds `sm_d`, and
 * slot `3 + 2 * i` holds the i-th tracked sphere index stored in a float
 * (decoded with FASI, presumably a float->int bit reinterpretation). An
 * index of -1 marks an unused slot. Other slots are not read here.
 *
 * Debug mode: a non-zero `offs_x` or `offs_y` restricts processing to the
 * single film pixel (offs_x, offs_y). With both offsets zero the full-film
 * limits are used, so pixel (0, 0) cannot be isolated this way.
 */
GLOBAL void calc_gradients(
    const CamInfo cam, /** Camera in world coordinates. */
    float const* const RESTRICT grad_im, /** The gradient image. */
    const float
        gamma, /** The transparency parameter used in the forward pass. */
    float3 const* const RESTRICT vert_poss, /** Vertex position vector. */
    float const* const RESTRICT vert_cols, /** Vertex color vector. */
    float const* const RESTRICT vert_rads, /** Vertex radius vector. */
    float const* const RESTRICT opacity, /** Vertex opacity. */
    const uint num_balls, /** Number of balls. */
    float const* const RESTRICT result_d, /** Result image. */
    float const* const RESTRICT forw_info_d, /** Forward pass info. */
    DrawInfo const* const RESTRICT di_d, /** Draw information. */
    IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */
    // Mode switches.
    const bool calc_grad_pos,
    const bool calc_grad_col,
    const bool calc_grad_rad,
    const bool calc_grad_cam,
    const bool calc_grad_opy,
    // Out variables.
    float* const RESTRICT grad_rad_d, /** Radius gradients. */
    float* const RESTRICT grad_col_d, /** Color gradients. */
    float3* const RESTRICT grad_pos_d, /** Position gradients. */
    CamGradInfo* const RESTRICT grad_cam_buf_d, /** Camera gradient buffer. */
    float* const RESTRICT grad_opy_d, /** Opacity gradient buffer. */
    int* const RESTRICT
        grad_contributed_d, /** Gradient contribution counter. */
    // Infrastructure.
    const int n_track,
    const uint offs_x,
    const uint offs_y /** Debug offsets. */
) {
  uint limit_x = cam.film_width, limit_y = cam.film_height;
  if (offs_x != 0 || offs_y != 0) {
    // We're in debug mode: only the single pixel (offs_x, offs_y) is
    // processed. Both offsets must be checked. Otherwise a request for
    // pixel (0, y > 0) keeps the full-film limits, shifts every row by
    // offs_y, and reads past the end of the image buffers.
    limit_x = 1;
    limit_y = 1;
  }
  GET_PARALLEL_IDS_2D(coord_x_base, coord_y_base, limit_x, limit_y);
  // coord_x_base and coord_y_base are in the film coordinate system.
  // We now need to translate to the aperture coordinate system. If
  // the principal point was shifted left/up nothing has to be
  // subtracted - only shift needs to be added in case it has been
  // shifted down/right.
  const uint film_coord_x = coord_x_base + offs_x;
  const uint ap_coord_x = film_coord_x +
      2 * static_cast<uint>(std::max(0, cam.principal_point_offset_x));
  const uint film_coord_y = coord_y_base + offs_y;
  const uint ap_coord_y = film_coord_y +
      2 * static_cast<uint>(std::max(0, cam.principal_point_offset_y));
  const float3 ray_dir = /** Ray cast through the pixel, NOT normalized. */
      cam.pixel_0_0_center + ap_coord_x * cam.pixel_dir_x +
      ap_coord_y * cam.pixel_dir_y;
  const float norm_ray_dir = length(ray_dir);
  // ray_dir_norm *must* be calculated here in the same way as in the draw
  // function to have the same values with no other numerical instabilities
  // (for example, ray_dir * FRCP(norm_ray_dir) does not work)!
  // That is also why the expression is spelled out again below instead of
  // reusing `ray_dir`.
  float3 ray_dir_norm; /** Ray cast through the pixel, normalized. */
  float2 projected_ray; /** Ray intersection with the sensor. */
  if (cam.orthogonal_projection) {
    ray_dir_norm = cam.sensor_dir_z;
    projected_ray.x = static_cast<float>(ap_coord_x);
    projected_ray.y = static_cast<float>(ap_coord_y);
  } else {
    ray_dir_norm = normalize(
        cam.pixel_0_0_center + ap_coord_x * cam.pixel_dir_x +
        ap_coord_y * cam.pixel_dir_y);
    // This is a reasonable assumption for normal focal lengths and image sizes.
    PASSERT(FABS(ray_dir_norm.z) > FEPS);
    projected_ray.x = ray_dir_norm.x / ray_dir_norm.z * cam.focal_length;
    projected_ray.y = ray_dir_norm.y / ray_dir_norm.z * cam.focal_length;
  }
  // `draw` takes the result pixel as an in/out pointer, hence the const_cast.
  // NOTE(review): result_d is declared const here; confirm draw does not
  // write to it in gradient mode, or that the buffer is actually writable.
  float* result = const_cast<float*>(
      result_d + film_coord_y * cam.film_width * cam.n_channels +
      film_coord_x * cam.n_channels);
  const float* grad_im_l = grad_im +
      film_coord_y * cam.film_width * cam.n_channels +
      film_coord_x * cam.n_channels;
  // Set up shared infrastructure: per-pixel forward info record.
  const int fwi_loc = film_coord_y * cam.film_width * (3 + 2 * n_track) +
      film_coord_x * (3 + 2 * n_track);
  // Passed to draw as in/out values, so they carry over between spheres.
  float sm_m = forw_info_d[fwi_loc];
  float sm_d = forw_info_d[fwi_loc + 1];
  PULSAR_LOG_DEV_APIX(
      PULSAR_LOG_GRAD,
      "grad|sm_m: %f, sm_d: %f, result: "
      "%f, %f, %f; grad_im: %f, %f, %f.\n",
      sm_m,
      sm_d,
      result[0],
      result[1],
      result[2],
      grad_im_l[0],
      grad_im_l[1],
      grad_im_l[2]);
  // Start processing.
  for (int grad_idx = 0; grad_idx < n_track; ++grad_idx) {
    int sphere_idx;
    FASI(forw_info_d[fwi_loc + 3 + 2 * grad_idx], sphere_idx);
    PASSERT(
        sphere_idx == -1 ||
        (sphere_idx >= 0 && static_cast<uint>(sphere_idx) < num_balls));
    if (sphere_idx < 0) {
      // Unused tracking slot.
      continue;
    }
    // TODO: make more efficient.
    // Fresh per-sphere accumulators; draw adds into them.
    float3 grad_pos = make_float3(0.f, 0.f, 0.f);
    float grad_rad = 0.f;
    float grad_opy = 0.f;
    CamGradInfo grad_cam_local = CamGradInfo();
    const DrawInfo di = di_d[sphere_idx];
    draw(
        di,
        opacity == NULL ? 1.f : opacity[sphere_idx],
        cam,
        gamma,
        ray_dir_norm,
        projected_ray,
        // Mode switches.
        false, // draw only
        calc_grad_pos,
        calc_grad_col,
        calc_grad_rad,
        calc_grad_cam,
        calc_grad_opy,
        // Position info.
        ap_coord_x,
        ap_coord_y,
        sphere_idx,
        // Optional in.
        &ii_d[sphere_idx],
        &ray_dir,
        &norm_ray_dir,
        grad_im_l,
        NULL,
        // In/out
        &sm_d,
        &sm_m,
        result,
        // Optional out.
        NULL,
        NULL,
        &grad_pos,
        grad_col_d + sphere_idx * cam.n_channels,
        &grad_rad,
        &grad_cam_local,
        &grad_opy);
    // Several pixels may contribute to the same sphere concurrently, so all
    // accumulation into the shared buffers is atomic.
    ATOMICADD(&(grad_rad_d[sphere_idx]), grad_rad);
    // Color has been added directly.
    ATOMICADD_F3(&(grad_pos_d[sphere_idx]), grad_pos);
    ATOMICADD_F3(
        &(grad_cam_buf_d[sphere_idx].cam_pos), grad_cam_local.cam_pos);
    if (!cam.orthogonal_projection) {
      ATOMICADD_F3(
          &(grad_cam_buf_d[sphere_idx].pixel_0_0_center),
          grad_cam_local.pixel_0_0_center);
    }
    ATOMICADD_F3(
        &(grad_cam_buf_d[sphere_idx].pixel_dir_x),
        grad_cam_local.pixel_dir_x);
    ATOMICADD_F3(
        &(grad_cam_buf_d[sphere_idx].pixel_dir_y),
        grad_cam_local.pixel_dir_y);
    ATOMICADD(&(grad_opy_d[sphere_idx]), grad_opy);
    ATOMICADD(&(grad_contributed_d[sphere_idx]), 1);
  }
  END_PARALLEL_2D_NORET();
}