std::vector<torch::Tensor> raymarch_backward()

in extensions/mvpraymarch/mvpraymarch.cpp [282:396]


/// Host-side entry point for the ray-march backward pass.
///
/// Validates every tensor argument with CHECK_INPUT and reads the extents the
/// CUDA launcher needs. It then hands raw data pointers to
/// raymarch_backward_cuda. Absent optional tensors are passed as nullptr, and
/// absent warp extents are passed as 0.
///
/// @return an empty vector. NOTE(review): results are presumably written into
///         the caller-provided grad_* buffers; confirm against
///         raymarch_backward_cuda.
std::vector<torch::Tensor> raymarch_backward(
        torch::Tensor rayposim,
        torch::Tensor raydirim,
        float stepsize,
        torch::Tensor tminmaxim,

        torch::optional<torch::Tensor> sortedobjid,
        torch::optional<torch::Tensor> nodechildren,
        torch::optional<torch::Tensor> nodeaabb,

        torch::Tensor primpos,
        torch::Tensor grad_primpos,
        torch::optional<torch::Tensor> primrot,
        torch::optional<torch::Tensor> grad_primrot,
        torch::optional<torch::Tensor> primscale,
        torch::optional<torch::Tensor> grad_primscale,

        torch::Tensor tplate,
        torch::Tensor grad_tplate,
        torch::optional<torch::Tensor> warp,
        torch::optional<torch::Tensor> grad_warp,

        torch::Tensor rayrgbaim,
        torch::Tensor grad_rayrgba,
        torch::optional<torch::Tensor> raysatim,
        torch::optional<torch::Tensor> raytermim,

        int algorithm=0,
        bool sortboxes=true,
        int maxhitboxes=512,
        bool synchitboxes=false,
        bool chlast=false,
        float fadescale=8.f,
        float fadeexp=8.f,
        int accum=0,
        float termthresh=0.f,
        int griddim=3,
        int blocksizex=8,
        int blocksizey=16) {
    // --- Input validation -------------------------------------------------
    // The macro argument text is kept exactly as-is (and in this order): the
    // macro may stringify its argument into the error message it raises.
    CHECK_INPUT(rayposim);
    CHECK_INPUT(raydirim);
    CHECK_INPUT(tminmaxim);
    if (sortedobjid.has_value()) { CHECK_INPUT(*sortedobjid); }
    if (nodechildren.has_value()) { CHECK_INPUT(*nodechildren); }
    if (nodeaabb.has_value()) { CHECK_INPUT(*nodeaabb); }
    CHECK_INPUT(tplate);
    if (warp.has_value()) { CHECK_INPUT(*warp); }
    CHECK_INPUT(primpos);
    if (primrot.has_value()) { CHECK_INPUT(*primrot); }
    if (primscale.has_value()) { CHECK_INPUT(*primscale); }
    CHECK_INPUT(rayrgbaim);
    if (raysatim.has_value()) { CHECK_INPUT(*raysatim); }
    if (raytermim.has_value()) { CHECK_INPUT(*raytermim); }
    CHECK_INPUT(grad_rayrgba);
    CHECK_INPUT(grad_tplate);
    if (grad_warp.has_value()) { CHECK_INPUT(*grad_warp); }
    CHECK_INPUT(grad_primpos);
    if (grad_primrot.has_value()) { CHECK_INPUT(*grad_primrot); }
    if (grad_primscale.has_value()) { CHECK_INPUT(*grad_primscale); }

    // --- Extents ----------------------------------------------------------
    const int N = rayposim.size(0);
    const int H = rayposim.size(1);
    const int W = rayposim.size(2);
    const int K = primpos.size(1);

    // The three grid extents live at axes 2..4 when chlast is set and at
    // axes 3..5 otherwise (same layout for tplate and warp).
    const int axis0 = chlast ? 2 : 3;

    const int TD = tplate.size(axis0);
    const int TH = tplate.size(axis0 + 1);
    const int TW = tplate.size(axis0 + 2);

    // Warp extents stay 0 when no warp field is supplied.
    int WD = 0;
    int WH = 0;
    int WW = 0;
    if (warp.has_value()) {
        WD = warp->size(axis0);
        WH = warp->size(axis0 + 1);
        WW = warp->size(axis0 + 2);
    }

    // --- Raw-pointer views ------------------------------------------------
    const auto f32 = [](const torch::Tensor &t) -> float * {
        return reinterpret_cast<float *>(t.data_ptr());
    };
    const auto opt_f32 = [](const torch::optional<torch::Tensor> &t) -> float * {
        if (!t.has_value()) { return nullptr; }
        return reinterpret_cast<float *>(t->data_ptr());
    };
    const auto opt_i32 = [](const torch::optional<torch::Tensor> &t) -> int * {
        if (!t.has_value()) { return nullptr; }
        return reinterpret_cast<int *>(t->data_ptr());
    };

    raymarch_backward_cuda(N, H, W, K,
            f32(rayposim), f32(raydirim), stepsize, f32(tminmaxim),
            opt_i32(sortedobjid), opt_i32(nodechildren), opt_f32(nodeaabb),

            f32(primpos), f32(grad_primpos),
            opt_f32(primrot), opt_f32(grad_primrot),
            opt_f32(primscale), opt_f32(grad_primscale),

            TD, TH, TW, f32(tplate), f32(grad_tplate),
            WD, WH, WW, opt_f32(warp), opt_f32(grad_warp),

            f32(rayrgbaim), f32(grad_rayrgba),
            opt_f32(raysatim), opt_i32(raytermim),

            algorithm, sortboxes, maxhitboxes, synchitboxes, chlast,
            fadescale, fadeexp, accum, termthresh,
            griddim, blocksizex, blocksizey,
            // NOTE(review): presumably the CUDA stream (0 = default stream);
            // confirm against raymarch_backward_cuda's declaration.
            0);

    return std::vector<torch::Tensor>{};
}