# loaders/video_dataset.py
def __getitem__(self, index: int):
"""Fetch tuples of data. index = i * (i-1) / 2 + j, where i > j for pair (i,j)
So [-1+sqrt(1+8k)]/2 < i <= [1+sqrt(1+8k))]/2, where k=index. So
i = floor([1+sqrt(1+8k))]/2)
j = k - i * (i - 1) / 2.
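For example, k = 4 gives i = floor((1 + sqrt(33)) / 2) = 3 and
j = 4 - 3 * 2 / 2 = 1, i.e. the frame pair (3, 1).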
The number of image frames fetched, N, is not fixed at 1; it is computed
from the kinds of consistency to be measured.
For instance, geometry_consistency_loss requires random pairs as samples -> N = 2,
while temporal_smoothness_loss requires two sets of triplets -> N = 6.
Returns:
stacked_images (N, C, H, W): image frames
targets: {
'extrinsics': torch.tensor (N, 3, 4), # extrinsics of each frame.
Each (3, 4) = [R, t].
point_world = R * point_cam + t
'intrinsics': torch.tensor (N, 4), # (fx, fy, cx, cy) for each frame
'geometry_consistency':
{
'indices': torch.tensor (2),
indices for corresponding pairs
[(ref_index, tgt_index), ...]
'flows': ((2, H, W),) * 2 in pixels.
For k in range(2) (ref or tgt),
pixel p = pixels[indices[b, k]][:, i, j]
corresponds to
p + flows[k][b, :, i, j]
in frame indices[b, (k + 1) % 2].
'masks': ((1, H, W),) * 2. Masks of valid flow matches
to compute the consistency in training.
Values are 0 or 1.
}
'temporal_smoothness': (if using temporal smoothness scene flow loss)
{
'indices': torch.tensor (4),
indices of the neighboring frames
[(ref_index - 1, ref_index + 1, tgt_index - 1, tgt_index + 1), ...]
'flows': ((2, H, W),) * 4 in pixels.
flows[0][b, :, i, j] - flow map for ref_index -> ref_index - 1 (backward flow)
flows[1][b, :, i, j] - flow map for ref_index -> ref_index + 1 (forward flow)
flows[2][b, :, i, j] - flow map for tgt_index -> tgt_index - 1 (backward flow)
flows[3][b, :, i, j] - flow map for tgt_index -> tgt_index + 1 (forward flow)
'masks': ((1, H, W),) * 4. Masks of valid flow matches
to compute the consistency in training.
Values are 0 or 1.
'valid': torch.tensor (2, 1). 1.0: the frame's flow neighbors
are valid; 0.0: invalid (clamped at a sequence boundary).
}
}
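For example, the pair (ref_index, tgt_index) = (5, 9) yields
temporal_smoothness indices (4, 6, 8, 10). At sequence boundaries the
neighbor indices are clamped to [0, num_frames - 1] and the
corresponding 'valid' entry is set to 0.0.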
"""
pair = self.flow_indices[index]
indices = torch.tensor(pair)
# Prepare metadata for the sampled frame pair
intrinsics = torch.stack([self.intrinsics[k] for k in pair], dim=0)
extrinsics = torch.stack([self.extrinsics[k] for k in pair], dim=0)
images = torch.stack(
[load_color(self.color_fmt.format(k), channels_first=True) for k in pair],
dim=0,
)
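# Load forward (ref -> tgt) and backward (tgt -> ref) flows and their
# validity masks for the geometry consistency loss.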
flows = [
load_flow(self.flow_fmt.format(k_ref, k_tgt), channels_first=True)
for k_ref, k_tgt in [pair, pair[::-1]]
]
masks = [
load_mask(self.mask_fmt.format(k_ref, k_tgt), channels_first=True)
for k_ref, k_tgt in [pair, pair[::-1]]
]
metadata = {
"extrinsics": extrinsics,
"intrinsics": intrinsics,
"geometry_consistency": {
"indices": indices,
"flows": flows,
"masks": masks,
},
}
if self.use_temporal_smooth_loss:
ref_index, trg_index = pair
_, _, H, W = images.shape
# Get metadata from forward/backward frames
intrinsics_n_ref, extrinsics_n_ref, images_n_ref, flows_n_ref, masks_n_ref = \
self.get_neighbor_meta(ref_index, (H, W))
intrinsics_n_trg, extrinsics_n_trg, images_n_trg, flows_n_trg, masks_n_trg = \
self.get_neighbor_meta(trg_index, (H, W))
# Concatenate the flow neighbors together
intrinsics_n = torch.cat([intrinsics_n_ref, intrinsics_n_trg], dim=0)
extrinsics_n = torch.cat([extrinsics_n_ref, extrinsics_n_trg], dim=0)
images_n = torch.cat([images_n_ref, images_n_trg], dim=0)
flows_n = flows_n_ref + flows_n_trg
masks_n = masks_n_ref + masks_n_trg
# Label invalid boundary frames
valid_flow_neighbor = torch.zeros(2, 1)
if 0 < ref_index < self.num_frames - 1:
valid_flow_neighbor[0] = 1.0
if 0 < trg_index < self.num_frames - 1:
valid_flow_neighbor[1] = 1.0
# Flow neighbor frame indices
anchor = [ref_index, ref_index, trg_index, trg_index]
neighbors = [a + b for a, b in zip(anchor, [-1, 1, -1, 1])]
neighbors = [max(0, min(idx, self.num_frames - 1)) for idx in neighbors]
# Update the images, intrinsics, extrinsics, and metadata
images = torch.cat([images, images_n], dim=0)
intrinsics = torch.cat([intrinsics, intrinsics_n], dim=0)
extrinsics = torch.cat([extrinsics, extrinsics_n], dim=0)
# Store the updated extrinsics/intrinsics and the temporal smoothness metadata
metadata["extrinsics"] = extrinsics
metadata["intrinsics"] = intrinsics
metadata["temporal_smoothness"] = {
"indices": torch.tensor(neighbors),
"flows": flows_n,
"masks": masks_n,
"valid": valid_flow_neighbor,
}
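# Total number of frames fetched: the sampled pair, plus the four flow
# neighbors when the temporal smoothness loss is enabled.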
N = 6 if self.use_temporal_smooth_loss else 2
idx_list = list(pair) + neighbors if self.use_temporal_smooth_loss else pair
# Prepare the depth scale (map)
if getattr(self, "scales", None):
if isinstance(self.scales, dict):
metadata["scales"] = torch.stack([self.scales[k] for k in idx_list], dim=0) # (N, 1, 1) or (N, H, W)
else:
metadata["scales"] = self.scales * torch.ones([N, 1], dtype=_dtype)
# Prepare the 2D warp map from spatial transformation
if self.recon != "colmap":
metadata["warp"] = torch.stack([self.warp_map[k] for k in idx_list], dim=0) # (N, 2, H, W)
return (images, metadata)
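# A minimal sketch (not part of the original training path; pair_from_index
# is a hypothetical helper name) of the inverse index mapping described in
# the __getitem__ docstring: recover the frame pair (i, j) from a flat
# index k = i * (i - 1) / 2 + j with i > j.
@staticmethod
def pair_from_index(k: int):
    import math  # local import keeps the sketch self-contained
    # math.isqrt is exact, so the decomposition holds for arbitrarily large k.
    i = (1 + math.isqrt(1 + 8 * k)) // 2
    j = k - i * (i - 1) // 2
    return i, j
# e.g. pair_from_index(4) == (3, 1). __getitem__ itself reads the pair from
# self.flow_indices rather than recomputing it.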