in data/dataset.py [0:0]
def get_split_info(self):
    """Load this split's metadata file and partition it into train/val/test.

    Reads ``metadata_<self.split>.t7`` from ``self.root`` with ``torch.load``.
    The file is expected to hold an iterable of dict-like records with the
    keys 'image', 'attr', 'obj' and 'set'. A record is skipped when:

    * its attribute is the literal string 'NA' (unlabeled attribute),
    * its (attr, obj) pair is not in ``self.pairs``, or
    * its 'set' value is the literal string 'NA'.

    Returns:
        A tuple ``(train_data, val_data, test_data)`` of lists whose entries
        are ``[image, attr, obj]``. Records whose 'set' is 'train' go to
        ``train_data`` and 'val' goes to ``val_data``. Every other non-'NA'
        set label falls through to ``test_data``.
    """
    import os

    metadata_path = os.path.join(self.root, 'metadata_{}.t7'.format(self.split))
    # NOTE(review): torch.load unpickles the file. Only point this at trusted
    # metadata, never at user-supplied paths.
    data = torch.load(metadata_path)

    # Build the lookup once. self.pairs is presumably a list, which would make
    # each `in` test O(len(pairs)) inside the loop below.
    valid_pairs = set(self.pairs)

    train_data, val_data, test_data = [], [], []
    for instance in data:
        image = instance['image']
        attr = instance['attr']
        obj = instance['obj']
        settype = instance['set']

        # Skip instances with unlabeled attributes, instances whose pair is
        # not part of the current split, and instances with no set label.
        if attr == 'NA' or settype == 'NA' or (attr, obj) not in valid_pairs:
            continue

        data_i = [image, attr, obj]
        if settype == 'train':
            train_data.append(data_i)
        elif settype == 'val':
            val_data.append(data_i)
        else:
            # Catch-all: any other set label is treated as test data.
            test_data.append(data_i)
    return train_data, val_data, test_data