in extensions/mvpraymarch/mvpraymarch.cpp [282:396]
// Backward pass of the raymarcher, exposed as a Python binding.
//
// Validates every supplied tensor with CHECK_INPUT, derives the problem sizes
// from tensor shapes, and forwards raw data pointers to raymarch_backward_cuda.
// Gradients are not returned: the grad_* tensors are passed to the CUDA routine
// as writable pointers and are presumably written / accumulated in place there
// (TODO confirm against raymarch_backward_cuda). The return value is always an
// empty vector.
//
// Optional tensors that are absent are passed to the CUDA routine as nullptr.
//
// Sizes read here (layouts inferred from the indexing; TODO confirm):
//   N, H, W  <- rayposim dims 0, 1, 2
//   K        <- primpos dim 1
//   TD/TH/TW <- tplate spatial dims; WD/WH/WW <- warp spatial dims (0 if no warp).
//               Spatial dims start at axis 2 when chlast, at axis 3 otherwise.
//
// Pointers are taken with data_ptr<T>(), which raises if a tensor's dtype does
// not match T (float32 for float*, int32 for int*) instead of silently
// reinterpreting e.g. float64/int64 storage.
std::vector<torch::Tensor> raymarch_backward(
        torch::Tensor rayposim,
        torch::Tensor raydirim,
        float stepsize,
        torch::Tensor tminmaxim,
        torch::optional<torch::Tensor> sortedobjid,
        torch::optional<torch::Tensor> nodechildren,
        torch::optional<torch::Tensor> nodeaabb,
        torch::Tensor primpos,
        torch::Tensor grad_primpos,
        torch::optional<torch::Tensor> primrot,
        torch::optional<torch::Tensor> grad_primrot,
        torch::optional<torch::Tensor> primscale,
        torch::optional<torch::Tensor> grad_primscale,
        torch::Tensor tplate,
        torch::Tensor grad_tplate,
        torch::optional<torch::Tensor> warp,
        torch::optional<torch::Tensor> grad_warp,
        torch::Tensor rayrgbaim,
        torch::Tensor grad_rayrgba,
        torch::optional<torch::Tensor> raysatim,
        torch::optional<torch::Tensor> raytermim,
        int algorithm=0,
        bool sortboxes=true,
        int maxhitboxes=512,
        bool synchitboxes=false,
        bool chlast=false,
        float fadescale=8.f,
        float fadeexp=8.f,
        int accum=0,
        float termthresh=0.f,
        int griddim=3,
        int blocksizex=8,
        int blocksizey=16) {
    // Validate inputs first so that no work is launched on bad tensors.
    CHECK_INPUT(rayposim);
    CHECK_INPUT(raydirim);
    CHECK_INPUT(tminmaxim);
    if (sortedobjid) { CHECK_INPUT(*sortedobjid); }
    if (nodechildren) { CHECK_INPUT(*nodechildren); }
    if (nodeaabb) { CHECK_INPUT(*nodeaabb); }
    CHECK_INPUT(tplate);
    if (warp) { CHECK_INPUT(*warp); }
    CHECK_INPUT(primpos);
    if (primrot) { CHECK_INPUT(*primrot); }
    if (primscale) { CHECK_INPUT(*primscale); }
    CHECK_INPUT(rayrgbaim);
    if (raysatim) { CHECK_INPUT(*raysatim); }
    if (raytermim) { CHECK_INPUT(*raytermim); }
    CHECK_INPUT(grad_rayrgba);
    CHECK_INPUT(grad_tplate);
    if (grad_warp) { CHECK_INPUT(*grad_warp); }
    CHECK_INPUT(grad_primpos);
    if (grad_primrot) { CHECK_INPUT(*grad_primrot); }
    if (grad_primscale) { CHECK_INPUT(*grad_primscale); }

    // Absent optional tensors become nullptr; present ones must have the
    // matching dtype, which data_ptr<T>() enforces.
    auto optional_float_ptr = [](const torch::optional<torch::Tensor>& t) -> float* {
        return t ? t->data_ptr<float>() : nullptr;
    };
    auto optional_int_ptr = [](const torch::optional<torch::Tensor>& t) -> int* {
        return t ? t->data_ptr<int>() : nullptr;
    };

    const int N = static_cast<int>(rayposim.size(0));
    const int H = static_cast<int>(rayposim.size(1));
    const int W = static_cast<int>(rayposim.size(2));
    const int K = static_cast<int>(primpos.size(1));

    // Index of the first spatial axis of tplate/warp: shifted by one unless
    // chlast (presumably because the channel axis precedes the spatial axes
    // in the non-channels-last layout).
    const int64_t spatial = chlast ? 2 : 3;
    const int TD = static_cast<int>(tplate.size(spatial));
    const int TH = static_cast<int>(tplate.size(spatial + 1));
    const int TW = static_cast<int>(tplate.size(spatial + 2));

    // Warp volume sizes stay 0 when no warp field is supplied.
    int WD = 0, WH = 0, WW = 0;
    if (warp) {
        WD = static_cast<int>(warp->size(spatial));
        WH = static_cast<int>(warp->size(spatial + 1));
        WW = static_cast<int>(warp->size(spatial + 2));
    }

    raymarch_backward_cuda(N, H, W, K,
            rayposim.data_ptr<float>(),
            raydirim.data_ptr<float>(),
            stepsize,
            tminmaxim.data_ptr<float>(),
            optional_int_ptr(sortedobjid),
            optional_int_ptr(nodechildren),
            optional_float_ptr(nodeaabb),
            primpos.data_ptr<float>(),
            grad_primpos.data_ptr<float>(),
            optional_float_ptr(primrot),
            optional_float_ptr(grad_primrot),
            optional_float_ptr(primscale),
            optional_float_ptr(grad_primscale),
            TD, TH, TW,
            tplate.data_ptr<float>(),
            grad_tplate.data_ptr<float>(),
            WD, WH, WW,
            optional_float_ptr(warp),
            optional_float_ptr(grad_warp),
            rayrgbaim.data_ptr<float>(),
            grad_rayrgba.data_ptr<float>(),
            optional_float_ptr(raysatim),
            optional_int_ptr(raytermim),
            algorithm, sortboxes, maxhitboxes, synchitboxes, chlast, fadescale, fadeexp, accum, termthresh,
            griddim, blocksizex, blocksizey,
            // NOTE(review): if this trailing argument is a CUDA stream, 0 is the
            // legacy default stream rather than PyTorch's current stream —
            // confirm against raymarch_backward_cuda.
            0);

    // Gradients live in the grad_* tensors; nothing is returned.
    return {};
}