in fairnr/modules/reader.py [0:0]
def sample_pixels(self, uv, size, alpha=None, mask=None, **kwargs):
    """Randomly choose which pixels of each view are sampled.

    Builds a per-view sampling distribution from an optional mask (or
    ``alpha > 0``), optionally restricted to a centred crop and/or to the
    rows x columns occupied by the mask, optionally mixed with a uniform
    distribution, and then draws ``self.args.pixel_per_view`` pixels per view
    via ``sampling_without_replacement``. When ``sampling_patch_size`` or
    ``sampling_skipping_size`` is > 1, cells of ``patch_size * skip_size``
    pixels are drawn and expanded into (optionally dilated) patches.

    Args:
        uv: pixel coordinates; ``uv[:, :, 0]`` / ``uv[:, :, 1]`` are reshaped
            to ``(S, V, H, W)``, so presumably ``(S, V, 2, H * W)`` — TODO
            confirm against callers.
        size: tensor whose ``[0, 0, 0]`` / ``[0, 0, 1]`` entries give H and W.
        alpha: optional per-pixel values; ``alpha > 0`` is used as the mask
            when ``mask`` is None.
        mask: optional per-pixel mask, reshapeable to ``(S, V, H, W)``.
        **kwargs: accepted and ignored.

    Returns:
        ``(coords, sampled_masks)``: ``coords`` has shape
        ``(S, V, 2, N, patch_size, patch_size)`` (X stacked before Y), and
        ``sampled_masks`` is an ``(S, V, H, W)`` float map with 1 at every
        sampled pixel.
    """
    H, W = int(size[0, 0, 0]), int(size[0, 0, 1])
    S, V = uv.size()[:2]

    if mask is None:
        if alpha is not None:
            mask = alpha > 0
        else:
            mask = uv.new_ones(S, V, uv.size(-1)).bool()
    mask = mask.float().reshape(S, V, H, W)

    # Keep only a centred crop spanning `sampling_at_center` of each side.
    if self.args.sampling_at_center < 1.0:
        r = (1 - self.args.sampling_at_center) / 2.0
        mask0 = mask.new_zeros(S, V, H, W)
        mask0[:, :, int(H * r): H - int(H * r), int(W * r): W - int(W * r)] = 1
        mask = mask * mask0

    # Replace the mask by every (row, column) pair whose row AND column each
    # contain a masked pixel. NOTE(review): for masks with fully empty rows or
    # columns between occupied ones this is not the tight min/max bounding box.
    if self.args.sampling_on_bbox:
        x_has_points = mask.sum(2, keepdim=True) > 0  # (S, V, 1, W): occupied columns
        y_has_points = mask.sum(3, keepdim=True) > 0  # (S, V, H, 1): occupied rows
        mask = (x_has_points & y_has_points).float()

    # Normalise per view so each view's distribution sums to 1; the convex mix
    # with the per-view uniform term 1 / (H * W) below relies on this.
    # (Normalising over the whole batch made the on-mask fraction shrink with S * V.)
    probs = mask / (mask.sum(dim=(2, 3), keepdim=True) + 1e-8)
    # Note: sampling_on_mask == 0 disables mixing (pure mask sampling), while a
    # small positive value gives nearly uniform sampling.
    if self.args.sampling_on_mask > 0.0:
        probs = self.args.sampling_on_mask * probs + (1 - self.args.sampling_on_mask) * 1.0 / (H * W)

    num_pixels = int(self.args.pixel_per_view)
    patch_size, skip_size = self.args.sampling_patch_size, self.args.sampling_skipping_size
    C = patch_size * skip_size  # side (in pixels) of one sampling cell

    if C > 1:
        # Sample C x C cells instead of pixels; each cell yields
        # patch_size * patch_size pixels, so shrink the draw count accordingly.
        probs = probs.reshape(S, V, H // C, C, W // C, C).sum(3).sum(-1)
        num_pixels = num_pixels // patch_size // patch_size

    flatten_probs = probs.reshape(S, V, -1)
    sampled_index = sampling_without_replacement(torch.log(flatten_probs + TINY), num_pixels)
    sampled_masks = torch.zeros_like(flatten_probs).scatter_(-1, sampled_index, 1).reshape(S, V, H // C, W // C)

    if C > 1:
        # Expand each chosen cell into a patch_size x patch_size block on the
        # (H // skip_size, W // skip_size) grid.
        sampled_masks = sampled_masks[:, :, :, None, :, None].repeat(
            1, 1, 1, patch_size, 1, patch_size).reshape(S, V, H // skip_size, W // skip_size)
        if skip_size > 1:
            # Put the coarse grid at one random (row, col) offset per view inside
            # every skip_size x skip_size block, which yields dilated patches.
            full_datamask = sampled_masks.new_zeros(S, V, skip_size * skip_size, H // skip_size, W // skip_size)
            # Drawn on CPU (as before) so the random stream is unchanged.
            full_index = torch.randint(skip_size * skip_size, (S, V)).to(sampled_masks.device)
            s_idx = torch.arange(S, device=sampled_masks.device)[:, None]
            v_idx = torch.arange(V, device=sampled_masks.device)[None, :]
            full_datamask[s_idx, v_idx, full_index] = sampled_masks
            sampled_masks = full_datamask.reshape(
                S, V, skip_size, skip_size, H // skip_size, W // skip_size).permute(0, 1, 4, 2, 5, 3).reshape(S, V, H, W)

    X, Y = uv[:, :, 0].reshape(S, V, H, W), uv[:, :, 1].reshape(S, V, H, W)
    selected = sampled_masks > 0
    # NOTE(review): boolean-mask selection returns pixels in row-major
    # (S, V, H, W) order, so when several patches share rows the trailing
    # (patch_size, patch_size) groups mix pixels of different patches; the
    # reshape also assumes every view selected the same number of pixels.
    # Confirm downstream consumers only rely on the flat order.
    X = X[selected].reshape(S, V, 1, -1, patch_size, patch_size)
    Y = Y[selected].reshape(S, V, 1, -1, patch_size, patch_size)
    return torch.cat([X, Y], 2), sampled_masks