in run_nerf_helpers.py [0:0]
def forward(self, x, detailed_output=False):
    """Run the (optionally ray-bent) NeRF MLP on a batch of samples.

    ``x`` is split along its last dimension into three consecutive chunks:
    point features (``self.input_ch`` wide), view-direction features
    (``self.input_ch_views`` wide) and ray-bending latent codes
    (``self.ray_bending_latent_size`` wide). The ``[:, :3]`` slicing below
    requires at least a 2-D ``x``; presumably shape (num_samples, channels),
    with raw xyz / raw direction as the first three entries of the point /
    view chunks -- confirm with the caller.

    If ``self.ray_bender[0]`` is not None, the points are bent by it before
    the MLP trunk runs. With ``self.use_viewdirs``, the view directions are
    then recomputed for the bent points, either approximately
    (``viewdirs_via_finite_differences``, when
    ``self.approx_nonrigid_viewdirs``) or exactly
    (``exact_nonrigid_viewdirs``).

    Args:
        x: Concatenated input tensor (see above).
        detailed_output: If True, also return a dict of intermediate values.

    Returns:
        ``outputs``: with ``self.use_viewdirs`` this is
        ``cat([rgb_linear(...), alpha_linear(...)], -1)``; otherwise
        ``self.output_linear(h)``. When ``detailed_output`` is True, returns
        ``(outputs, details)``. ``details`` holds ``"initial_input_pts"`` (xyz
        before bending), ``"input_pts"`` (xyz after bending), and any entries
        the ray bender writes into the dict it is passed.

    Raises:
        RuntimeError: if exact non-rigid view directions are requested while
            the ray bender consumes positionally encoded input.
    """
    input_pts, input_views, input_latents = torch.split(
        x,
        [self.input_ch, self.input_ch_views, self.ray_bending_latent_size],
        dim=-1,
    )

    if detailed_output:
        # Only keep xyz: the embedding/positional encoding has raw xyz as the
        # first three entries.
        details = {"initial_input_pts": input_pts[:, :3].clone().detach()}
    else:
        details = None

    ray_bender = self.ray_bender[0]
    if ray_bender is not None:
        if self.use_viewdirs and not self.approx_nonrigid_viewdirs:
            # The exact path feeds raw xyz to the bender, so an encoded-input
            # bender cannot be used here.
            if ray_bender.use_positionally_encoded_input:
                raise RuntimeError("not supported")
            # enable_grad() is necessary for this to work when the caller
            # renders under torch.no_grad(). Presumably exact_nonrigid_viewdirs
            # differentiates bent_input_pts w.r.t. initial_input_pts -- confirm.
            with torch.enable_grad():
                initial_input_pts = input_pts[:, :3]
                if not initial_input_pts.requires_grad:
                    # Make the xyz slice a grad-tracking leaf when it is not
                    # one already (e.g. rendering in no_grad() mode).
                    initial_input_pts.requires_grad = True
                input_pts = ray_bender(initial_input_pts, input_latents, details)
                bent_input_pts = input_pts[:, :3]
        else:
            input_pts = ray_bender(input_pts, input_latents, details)

    if detailed_output:
        details["input_pts"] = input_pts[:, :3].clone().detach()

    # Position trunk.
    h = input_pts
    if self.time_conditioned_baseline:
        h = torch.cat([h, input_latents], -1)
    for i, layer in enumerate(self.pts_linears):
        h = F.relu(layer(h))
        if i in self.skips:
            # Skip connection: re-inject the (bent) inputs, plus the latents
            # for the time-conditioned baseline.
            if self.time_conditioned_baseline:
                h = torch.cat([input_pts, input_latents, h], -1)
            else:
                h = torch.cat([input_pts, h], -1)

    if self.use_viewdirs:
        alpha = self.alpha_linear(h)
        feature = self.feature_linear(h)
        if ray_bender is not None:
            # Recompute view directions for the bent points, approximately or
            # exactly depending on self.approx_nonrigid_viewdirs.
            if self.approx_nonrigid_viewdirs:
                input_views = self.viewdirs_via_finite_differences(input_pts[:, :3])
            else:
                input_views = self.exact_nonrigid_viewdirs(
                    initial_input_pts, bent_input_pts, input_views[:, :3]
                )
        h = torch.cat([feature, input_views], -1)
        for layer in self.views_linears:
            h = F.relu(layer(h))
        rgb = self.rgb_linear(h)
        outputs = torch.cat([rgb, alpha], -1)
    else:
        outputs = self.output_linear(h)

    if not detailed_output:
        return outputs

    threshold = self.test_time_nonrigid_object_removal_threshold
    if threshold is not None:
        # Test-time removal: zero channel 3 (alpha when rgb has three channels;
        # confirm for the output_linear path) wherever the rigidity mask is
        # >= threshold, making non-rigid objects invisible. Use `<=` instead to
        # make rigid objects invisible. Assigning 0 (rather than `*= 0.`) also
        # clears NaN/inf values. "rigidity_mask" is presumably written by the
        # ray bender; a KeyError is raised if it is absent.
        nonrigid = details["rigidity_mask"].flatten() >= threshold
        outputs[nonrigid, 3] = 0.0
    return outputs, details