in loss/consistency_loss.py [0:0]
def geometry_consistency_loss(self, points_cam, metadata, pixels):
"""Geometry Consistency Loss.
For each pair as specified by indices,
geom_consistency = reprojection_error + disparity_error
reprojection_error is measured in the screen space of each camera in the pair.
Args:
points_cam (B, N, 3, H, W): points in local camera coordinate.
pixels (B, N, 2, H, W)
metadata: dictionary of related metadata to compute the loss. Here assumes
metadata include entries as below.
{
'extrinsics': torch.tensor (B, N, 3, 4), # extrinsics of each frame.
Each (3, 4) = [R, t]
'intrinsics': torch.tensor (B, N, 4), # (fx, fy, cx, cy)
'geometry_consistency':
{
'flows': (B, 2, H, W),) * 2 in pixels.
For k in range(2) (ref or tgt),
pixel p = pixels[indices[b, k]][:, i, j]
correspond to
p + flows[k][b, :, i, j]
in frame indices[b, (k + 1) % 2].
'masks': ((B, 1, H, W),) * 2. Masks of valid flow
matches. Values are 0 or 1.
}
}
"""
geom_meta = metadata["geometry_consistency"]
points_cam_pair = select_tensors(points_cam)
extrinsics = metadata["extrinsics"]
extrinsics_pair = select_tensors(extrinsics)
intrinsics = metadata["intrinsics"]
intrinsics_pair = select_tensors(intrinsics)
pixels_pair = select_tensors(pixels)
flows_pair = tuple(geom_meta["flows"])
masks_pair = tuple(geom_meta["masks"])
reproj_losses, disp_losses, depth_ratio_losses = [], [], []
inv_idxs = [1, 0]
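# Each *_pair tensor has the pair (ref/tgt) dimension first, so the zip below
# iterates over the two frames of each pair; indexing with inv_idxs = [1, 0]
# swaps the roles, so every pair contributes to the loss in both directions
# (frame 0 as reference against frame 1, and vice versa).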
for (
points_cam_ref,
tgt_points_cam_tgt,
pixels_ref,
flows_ref,
masks_ref,
intrinsics_ref,
intrinsics_tgt,
extrinsics_ref,
extrinsics_tgt,
) in zip(
points_cam_pair,
points_cam_pair[inv_idxs],
pixels_pair,
flows_pair,
masks_pair,
intrinsics_pair,
intrinsics_pair[inv_idxs],
extrinsics_pair,
extrinsics_pair[inv_idxs],
):
# Quantities shared by all pairwise terms: transform the reference-frame
# points into the target camera's coordinate frame, and use the flow to find
# where each reference pixel lands in the target image.
points_cam_tgt = reproject_points(
points_cam_ref, extrinsics_ref, extrinsics_tgt
)
matched_pixels_tgt = pixels_ref + flows_ref
# === Reprojection loss ===
if self.opt.lambda_static_reprojection > 0:
# Project the transformed points with the target intrinsics and penalize
# their screen-space distance to the flow-matched pixels.
pixels_tgt = project(points_cam_tgt, intrinsics_tgt)
reproj_dist = torch.norm(
pixels_tgt - matched_pixels_tgt, dim=1, keepdim=True
)
reproj_losses.append(
weighted_mean_loss(self.robust_dist(reproj_dist), masks_ref)
)
# Warp the target frame's points (in target-camera coordinates) from the
# target image grid onto the reference image grid via the flow matches; both
# the disparity and the depth-ratio terms compare against this warp.
if (self.opt.lambda_static_disparity > 0
or self.opt.lambda_static_depth_ratio > 0):
warped_tgt_points_cam_tgt = sample(
tgt_points_cam_tgt, matched_pixels_tgt
)
# === Disparity loss ===
if self.opt.lambda_static_disparity > 0:
# Disparity consistency: penalize differences in inverse depth, scaled by
# the mean focal length to put the error on a screen-space-like scale.
f = torch.mean(focal_length(intrinsics_ref))
disp_diff = (
1.0 / points_cam_tgt[:, -1:, ...]
- 1.0 / warped_tgt_points_cam_tgt[:, -1:, ...]
)
disp_losses.append(
f * weighted_mean_loss(self.robust_dist(disp_diff), masks_ref)
)
# === Depth ratio loss ===
if self.opt.lambda_static_depth_ratio > 0:
# the camera is facing the -z axis
depth_warped_tgt = torch.abs(warped_tgt_points_cam_tgt[:, -1:, ...])
depth_tgt = torch.abs(points_cam_tgt[:, -1:, ...])
# Take the elementwise min and max of the two depth estimates.
depth_min = torch.min(depth_warped_tgt, depth_tgt)
depth_max = torch.max(depth_warped_tgt, depth_tgt)
# Compute the log depth ratio. The weight is pre-multiplied before applying
# the robust function, so the aggregation below does not apply it again.
depth_ratio = self.opt.lambda_static_depth_ratio * torch.log(depth_min / depth_max)
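# log(depth_min / depth_max) is symmetric in the two estimates, zero when
# they agree, and scale-invariant, so the penalty depends only on the
# relative depth discrepancy rather than on the absolute depth.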
depth_ratio_losses.append(
weighted_mean_loss(self.robust_dist(depth_ratio), masks_ref)
)
B = points_cam_pair[0].shape[0]
dtype = points_cam_pair[0].dtype
device = points_cam_pair[0].device
batch_losses = {}
consistency_loss_sum = 0.0
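# Each *_losses list holds one per-batch term per direction of the pair;
# stacking and averaging over the last dimension combines the two directions.
# The reprojection and disparity terms are scaled by their weights here,
# while the depth-ratio weight was already applied inside the loop.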
# Spatial reprojection loss
if self.opt.lambda_static_reprojection > 0:
reproj_loss = (
self.opt.lambda_static_reprojection
* torch.mean(torch.stack(reproj_losses, dim=-1), dim=-1)
if len(reproj_losses) > 0
else torch.zeros(B, dtype=dtype, device=device)
)
batch_losses.update({"reproj": reproj_loss})
consistency_loss_sum = consistency_loss_sum + reproj_loss
# Disparity loss
if self.opt.lambda_static_disparity > 0:
disp_loss = (
self.opt.lambda_static_disparity
* torch.mean(torch.stack(disp_losses, dim=-1), dim=-1)
if len(disp_losses) > 0
else torch.zeros(B, dtype=dtype, device=device)
)
batch_losses.update({"disp": disp_loss})
consistency_loss_sum = consistency_loss_sum + disp_loss
# Depth ratio loss
if self.opt.lambda_static_depth_ratio > 0:
depth_ratio_loss = (
torch.mean(torch.stack(depth_ratio_losses, dim=-1), dim=-1)
if len(depth_ratio_losses) > 0
else torch.zeros(B, dtype=dtype, device=device)
)
batch_losses.update({"depth ratio": depth_ratio_loss})
consistency_loss_sum = consistency_loss_sum + depth_ratio_loss
return torch.mean(consistency_loss_sum), batch_losses
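# ---------------------------------------------------------------------------
# Minimal stand-in sketches for two helpers used above, written only to make
# this loss testable in isolation. Their behavior is inferred from how the
# helpers are called here (an assumption), not copied from the project's own
# utilities, and the underscore-prefixed names are hypothetical.

def _select_tensors_sketch(x: torch.Tensor) -> torch.Tensor:
    """Assumed behavior: move the pair dimension to the front, (B, N, ...) -> (N, B, ...)."""
    return x.transpose(0, 1)

def _weighted_mean_loss_sketch(
    loss_map: torch.Tensor, mask: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """Assumed behavior: masked mean of a (B, 1, H, W) loss map, one value per batch item."""
    B = loss_map.shape[0]
    # Flatten all non-batch dimensions so the reduction is a single masked mean.
    loss_flat = loss_map.reshape(B, -1)
    mask_flat = mask.reshape(B, -1)
    return (loss_flat * mask_flat).sum(dim=1) / (mask_flat.sum(dim=1) + eps)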