in slowfast/datasets/ssv2.py [0:0]
def __getitem__(self, index):
"""
Given the video index, return the list of frames, label, and video
index if the video frames can be fetched.
Args:
index (int): the video index provided by the pytorch sampler.
Returns:
        frames (list of tensors): the frames sampled from the video, packed
            into one tensor per pathway. Each tensor has dimension
            `channel` x `num frames` x `height` x `width`.
        label (int): the label of the current video.
        index (int): the index of the video.
        extra_data (dict): an empty dict, reserved for extra per-sample data.
    """
short_cycle_idx = None
    # When short cycle is used, the input index is a tuple.
if isinstance(index, tuple):
index, short_cycle_idx = index
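    # In multigrid training, the batch sampler yields (index, short_cycle_idx)
    # pairs so that the spatial crop size can vary across iterations.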
if self.mode in ["train", "val"]:
# -1 indicates random sampling.
spatial_sample_index = -1
min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0]
max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[1]
crop_size = self.cfg.DATA.TRAIN_CROP_SIZE
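        # For multigrid short cycles, shrink the crop size according to the
        # configured short-cycle factors before deriving the jitter scales.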
if short_cycle_idx in [0, 1]:
crop_size = int(
round(
self.cfg.MULTIGRID.SHORT_CYCLE_FACTORS[short_cycle_idx]
* self.cfg.MULTIGRID.DEFAULT_S
)
)
if self.cfg.MULTIGRID.DEFAULT_S > 0:
# Decreasing the scale is equivalent to using a larger "span"
# in a sampling grid.
min_scale = int(
round(
float(min_scale)
* crop_size
/ self.cfg.MULTIGRID.DEFAULT_S
)
)
elif self.mode in ["test"]:
        # spatial_sample_index is in [0, 1, 2], corresponding to left, center,
        # or right if width is larger than height, and top, middle, or bottom
        # otherwise.
spatial_sample_index = (
self._spatial_temporal_idx[index]
% self.cfg.TEST.NUM_SPATIAL_CROPS
)
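        # Each test video is replicated once per (ensemble view, spatial crop)
        # pair, so the modulo recovers which spatial crop this copy should take.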
min_scale, max_scale, crop_size = [self.cfg.DATA.TEST_CROP_SIZE] * 3
# The testing is deterministic and no jitter should be performed.
        # min_scale, max_scale, and crop_size are expected to be the same.
assert len({min_scale, max_scale, crop_size}) == 1
else:
raise NotImplementedError(
"Does not support {} mode".format(self.mode)
)
label = self._labels[index]
seq = self.get_seq_frames(index)
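    # get_seq_frames performs the temporal sampling and returns the frame
    # indices that make up this clip.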
frames = torch.as_tensor(
utils.retry_load_images(
[self._path_to_videos[index][frame] for frame in seq],
self._num_retries,
)
)
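    # frames is now a uint8 tensor of shape T x H x W x C.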
if self.cfg.DATA.USE_RAND_AUGMENT and self.mode in ["train"]:
# Transform to PIL Image
        frames = [
            transforms.ToPILImage()(frame.squeeze().numpy())
            for frame in frames
        ]
# Perform RandAugment
img_size_min = crop_size
auto_augment_desc = "rand-m20-mstd0.5-inc1"
aa_params = dict(
translate_const=int(img_size_min * 0.45),
img_mean=tuple([min(255, round(255 * x)) for x in self.cfg.DATA.MEAN]),
)
seed = random.randint(0, 100000000)
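        # The same seed is reused for every frame so that identical
        # augmentation ops are sampled across the clip, keeping the
        # augmentation temporally consistent.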
        frames = [
            autoaugment.rand_augment_transform(
                auto_augment_desc, aa_params, seed
            )(frame)
            for frame in frames
        ]
# To Tensor: T H W C
frames = [torch.tensor(np.array(frame)) for frame in frames]
frames = torch.stack(frames)
# Perform color normalization.
frames = utils.tensor_normalize(
frames, self.cfg.DATA.MEAN, self.cfg.DATA.STD
)
# T H W C -> C T H W.
frames = frames.permute(3, 0, 1, 2)
# Perform data augmentation.
use_random_resize_crop = self.cfg.DATA.USE_RANDOM_RESIZE_CROPS
if use_random_resize_crop:
if self.mode in ["train", "val"]:
frames = transform.random_resize_crop_video(
frames, crop_size,
scale=(0.05, 1.0),
interpolation_mode="bilinear")
frames, _ = transform.horizontal_flip(0.5, frames)
else:
assert len({min_scale, max_scale, crop_size}) == 1
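            # With min_scale == max_scale == crop_size, the "jitter" below is
            # a deterministic resize followed by the selected uniform crop.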
frames, _ = transform.random_short_side_scale_jitter(
frames, min_scale, max_scale
)
frames, _ = transform.uniform_crop(
frames, crop_size, spatial_sample_index)
else:
        # Fall back to the default spatial sampling: random scale jitter and
        # random crop (plus optional flip) for train/val, deterministic
        # uniform crops for test.
frames = utils.spatial_sampling(
frames,
spatial_idx=spatial_sample_index,
min_scale=min_scale,
max_scale=max_scale,
crop_size=crop_size,
random_horizontal_flip=self.cfg.DATA.RANDOM_FLIP,
inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE,
)
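    # Either branch yields frames of shape C x T x crop_size x crop_size.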
    # C T H W -> T C H W.
if self.mode in ["train", "val"]:
frames = frames.permute(1, 0, 2, 3)
frames = utils.frames_augmentation(
frames,
colorjitter=self.cfg.DATA.COLORJITTER,
use_grayscale=self.cfg.DATA.GRAYSCALE,
use_gaussian=self.cfg.DATA.GAUSSIAN
)
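    # pack_pathway_output wraps the tensor into the per-pathway list of inputs
    # (one entry for single-pathway models, two for SlowFast).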
frames = utils.pack_pathway_output(self.cfg, frames)
return frames, label, index, {}
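    # Usage sketch: the dataset is consumed through a standard PyTorch
    # DataLoader. The `Ssv2(cfg, mode)` constructor name follows the public
    # SlowFast repo and is an assumption here.
    #
    #   dataset = Ssv2(cfg, mode="train")
    #   loader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=4)
    #   inputs, labels, indices, meta = next(iter(loader))
    #   # inputs: list of pathway tensors, each B x C x T x crop_size x crop_size
    #   # after default collation; labels and indices have shape B; meta is {}.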