in src/data_loader.py [0:0]
def __init__(self,
             root,
             split,
             transform=None,
             shuffle=False,
             perm=None,
             include_eos=False):
    """Load the category names and the sample list found under ``root``.

    Args:
        root: Directory containing ``objectInfo150.txt`` and the
            ``ADE20K_object150_{train,val}.txt`` sample lists.
        split: ``'train'`` and ``'val'`` both read the training list; any
            other value reads the validation list.
        transform: Stored on the instance as ``self.transform``.
        shuffle: Stored on the instance as ``self.shuffle``.
        perm: Optional index (anything accepted by NumPy array indexing)
            used to select and reorder the loaded sample lines.
        include_eos: When true, ``'eos'`` is kept as the first entry of
            ``self.category_list``.

    Raises:
        OSError: If one of the expected files cannot be opened.
    """
    # Category names: drop the header row, then for each row keep the last
    # tab-separated field up to its first comma.
    labels_path = os.path.join(root, 'objectInfo150.txt')
    with open(labels_path) as labels_file:
        label_rows = labels_file.readlines()[1:]
    names = [row.split('\t')[-1].split(',')[0].rstrip() for row in label_rows]

    self.include_eos = include_eos
    # 'eos' (only when requested) comes first, '<pad>' always comes last.
    self.category_list = (['eos'] if include_eos else []) + names + ['<pad>']

    self.root = root

    # NOTE(review): 'train' and 'val' both read the training list;
    # presumably the caller carves the two subsets out of it via `perm`
    # -- confirm.
    if split in ('train', 'val'):
        samples_name = 'ADE20K_object150_train.txt'
    else:
        samples_name = 'ADE20K_object150_val.txt'
    # Read inside a context manager so the handle is closed right away.
    # Each entry keeps its trailing newline, as before.
    with open(os.path.join(root, samples_name), 'r') as samples_file:
        sample_lines = samples_file.readlines()

    self.ids = np.array(sample_lines)
    if perm is not None:
        self.ids = self.ids[perm]

    self.transform = transform
    self.shuffle = shuffle