in extensions/mvpraymarch/utils.h [137:236]
// Trilinearly samples a C-channel volume at a single normalized 3D position.
//
// Parameters:
//   C                    number of channels to sample
//   inp_D, inp_H, inp_W  extent of the volume along z, y and x
//   vals                 volume data, read as
//                        vals[c*D*H*W + z*H*W + y*W + x]
//   pos                  sample position; each component is mapped so that
//                        -1 lands on index 0 and +1 on index (extent - 1)
//   border               when true, the mapped coordinates are additionally
//                        passed through clip_coordinates() per axis
//
// Returns an out_t whose floats, starting at member .x, receive one blended
// value per channel. Corners that fail within_bounds_3d() are skipped and
// therefore contribute nothing to the sum.
// NOTE(review): assumes out_t holds at least C contiguous floats starting
// at .x -- confirm against the out_t definition.
static __device__ out_t grid_sample_forward(int C, int inp_D, int inp_H,
        int inp_W, float* vals, float3 pos, bool border) {
    // Element strides of the densely packed (channel, z, y, x) layout.
    const int stride_x = 1;
    const int stride_y = inp_W;
    const int stride_z = inp_W * inp_H;
    const int stride_c = inp_W * inp_H * inp_D;

    // Map each component from [-1, 1] to [0, 1]; the result is clamped to
    // [-10, 10] before being scaled to index units (presumably to keep the
    // later float -> int conversion well inside the int range).
    const float unit_x = (pos.x + 1.f) * 0.5f;
    const float unit_y = (pos.y + 1.f) * 0.5f;
    const float unit_z = (pos.z + 1.f) * 0.5f;
    float ix = max(-10.f, min(10.f, unit_x)) * (inp_W - 1);
    float iy = max(-10.f, min(10.f, unit_y)) * (inp_H - 1);
    float iz = max(-10.f, min(10.f, unit_z)) * (inp_D - 1);

    if (border) {
        // Border mode: clip the coordinates per axis.
        ix = clip_coordinates(ix, inp_W);
        iy = clip_coordinates(iy, inp_H);
        iz = clip_coordinates(iz, inp_D);
    }

    // Integer coordinates of the lower corner of the enclosing cell.
    const int x0 = static_cast<int>(::floor(ix));
    const int y0 = static_cast<int>(::floor(iy));
    const int z0 = static_cast<int>(::floor(iz));

    // Per-axis linear weights: slot 0 belongs to the lower corner (distance
    // to the upper one), slot 1 to the upper corner (distance to the lower).
    const float wx[2] = { (x0 + 1) - ix, ix - x0 };
    const float wy[2] = { (y0 + 1) - iy, iy - y0 };
    const float wz[2] = { (z0 + 1) - iz, iz - z0 };

    // Channel-independent data for the eight cell corners. Bit 0 of k picks
    // the x offset, bit 1 the y offset and bit 2 the z offset, so corners are
    // visited in the order tnw, tne, tsw, tse, bnw, bne, bsw, bse.
    bool  corner_valid[8];
    int   corner_offset[8];
    float corner_weight[8];
#pragma unroll
    for (int k = 0; k < 8; ++k) {
        const int dx = k & 1;
        const int dy = (k >> 1) & 1;
        const int dz = k >> 2;
        const int cx = x0 + dx;
        const int cy = y0 + dy;
        const int cz = z0 + dz;

        corner_valid[k] = within_bounds_3d(cz, cy, cx, inp_D, inp_H, inp_W);
        // The offset is only formed for corners that will actually be read.
        corner_offset[k] = corner_valid[k]
            ? cz * stride_z + cy * stride_y + cx * stride_x
            : 0;
        corner_weight[k] = wx[dx] * wy[dy] * wz[dz];
    }

    out_t result;
    float* dst = &result.x;   // one output float per channel
    float* chan = vals;       // start of the current channel's volume
    for (int c = 0; c < C; ++c) {
        // Weighted sum over the in-bounds corners of this channel.
        float acc = static_cast<float>(0);
#pragma unroll
        for (int k = 0; k < 8; ++k) {
            if (corner_valid[k]) {
                acc += chan[corner_offset[k]] * corner_weight[k];
            }
        }
        *dst = acc;

        dst += 1;
        chan += stride_c;
    }
    return result;
}