def sample_pixels()

in fairnr/modules/reader.py [0:0]


    def sample_pixels(self, uv, size, alpha=None, mask=None, **kwargs):
        """Randomly choose which pixels (or pixel patches) of each view to keep.

        Args:
            uv: per-pixel coordinates. ``uv[:, :, 0]`` and ``uv[:, :, 1]`` are each
                reshaped to ``(S, V, H, W)``, so this presumably has layout
                ``(S, V, 2, H*W)`` -- TODO confirm against callers.
            size: tensor whose ``size[0, 0, 0]`` and ``size[0, 0, 1]`` hold the
                image height ``H`` and width ``W``.
            alpha: optional per-pixel values. When ``mask`` is None, pixels with
                ``alpha > 0`` form the sampling mask.
            mask: optional per-pixel mask, reshapeable to ``(S, V, H, W)``. If both
                ``mask`` and ``alpha`` are None, every pixel is eligible (this
                assumes ``uv.size(-1) == H * W``).
            **kwargs: ignored.

        Sampling options are read from ``self.args``:
            sampling_at_center: keep only the central ``sampling_at_center``
                fraction of each image side (values < 1.0 crop a border).
            sampling_on_bbox: replace the mask with the cross product of the rows
                and columns that contain any masked pixel.
            sampling_on_mask: if > 0, mix the per-view mask distribution (weight
                ``sampling_on_mask``) with a uniform distribution over the image
                (weight ``1 - sampling_on_mask``). If <= 0, sample from the mask
                distribution alone.
            pixel_per_view: number of pixels to sample per view.
            sampling_patch_size, sampling_skipping_size: sample square patches of
                ``patch_size`` pixels per side, with pixels ``skip_size`` apart.
                ``H`` and ``W`` must be divisible by ``patch_size * skip_size``.

        Returns:
            ``(uv_sampled, sampled_masks)``:

            - ``uv_sampled`` has shape ``(S, V, 2, N, patch_size, patch_size)``.
            - ``sampled_masks`` is a float ``(S, V, H, W)`` tensor with ones at the
              selected pixels.

            NOTE(review): pixels are gathered with a boolean mask, i.e. in
            row-major (raster) order of each image. When ``patch_size > 1``, the
            trailing ``(patch_size, patch_size)`` dims therefore group pixels by
            raster order, not by spatial patch. This is kept as-is because callers
            may index other tensors with ``sampled_masks`` in the same order.
        """
        args = self.args
        H, W = int(size[0, 0, 0]), int(size[0, 0, 1])
        S, V = uv.size()[:2]

        if mask is None:
            if alpha is not None:
                mask = (alpha > 0)
            else:
                mask = uv.new_ones(S, V, uv.size(-1)).bool()
        mask = mask.float().reshape(S, V, H, W)

        if args.sampling_at_center < 1.0:
            # Zero out a border of width r * side on every edge of the image.
            r = (1 - args.sampling_at_center) / 2.0
            center = mask.new_zeros(S, V, H, W)
            center[:, :, int(H * r): H - int(H * r), int(W * r): W - int(W * r)] = 1
            mask = mask * center

        if args.sampling_on_bbox:
            # A pixel is eligible iff its column and its row each contain a masked pixel.
            x_has_points = mask.sum(2, keepdim=True) > 0
            y_has_points = mask.sum(3, keepdim=True) > 0
            mask = (x_has_points & y_has_points).float()

        # Normalize per view so each view's mask distribution sums to 1, matching
        # the per-view uniform term below. Normalizing over all S*V views would
        # shrink the mask term by ~S*V and push masked probabilities toward TINY.
        probs = mask / (mask.sum(dim=(2, 3), keepdim=True) + 1e-8)
        if args.sampling_on_mask > 0.0:
            probs = args.sampling_on_mask * probs + (1 - args.sampling_on_mask) * 1.0 / (H * W)

        num_pixels = int(args.pixel_per_view)
        patch_size, skip_size = args.sampling_patch_size, args.sampling_skipping_size
        C = patch_size * skip_size  # side length (in pixels) of one sampling cell

        if C > 1:
            # Aggregate probabilities over C x C cells and sample cells instead of
            # pixels; each sampled cell later yields patch_size**2 pixels.
            probs = probs.reshape(S, V, H // C, C, W // C, C).sum(3).sum(-1)
            num_pixels = num_pixels // patch_size // patch_size

        flatten_probs = probs.reshape(S, V, -1)
        sampled_index = sampling_without_replacement(torch.log(flatten_probs + TINY), num_pixels)
        sampled_masks = torch.zeros_like(flatten_probs).scatter_(-1, sampled_index, 1).reshape(S, V, H // C, W // C)

        if C > 1:
            # Expand each sampled cell into a patch_size x patch_size block on the
            # (H // skip_size, W // skip_size) grid.
            sampled_masks = sampled_masks[:, :, :, None, :, None].repeat(
                1, 1, 1, patch_size, 1, patch_size).reshape(S, V, H // skip_size, W // skip_size)
            if skip_size > 1:
                # For every view, pick one random offset k = a * skip_size + b and
                # place the coarse grid at full-resolution pixels
                # (row * skip_size + a, col * skip_size + b). This is done with a
                # single advanced-indexing write instead of a per-(S, V) loop.
                device = sampled_masks.device
                full_datamask = sampled_masks.new_zeros(
                    S, V, skip_size * skip_size, H // skip_size, W // skip_size)
                full_index = torch.randint(skip_size * skip_size, (S, V)).to(device)
                s_idx = torch.arange(S, device=device)[:, None]
                v_idx = torch.arange(V, device=device)[None, :]
                full_datamask[s_idx, v_idx, full_index] = sampled_masks
                sampled_masks = full_datamask.reshape(
                    S, V, skip_size, skip_size, H // skip_size, W // skip_size
                ).permute(0, 1, 4, 2, 5, 3).reshape(S, V, H, W)

        X, Y = uv[:, :, 0].reshape(S, V, H, W), uv[:, :, 1].reshape(S, V, H, W)
        selected = sampled_masks > 0
        X = X[selected].reshape(S, V, 1, -1, patch_size, patch_size)
        Y = Y[selected].reshape(S, V, 1, -1, patch_size, patch_size)
        return torch.cat([X, Y], 2), sampled_masks