static __device__ float3 grid_sample_backward()

in extensions/mvpraymarch/utils.h [239:380]


// Backward pass of trilinear grid sampling at a single location `pos`.
// Coordinates follow the align_corners=True convention: -1 and +1 map to the
// first and last voxel index along each axis.
//
// Parameters:
//   C                    number of channels.
//   inp_D, inp_H, inp_W  volume extents; `vals` and `grad_vals` are indexed as
//                        [c][d][h][w] with contiguous strides.
//   vals                 sampled volume (only read here).
//   grad_vals            gradient w.r.t. `vals`; each channel receives the 8
//                        weighted corner contributions via safe_add_3d
//                        (presumably a bounds-checked atomic add — confirm).
//   pos                  sample position, normalized so [-1, 1] spans the volume.
//   grad_out             upstream gradient; channel c is read from (&grad_out.x)[c].
//                        NOTE(review): assumes out_t stores at least C
//                        contiguous floats starting at .x — confirm per use.
//   border               true: coordinates are clamped to the volume ("border"
//                        padding). false: corners outside the volume are
//                        skipped for the position gradient (zero padding).
// Returns d(loss)/d(pos) in normalized [-1, 1] coordinates.
static __device__ float3 grid_sample_backward(int C, int inp_D, int inp_H,
        int inp_W, float* vals, float* grad_vals, float3 pos, out_t grad_out,
        bool border) {
    // Contiguous [c][d][h][w] strides, shared by `vals` and `grad_vals`.
    const int sW = 1, sH = inp_W, sD = inp_W * inp_H, sC = inp_W * inp_H * inp_D;

    // Map [-1, 1] to [0, size - 1]. The normalized value is clamped to
    // [-10, 10] first, so far-away (or infinite) inputs land well outside the
    // volume while the float->int conversion below stays in range.
    float ix = max(-10.f, min(10.f, ((pos.x + 1.f) * 0.5f))) * (inp_W - 1);
    float iy = max(-10.f, min(10.f, ((pos.y + 1.f) * 0.5f))) * (inp_H - 1);
    float iz = max(-10.f, min(10.f, ((pos.z + 1.f) * 0.5f))) * (inp_D - 1);

    // d(voxel coord) / d(normalized coord) of the mapping above.
    float gix_mult = (inp_W - 1.f) / 2;
    float giy_mult = (inp_H - 1.f) / 2;
    float giz_mult = (inp_D - 1.f) / 2;

    if (border) {
        // Clamp to the volume and fold the clamp's derivative (1 inside,
        // 0 where clamped) into the scale. The derivative goes into its own
        // local and is multiplied in, instead of letting the helper write
        // straight into gix_mult, so the (size - 1) / 2 factor is not lost.
        float gclip_x = 1.f, gclip_y = 1.f, gclip_z = 1.f;
        ix = clip_coordinates_set_grad(ix, inp_W, &gclip_x);
        iy = clip_coordinates_set_grad(iy, inp_H, &gclip_y);
        iz = clip_coordinates_set_grad(iz, inp_D, &gclip_z);
        gix_mult *= gclip_x;
        giy_mult *= gclip_y;
        giz_mult *= gclip_z;
    }

    // Lower corner of the enclosing cell; the other 7 corners are offset by
    // 0 or 1 along each axis.
    const int x0 = static_cast<int>(floorf(ix));
    const int y0 = static_cast<int>(floorf(iy));
    const int z0 = static_cast<int>(floorf(iz));

    // Per-axis linear weights: index 0 = lower corner, index 1 = upper corner.
    // A corner's trilinear weight is the product of its three axis weights.
    const float wx[2] = { (x0 + 1) - ix, ix - x0 };
    const float wy[2] = { (y0 + 1) - iy, iy - y0 };
    const float wz[2] = { (z0 + 1) - iz, iz - z0 };

    float gix = 0.f, giy = 0.f, giz = 0.f;
    const float* gOut_ptr = &grad_out.x;
    float* gInp_ptr = grad_vals;
    const float* inp_ptr = vals;
    for (int c = 0; c < C; ++c, ++gOut_ptr, gInp_ptr += sC, inp_ptr += sC) {
        const float gOut = *gOut_ptr;

        // Corner k has offsets (k & 1, (k >> 1) & 1, k >> 2) along (x, y, z),
        // i.e. the order tnw, tne, tsw, tse, bnw, bne, bsw, bse.
        #pragma unroll
        for (int k = 0; k < 8; ++k) {
            const int dx = k & 1, dy = (k >> 1) & 1, dz = k >> 2;
            const int cx = x0 + dx, cy = y0 + dy, cz = z0 + dz;

            // Gradient w.r.t. the volume: scatter the upstream gradient by weight.
            const float weight = wx[dx] * wy[dy] * wz[dz];
            safe_add_3d(gInp_ptr, cz, cy, cx, sD, sH, sW, inp_D, inp_H, inp_W, weight * gOut);

            // Gradient w.r.t. the position: along one axis, the corner weight's
            // derivative is +1 (upper corner) or -1 (lower corner) times the
            // other two axis weights.
            if (within_bounds_3d(cz, cy, cx, inp_D, inp_H, inp_W)) {
                const float v = inp_ptr[cz * sD + cy * sH + cx * sW];
                const float tx = v * wy[dy] * wz[dz] * gOut;
                const float ty = v * wx[dx] * wz[dz] * gOut;
                const float tz = v * wx[dx] * wy[dy] * gOut;
                gix += dx ? tx : -tx;
                giy += dy ? ty : -ty;
                giz += dz ? tz : -tz;
            }
        }
    }

    // Chain rule back to normalized coordinates (zeroed where border-clamped).
    return make_float3(gix_mult * gix, giy_mult * giy, giz_mult * giz);
}