in pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp [330:471]
// Backward pass of the CPU mesh rasterizer.
//
// For every (pixel, k) entry recorded in pix_to_face, this recomputes the
// per-pixel quantities (barycentric coordinates, optionally perspective
// corrected and/or clipped) and back-propagates the upstream gradients of the
// depth buffer, the barycentric coordinates and the pixel-face distances onto
// the face vertex coordinates.
//
// Args:
//   face_verts: (F, 3, 3) face vertex coordinates, read as float32.
//   pix_to_face: (N, H, W, K) face index per pixel, read as int64. Negative
//     entries are treated as padding and skipped.
//   grad_zbuf: (N, H, W, K) upstream gradient of the depth buffer.
//   grad_bary: (N, H, W, K, 3) upstream gradient of barycentric coordinates.
//   grad_dists: (N, H, W, K) upstream gradient of the signed distances.
//   perspective_correct: whether perspective-corrected barycentrics are used.
//     The forward quantities are recomputed here with this flag, so it should
//     match the value used in the forward pass.
//   clip_barycentric_coords: whether barycentrics are clipped. As above, it
//     should match the forward pass.
//
// Returns:
//   (F, 3, 3) tensor with the same options as face_verts, holding the
//   gradient accumulated for every face vertex.
//
// NOTE(review): all inputs are read through CPU accessors, so they presumably
// must be CPU tensors. accessor<> also enforces the dtypes noted above.
torch::Tensor RasterizeMeshesBackwardCpu(
    const torch::Tensor& face_verts, // (F, 3, 3)
    const torch::Tensor& pix_to_face, // (N, H, W, K)
    const torch::Tensor& grad_zbuf, // (N, H, W, K)
    const torch::Tensor& grad_bary, // (N, H, W, K, 3)
    const torch::Tensor& grad_dists, // (N, H, W, K)
    const bool perspective_correct,
    const bool clip_barycentric_coords) {
  const int64_t F = face_verts.size(0);
  const int N = pix_to_face.size(0);
  const int H = pix_to_face.size(1);
  const int W = pix_to_face.size(2);
  const int K = pix_to_face.size(3);

  torch::Tensor grad_face_verts = torch::zeros({F, 3, 3}, face_verts.options());
  auto face_verts_a = face_verts.accessor<float, 3>();
  auto pix_to_face_a = pix_to_face.accessor<int64_t, 4>();
  auto grad_dists_a = grad_dists.accessor<float, 4>();
  auto grad_zbuf_a = grad_zbuf.accessor<float, 4>();
  auto grad_bary_a = grad_bary.accessor<float, 5>();
  // Accumulate through an accessor rather than Tensor indexing:
  // `grad_face_verts[f][i][j] += v` would build 0-dim tensor views and
  // dispatch an in-place add for every single element, which is very slow in
  // this four-deep loop. grad_face_verts shares face_verts' (float) options.
  auto grad_face_verts_a = grad_face_verts.accessor<float, 3>();

  for (int n = 0; n < N; ++n) {
    // Iterate through the horizontal lines of the image from top to bottom.
    for (int y = 0; y < H; ++y) {
      // Reverse the order of yi so that +Y is pointing upwards in the image.
      const int yidx = H - 1 - y;
      // Y coordinate of the top of the pixel.
      const float yf = PixToNonSquareNdc(yidx, H, W);
      // Iterate through pixels on this horizontal line, left to right.
      for (int x = 0; x < W; ++x) {
        // Reverse the order of xi so that +X is pointing to the left in the
        // image.
        const int xidx = W - 1 - x;
        // X coordinate of the left of the pixel.
        const float xf = PixToNonSquareNdc(xidx, W, H);
        const vec2<float> pxy(xf, yf);
        // Iterate through the faces that hit this pixel.
        for (int k = 0; k < K; ++k) {
          // Get face index from forward pass output. Kept as int64_t to avoid
          // narrowing the int64 value stored in pix_to_face.
          const int64_t f = pix_to_face_a[n][y][x][k];
          if (f < 0) {
            continue; // padded face.
          }
          // Get coordinates of the three face vertices.
          const auto face_verts_f = face_verts_a[f];
          const float x0 = face_verts_f[0][0];
          const float y0 = face_verts_f[0][1];
          const float z0 = face_verts_f[0][2];
          const float x1 = face_verts_f[1][0];
          const float y1 = face_verts_f[1][1];
          const float z1 = face_verts_f[1][2];
          const float x2 = face_verts_f[2][0];
          const float y2 = face_verts_f[2][1];
          const float z2 = face_verts_f[2][2];
          const vec2<float> v0xy(x0, y0);
          const vec2<float> v1xy(x1, y1);
          const vec2<float> v2xy(x2, y2);

          // Get upstream gradients for the face.
          const float grad_dist_upstream = grad_dists_a[n][y][x][k];
          const float grad_zbuf_upstream = grad_zbuf_a[n][y][x][k];
          const auto grad_bary_upstream_w012 = grad_bary_a[n][y][x][k];
          const vec3<float> grad_bary_upstream(
              grad_bary_upstream_w012[0],
              grad_bary_upstream_w012[1],
              grad_bary_upstream_w012[2]);

          // Recompute the forward-pass barycentric coordinates.
          const vec3<float> bary0 =
              BarycentricCoordinatesForward(pxy, v0xy, v1xy, v2xy);
          const vec3<float> bary = !perspective_correct
              ? bary0
              : BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2);
          const vec3<float> bary_clip =
              !clip_barycentric_coords ? bary : BarycentricClipForward(bary);

          // Distances inside the face are negative so get the
          // correct sign to apply to the upstream gradient.
          const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f;
          const float sign = inside ? -1.0f : 1.0f;
          const auto grad_dist_f = PointTriangleDistanceBackward(
              pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream);
          const auto ddist_d_v0 = std::get<1>(grad_dist_f);
          const auto ddist_d_v1 = std::get<2>(grad_dist_f);
          const auto ddist_d_v2 = std::get<3>(grad_dist_f);

          // Upstream gradient for barycentric coords from zbuf calculation:
          //   zbuf = bary_w0 * z0 + bary_w1 * z1 + bary_w2 * z2
          // Therefore
          //   d_zbuf/d_bary_w0 = z0
          //   d_zbuf/d_bary_w1 = z1
          //   d_zbuf/d_bary_w2 = z2
          const vec3<float> d_zbuf_d_baryclip(z0, z1, z2);
          // Total upstream barycentric gradients are the sum of
          // external upstream gradients and contribution from zbuf.
          const vec3<float> grad_bary_f_sum =
              (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_baryclip);

          // Undo the forward transforms in reverse order: clip, then
          // perspective correction.
          vec3<float> grad_bary0 = grad_bary_f_sum;
          if (clip_barycentric_coords) {
            grad_bary0 = BarycentricClipBackward(bary, grad_bary0);
          }
          if (perspective_correct) {
            const auto perspective_grads =
                BarycentricPerspectiveCorrectionBackward(
                    bary0, z0, z1, z2, grad_bary0);
            grad_bary0 = std::get<0>(perspective_grads);
            grad_face_verts_a[f][0][2] += std::get<1>(perspective_grads);
            grad_face_verts_a[f][1][2] += std::get<2>(perspective_grads);
            grad_face_verts_a[f][2][2] += std::get<3>(perspective_grads);
          }
          const auto grad_bary_f =
              BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0);
          const vec2<float> dbary_d_v0 = std::get<1>(grad_bary_f);
          const vec2<float> dbary_d_v1 = std::get<2>(grad_bary_f);
          const vec2<float> dbary_d_v2 = std::get<3>(grad_bary_f);

          // Update output gradient buffer. The z components also receive the
          // direct zbuf term d_zbuf/d_zi = bary_clip_i.
          grad_face_verts_a[f][0][0] += dbary_d_v0.x + ddist_d_v0.x;
          grad_face_verts_a[f][0][1] += dbary_d_v0.y + ddist_d_v0.y;
          grad_face_verts_a[f][0][2] += grad_zbuf_upstream * bary_clip.x;
          grad_face_verts_a[f][1][0] += dbary_d_v1.x + ddist_d_v1.x;
          grad_face_verts_a[f][1][1] += dbary_d_v1.y + ddist_d_v1.y;
          grad_face_verts_a[f][1][2] += grad_zbuf_upstream * bary_clip.y;
          grad_face_verts_a[f][2][0] += dbary_d_v2.x + ddist_d_v2.x;
          grad_face_verts_a[f][2][1] += dbary_d_v2.y + ddist_d_v2.y;
          grad_face_verts_a[f][2][2] += grad_zbuf_upstream * bary_clip.z;
        }
      }
    }
  }
  return grad_face_verts;
}