def scene_flow_loss()

in loss/scene_flow_loss.py [0:0]


    def scene_flow_loss(self, points_cam, metadata, pixels):
        """Compute the weighted scene flow loss and its per-term breakdown.

        The scene flow loss consists of two parts:
            - static scene flow loss: geometric consistency loss
            - temporal smoothness scene flow losses (reprojection, disparity
              and depth ratio)
        Both parts are measured in 3D world coordinates.

        Using only the static loss: N = 2.

        Each part is only evaluated when its weight in ``self.opt`` is
        positive (``lambda_scene_flow_static``, ``lambda_smooth_reprojection``,
        ``lambda_smooth_disparity``, ``lambda_smooth_depth_ratio``).

        Args:
            points_cam (B, N, 3, H, W): points in local camera coordinates.
            metadata: dictionary of related metadata to compute the loss.
                It is assumed to include the entries below.
                {
                    'extrinsics': torch.tensor (B, N, 3, 4),
                        extrinsics of each frame. Each (3, 4) = [R, t]
                    'intrinsics': torch.tensor (B, N, 4), (fx, fy, cx, cy)
                    'geometry_consistency':
                        {
                            'flows':    ((B, 2, H, W),) * 2 in pixels.
                                        For k in range(2) (ref or tgt),
                                            pixel p = pixels[indices[b, k]][:, i, j]
                                        corresponds to
                                            p + flows[k][b, :, i, j]
                                        in frame indices[b, (k + 1) % 2].
                            'masks':    ((B, 1, H, W),) * 2. Masks of valid flow
                                        matches. Values are 0 or 1.
                        }
                    'temporal_smoothness': (only read when a smoothness
                                            weight is positive)
                        {
                            'indices':  torch.tensor (4),
                                indices of the neighboring frames
                                (ref_index - 1, ref_index + 1,
                                 tgt_index - 1, tgt_index + 1)
                            'flows':    ((B, 2, H, W),) * 4 in pixels.
                                flows[0][b, :, i, j] - ref_index -> ref_index - 1 (backward flow)
                                flows[1][b, :, i, j] - ref_index -> ref_index + 1 (forward flow)
                                flows[2][b, :, i, j] - tgt_index -> tgt_index - 1 (backward flow)
                                flows[3][b, :, i, j] - tgt_index -> tgt_index + 1 (forward flow)
                            'masks':    ((B, 1, H, W),) * 4. Masks of valid flow
                                matches to compute the consistency in training.
                                Values are 0 or 1.
                            'valid':    tensor that gets a trailing singleton
                                dimension appended and is forwarded to
                                ``smooth_scene_flow_loss``. Presumably a
                                per-sample validity flag - TODO confirm.
                        }
                }
            pixels (B, N, 2, H, W): pixel coordinates.

        Returns:
            A tuple ``(scene_flow_loss_mean, batch_losses, scene_flow)``:
                - scene_flow_loss_mean: ``torch.mean`` of the sum of all
                  enabled, weighted loss terms.
                - batch_losses: dict holding each enabled weighted term under
                  the keys ``"static"``, ``"smooth_reproj"``,
                  ``"smooth_disparity"`` and ``"smooth_depth_ratio"``.
                - scene_flow: the scene flow maps returned by the static loss
                  followed by those returned by the smoothness loss
                  (for visualization).
        """
        extrinsics = select_tensors(metadata["extrinsics"])
        intrinsics = select_tensors(metadata["intrinsics"])
        points_cam = select_tensors(points_cam)
        pixels = select_tensors(pixels)

        # The first two entries are the (ref, tgt) pair used by the static loss.
        pair_idx = [0, 1]

        points_cam_pair = points_cam[pair_idx]
        pixels_pair = pixels[pair_idx]
        extrinsics_pair = extrinsics[pair_idx]
        intrinsics_pair = intrinsics[pair_idx]

        geom_meta = metadata["geometry_consistency"]
        # Materialize as tuples: a generator could only be consumed once and
        # could not be indexed by the callee.
        flows_pair = tuple(geom_meta["flows"])
        masks_pair = tuple(geom_meta["masks"])

        # Every list defaults to empty so a disabled part contributes nothing.
        static_losses = []
        smooth_reproj_losses = []
        smooth_disparity_losses = []
        smooth_depth_ratio_losses = []
        scene_flow_pair, scene_flow_neighbor = [], []

        if self.opt.lambda_scene_flow_static > 0:
            static_losses, scene_flow_pair = self.static_scene_flow_loss(
                points_cam_pair,
                pixels_pair,
                extrinsics_pair,
                flows_pair,
                masks_pair,
            )

        if (
            self.opt.lambda_smooth_disparity > 0
            or self.opt.lambda_smooth_reprojection > 0
            or self.opt.lambda_smooth_depth_ratio > 0
        ):
            smooth_meta = metadata["temporal_smoothness"]
            smooth_valid = smooth_meta["valid"].unsqueeze(-1)

            # NOTE(review): only the intrinsics are restricted to the
            # (ref, tgt) pair here while points, pixels and extrinsics cover
            # all frames. Kept as in the original; confirm against
            # smooth_scene_flow_loss that this is intended.
            (
                smooth_reproj_losses,
                smooth_disparity_losses,
                smooth_depth_ratio_losses,
                scene_flow_neighbor,
            ) = self.smooth_scene_flow_loss(
                points_cam,
                pixels,
                intrinsics_pair,
                extrinsics,
                smooth_meta["flows"],
                smooth_meta["masks"],
                smooth_valid,
            )

        # Take batch size, dtype and device from the inputs so the fallback
        # zeros always live next to the real loss tensors.
        reference = points_cam_pair[0]
        batch_size = reference.shape[0]
        dtype = reference.dtype
        device = reference.device

        def weighted_term(weight, losses):
            """Average the stacked losses over the last dim and scale by weight.

            Falls back to a zero tensor of length ``batch_size`` when there
            are no losses to average.
            """
            if len(losses) > 0:
                return weight * torch.mean(torch.stack(losses, dim=-1), dim=-1)
            return torch.zeros(batch_size, dtype=dtype, device=device)

        # (key in batch_losses, weight, list of loss tensors); the order fixes
        # both the dict insertion order and the summation order.
        loss_terms = (
            # Static scene flow loss directly in 3D (not stable)
            ("static", self.opt.lambda_scene_flow_static, static_losses),
            # Smooth scene flow loss (on spatial reprojection errors)
            (
                "smooth_reproj",
                self.opt.lambda_smooth_reprojection,
                smooth_reproj_losses,
            ),
            # Smooth scene flow loss (on disparity errors)
            (
                "smooth_disparity",
                self.opt.lambda_smooth_disparity,
                smooth_disparity_losses,
            ),
            # Smooth scene flow loss (on depth ratio). The weight is applied
            # here like for every other term; previously it only acted as an
            # on/off switch.
            (
                "smooth_depth_ratio",
                self.opt.lambda_smooth_depth_ratio,
                smooth_depth_ratio_losses,
            ),
        )

        batch_losses = {}
        scene_flow_loss_sum = 0.0
        for loss_name, weight, losses in loss_terms:
            if weight > 0:
                term = weighted_term(weight, losses)
                batch_losses[loss_name] = term
                scene_flow_loss_sum = scene_flow_loss_sum + term

        # With every weight disabled the sum is still a Python float, which
        # torch.mean cannot reduce; report a zero loss instead of raising.
        if not torch.is_tensor(scene_flow_loss_sum):
            scene_flow_loss_sum = torch.zeros(
                batch_size, dtype=dtype, device=device
            )

        # List of scene flow maps for visualization
        scene_flow = scene_flow_pair + scene_flow_neighbor

        scene_flow_loss_mean = torch.mean(scene_flow_loss_sum)

        return scene_flow_loss_mean, batch_losses, scene_flow