in train.py [0:0]
def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
    """Transform the model's raw predictions into rendered per-ray quantities.

    Raw densities are turned into per-sample opacities. These are composited
    front-to-back into per-sample visibility weights, and the weights are then
    reduced into per-ray color, depth, disparity and accumulated opacity.

    Leading dimensions (written ``num_rays`` below) may be any batch shape, as
    long as ``raw``, ``z_vals`` and ``rays_d`` agree on them.

    Args:
        raw: [num_rays, num_samples along ray, 4]. Prediction from model.
            Channels 0-2 go through a sigmoid to give RGB. Channel 3 is the
            raw density, which goes through a ReLU.
        z_vals: [num_rays, num_samples along ray]. Integration time of each
            sample along its ray.
        rays_d: [num_rays, 3]. Direction of each ray. Its norm scales the
            sample spacing into world-space distances.
        raw_noise_std: Std-dev of noise added to the raw density before the
            ReLU. A value of 0 disables it.
        white_bkgd: If True, composite the rendered color onto a white
            background (add ``1 - acc_map`` to every channel).
        pytest: If True and noise is enabled, replace the Gaussian noise with
            reproducible uniform noise drawn right after ``np.random.seed(0)``.

    Returns:
        rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
        disp_map: [num_rays]. Disparity map. Inverse of depth map.
        acc_map: [num_rays]. Sum of weights along each ray.
        opacity_alpha: [num_rays, num_samples]. Opacity of each sample, from
            its own density and interval length only (independent of ray).
        visibility_weights: [num_rays, num_samples]. Weight of each sample's
            color: its opacity times the transmittance in front of it.
        depth_map: [num_rays]. Estimated distance to object.
    """
    # ``raw.get_device()`` returns -1 for CPU tensors, which is not a usable
    # device argument. The device object works for CPU and CUDA alike.
    device = raw.device

    # Interval lengths between consecutive samples. The last sample gets an
    # effectively infinite interval so that it absorbs all remaining light.
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    far = torch.tensor([1e10], device=device).expand(dists[..., :1].shape)
    dists = torch.cat([dists, far], -1)  # [..., N_samples]
    # Scale by the ray-direction length to get world-space distances.
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    rgb = torch.sigmoid(raw[..., :3])  # [..., N_samples, 3]
    density = raw[..., 3]  # [..., N_samples]

    noise = 0.0
    if raw_noise_std > 0.0:
        noise = torch.randn_like(density) * raw_noise_std
        # Overwrite randomly sampled data if pytest, so runs are reproducible.
        if pytest:
            np.random.seed(0)
            noise_np = np.random.rand(*density.shape) * raw_noise_std
            noise = torch.as_tensor(noise_np, dtype=density.dtype, device=device)

    # alpha_i = 1 - exp(-relu(sigma_i) * delta_i)
    opacity_alpha = 1.0 - torch.exp(-F.relu(density + noise) * dists)

    # Exclusive cumulative product: T_i = prod_{j<i} (1 - alpha_j + eps).
    # A column of ones is prepended and the last entry dropped. The 1e-10
    # keeps the product from collapsing to exactly zero.
    transmittance = torch.cumprod(
        torch.cat(
            [torch.ones_like(opacity_alpha[..., :1]), 1.0 - opacity_alpha + 1e-10],
            -1,
        ),
        -1,
    )[..., :-1]
    visibility_weights = opacity_alpha * transmittance  # [..., N_samples]

    rgb_map = torch.sum(visibility_weights[..., None] * rgb, -2)  # [..., 3]
    depth_map = torch.sum(visibility_weights * z_vals, -1)
    acc_map = torch.sum(visibility_weights, -1)
    # The epsilon in the denominator prevents 0/0 = NaN on rays whose weights
    # are all zero (e.g. rays through empty space).
    disp_map = 1.0 / torch.clamp(depth_map / (acc_map + 1e-10), min=1e-10)

    if white_bkgd:
        rgb_map = rgb_map + (1.0 - acc_map[..., None])

    return rgb_map, disp_map, acc_map, opacity_alpha, visibility_weights, depth_map