def ground_projection()

in occant_baselines/models/mapnet.py

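Max-pools image features onto a top-down map grid: each feature cell is scattered to the (x, y) map location given by spatial_locs, and collisions (multiple features landing in the same cell) are resolved with a feature-wise max.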

import torch
import torch_scatter


def ground_projection(img_feats, spatial_locs, valid_inputs, local_shape, K, eps=-1e16):
    r"""Inputs:
        img_feats       - (bs, F, H/K, W/K) image features to project to the ground plane
        spatial_locs    - (bs, 2, H, W)
                          for each batch element and pixel (h, w), the (x, y) location
                          on the map is given
        valid_inputs    - (bs, 1, H, W) boolean mask (ByteTensor in older PyTorch)
                          marking pixels with valid depth
        local_shape     - (outh, outw) tuple indicating the size of the output projection
        K               - image_size / map_shape ratio (needed for sampling values
                          from spatial_locs)
        eps             - large negative fill value; serves as a sentinel for empty
                          map cells
    Outputs:
        proj_feats      - (bs, F, outh, outw)
    """
    device = img_feats.device
    outh, outw = local_shape
    bs, F, HbyK, WbyK = img_feats.shape
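    # Pixel-space centers of the K x K patches behind each feature cell; used to
    # sample spatial_locs and valid_inputs down to the feature resolution.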
    img_feat_locs = (
        (torch.arange(0, HbyK, 1) * K + K / 2).long().to(device),
        (torch.arange(0, WbyK, 1) * K + K / 2).long().to(device),
    )

    input_feats = img_feats
    input_idxes = spatial_locs[
        :, :, img_feat_locs[0][:, None], img_feat_locs[1]
    ]  # (bs, 2, HbyK, WbyK)
    valid_inputs_depth = valid_inputs[
        :, :, img_feat_locs[0][:, None], img_feat_locs[1]
    ]  # (bs, 1, HbyK, WbyK)
    valid_inputs_depth = valid_inputs_depth.squeeze(1)  # (bs, HbyK, WbyK)
    invalid_inputs_depth = ~valid_inputs_depth  # logical NOT; assumes a bool mask

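    # Output grid of map cells, pre-filled with the eps sentinel.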
    output_feats = torch.full((bs, F, outh, outw), eps, device=device)
    output_feats_rshp = output_feats.view(*output_feats.shape[:2], -1)
    input_idxes_flip = torch.flip(input_idxes, [1])  # convert x, y to y, x

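    # Mark writes that fall outside the output map or come from invalid depth.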
    invalid_writes = (
        (input_idxes_flip[:, 0] >= outh)
        | (input_idxes_flip[:, 1] >= outw)
        | (input_idxes_flip[:, 0] < 0)
        | (input_idxes_flip[:, 1] < 0)
        | invalid_inputs_depth
    )  # (bs, HbyK, WbyK)

    # Set the idxes for all invalid locations to (0, 0)
    input_idxes_flip[:, 0][invalid_writes] = 0
    input_idxes_flip[:, 1][invalid_writes] = 0

    invalid_writes = invalid_writes.float().unsqueeze(1)

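    # Replace features at invalid locations with eps so scatter_max never selects them.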
    input_feats_masked = input_feats * (1 - invalid_writes) + eps * invalid_writes
    input_feats_rshp = input_feats_masked.view(bs, F, -1)
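    # Flatten (y, x) map coordinates into linear indices over the outh * outw grid.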
    input_idxes_rshp = (
        input_idxes_flip[:, 0, :, :] * outw + input_idxes_flip[:, 1, :, :]
    )
    input_idxes_rshp = input_idxes_rshp.view(bs, 1, -1).expand(-1, F, -1)
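    # Max-pool all features that land in the same map cell. Note: the fill_value
    # kwarg is part of the torch_scatter 1.x API and was removed in 2.x.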
    output_feats_rshp, _ = torch_scatter.scatter_max(
        input_feats_rshp, input_idxes_rshp, dim=2, dim_size=outh * outw, fill_value=eps,
    )
    output_feats = output_feats_rshp.view(bs, F, outh, outw)
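    # Zero out cells that received no writes: where output_feats == eps,
    # (output_feats - eps) evaluates to 0.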
    eps_mask = (output_feats == eps).float()
    output_feats = output_feats * (1 - eps_mask) + eps_mask * (output_feats - eps)

    return output_feats
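
A minimal smoke test of the call, as a sketch: the shapes, the random features, the synthetic map coordinates, and the all-valid depth mask below are illustrative assumptions, and torch_scatter 1.x is assumed for the fill_value keyword used above.

import torch

bs, F, H, W, K = 2, 32, 128, 128, 8
outh, outw = 64, 64

img_feats = torch.randn(bs, F, H // K, W // K)
# Synthetic in-bounds (x, y) map locations and an all-valid depth mask.
spatial_locs = torch.randint(0, min(outh, outw), (bs, 2, H, W))
valid_inputs = torch.ones(bs, 1, H, W, dtype=torch.bool)

proj_feats = ground_projection(img_feats, spatial_locs, valid_inputs, (outh, outw), K)
print(proj_feats.shape)  # torch.Size([2, 32, 64, 64])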