torch::Tensor RasterizeMeshesBackwardCpu()

in pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp [330:471]


/// CPU backward pass of mesh rasterization.
///
/// For every (pixel, k) entry that `pix_to_face` marks as hit, this recomputes
/// the barycentric coordinates (optionally perspective-corrected and/or
/// clipped, per the flags) and the pixel-to-triangle distance, then chains the
/// upstream gradients of zbuf, barycentric coordinates and distances back onto
/// the coordinates of that face's three vertices.
///
/// @param face_verts   (F, 3, 3) float32 tensor; row f holds the (x, y, z) of
///                     the three vertices of face f.
/// @param pix_to_face  (N, H, W, K) int64 tensor of indices into `face_verts`;
///                     negative entries are padding and are skipped.
/// @param grad_zbuf    (N, H, W, K) float32 upstream gradient of the depth
///                     buffer.
/// @param grad_bary    (N, H, W, K, 3) float32 upstream gradient of the
///                     barycentric coordinates.
/// @param grad_dists   (N, H, W, K) float32 upstream gradient of the
///                     pixel-to-face distances.
/// @param perspective_correct      Recompute barycentrics with perspective
///                                 correction (should match the forward pass).
/// @param clip_barycentric_coords  Recompute barycentrics with clipping
///                                 (should match the forward pass).
/// @return (F, 3, 3) tensor with the options of `face_verts`, holding the
///         gradient accumulated onto each face vertex coordinate.
torch::Tensor RasterizeMeshesBackwardCpu(
    const torch::Tensor& face_verts, // (F, 3, 3)
    const torch::Tensor& pix_to_face, // (N, H, W, K)
    const torch::Tensor& grad_zbuf, // (N, H, W, K)
    const torch::Tensor& grad_bary, // (N, H, W, K, 3)
    const torch::Tensor& grad_dists, // (N, H, W, K)
    const bool perspective_correct,
    const bool clip_barycentric_coords) {
  const int F = face_verts.size(0);
  const int N = pix_to_face.size(0);
  const int H = pix_to_face.size(1);
  const int W = pix_to_face.size(2);
  const int K = pix_to_face.size(3);

  torch::Tensor grad_face_verts = torch::zeros({F, 3, 3}, face_verts.options());
  auto face_verts_a = face_verts.accessor<float, 3>();
  // Accumulate through a raw accessor rather than Tensor::operator[]: each
  // `grad_face_verts[f][i][j] += v` would build three select() views and
  // dispatch an in-place add_ per scalar update, which dominates this loop.
  auto grad_face_verts_a = grad_face_verts.accessor<float, 3>();
  auto pix_to_face_a = pix_to_face.accessor<int64_t, 4>();
  auto grad_dists_a = grad_dists.accessor<float, 4>();
  auto grad_zbuf_a = grad_zbuf.accessor<float, 4>();
  auto grad_bary_a = grad_bary.accessor<float, 5>();

  for (int n = 0; n < N; ++n) {
    // Iterate through the horizontal lines of the image from top to bottom.
    for (int y = 0; y < H; ++y) {
      // Reverse the order of yi so that +Y is pointing upwards in the image.
      const int yidx = H - 1 - y;

      // NDC y coordinate of this row as given by PixToNonSquareNdc; it must
      // match the coordinate used by the forward pass.
      const float yf = PixToNonSquareNdc(yidx, H, W);
      // Iterate through pixels on this horizontal line, left to right.
      for (int x = 0; x < W; ++x) {
        // Reverse the order of xi so that +X is pointing to the left in the
        // image.
        const int xidx = W - 1 - x;

        // NDC x coordinate of this column (same convention as yf above).
        const float xf = PixToNonSquareNdc(xidx, W, H);
        const vec2<float> pxy(xf, yf);

        // Iterate through the faces that hit this pixel.
        for (int k = 0; k < K; ++k) {
          // Get face index from forward pass output. Kept as int64_t to
          // avoid narrowing the stored index.
          const int64_t f = pix_to_face_a[n][y][x][k];
          if (f < 0) {
            continue; // padded face.
          }
          // Get coordinates of the three face vertices.
          const auto face_verts_f = face_verts_a[f];
          // Writable view of this face's rows in the output gradient.
          auto grad_f = grad_face_verts_a[f];
          const float x0 = face_verts_f[0][0];
          const float y0 = face_verts_f[0][1];
          const float z0 = face_verts_f[0][2];
          const float x1 = face_verts_f[1][0];
          const float y1 = face_verts_f[1][1];
          const float z1 = face_verts_f[1][2];
          const float x2 = face_verts_f[2][0];
          const float y2 = face_verts_f[2][1];
          const float z2 = face_verts_f[2][2];
          const vec2<float> v0xy(x0, y0);
          const vec2<float> v1xy(x1, y1);
          const vec2<float> v2xy(x2, y2);

          // Get upstream gradients for the face.
          const float grad_dist_upstream = grad_dists_a[n][y][x][k];
          const float grad_zbuf_upstream = grad_zbuf_a[n][y][x][k];
          const auto grad_bary_upstream_w012 = grad_bary_a[n][y][x][k];
          const float grad_bary_upstream_w0 = grad_bary_upstream_w012[0];
          const float grad_bary_upstream_w1 = grad_bary_upstream_w012[1];
          const float grad_bary_upstream_w2 = grad_bary_upstream_w012[2];
          const vec3<float> grad_bary_upstream(
              grad_bary_upstream_w0,
              grad_bary_upstream_w1,
              grad_bary_upstream_w2);

          // Recompute the forward barycentrics: raw -> (perspective
          // corrected) -> (clipped), mirroring the flags.
          const vec3<float> bary0 =
              BarycentricCoordinatesForward(pxy, v0xy, v1xy, v2xy);
          const vec3<float> bary = !perspective_correct
              ? bary0
              : BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2);
          const vec3<float> bary_clip =
              !clip_barycentric_coords ? bary : BarycentricClipForward(bary);

          // Distances inside the face are negative so get the
          // correct sign to apply to the upstream gradient.
          const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f;
          const float sign = inside ? -1.0f : 1.0f;

          const auto grad_dist_f = PointTriangleDistanceBackward(
              pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream);
          const auto ddist_d_v0 = std::get<1>(grad_dist_f);
          const auto ddist_d_v1 = std::get<2>(grad_dist_f);
          const auto ddist_d_v2 = std::get<3>(grad_dist_f);

          // Upstream gradient for barycentric coords from zbuf calculation:
          // zbuf = bary_w0 * z0 + bary_w1 * z1 + bary_w2 * z2
          // Therefore
          // d_zbuf/d_bary_w0 = z0
          // d_zbuf/d_bary_w1 = z1
          // d_zbuf/d_bary_w2 = z2
          const vec3<float> d_zbuf_d_baryclip(z0, z1, z2);

          // Total upstream barycentric gradients are the sum of
          // external upstream gradients and contribution from zbuf.
          const vec3<float> grad_bary_f_sum =
              (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_baryclip);

          vec3<float> grad_bary0 = grad_bary_f_sum;

          // Undo the forward transforms in reverse order: clip first, then
          // perspective correction.
          if (clip_barycentric_coords) {
            grad_bary0 = BarycentricClipBackward(bary, grad_bary0);
          }

          if (perspective_correct) {
            const auto perspective_grads =
                BarycentricPerspectiveCorrectionBackward(
                    bary0, z0, z1, z2, grad_bary0);
            grad_bary0 = std::get<0>(perspective_grads);
            // Perspective correction depends on the vertex depths.
            grad_f[0][2] += std::get<1>(perspective_grads);
            grad_f[1][2] += std::get<2>(perspective_grads);
            grad_f[2][2] += std::get<3>(perspective_grads);
          }

          const auto grad_bary_f =
              BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0);
          const vec2<float> dbary_d_v0 = std::get<1>(grad_bary_f);
          const vec2<float> dbary_d_v1 = std::get<2>(grad_bary_f);
          const vec2<float> dbary_d_v2 = std::get<3>(grad_bary_f);

          // Update output gradient buffer. The z terms are d_zbuf/d_z_i,
          // i.e. the (clipped) barycentric weights used to compute zbuf.
          grad_f[0][0] += dbary_d_v0.x + ddist_d_v0.x;
          grad_f[0][1] += dbary_d_v0.y + ddist_d_v0.y;
          grad_f[0][2] += grad_zbuf_upstream * bary_clip.x;
          grad_f[1][0] += dbary_d_v1.x + ddist_d_v1.x;
          grad_f[1][1] += dbary_d_v1.y + ddist_d_v1.y;
          grad_f[1][2] += grad_zbuf_upstream * bary_clip.y;
          grad_f[2][0] += dbary_d_v2.x + ddist_d_v2.x;
          grad_f[2][1] += dbary_d_v2.y + ddist_d_v2.y;
          grad_f[2][2] += grad_zbuf_upstream * bary_clip.z;
        }
      }
    }
  }
  return grad_face_verts;
}