in vision_charts/models.py [0:0]
def pooling(self, blocks, verts_pos, debug=False):
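        """Pool per-vertex image features from a list of CNN feature maps.

        `verts_pos` holds a (batch, num_verts, 3) tensor of vertex positions and
        `blocks` a sequence of (batch, channels, dim, dim) feature maps. Each
        vertex is projected into the image (assuming `self.matrix` is a 3x4
        camera projection matrix) and every feature map is bilinearly sampled
        at the projected location; the samples from all blocks are concatenated
        into a (batch, num_verts, total_channels) tensor.
        """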
        # project vertex positions to (row, col) image coordinates, expressed
        # as fractions of the hard-coded 256-pixel image dimension
        ones = torch.ones(verts_pos.shape[0], verts_pos.shape[1], 1, device=verts_pos.device)
        ext_verts_pos = torch.cat((verts_pos, ones), dim=-1)  # homogeneous coordinates
        ext_verts_pos = torch.matmul(ext_verts_pos, self.matrix.permute(1, 0))
        xs = ext_verts_pos[:, :, 1] / ext_verts_pos[:, :, 2] / 256.
        ys = ext_verts_pos[:, :, 0] / ext_verts_pos[:, :, 2] / 256.
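        # (assuming self.matrix is a 3x4 projection matrix P, the lines above
        #  compute [col, row, w]^T = P @ [X, Y, Z, 1]^T per vertex, apply the
        #  perspective divide by w, and normalize by the image size)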
full_features = None
batch_size = verts_pos.shape[0]
        # debug: check that the camera projection covers the image by drawing
        # the projected vertices in red and saving the result for inspection
        # (this branch assumes `blocks` is the raw image batch, (B, C, 256, 256))
        if debug:
            dim = 256
            xs = torch.clamp(xs * dim, 0, dim - 1).data.cpu().numpy().astype(np.int32)
            ys = torch.clamp(ys * dim, 0, dim - 1).data.cpu().numpy().astype(np.int32)
            for ex in range(blocks.shape[0]):
                img = blocks[ex].permute(1, 2, 0).data.cpu().numpy()[:, :, :3]
                for x, y in zip(xs[ex], ys[ex]):
                    img[x, y] = (1., 0., 0.)  # mark the projected vertex in red
                from PIL import Image
                Image.fromarray((img * 255).astype(np.uint8)).save('results/temp.png')
                print('saved')
                input()  # pause so each example can be inspected before continuing
        for block in blocks:
            # scale the projected vertex positions to the dimension of the
            # current (assumed square) feature map
            dim = block.shape[-1]
            cur_xs = torch.clamp(xs * dim, 0, dim - 1)
            cur_ys = torch.clamp(ys * dim, 0, dim - 1)
            # bilinearly interpolate the feature map at the (non-integer)
            # projected locations: https://en.wikipedia.org/wiki/Bilinear_interpolation
            x1s, y1s, x2s, y2s = torch.floor(cur_xs), torch.floor(cur_ys), torch.ceil(cur_xs), torch.ceil(cur_ys)
            # interpolation weights: each coordinate's distance to the far corner
            A = x2s - cur_xs
            B = cur_xs - x1s
            G = y2s - cur_ys
            H = cur_ys - y1s
            x1s = x1s.long()
            y1s = y1s.long()
            x2s = x2s.long()
            y2s = y2s.long()
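            # C, D, E, F below are the features at the four surrounding pixels,
            # combined as f(x, y) ≈ A*G*f(x1,y1) + A*H*f(x1,y2) + B*G*f(x2,y1) + B*H*f(x2,y2);
            # when a coordinate is exactly integral (floor == ceil) both of its
            # weights vanish, so clamped out-of-view vertices pool to zero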
            # flatten the batch of feature maps so each corner can be fetched
            # with a single index_select
            flat_block = block.permute(1, 0, 2, 3).contiguous().view(block.shape[1], -1)
            block_idx = torch.arange(0, batch_size, device=verts_pos.device).unsqueeze(-1).expand(batch_size, verts_pos.shape[1])
            block_idx = block_idx * dim * dim
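            # in the flattened map, pixel (x, y) of example b sits at index
            # b*dim*dim + x*dim + y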
            def gather_corner(x_idx, y_idx):
                # fetch per-vertex features at one corner -> (batch, channels, num_verts)
                selection = (block_idx + x_idx * dim + y_idx).view(-1)
                corner = torch.index_select(flat_block, 1, selection)
                return corner.view(-1, batch_size, verts_pos.shape[1]).permute(1, 0, 2)

            C = gather_corner(x1s, y1s)
            D = gather_corner(x1s, y2s)
            E = gather_corner(x2s, y1s)
            F = gather_corner(x2s, y2s)
            # weighted sum of the four corners
            features = ((A * G).unsqueeze(1) * C
                        + (A * H).unsqueeze(1) * D
                        + (B * G).unsqueeze(1) * E
                        + (B * H).unsqueeze(1) * F)
            features = features.permute(0, 2, 1)  # -> (batch, num_verts, channels)
            # concatenate the features pooled from every block along the channel dim
            if full_features is None:
                full_features = features
            else:
                full_features = torch.cat((full_features, features), dim=2)
return full_features
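
For reference, the same per-vertex bilinear sampling can be written with
torch.nn.functional.grid_sample. The sketch below is not a drop-in
replacement: with align_corners=True, grid_sample maps a normalized
coordinate u in [0, 1] to pixel u * (dim - 1), while the loop above uses
u * dim clamped to [0, dim - 1], so values near the image border differ
slightly. The function name and the assumed shapes (taken from pooling()
above) are illustrative only.

import torch
import torch.nn.functional as F

def pooling_grid_sample(blocks, xs, ys):
    # blocks: sequence of (B, C, H, W) feature maps
    # xs, ys: (B, N) fractional (row, col) coordinates in [0, 1]
    # grid_sample expects (x, y) = (col, row) coordinates in [-1, 1]
    grid = torch.stack((ys, xs), dim=-1) * 2 - 1              # (B, N, 2)
    grid = grid.unsqueeze(2)                                  # (B, N, 1, 2)
    feats = [F.grid_sample(b, grid, mode='bilinear', align_corners=True)
             for b in blocks]                                 # each (B, C, N, 1)
    return torch.cat(feats, dim=1).squeeze(-1).permute(0, 2, 1)  # (B, N, sum of C)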