in src/data_loader.py [0:0]
def __init__(self,
             root,
             split,
             transform=None,
             shuffle=False,
             perm=None,
             include_eos=False):
    """Load the category vocabulary, per-image tags and image ids for a split.

    Args:
        root: Dataset directory. It must contain 'Concepts81.txt' and an
            'ImageList' subdirectory holding 'TrainImagelist.txt' and
            'TestImagelist.txt'.
        split: 'train' or 'val' reads the training tags and training image
            list; any other value reads the test tags and test image list.
        transform: Stored unchanged on ``self.transform``.
        shuffle: Stored unchanged on ``self.shuffle``.
        perm: Optional index array. When given, it is applied to both
            ``self.tags`` and ``self.ids`` using NumPy integer/boolean
            indexing. For example, it can reorder them or select a subset.
        include_eos: If True, ``self.category_list`` starts with 'eos'.

    Attributes set:
        category_list: ['eos'] (only if include_eos) + one entry per line of
            Concepts81.txt + ['<pad>'], each right-stripped of whitespace.
        tags: numpy array built from ``self.load_tags(...)``.
        ids: numpy array of the image-list lines (trailing newline kept).
    """
    self.root = root
    self.transform = transform
    self.shuffle = shuffle
    self.include_eos = include_eos

    # Use context managers so the file handles are closed deterministically.
    # The original list(open(...)) calls leaked them.
    labels_path = os.path.join(self.root, 'Concepts81.txt')
    with open(labels_path, 'r') as f:
        lines = list(f)
    self.category_list = ['eos'] + lines + ['<pad>']
    # remove eos from category list if not needed
    if not self.include_eos:
        self.category_list = self.category_list[1:]

    # 'val' reads the training tags/image list too. Presumably the caller
    # passes `perm` to carve out the validation subset -- TODO confirm.
    if split in ('train', 'val'):
        tag_split, list_name = 'train', 'TrainImagelist.txt'
    else:
        tag_split, list_name = 'test', 'TestImagelist.txt'
    self.tags = self.load_tags(tag_split)
    with open(os.path.join(self.root, 'ImageList', list_name)) as f:
        # Lines keep their trailing newline; downstream code sees them as-is.
        self.ids = list(f)

    self.tags = np.array(self.tags)
    self.ids = np.array(self.ids)
    if perm is not None:
        self.tags = self.tags[perm]
        self.ids = self.ids[perm]

    # NOTE: stripping happens only after load_tags() has run. If load_tags
    # reads self.category_list, it sees the unstripped entries.
    self.category_list = [x.rstrip() for x in self.category_list]