def forward()

in run_nerf_helpers.py


    def forward(self, x, detailed_output=False):
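        """Evaluate the NeRF MLP on a batch of samples.

        x concatenates, along the last dimension, the positionally encoded sample
        points (self.input_ch channels), the positionally encoded view directions
        (self.input_ch_views channels), and the per-ray ray-bending latent codes
        (self.ray_bending_latent_size channels). If detailed_output is True,
        intermediate tensors (e.g. the raw and bent points) are collected in a
        dict and returned alongside the outputs.
        """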

        input_pts, input_views, input_latents = torch.split(
            x,
            [self.input_ch, self.input_ch_views, self.ray_bending_latent_size],
            dim=-1,
        )

        if detailed_output:
            details = {}
            details["initial_input_pts"] = (
                input_pts[:, :3].clone().detach()
            )  # only keep xyz (embedding/positional encoding has raw xyz as the first three entries)
        else:
            details = None
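
        # Apply ray bending (non-rigid scene deformation) to the sample points, if a
        # ray bender is attached. Gradient tracking is kept on for the unbent points,
        # since exact_nonrigid_viewdirs below differentiates through the bending.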
        if self.ray_bender[0] is not None:
            if self.use_viewdirs and not self.approx_nonrigid_viewdirs:
                if self.ray_bender[0].use_positionally_encoded_input:
                    raise RuntimeError("not supported")
                with torch.enable_grad():  # necessary to work properly in no_grad() mode
                    initial_input_pts = input_pts[:, :3]
                    if not initial_input_pts.requires_grad:
                        initial_input_pts.requires_grad = True  # only do this when the overall rendering is running in no_grad() mode
                    input_pts = self.ray_bender[0](
                        initial_input_pts, input_latents, details
                    )
                    bent_input_pts = input_pts[:, :3]
            else:
                input_pts = self.ray_bender[0](input_pts, input_latents, details)
        if detailed_output:
            details["input_pts"] = input_pts[:, :3].clone().detach()

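        # MLP trunk over the (bent) points; skip connections re-inject the trunk input
        # (plus the latent code for the time-conditioned baseline).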
        h = input_pts
        if self.time_conditioned_baseline:
            h = torch.cat([h, input_latents], -1)
        for i, l in enumerate(self.pts_linears):
            h = l(h)
            h = F.relu(h)
            if i in self.skips:
                if self.time_conditioned_baseline:
                    h = torch.cat([input_pts, input_latents, h], -1)
                else:
                    h = torch.cat([input_pts, h], -1)

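        # View-dependent head: alpha (density) is predicted from the trunk features,
        # while color is additionally conditioned on the view directions.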
        if self.use_viewdirs:
            alpha = self.alpha_linear(h)
            feature = self.feature_linear(h)

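            # Ray bending changes the ray geometry, so view directions are recomputed:
            # approximately via finite differences along the bent ray, or exactly by
            # differentiating the bending at each sample.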
            if self.ray_bender[0] is not None:
                if self.approx_nonrigid_viewdirs:
                    input_views = self.viewdirs_via_finite_differences(
                        input_pts[:, :3]
                    )
                else:
                    input_views = self.exact_nonrigid_viewdirs(
                        initial_input_pts, bent_input_pts, input_views[:, :3]
                    )

            h = torch.cat([feature, input_views], -1)

            for l in self.views_linears:
                h = l(h)
                h = F.relu(h)

            rgb = self.rgb_linear(h)
            outputs = torch.cat([rgb, alpha], -1)
        else:
            outputs = self.output_linear(h)

        if detailed_output:
            if self.test_time_nonrigid_object_removal_threshold is not None:
                # make nonrigid objects invisible
                outputs[
                    details["rigidity_mask"].flatten()
                    >= self.test_time_nonrigid_object_removal_threshold,
                    3,
                ] *= 0.0
                # to make rigid objects invisible instead, flip the comparison to <=
            return outputs, details
        else:
            return outputs
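
A minimal usage sketch (not taken from the repository) showing how the input tensor x
is assembled before calling forward(). The channel sizes and the model instance are
assumptions; the real values come from the positional-encoding configuration.

import torch

# Assumed channel sizes: 63 = 3 + 3*2*10 and 27 = 3 + 3*2*4 are common NeRF defaults
# for encoded points and view directions; 32 is a hypothetical latent size.
input_ch, input_ch_views, ray_bending_latent_size = 63, 27, 32

n_samples = 1024
embedded_pts = torch.randn(n_samples, input_ch)          # encoded xyz (raw xyz in the first 3 entries)
embedded_views = torch.randn(n_samples, input_ch_views)  # encoded view directions
latents = torch.randn(n_samples, ray_bending_latent_size)  # per-ray ray-bending latent codes

# forward() expects all three parts concatenated along the last dimension:
x = torch.cat([embedded_pts, embedded_views, latents], dim=-1)

# outputs = model(x)                                  # (n_samples, 4): rgb + alpha
# outputs, details = model(x, detailed_output=True)   # model: a hypothetical instance of this class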