in src/data_loader.py [0:0]
def __getitem__(self, item):
    """Load one image and its label sequence.

    Args:
        item: Index into ``self.ids`` (image paths) and ``self.tags``
            (per-image tag strings).

    Returns:
        ``(img, target_list)``. ``img`` is the image opened from
        ``<self.root>/Flickr/<path>`` and converted to RGB, then passed through
        ``self.transform`` when that is not None. ``target_list`` is a list of
        ints holding the positions of the tags equal to 1. Each position is
        shifted by +1 when ``self.include_eos`` is set, and in that case 0 is
        appended at the end as the end-of-sequence token. When ``self.shuffle``
        is set, the label order (excluding the trailing 0) is randomized with
        the global ``np.random`` state.
    """
    # Stored paths may use Windows separators; normalize them to '/'.
    rel_path = self.ids[item].rstrip().replace('\\', '/')
    impath = os.path.join(self.root, 'Flickr', rel_path)
    # Close the file handle deterministically. convert() returns a new,
    # fully loaded image, so it stays valid after the with-block exits.
    with Image.open(impath) as raw_img:
        img = raw_img.convert('RGB')
    if self.transform is not None:
        img = self.transform(img)

    # The tag string holds integer flags. Splitting on any whitespace accepts
    # ' ', '\n' and '\n ' separators alike. An empty string yields no labels
    # instead of raising ValueError.
    flags = np.asarray(self.tags[item].split(), dtype='uint8')
    # Positions of the positive (== 1) flags; these are unique and ascending,
    # so no de-duplication is needed below.
    positives = np.flatnonzero(flags == 1)
    if self.shuffle:
        np.random.shuffle(positives)

    # With EOS enabled, label ids are shifted by one so that 0 is free to
    # serve as the end-of-sequence token.
    offset = 1 if self.include_eos else 0
    target_list = (positives + offset).tolist()
    if self.include_eos:
        target_list.append(0)
    return img, target_list