def project_zbuffer()

in models/projection/depth_manipulator.py [0:0]


    def project_zbuffer(self, depth, K, K_inv, RTinv_cam1, RT_cam2):
        """Forward-project every pixel of ``depth`` into a second camera and
        build a sampler that points back at the source pixels.

        Each pixel of the (detached) homogeneous grid ``self.grid`` is scaled by
        ``depth``, unprojected with ``K_inv``, moved by ``RT_cam2 @ RTinv_cam1``
        and reprojected with ``K``. Every reprojected point is then scattered
        onto the target pixel it lands on, and that pixel stores the point's
        *source* grid coordinate (channels 0/1 of ``self.grid``, with y negated).

        Args:
            depth: depth map whose size is unpacked as ``(bs, _, w, h)``; the
                output has spatial size ``(w, h)``. ``w`` is indexed by the
                projected y coordinate and ``h`` by the projected x coordinate
                (presumably ``(bs, 1, H, W)`` — TODO confirm with callers).
            K, K_inv: batched intrinsics and inverse, applied with ``bmm`` to
                ``(bs, 4, w*h)`` homogeneous points.
            RTinv_cam1: batched inverse extrinsics of the source view.
            RT_cam2: batched extrinsics of the target view.

        Returns:
            ``(bilinear_sampler, depth2)``:

            * ``bilinear_sampler`` is ``(bs, 2, w, h)``. Pixels no point landed
              on hold -2. Points whose projection fell outside the image (or
              had ``|z| < EPS``) are clamped onto the border and stored with +4
              added, presumably so a downstream sampler treats them as invalid
              (TODO confirm). It is computed without autograd.
            * ``depth2`` is ``-z`` of every reprojected point, reshaped to
              ``(bs, 1, w, h)``. It keeps the autograd graph.

        Notes:
            Assumes ``self.grid`` has 4 channels (it is viewed as
            ``(bs, 4, -1)``) whose first two hold normalised x/y coordinates,
            presumably in [-1, 1]. TODO confirm.
        """
        import warnings

        bs, _, w, h = depth.size()

        # Unproject: scale the homogeneous pixel grid by depth, then reset the
        # last (homogeneous) channel to 1.
        orig_xys = self.grid.to(depth.device).repeat(bs, 1, 1, 1).detach()
        xys = orig_xys * depth
        xys[:, -1, :] = 1
        xys = xys.view(bs, 4, -1)

        # Camera-1 coordinates -> camera-2 coordinates -> apply intrinsics.
        cam1_X = K_inv.bmm(xys)
        RT = RT_cam2.bmm(RTinv_cam1)
        wrld_X = RT.bmm(cam1_X)
        xy_proj = K.bmm(wrld_X)
        proj_z = xy_proj[:, 2:3, :]  # (bs, 1, w*h); still carries gradients

        warnings.warn(
            "Warning : not backpropagating through the "
            + "projection -- is this what you want??",
            stacklevel=2,
        )

        device = xy_proj.device
        with torch.no_grad():
            # Perspective divide by -z. Points with |z| < EPS get a sentinel far
            # outside [-1, 1] so the bounds check below flags them.
            z = proj_z[:, 0, :]
            near_zero = z.abs() < EPS
            x_ndc = xy_proj[:, 0, :] / -z
            y_ndc = xy_proj[:, 1, :] / -z
            x_ndc[near_zero] = -10
            y_ndc[near_zero] = -10
            y_ndc = -y_ndc

            # Map [-1, 1] to continuous pixel coordinates of the output grid:
            # x indexes the last dim (size h), y indexes dim 2 (size w).
            px = (x_ndc + 1) * (h / 2)
            py = (y_ndc + 1) * (w / 2)
            # (bs, w*h) in the original point order.
            out_of_bounds = (px < 0) | (px > h - 1) | (py < 0) | (py > w - 1)

            # Visit points in descending-z order. NOTE(review): the original
            # author relied on this order so that nearer points are the ones
            # rendered. However, PyTorch documents index assignment with
            # duplicate indices (accumulate=False) as undefined, so which
            # point wins a shared target pixel is not guaranteed.
            order = z.argsort(dim=1, descending=True)  # (bs, w*h)

            xs = px.gather(1, order).long().clamp(min=0, max=h - 1)
            ys = py.gather(1, order).long().clamp(min=0, max=w - 1)
            # Gather the invalid flag in the SAME order as xs/ys so the +4
            # offset is applied to the point it was computed for.
            offset = out_of_bounds.float().gather(1, order) * 4

            src_xy = orig_xys[:, :2, :, :].reshape(bs, 2, -1)
            src_x = src_xy[:, 0, :].gather(1, order)
            src_y = src_xy[:, 1, :].gather(1, order)

            bsinds = torch.arange(bs, device=device).unsqueeze(1)  # broadcasts over points
            bilinear_sampler = torch.full((bs, 2, w, h), -2.0, device=device)
            bilinear_sampler[bsinds, 0, ys, xs] = src_x + offset
            bilinear_sampler[bsinds, 1, ys, xs] = -src_y + offset

        return bilinear_sampler, -xy_proj[:, 2:3, :].view(bs, 1, w, h)