in slowfast/datasets/kinetics.py [0:0]
def __getitem__(self, index):
    """
    Given the video index, return the packed frames, label, and video
    index if the video can be fetched and decoded successfully; otherwise
    repeatedly pick a random video that can be decoded as a replacement.

    Args:
        index (int or tuple): the video index provided by the pytorch
            sampler. When short-cycle multigrid sampling is used, this is a
            tuple ``(index, short_cycle_idx)``.
    Returns:
        frames: the result of ``utils.pack_pathway_output`` applied to the
            normalized, spatially sampled clip of shape
            `channel` x `num frames` x `height` x `width` (presumably one
            tensor per model pathway -- confirm against that helper).
        label: the label of the returned video (``self._labels[index]``).
        index (int): if the video provided by the pytorch sampler can be
            decoded, the index of that video; otherwise the index of the
            replacement video that was decoded instead.
        extra_data (dict): always an empty dict from this method.
    Raises:
        NotImplementedError: if ``self.mode`` is not "train", "val" or
            "test".
        RuntimeError: if no video could be loaded and decoded within
            ``self._num_retries`` attempts.
    """
    short_cycle_idx = None
    # When short cycle is used, the input index is a tuple.
    if isinstance(index, tuple):
        index, short_cycle_idx = index
    if self.mode in ["train", "val"]:
        # -1 indicates random sampling.
        temporal_sample_index = -1
        spatial_sample_index = -1
        min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0]
        max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[1]
        crop_size = self.cfg.DATA.TRAIN_CROP_SIZE
        if short_cycle_idx in [0, 1]:
            # Short cycle: crop size is a configured fraction of DEFAULT_S.
            crop_size = int(
                round(
                    self.cfg.MULTIGRID.SHORT_CYCLE_FACTORS[short_cycle_idx]
                    * self.cfg.MULTIGRID.DEFAULT_S
                )
            )
        if self.cfg.MULTIGRID.DEFAULT_S > 0:
            # Decreasing the scale is equivalent to using a larger "span"
            # in a sampling grid.
            min_scale = int(
                round(
                    float(min_scale)
                    * crop_size
                    / self.cfg.MULTIGRID.DEFAULT_S
                )
            )
    elif self.mode in ["test"]:
        temporal_sample_index = (
            self._spatial_temporal_idx[index]
            // self.cfg.TEST.NUM_SPATIAL_CROPS
        )
        # spatial_sample_index is in [0, 1, 2]. Corresponding to left,
        # center, or right if width is larger than height, and top, middle,
        # or bottom if height is larger than width.
        spatial_sample_index = (
            self._spatial_temporal_idx[index]
            % self.cfg.TEST.NUM_SPATIAL_CROPS
        )
        min_scale, max_scale, crop_size = [self.cfg.DATA.TEST_CROP_SIZE] * 3
        # The testing is deterministic and no jitter should be performed.
        # min_scale, max_scale, and crop_size are expected to be the same.
        assert len({min_scale, max_scale, crop_size}) == 1
    else:
        raise NotImplementedError(
            "Does not support {} mode".format(self.mode)
        )
    sampling_rate = utils.get_random_sampling_rate(
        self.cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE,
        self.cfg.DATA.SAMPLING_RATE,
    )
    # Try to decode and sample a clip from a video. If the video can not be
    # loaded or decoded, repeatedly pick a random replacement video.
    # NOTE(review): the temporal/spatial view indices above are computed
    # once from the ORIGINAL index, so in "test" mode a replacement video is
    # sampled with the original clip's view indices, and the replacement's
    # index is returned. If callers aggregate test predictions by the
    # returned index, views will be mismatched -- confirm this is acceptable.
    for i_try in range(self._num_retries):
        video_container = None
        try:
            video_container = container.get_video_container(
                self._path_to_videos[index],
                self.cfg.DATA_LOADER.ENABLE_MULTI_THREAD_DECODE,
                self.cfg.DATA.DECODING_BACKEND,
            )
        except Exception as e:
            # Best-effort: log and fall through to the replacement below.
            logger.warning(
                "Failed to load video from {} with error {}".format(
                    self._path_to_videos[index], e
                )
            )
        # Select a random video if the current video could not be accessed.
        if video_container is None:
            index = random.randint(0, len(self._path_to_videos) - 1)
            continue
        # Decode video. Meta info is used to perform selective decoding.
        frames = decoder.decode(
            video_container,
            sampling_rate,
            self.cfg.DATA.NUM_FRAMES,
            temporal_sample_index,
            self.cfg.TEST.NUM_ENSEMBLE_VIEWS,
            video_meta=self._video_meta[index],
            target_fps=self.cfg.DATA.TARGET_FPS,
            backend=self.cfg.DATA.DECODING_BACKEND,
            max_spatial_scale=max_scale,
        )
        # If decoding failed (wrong format, video is too short, etc.),
        # report it instead of silently swapping, then select another video.
        if frames is None:
            logger.warning(
                "Failed to decode video idx {} from {}; trial {}".format(
                    index, self._path_to_videos[index], i_try
                )
            )
            index = random.randint(0, len(self._path_to_videos) - 1)
            continue
        # Perform color normalization.
        frames = utils.tensor_normalize(
            frames, self.cfg.DATA.MEAN, self.cfg.DATA.STD
        )
        # T H W C -> C T H W.
        frames = frames.permute(3, 0, 1, 2)
        # Perform data augmentation.
        frames = utils.spatial_sampling(
            frames,
            spatial_idx=spatial_sample_index,
            min_scale=min_scale,
            max_scale=max_scale,
            crop_size=crop_size,
            random_horizontal_flip=self.cfg.DATA.RANDOM_FLIP,
            inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE,
        )
        label = self._labels[index]
        frames = utils.pack_pathway_output(self.cfg, frames)
        return frames, label, index, {}
    # Every attempt failed (the loop above only exits early via `return`).
    raise RuntimeError(
        "Failed to fetch video after {} retries.".format(
            self._num_retries
        )
    )