def backproject()

in models/backprojection_utils.py [0:0]


import torch


def backproject(voxel_dim, voxel_size, world_center, Rt, K, features, depth=None, resize=True, return_pointcloud=False):
    """
    Takes 2D features and fills them along camera rays in a 3D voxel volume.
    Inspired by https://github.com/magicleap/Atlas/blob/master/atlas/model.py#L35

    Args:
        voxel_dim (tuple, int): tuple indicating the number of voxels along each dimension
            of the voxel grid (nx,ny,nz)
        voxel_size (float): the side length of a voxel in real-world units (i.e. whatever units
            were used to measure the camera intrinsic and extrinsic matrices). E.g. if voxel_dim
            is (100, 100, 100) and the grid represents a 4m x 4m x 4m cube in reality, then voxel_size
            would be equal to 0.04 (4m per 100 voxels = 4/100 = 0.04m per voxel).
        world_center (torch.Tensor): xyz coordinate indicating the center of the world coordinate frame
        Rt (torch.Tensor): Bx4x4 extrinsic camera matrices
        K (torch.Tensor): Bx4x4 intrinsic camera matrices
        features (torch.Tensor): BxCxHxW features to be backprojected into 3D
        depth (torch.Tensor): Bx1xHxW depth values to use for backprojection; if given, a single
            sample is placed along each ray at its depth value (optional)
        resize (bool): indicates whether to upsample the input features to ensure there are no holes in
            the ray projection
        return_pointcloud (bool): if True, return the (batch_idx, x, y, z) voxel coordinates of the valid
            samples and their features instead of scattering them into a dense volume

    Returns:
        volume (torch.Tensor): volume containing the backprojected features, of shape [B, C, nx, ny, nz].
            If return_pointcloud is True, returns instead a tuple of (Nx4 batch-indexed voxel coordinates,
            NxC features) for the samples that fall inside the volume.

    """
    tform_cam2world = Rt.inverse()
    fx, fy = K[0, 0, 0], K[0, 1, 1]  # grab the first value and assume all others are similar

    # default: march one sample per voxel along each ray so the ray projection has no gaps
    samples_per_ray = max(voxel_dim)

    if resize:
        # upscale the feature map so that we don't get empty holes in our ray projections
        B, C, H, W = features.shape
        features = torch.nn.functional.interpolate(features, size=max(voxel_dim), mode="bilinear", align_corners=False)
        scale_ratio = max(voxel_dim) / max(H, W)  # new size / old size
        fx, fy = fx * scale_ratio, fy * scale_ratio

    nx, ny, nz = voxel_dim
    B, C, H, W = features.shape
    device = features.device

    if depth is not None:
        samples_per_ray = 1
        if resize:
            depth = torch.nn.functional.interpolate(depth, size=(H, W), mode="bilinear", align_corners=False)

    voxels_per_unit_dimension = 1 / voxel_size

    # get ray origins and ray directions based on focal length and extrinsic matrix
    ro, rd = get_ray_bundle_batch(H, W, (fx, fy), tform_cam2world)
    ro = ro.view((-1, 3))
    rd = rd.view((-1, 3))
    num_rays = ro.shape[0]

    if depth is None:
        # project points along each ray at uniform intervals
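        # e.g. with samples_per_ray=100 and voxel_size=0.04 (illustrative values matching the docstring
        # example), samples run from 0m to 4m along the ray, roughly one voxel (0.04m) apart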
        t_vals = torch.linspace(0.0, 1.0, samples_per_ray, dtype=ro.dtype, device=ro.device)
        near, far = 0.0, voxel_size * samples_per_ray  # cover one full grid extent in front of the camera
        z_vals = near * (1.0 - t_vals) + far * t_vals
        z_vals = z_vals.expand([num_rays, samples_per_ray])
    else:
        z_vals = depth.view(num_rays, 1)

    # pts -> (num_rays, N_samples, 3)
    # pts are in world coordinates
    pts = ro[..., None, :] + rd[..., None, :] * z_vals[..., :, None]

    # now we want to convert from world coordinates to voxel coordinates
    pts = pts * voxels_per_unit_dimension  # scale to match voxel grid
    world_center = torch.as_tensor(world_center, dtype=torch.float, device=device)
    world_center = world_center * voxels_per_unit_dimension
    # one corner of the voxel grid will always be at (0, 0, 0) and the opposite corner at voxel_dim,
    #     so we can easily find the center by dividing by 2
    voxel_center = torch.tensor(voxel_dim, dtype=torch.float, device=device) / 2
    offset = voxel_center - world_center
    pts_aligned = pts + offset  # pts should now be aligned with voxel grid
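    # e.g. with voxel_dim=(100, 100, 100), voxel_size=0.04 and world_center=(0, 0, 0) (illustrative values),
    # a world point at the origin lands at voxel (50, 50, 50) and a point at (1m, 0, 0) lands at (75, 50, 50)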
    pts_grid = pts_aligned.round().long()  # snap to grid

    pts_flat = pts_grid.view(-1, 3)
    px = pts_flat[:, 0]
    py = pts_flat[:, 1]
    pz = pts_flat[:, 2]

    # find out which points along the backprojected rays lie within the volume
    valid = (px >= 0) & (px < voxel_dim[0]) & (py >= 0) & (py < voxel_dim[1]) & (pz >= 0) & (pz < voxel_dim[2])

    volume = torch.zeros(B, nx, ny, nz, C, dtype=features.dtype, device=device)

    # one batch index per sample along every ray, so features scatter into the correct batch entry
    batch_idx = torch.arange(B, dtype=torch.long, device=device)
    batch_idx = torch.repeat_interleave(batch_idx, repeats=H * W * samples_per_ray, dim=0)

    # put channel dimension at the back
    features = features.permute(0, 2, 3, 1).contiguous()
    # replicate features for each sample along the rays
    features = features.view(-1, C).unsqueeze(1).expand(-1, samples_per_ray, -1).reshape(-1, C)

    if return_pointcloud:
        # return the un-snapped voxel-space coordinates prefixed with their batch index, plus the features
        batch_coords = torch.cat([batch_idx.unsqueeze(1).to(pts_aligned.dtype), pts_aligned.view(-1, 3)], dim=-1)
        return batch_coords[valid], features[valid]
    else:
        # scatter features into the dense volume (duplicate voxel indices overwrite rather than accumulate)
        volume[batch_idx[valid], pts_flat[valid][:, 0], pts_flat[valid][:, 1], pts_flat[valid][:, 2], :] = features[
            valid
        ]
        volume = volume.permute(0, 4, 1, 2, 3).contiguous()
        return volume
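
A minimal usage sketch (illustrative, not from the repository): the batch/feature shapes, the 64^3 grid,
the 0.04m voxel size, and the identity extrinsics below are assumptions, and the call presumes torch and
get_ray_bundle_batch are importable alongside backproject in models/backprojection_utils.py.

import torch
from models.backprojection_utils import backproject

B, C, H, W = 2, 32, 60, 80
voxel_dim = (64, 64, 64)
voxel_size = 0.04                                  # 64 voxels * 0.04m per voxel = a 2.56m cube
world_center = torch.zeros(3)                      # grid centered on the world origin
Rt = torch.eye(4).unsqueeze(0).repeat(B, 1, 1)     # identity extrinsics: camera at the world origin
K = torch.eye(4).unsqueeze(0).repeat(B, 1, 1)
K[:, 0, 0] = 300.0                                 # fx in pixels
K[:, 1, 1] = 300.0                                 # fy in pixels
features = torch.randn(B, C, H, W)

volume = backproject(voxel_dim, voxel_size, world_center, Rt, K, features)
print(volume.shape)                                # torch.Size([2, 32, 64, 64, 64])

# or get a sparse point cloud of the valid samples instead of a dense volume
coords, feats = backproject(voxel_dim, voxel_size, world_center, Rt, K, features, return_pointcloud=True)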