in build_graph/tools/superpoint/model.py [0:0]
def run(self, img):
    """Process a numpy image to extract points and descriptors.

    Input
      img - HxW numpy float32 input image, expected in range [0,1].
    Output
      pts     - 3xN numpy array of detections [x_i, y_i, confidence_i]^T,
                sorted by decreasing confidence, with points closer than
                ``self.border_remove`` pixels to the image border removed.
      desc    - DxN numpy array of unit-normalized descriptors, where D is the
                channel count of the network's descriptor output (256xN for
                the standard SuperPoint descriptor head).
      heatmap - (Hc*cell)x(Wc*cell) numpy array of point confidences in
                [0,1], with Hc = int(H / cell) and Wc = int(W / cell); this
                equals HxW when H and W are multiples of ``self.cell``.

    When no detection survives (nothing reaches ``self.conf_thresh``, or all
    points are removed near the border) the result is an empty 3x0 point
    array, an empty Dx0 descriptor array, and the heatmap.
    """
    assert img.ndim == 2, 'Image must be grayscale.'
    assert img.dtype == np.float32, 'Image must be float32.'
    H, W = img.shape[0], img.shape[1]
    # Copy so the input tensor never aliases the caller's array.
    inp = torch.from_numpy(img.copy()).view(1, 1, H, W)
    if self.cuda:
        inp = inp.cuda()
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outs = self.net.forward(inp)
    semi, coarse_desc = outs[0], outs[1]
    D = coarse_desc.shape[1]

    # --- Process points.
    # Take the single batch entry explicitly; .squeeze() would also drop a
    # spatial axis of size 1 (images shorter/narrower than 2 * cell).
    semi = semi.detach().cpu().numpy()[0]
    # Numerically stable softmax over the channel axis: subtracting the
    # per-location maximum keeps np.exp from overflowing on large logits and
    # guarantees the denominator is >= 1.
    dense = np.exp(semi - np.max(semi, axis=0, keepdims=True))
    dense = dense / np.sum(dense, axis=0, keepdims=True)
    # Remove the last ("dustbin") channel.
    nodust = dense[:-1, :, :]
    # Rearrange (cell*cell, Hc, Wc) into a full-resolution heatmap.
    cell = self.cell
    Hc = int(H / cell)
    Wc = int(W / cell)
    heatmap = np.reshape(nodust.transpose(1, 2, 0), [Hc, Wc, cell, cell])
    heatmap = np.transpose(heatmap, [0, 2, 1, 3])
    heatmap = np.reshape(heatmap, [Hc * cell, Wc * cell])
    # np.where yields (row, col) indices, i.e. (y, x).
    ys, xs = np.where(heatmap >= self.conf_thresh)  # Confidence threshold.
    if len(xs) == 0:
        return np.zeros((3, 0)), np.zeros((D, 0)), heatmap
    pts = np.zeros((3, len(xs)))  # 3xN rows: x, y, confidence.
    pts[0, :] = xs
    pts[1, :] = ys
    pts[2, :] = heatmap[ys, xs]
    pts, _ = self.nms_fast(pts, H, W, dist_thresh=self.nms_dist)  # Apply NMS.
    # Sort by decreasing confidence.
    inds = np.argsort(pts[2, :])
    pts = pts[:, inds[::-1]]
    # Remove points within border_remove pixels of the image border.
    bord = self.border_remove
    toremoveW = np.logical_or(pts[0, :] < bord, pts[0, :] >= (W - bord))
    toremoveH = np.logical_or(pts[1, :] < bord, pts[1, :] >= (H - bord))
    pts = pts[:, ~np.logical_or(toremoveW, toremoveH)]

    # --- Process descriptor.
    if pts.shape[1] == 0:
        return pts, np.zeros((D, 0)), heatmap
    # Map pixel coordinates into grid_sample's normalized [-1, 1] range.
    samp = pts[:2, :].copy()
    samp[0, :] = (samp[0, :] / (float(W) / 2.)) - 1.
    samp[1, :] = (samp[1, :] / (float(H) / 2.)) - 1.
    # grid_sample expects an (N, H_out, W_out, 2) grid of (x, y) pairs.
    grid = torch.from_numpy(np.ascontiguousarray(samp.T, dtype=np.float32))
    grid = grid.view(1, 1, -1, 2)
    if self.cuda:
        grid = grid.cuda()
    # NOTE(review): align_corners is not passed, so torch's default applies
    # (False since PyTorch 1.3, True before) and newer versions emit a
    # warning; pin it once the supported torch version is confirmed.
    desc = torch.nn.functional.grid_sample(coarse_desc, grid)
    desc = desc.detach().cpu().numpy().reshape(D, -1)
    desc /= np.linalg.norm(desc, axis=0)[np.newaxis, :]
    return pts, desc, heatmap