in datasets/GDTPretrainDataset.py [0:0]
def __getitem__(self, index):
    """
    Given the video index, return the tensors: video, audio, label, vid_idx, idx.
    If the video cannot be decoded, repeatedly find a random video that can
    be decoded as a replacement.
    Args:
        index (int): the video index provided by the PyTorch sampler.
    Returns:
        video (tensor): the augmented frames of the four sampled clips
            (two temporal shifts x two time-reversal states), concatenated
            along the first dimension. Each clip is
            `channel` x `num frames` x `height` x `width`.
        audio (tensor): the four matching audio spectrograms, concatenated
            along the first dimension.
        label (int): the label of the current video.
        vid_idx (int): the index of the underlying video.
        idx (int): if the video provided by the sampler can be decoded,
            the index of that video; otherwise the index of a replacement
            video that can be decoded.
    """
    # T1: Sample selection
    # Map the sampler index to a video known to be decodable.
    index_capped = index
    index = self.valid_indices[index_capped]
    # Get two random temporal shifts; when self.sync is set, both clips
    # share the same shift. randint(0, 1000) is never negative, so no
    # wrap-around guard is needed.
    clip_idx1 = random.randint(0, 1000)
    clip_idx2 = clip_idx1 if self.sync else random.randint(0, 1000)
    num_clips = 1000
    clip_idx_list = [clip_idx1, clip_idx2]
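    # NOTE (assumption, for illustration): `decode` presumably maps the
    # clip_idx / num_clips ratio to a temporal position in the video, roughly
    #     start_sec ~ (clip_idx / num_clips) * (duration - clip_duration)
    # so two independent draws yield two temporal shifts tau_1, tau_2
    # (identical when self.sync is set).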
    # Lists to store GDTs
    V = []
    A = []
    # T2: Temporal Shift transformation: tau_1, tau_2
    for tau_ix in range(2):
        time_idx = clip_idx_list[tau_ix]
        # Get video container
        video_container = get_video_container(
            self._path_to_videos[index],
            ENABLE_MULTI_THREAD_DECODE,
            DECODING_BACKEND,
        )
        # T3: Modality splicing transformation: (V, A)
        frames, spec = decode(
            self._path_to_videos[index],
            video_container,
            self.sample_rate,
            self.num_frames,
            time_idx,
            num_clips=num_clips,
            video_meta=self._video_meta[index],
            target_fps=self.target_fps,
            backend=DECODING_BACKEND,
            max_spatial_scale=self.train_jitter_scales[1],
            decode_audio=self.decode_audio,
            aug_audio=self.aug_audio,
            num_sec=self.num_sec,
            aud_sample_rate=self.aud_sample_rate,
            aud_spec_type=self.aud_spec_type,
            use_volume_jittering=self.use_volume_jittering,
            use_temporal_jittering=self.use_temporal_jittering,
            z_normalize=self.z_normalize,
        )
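        # Shape assumptions (inferred from the flips below, illustrative):
        # `frames` is T x H x W x C before augmentation and `spec` is F x T,
        # so flip(0) reverses video time and flip(-1) reverses audio time.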
        # T4: Time Reversal Operation: (R, RT)
        for r_ix in range(2):
            # Clone frames and spec so each direction is independent
            no_aug_frames = frames.clone()
            aug_spec = spec.clone()
            # r_ix == 0 keeps the forward direction; r_ix == 1 reverses time
            if r_ix == 1:
                no_aug_frames = no_aug_frames.flip(0)  # T H W C
                aug_spec = aug_spec.flip(-1)  # F x T
            # T5: Data Augmentation: (gv, ga)
            aug_frames = self.augmentation(no_aug_frames)
            # Add to V, A lists
            V.append(aug_frames)
            A.append(aug_spec)
    label = self._labels[index]
    vid_idx = self._vid_indices[index]
    # Concatenate the four (shift, direction) clips along the first dimension
    return torch.cat(V, dim=0), torch.cat(A, dim=0), label, vid_idx, index_capped
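
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the repo): each item packs the four
# (temporal shift x time direction) clips along dim 0, so a DataLoader batch
# can be flattened into B * 4 clips for the GDT contrastive objective.
# `dataset` below stands for a constructed GDTPretrainDataset; the exact
# video layout depends on what self.augmentation returns.
#
#     loader = torch.utils.data.DataLoader(dataset, batch_size=8)
#     video, audio, label, vid_idx, idx = next(iter(loader))
#     # video: B x (4 * C) x T x H x W (if augmentation outputs C x T x H x W)
#     # audio: B x 4 stacked F x T spectrograms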