in loss/scene_flow_loss.py [0:0]
def scene_flow_loss(self, points_cam, metadata, pixels):
    """Scene flow loss.

    The scene flow loss consists of two families of terms:
      - static scene flow loss: a geometric consistency loss between the
        ref/tgt frame pair;
      - temporal smoothness scene flow losses (reprojection, disparity and
        depth-ratio terms) computed against temporal neighbor frames.
    All losses are measured in 3D world coordinates.
    When only the static loss is used, N = 2.

    Args:
        points_cam (B, N, 3, H, W): points in local camera coordinates.
        metadata: dictionary of related metadata used to compute the loss.
            It is assumed to contain the entries below.
            {
            'extrinsics': torch.tensor (B, N, 3, 4), # extrinsics of each frame.
                            Each (3, 4) = [R, t]
            'intrinsics': torch.tensor (B, N, 4), # (fx, fy, cx, cy)
            'geometry_consistency':
                {
                'flows': ((B, 2, H, W),) * 2 in pixels.
                    For k in range(2) (ref or tgt),
                        pixel p = pixels[indices[b, k]][:, i, j]
                    corresponds to
                        p + flows[k][b, :, i, j]
                    in frame indices[b, (k + 1) % 2].
                'masks': ((B, 1, H, W),) * 2. Masks of valid flow
                    matches. Values are 0 or 1.
                }
            'temporal_smoothness': (only read if a smoothness loss is enabled)
                {
                'indices': torch.tensor (4), indices of the neighbor frames
                    [(ref_index - 1, ref_index + 1, tgt_index - 1, tgt_index + 1), ...]
                    (not read directly by this function).
                'valid': validity tensor for the neighbor frames; a trailing
                    dimension is added (unsqueeze(-1)) before it is passed on.
                    NOTE(review): exact shape not visible here — confirm.
                'flows': ((B, 2, H, W),) * 4 in pixels.
                    flows[0][b, :, i, j] - flow map for ref_index -> ref_index - 1 (backward flow)
                    flows[1][b, :, i, j] - flow map for ref_index -> ref_index + 1 (forward flow)
                    flows[2][b, :, i, j] - flow map for tgt_index -> tgt_index - 1 (backward flow)
                    flows[3][b, :, i, j] - flow map for tgt_index -> tgt_index + 1 (forward flow)
                'masks': ((B, 1, H, W),) * 4. Masks of valid flow matches
                    used to compute the consistency in training.
                    Values are 0 or 1.
                }
            }
        pixels (B, N, 2, H, W).

    Returns:
        (scene_flow_loss_mean, batch_losses, scene_flow):
          scene_flow_loss_mean: torch.mean of the sum of all enabled loss
            terms (a zero tensor if no term is enabled).
          batch_losses: dict with one entry per enabled term, keyed by
            "static", "smooth_reproj", "smooth_disparity",
            "smooth_depth_ratio".
          scene_flow: scene_flow_pair + scene_flow_neighbor, the scene flow
            maps returned by the loss helpers (for visualization).
    """
    opt = self.opt

    extrinsics = select_tensors(metadata["extrinsics"])
    intrinsics = select_tensors(metadata["intrinsics"])
    points_cam = select_tensors(points_cam)
    pixels = select_tensors(pixels)

    # Pick the ref/tgt pair. Presumably select_tensors moves the frame axis
    # to the front so that indexing with pair_idx selects frames — TODO confirm.
    pair_idx = [0, 1]
    points_cam_pair = points_cam[pair_idx]
    pixels_pair = pixels[pair_idx]
    extrinsics_pair = extrinsics[pair_idx]
    intrinsics_pair = intrinsics[pair_idx]

    geom_meta = metadata["geometry_consistency"]
    # Materialize as tuples: a generator can be consumed only once and
    # cannot be indexed, which is fragile for a value handed to a helper.
    flows_pair = tuple(geom_meta["flows"])
    masks_pair = tuple(geom_meta["masks"])

    B = points_cam_pair[0].shape[0]
    dtype = points_cam_pair[0].dtype
    # Allocate fallbacks on the inputs' own device so they can always be
    # combined with the real loss tensors.
    device = points_cam_pair[0].device

    def _mean_or_zeros(losses, weight=1.0):
        # Weighted mean over the stacked per-term losses, or zeros(B) when
        # the helper produced no losses (weight is not applied to zeros).
        if len(losses) == 0:
            return torch.zeros(B, dtype=dtype, device=device)
        return weight * torch.mean(torch.stack(losses, dim=-1), dim=-1)

    static_losses, scene_flow_pair = [], []
    smooth_reproj_losses, smooth_disparity_losses, smooth_depth_ratio_losses = [], [], []
    scene_flow_neighbor = []

    if opt.lambda_scene_flow_static > 0:
        static_losses, scene_flow_pair = self.static_scene_flow_loss(
            points_cam_pair,
            pixels_pair,
            extrinsics_pair,
            flows_pair,
            masks_pair,
        )

    if (
        opt.lambda_smooth_disparity > 0
        or opt.lambda_smooth_reprojection > 0
        or opt.lambda_smooth_depth_ratio > 0
    ):
        smooth_meta = metadata["temporal_smoothness"]
        smooth_valid = smooth_meta["valid"].unsqueeze(-1)
        (
            smooth_reproj_losses,
            smooth_disparity_losses,
            smooth_depth_ratio_losses,
            scene_flow_neighbor,
        ) = self.smooth_scene_flow_loss(
            points_cam,
            pixels,
            # NOTE(review): only the ref/tgt pair intrinsics are passed here,
            # while points/pixels/extrinsics cover all frames — confirm intended.
            intrinsics_pair,
            extrinsics,
            smooth_meta["flows"],
            smooth_meta["masks"],
            smooth_valid,
        )

    batch_losses = {}
    scene_flow_loss_sum = 0.0

    # Static scene flow loss directly in 3D (not stable)
    if opt.lambda_scene_flow_static > 0:
        static_loss = _mean_or_zeros(static_losses, opt.lambda_scene_flow_static)
        batch_losses["static"] = static_loss
        scene_flow_loss_sum = scene_flow_loss_sum + static_loss

    # Smooth scene flow loss (on spatial reprojection errors)
    if opt.lambda_smooth_reprojection > 0:
        smooth_reproj_loss = _mean_or_zeros(
            smooth_reproj_losses, opt.lambda_smooth_reprojection
        )
        batch_losses["smooth_reproj"] = smooth_reproj_loss
        scene_flow_loss_sum = scene_flow_loss_sum + smooth_reproj_loss

    # Smooth scene flow loss (on disparity errors)
    if opt.lambda_smooth_disparity > 0:
        smooth_disparity_loss = _mean_or_zeros(
            smooth_disparity_losses, opt.lambda_smooth_disparity
        )
        batch_losses["smooth_disparity"] = smooth_disparity_loss
        scene_flow_loss_sum = scene_flow_loss_sum + smooth_disparity_loss

    # Smooth scene flow loss (on depth ratio)
    if opt.lambda_smooth_depth_ratio > 0:
        # NOTE(review): unlike the other terms, lambda_smooth_depth_ratio is
        # not applied here — confirm smooth_scene_flow_loss already weights it.
        smooth_depth_ratio_loss = _mean_or_zeros(smooth_depth_ratio_losses)
        batch_losses["smooth_depth_ratio"] = smooth_depth_ratio_loss
        scene_flow_loss_sum = scene_flow_loss_sum + smooth_depth_ratio_loss

    if not batch_losses:
        # No term enabled: torch.mean cannot take the float 0.0.
        scene_flow_loss_sum = torch.zeros(B, dtype=dtype, device=device)

    # List of scene flow maps for visualization
    scene_flow = scene_flow_pair + scene_flow_neighbor
    scene_flow_loss_mean = torch.mean(scene_flow_loss_sum)
    return scene_flow_loss_mean, batch_losses, scene_flow