in dataset/dataset_downstream.py [0:0]
def __getitem__(self, item):
    """Sample one random frame from each of ``sample_num`` temporal segments.

    The video's frame indices are split into ``sample_num`` contiguous
    segments of ``frame_num // sample_num`` frames each (the last segment
    also absorbs the remainder). One index is drawn uniformly at random
    from each segment via ``np.random.choice``, and the corresponding
    frames returned by ``self._get_aug_frame`` are stacked along a new
    leading dimension.

    Args:
        item: Index into ``self.samples``; each entry unpacks to
            ``(frame_root, frame_num, cls)``.

    Returns:
        ``(clips, cls)`` where ``clips`` is ``torch.stack`` of the
        ``sample_num`` frames along dim 0, or ``None`` when the video has
        fewer frames than segments to sample from.
        NOTE(review): callers presumably use a collate_fn that filters out
        ``None`` samples -- confirm.
    """
    frame_root, frame_num, cls = self.samples[item]

    # Fall back to 3 segments when the configured count is unset (<= 0)
    # or exceeds the number of available frames.
    sample_num = self.sample_num
    if sample_num <= 0 or sample_num > frame_num:
        sample_num = 3

    # The fallback can still exceed very short videos (frame_num < 3).
    # Those would produce empty segments, and np.random.choice cannot
    # sample from an empty array, so report the sample as unusable before
    # doing any frame I/O.
    if frame_num < sample_num:
        return None

    # Frame indices are 0-based for the 'anet' source and 1-based
    # otherwise (presumably matching each dataset's frame file numbering).
    first_index = 0 if self.data_source == 'anet' else 1
    frame_indices = np.arange(first_index, first_index + frame_num, dtype=np.int64)

    segment_length = frame_num // sample_num
    frames = []
    for seg in range(sample_num):
        start = seg * segment_length
        # The last segment runs to the end so trailing frames are not dropped.
        stop = frame_num if seg == sample_num - 1 else start + segment_length
        frame_index = np.random.choice(frame_indices[start:stop], 1)[0]
        frames.append(self._get_aug_frame(frame_root, frame_index))

    clips = torch.stack(frames, dim=0)
    return clips, cls