in src/data_loader.py [0:0]
def __init__(self,
             data_dir,
             split,
             maxnumims,
             transform=None,
             use_lmdb=False,
             suff='',
             shuffle=False,
             perm=None,
             include_eos=False):
    """Load the preprocessed Recipe1M vocabulary and dataset for one split.

    Args:
        data_dir: Root data directory; expects 'preprocessed' and 'images'
            subdirectories underneath it.
        split: Split name (e.g. 'train'/'val'/'test') used to build file names.
        maxnumims: Maximum number of images per recipe (stored for later use).
        transform: Optional image transform (stored for later use).
        use_lmdb: If True, open the per-split LMDB image store read-only.
        suff: Optional filename prefix for the pickled vocab/dataset files.
        shuffle: Stored flag; presumably controls image sampling downstream —
            TODO confirm against the consumer of this attribute.
        perm: Optional index permutation applied to the kept sample ids.
        include_eos: If False, the EOS token is removed from the ingredient
            vocabulary.
    """
    self.aux_data_dir = os.path.join(data_dir, 'preprocessed')

    # Use context managers so the pickle file handles are closed promptly
    # (the previous code left them open until garbage collection).
    with open(os.path.join(self.aux_data_dir,
                           suff + 'recipe1m_vocab_ingrs.pkl'), 'rb') as f:
        self.ingrs_vocab = pickle.load(f)
    with open(os.path.join(self.aux_data_dir,
                           suff + 'recipe1m_' + split + '.pkl'), 'rb') as f:
        self.dataset = pickle.load(f)

    self.use_lmdb = use_lmdb
    if use_lmdb:
        # Read-only LMDB environment holding the images for this split.
        self.image_file = lmdb.open(
            os.path.join(self.aux_data_dir, 'lmdb_' + split),
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False)

    self.split = split
    # Keep only the indices of recipes that have at least one image.
    self.ids = [i for i, entry in enumerate(self.dataset) if entry['images']]

    self.root = os.path.join(data_dir, 'images', split)
    self.transform = transform
    self.maxnumims = maxnumims
    self.shuffle = shuffle
    self.include_eos = include_eos

    # Remove EOS from the vocabulary when the caller does not need it.
    if not self.include_eos:
        self.ingrs_vocab.remove_eos()

    # Build the id array once; apply the optional permutation on top of it.
    self.ids = np.array(self.ids)
    if perm is not None:
        self.ids = self.ids[perm]