in fairnr/clib/__init__.py [0:0]
def forward(ctx, pts_idx, min_depth, max_depth, probs, steps, fixed_step_size=-1, deterministic=False):
    """Run the extension's inverse-CDF sampling kernel over a batch of rows.

    The N input rows are padded up to a multiple of ``G = 200`` (by repeating
    the first row), folded into a ``(G, H / G, ...)`` layout, and passed to
    ``_ext.inverse_cdf_sampling`` in chunks along the second dimension. The
    padded rows are dropped again before returning.

    Args:
        ctx: autograd context of the enclosing ``Function``.
        pts_idx: ``(N, P)`` tensor, forwarded to the kernel without a dtype
            cast (presumably per-row point/voxel indices; confirm against
            the caller).
        min_depth: ``(N, P)`` tensor; cast to float32 for the kernel. Its
            dtype is also the dtype of the returned depth/dist tensors.
        max_depth: ``(N, P)`` tensor; cast to float32 for the kernel.
        probs: ``(N, P)`` tensor; cast to float32 for the kernel.
        steps: ``(N,)`` tensor; its ceiled maximum plus ``P`` sets the width
            of the pre-generated noise.
        fixed_step_size: forwarded unchanged to the kernel.
        deterministic: if true the noise is the constant 0.5; otherwise it is
            uniform noise clamped to ``[0.001, 0.999]``.

    Returns:
        Tuple ``(sampled_idx, sampled_depth, sampled_dists)``, each with N
        rows and truncated to the largest per-row count of entries of
        ``sampled_idx`` that differ from -1. All three are marked
        non-differentiable.
    """
    G, N, P = 200, pts_idx.size(0), pts_idx.size(1)
    H = int(np.ceil(N / G)) * G  # N rounded up to a multiple of G

    if H > N:
        # Pad with copies of the first row so the batch folds evenly into G
        # groups; these rows are sliced off again after the kernel call.
        pad = H - N
        pts_idx = torch.cat([pts_idx, pts_idx[:1].expand(pad, P)], 0)
        min_depth = torch.cat([min_depth, min_depth[:1].expand(pad, P)], 0)
        max_depth = torch.cat([max_depth, max_depth[:1].expand(pad, P)], 0)
        probs = torch.cat([probs, probs[:1].expand(pad, P)], 0)
        steps = torch.cat([steps, steps[:1].expand(pad)], 0)

    pts_idx = pts_idx.reshape(G, -1, P)
    min_depth = min_depth.reshape(G, -1, P)
    max_depth = max_depth.reshape(G, -1, P)
    probs = probs.reshape(G, -1, P)
    steps = steps.reshape(G, -1)

    # pre-generate noise
    max_steps = int(steps.ceil().long().max()) + P
    noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps)
    if deterministic:
        noise += 0.5
    else:
        # Keep the noise strictly inside (0, 1).
        noise = noise.uniform_().clamp(min=0.001, max=0.999)

    # call cuda function
    chunk_size = 4 * G  # to avoid oom?
    results = []
    for start in range(0, min_depth.size(1), chunk_size):
        rows = slice(start, start + chunk_size)
        # Slice before casting so only the current chunk is converted to
        # float32, rather than the whole tensor on every iteration.
        results.append(
            _ext.inverse_cdf_sampling(
                pts_idx[:, rows].contiguous(),
                min_depth[:, rows].float().contiguous(),
                max_depth[:, rows].float().contiguous(),
                noise[:, rows].float().contiguous(),
                probs[:, rows].float().contiguous(),
                steps[:, rows].float().contiguous(),
                fixed_step_size,
            )
        )

    # Stitch the per-chunk outputs back together along the chunked dimension.
    sampled_idx, sampled_depth, sampled_dists = [
        torch.cat([r[i] for r in results], 1)
        for i in range(3)
    ]
    sampled_depth = sampled_depth.type_as(min_depth)
    sampled_dists = sampled_dists.type_as(min_depth)

    # Undo the (G, H / G) folding, then drop the padded rows.
    sampled_idx = sampled_idx.reshape(H, -1)
    sampled_depth = sampled_depth.reshape(H, -1)
    sampled_dists = sampled_dists.reshape(H, -1)
    if H > N:
        sampled_idx = sampled_idx[:N]
        sampled_depth = sampled_depth[:N]
        sampled_dists = sampled_dists[:N]

    # Trim trailing columns beyond the longest row. This assumes entries that
    # differ from -1 are packed at the front of each row -- TODO confirm
    # against the kernel.
    max_len = int(sampled_idx.ne(-1).sum(-1).max())
    sampled_idx = sampled_idx[:, :max_len]
    sampled_depth = sampled_depth[:, :max_len]
    sampled_dists = sampled_dists[:, :max_len]

    # mark_non_differentiable must be called at most once with every output:
    # each call replaces the previously recorded set, so separate calls would
    # leave only the last tensor marked.
    ctx.mark_non_differentiable(sampled_idx, sampled_depth, sampled_dists)
    return sampled_idx, sampled_depth, sampled_dists