in dataset/dataset_downstream.py [0:0]
def __getitem__(self, item):
    """Load a uniformly sampled, augmented frame clip and its label.

    Args:
        item: Index into ``self.samples``; each sample is a
            ``(frame_root, frame_num, cls)`` tuple.

    Returns:
        A ``(frames, cls)`` tuple. ``frames`` stacks the outputs of
        ``self._get_aug_frame`` along a new leading dimension (one entry per
        sampled frame); ``cls`` is the label stored with the sample.

    Raises:
        RuntimeError: If the sample reports no frames.
    """
    frame_root, frame_num, cls = self.samples[item]
    if frame_num <= 0:
        # Stacking an empty clip would fail anyway; fail with a clearer message.
        raise RuntimeError(f"sample {item} ({frame_root}) has no frames to load")

    # A non-positive ``self.sample_num``, or one larger than the video,
    # means "use every frame".
    if self.sample_num <= 0 or self.sample_num > frame_num:
        sample_num = frame_num
    else:
        sample_num = self.sample_num

    # Frame indices appear to be 1-based by default and 0-based for
    # ActivityNet ('anet') -- TODO confirm against the frame extraction step.
    if self.data_source == 'anet':
        first, last = 0, frame_num - 1
    else:
        first, last = 1, frame_num

    # Spread ``sample_num`` indices evenly over [first, last] for every source,
    # so long videos are sampled across their full length rather than
    # truncated to their opening frames.
    frame_indices = np.round(np.linspace(first, last, num=sample_num)).astype(np.int64)

    frames = torch.stack(
        [self._get_aug_frame(frame_root, idx) for idx in frame_indices], dim=0
    )
    return frames, cls