def forward()

in fairnr/clib/__init__.py [0:0]


    def forward(ctx, pts_idx, min_depth, max_depth, probs, steps, fixed_step_size=-1, deterministic=False):
        """Sample depths along each row by calling the ``inverse_cdf_sampling`` kernel.

        The N input rows are padded up to a multiple of ``G`` (by repeating the
        first row) and regrouped into ``G`` groups before being handed to the
        extension in column chunks. The padding is removed again before returning.

        Args:
            ctx: autograd context; every output is marked non-differentiable.
            pts_idx: ``(N, P)`` tensor of point indices per row.
            min_depth: ``(N, P)`` tensor; lower depth bound per point.
            max_depth: ``(N, P)`` tensor; upper depth bound per point.
            probs: ``(N, P)`` tensor of per-point probabilities.
            steps: ``(N,)`` tensor of per-row step counts. ``ceil(steps).max() + P``
                sizes the last dimension of the pre-generated noise buffer.
            fixed_step_size: forwarded unchanged to ``_ext.inverse_cdf_sampling``
                (presumably ``-1`` means "no fixed step" -- confirm against the kernel).
            deterministic: if True, every noise value is 0.5. Otherwise the noise is
                uniform and clamped to ``[0.001, 0.999]``.

        Returns:
            ``(sampled_idx, sampled_depth, sampled_dists)``, each of shape
            ``(N, max_len)``. ``max_len`` is the largest per-row count of entries
            in ``sampled_idx`` that are not ``-1``. Columns beyond it are dropped,
            which assumes valid entries are packed at the front of each row
            (TODO confirm with the kernel). Depths and dists are cast back to the
            dtype of ``min_depth``.

        Note:
            ``N`` must be positive: with an empty batch, ``reshape(G, -1, P)``
            cannot infer the ``-1`` dimension and raises.
        """
        G, N, P = 200, pts_idx.size(0), pts_idx.size(1)
        # Round N up to a multiple of G using exact integer arithmetic.
        H = (N + G - 1) // G * G

        if H > N:
            # Pad by repeating the first row; the extra rows are sliced off below.
            pad = H - N
            pts_idx = torch.cat([pts_idx, pts_idx[:1].expand(pad, P)], 0)
            min_depth = torch.cat([min_depth, min_depth[:1].expand(pad, P)], 0)
            max_depth = torch.cat([max_depth, max_depth[:1].expand(pad, P)], 0)
            probs = torch.cat([probs, probs[:1].expand(pad, P)], 0)
            steps = torch.cat([steps, steps[:1].expand(pad)], 0)

        pts_idx = pts_idx.reshape(G, -1, P)
        min_depth = min_depth.reshape(G, -1, P)
        max_depth = max_depth.reshape(G, -1, P)
        probs = probs.reshape(G, -1, P)
        steps = steps.reshape(G, -1)

        # Pre-generate noise with room for the largest step count plus P extra slots.
        max_steps = steps.ceil().long().max() + P
        noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps)
        if deterministic:
            noise += 0.5
        else:
            # Keep noise strictly inside (0, 1); presumably this guards the
            # inverse CDF against the interval endpoints.
            noise = noise.uniform_().clamp(min=0.001, max=0.999)

        # Call the CUDA kernel on column chunks to bound per-call memory.
        chunk_size = 4 * G

        def _chunk(t, start):
            # Slice first, then cast: only the chunk is converted to float32,
            # instead of re-casting the whole tensor on every iteration.
            return t[:, start:start + chunk_size].float().contiguous()

        results = [
            _ext.inverse_cdf_sampling(
                pts_idx[:, i:i + chunk_size].contiguous(),
                _chunk(min_depth, i),
                _chunk(max_depth, i),
                _chunk(noise, i),
                _chunk(probs, i),
                _chunk(steps, i),
                fixed_step_size)
            for i in range(0, min_depth.size(1), chunk_size)
        ]
        sampled_idx, sampled_depth, sampled_dists = [
            torch.cat([r[k] for r in results], 1)
            for k in range(3)
        ]
        sampled_depth = sampled_depth.type_as(min_depth)
        sampled_dists = sampled_dists.type_as(min_depth)

        # Undo the (G, H/G) grouping and drop the padded rows (no-op when H == N).
        sampled_idx = sampled_idx.reshape(H, -1)[:N]
        sampled_depth = sampled_depth.reshape(H, -1)[:N]
        sampled_dists = sampled_dists.reshape(H, -1)[:N]

        # Trim trailing columns that no row uses (entries equal to -1).
        max_len = sampled_idx.ne(-1).sum(-1).max()
        sampled_idx = sampled_idx[:, :max_len]
        sampled_depth = sampled_depth[:, :max_len]
        sampled_dists = sampled_dists[:, :max_len]

        ctx.mark_non_differentiable(sampled_idx, sampled_depth, sampled_dists)
        return sampled_idx, sampled_depth, sampled_dists