in occant_baselines/models/mapnet.py
import torch
import torch_scatter


def ground_projection(img_feats, spatial_locs, valid_inputs, local_shape, K, eps=-1e16):
    r"""Projects image features onto the ground-plane map. Each map cell takes the
    element-wise maximum over all image features that project to it.

    Inputs:
        img_feats    - (bs, F, H/K, W/K) image features to project to the ground plane
        spatial_locs - (bs, 2, H, W)
                       for each batch and each image pixel, the (x, y) location on the map
        valid_inputs - (bs, 1, H, W) ByteTensor indicating pixels with valid depth
        local_shape  - (outh, outw) tuple indicating the size of the output projection
        K            - image_size / feature-map size ratio (needed for sampling values
                       from spatial_locs at feature-map resolution)
        eps          - fill value for map cells that receive no projected feature
    Outputs:
        proj_feats   - (bs, F, outh, outw)

    See the usage sketch after this function for how the arguments line up.
    """
    device = img_feats.device
    outh, outw = local_shape
    bs, F, HbyK, WbyK = img_feats.shape
    # Centers of the KxK image patches corresponding to each feature-map cell.
    img_feat_locs = (
        (torch.arange(0, HbyK, 1) * K + K / 2).long().to(device),
        (torch.arange(0, WbyK, 1) * K + K / 2).long().to(device),
    )
    input_feats = img_feats
    # Sample the per-pixel map coordinates and validity mask at those patch centers.
    input_idxes = spatial_locs[
        :, :, img_feat_locs[0][:, None], img_feat_locs[1]
    ]  # (bs, 2, HbyK, WbyK)
    valid_inputs_depth = valid_inputs[
        :, :, img_feat_locs[0][:, None], img_feat_locs[1]
    ]  # (bs, 1, HbyK, WbyK)
    valid_inputs_depth = valid_inputs_depth.squeeze(1)  # (bs, HbyK, WbyK)
    invalid_inputs_depth = ~valid_inputs_depth
    output_feats = torch.zeros(bs, F, outh, outw).to(device)
    output_feats.fill_(eps)
    output_feats_rshp = output_feats.view(*output_feats.shape[:2], -1)
    input_idxes_flip = torch.flip(input_idxes, [1])  # convert (x, y) to (y, x)
    # A write is invalid if it falls outside the map or comes from invalid depth.
    invalid_writes = (
        (input_idxes_flip[:, 0] >= outh)
        | (input_idxes_flip[:, 1] >= outw)
        | (input_idxes_flip[:, 0] < 0)
        | (input_idxes_flip[:, 1] < 0)
        | invalid_inputs_depth
    )  # (bs, HbyK, WbyK)
    # Set the idxes for all invalid locations to (0, 0)
    input_idxes_flip[:, 0][invalid_writes] = 0
    input_idxes_flip[:, 1][invalid_writes] = 0
    # Replace invalid features with eps so they never win the scatter-max below.
    invalid_writes = invalid_writes.float().unsqueeze(1)
    input_feats_masked = input_feats * (1 - invalid_writes) + eps * invalid_writes
    input_feats_rshp = input_feats_masked.view(bs, F, -1)  # (bs, F, HbyK*WbyK)
    # Flatten the (y, x) map coordinates into linear indices over the outh*outw grid.
    input_idxes_rshp = (
        input_idxes_flip[:, 0, :, :] * outw + input_idxes_flip[:, 1, :, :]
    )
    input_idxes_rshp = input_idxes_rshp.view(bs, 1, -1).expand(-1, F, -1)
    # For each map cell, keep the per-channel maximum over all features landing on it.
    output_feats_rshp, _ = torch_scatter.scatter_max(
        input_feats_rshp, input_idxes_rshp, dim=2, dim_size=outh * outw, fill_value=eps,
    )
    output_feats = output_feats_rshp.view(bs, F, outh, outw)
    # Cells that received no valid write are still at eps; reset them to zero.
    eps_mask = (output_feats == eps).float()
    output_feats = output_feats * (1 - eps_mask) + eps_mask * (output_feats - eps)
    return output_feats
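

# Usage sketch (illustrative, not part of the original file). The shapes below
# (a 128x128 image, feature stride K=8, a 64x64 output map) and the random
# spatial_locs / all-valid depth mask are assumptions chosen only to show how
# the arguments line up; torch_scatter must be a release whose scatter_max
# still accepts fill_value.
if __name__ == "__main__":
    bs, F, H, W, K = 2, 32, 128, 128, 8
    outh, outw = 64, 64
    img_feats = torch.randn(bs, F, H // K, W // K)
    # For every image pixel, an (x, y) cell on the map (random for this toy case).
    spatial_locs = torch.randint(0, min(outh, outw), (bs, 2, H, W))
    # Depth-validity mask; a torch.bool mask stands in for the docstring's ByteTensor.
    valid_inputs = torch.ones(bs, 1, H, W, dtype=torch.bool)
    proj_feats = ground_projection(img_feats, spatial_locs, valid_inputs, (outh, outw), K)
    print(proj_feats.shape)  # torch.Size([2, 32, 64, 64])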