in dataset/dataset_kinetics.py [0:0]
def __getitem__(self, item):
    """Build one training sample for the video at index ``item``.

    The frame indices ``1..frame_num`` are split into ``self.num_segments``
    contiguous segments (the last segment absorbs the remainder). Two
    distinct frames are drawn from every segment: the first becomes a "key"
    frame and the second a "queue" frame. From those frames the method
    derives the SeCo frames, the order-prediction label and the two TSN
    clips.

    Args:
        item: Index into ``self.samples``; every entry unpacks to
            ``(frame_root, frame_num, cls)``.

    Returns:
        ``None`` when at least one segment holds fewer than two frames, i.e.
        the video is too short to sample a key/queue pair per segment.
        Otherwise the tuple ``(frame1_aug1, frame1_aug2, frame2_aug,
        frame3_aug, order_label, tsn_q, tsn_k)`` where

        * ``frame1_aug1`` / ``frame1_aug2`` are two separately augmented
          loads of the same queue frame,
        * ``frame2_aug`` / ``frame3_aug`` are the queue frames of the two
          remaining segments,
        * ``order_label`` is ``0`` (nothing shuffled), ``1`` (queue
          shuffled), ``2`` (key shuffled) or ``3`` (both shuffled),
        * ``tsn_q`` / ``tsn_k`` are the queue / key frames concatenated
          along dim 0 (after the optional shuffle).

    Note:
        Step 2 indexes queue frames 0, 1 and 2, so at least three segments
        are required for a non-``None`` result.
    """
    frame_root, frame_num, _ = self.samples[item]

    # Re-seed through the project helper so every sample starts from a fresh
    # seed (set_rng is defined elsewhere in this file).
    initial_seed = random.randint(0, 2 ** 31)
    set_rng(initial_seed)

    ###### Step 1: TSN samples ######
    # Frame ids are 1-based: 1, 2, ..., frame_num.
    frame_indices = np.arange(1, frame_num + 1, dtype=np.int64)
    segments_length = frame_num // self.num_segments
    last_segment = self.num_segments - 1
    segments = [
        frame_indices[i * segments_length:]
        if i == last_segment
        else frame_indices[i * segments_length:(i + 1) * segments_length]
        for i in range(self.num_segments)
    ]

    # np.random.choice(..., 2, replace=False) raises ValueError on a segment
    # with fewer than two frames, so the "too short" case has to be detected
    # before sampling. This also avoids loading frames that would be thrown
    # away.
    if any(len(segment) < 2 for segment in segments):
        return None

    # Sample one key frame and one queue frame from each segment. The ids are
    # kept alongside the images because shuffle_list permutes both together.
    key_images = []
    queue_images = []
    key_ids = []
    queue_ids = []
    for segment in segments:
        key_id, queue_id = np.random.choice(segment, 2, replace=False)
        key_images.append(self._get_aug_frame(frame_root, key_id).unsqueeze(dim=0))
        key_ids.append(key_id)
        queue_images.append(self._get_aug_frame(frame_root, queue_id).unsqueeze(dim=0))
        queue_ids.append(queue_id)

    ###### Step 2: SeCo samples ######
    # Pick the anchor segment (first or third); the other two segments supply
    # frame2 / frame3 in their original order.
    if random.randint(0, 1) == 0:
        anchor, second, third = 0, 1, 2
    else:
        anchor, second, third = 2, 0, 1
    frame1_aug1 = queue_images[anchor].squeeze(dim=0)
    # Load the anchor frame a second time to obtain another augmented view.
    frame1_aug2 = self._get_aug_frame(frame_root, queue_ids[anchor])
    frame2_aug = queue_images[second].squeeze(dim=0)
    frame3_aug = queue_images[third].squeeze(dim=0)

    ###### Step 3: Order samples ######
    # 4 labels: 0 (00), 1 (10), 2 (01), 3 (11)
    # First digit: queue clip shuffled; second digit: key clip shuffled.
    shuffle_queue = random.randint(0, 1)
    shuffle_key = random.randint(0, 1)
    if shuffle_queue:
        queue_images, queue_ids = shuffle_list(queue_images, queue_ids)
    if shuffle_key:
        key_images, key_ids = shuffle_list(key_images, key_ids)
    order_label = shuffle_queue + 2 * shuffle_key

    # tsn q and k
    tsn_q = torch.cat(queue_images, dim=0)
    tsn_k = torch.cat(key_images, dim=0)
    return frame1_aug1, frame1_aug2, frame2_aug, frame3_aug, order_label, tsn_q, tsn_k