in loaders/video_dataset.py [0:0]
def __getitem__(self, index: int):
"""Fetch tuples of data. index = i * (i-1) / 2 + j, where i > j for pair (i,j)
So [-1+sqrt(1+8k)]/2 < i <= [1+sqrt(1+8k))]/2, where k=index. So
i = floor([1+sqrt(1+8k))]/2)
j = k - i * (i - 1) / 2.
The number of image frames fetched, N, is not the 1, but computed
based on what kind of consistency to be measured.
For instance, geometry_consistency_loss requires random pairs as samples.
So N = 2.
If with more losses, say triplet one from temporal_consistency_loss. Then
N = 2 + 3.
Returns:
stacked_images (N, C, H, W): image frames
targets: {
'extrinsics': torch.tensor (N, 3, 4), # extrinsics of each frame.
Each (3, 4) = [R, t].
point_wolrd = R * point_cam + t
'intrinsics': torch.tensor (N, 4), # (fx, fy, cx, cy) for each frame
'geometry_consistency':
{
'indices': torch.tensor (2),
indices for corresponding pairs
[(ref_index, tgt_index), ...]
'flows': ((2, H, W),) * 2 in pixels.
For k in range(2) (ref or tgt),
pixel p = pixels[indices[b, k]][:, i, j]
correspond to
p + flows[k][b, :, i, j]
in frame indices[b, (k + 1) % 2].
'masks': ((1, H, W),) * 2. Masks of valid flow matches
to compute the consistency in training.
Values are 0 or 1.
}
}
"""
    pair = self.flow_indices[index]
    indices = torch.tensor(pair)
    intrinsics = torch.stack([self.intrinsics[k] for k in pair], dim=0)
    extrinsics = torch.stack([self.extrinsics[k] for k in pair], dim=0)
    images = torch.stack(
        [load_color(self.color_fmt.format(k), channels_first=True) for k in pair],
        dim=0,
    )
    # Flows and masks are loaded in both directions: ref -> tgt and tgt -> ref.
    flows = [
        load_flow(self.flow_fmt.format(k_ref, k_tgt), channels_first=True)
        for k_ref, k_tgt in [pair, pair[::-1]]
    ]
    masks = [
        load_mask(self.mask_fmt.format(k_ref, k_tgt), channels_first=True)
        for k_ref, k_tgt in [pair, pair[::-1]]
    ]
    metadata = {
        "extrinsics": extrinsics,
        "intrinsics": intrinsics,
        "geometry_consistency": {
            "indices": indices,
            "flows": flows,
            "masks": masks,
        },
    }
if getattr(self, "scales", None):
if isinstance(self.scales, dict):
metadata["scales"] = torch.stack(
[torch.Tensor([self.scales[k]]) for k in pair], dim=0
)
else:
metadata["scales"] = torch.Tensor(
[self.scales, self.scales]).reshape(2, 1)
return (images, metadata)
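
# A minimal consumption sketch (not part of the original file). It assumes an
# instance of this dataset class has already been constructed as `dataset`;
# only the DataLoader wiring shown is standard PyTorch. After default
# collation, the pair dimension N = 2 described in the docstring sits right
# after the batch dimension B.
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
#     images, metadata = next(iter(loader))
#     # images:                                      (B, 2, C, H, W)
#     # metadata["extrinsics"]:                      (B, 2, 3, 4)
#     # metadata["geometry_consistency"]["indices"]: (B, 2)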