in data_loaders.py [0:0]
def __init__(self, classes, args, set_type='train', sample_num=3000):
    """Index the objects available for one data split.

    Scans ``../data/images/<class>/<object>`` and keeps every object that
    meets all of these conditions:

    * its class is in ``classes``;
    * its key ``<object><class>`` is listed under ``set_type`` in
      ``../data/split.npy``;
    * a surface file ``../data/surfaces/<class>/<object>.npy`` exists;
    * a touch entry ``../data/scene_info/<class>/<object>`` exists.

    All data paths are relative to the current working directory.

    Args:
        classes: list/tuple of class (sub-directory) names to include. An
            object is grouped into ``self.classes_names[classes.index(cls)]``.
        args: stored unchanged on ``self.args`` (not read here).
        set_type: key selecting the split inside ``split.npy``
            (default ``'train'``).
        sample_num: stored unchanged on ``self.sample_num`` (not read here).

    Side effects:
        Sets ``self.names``, a list of ``[object, class]`` pairs in shuffled
        order. Sets ``self.classes_names``, the same pairs grouped per class.
        Shuffles with NumPy's global RNG and prints how many objects were
        found.
    """
    # initialization of data locations (relative to the working directory)
    self.args = args
    self.surf_location = '../data/surfaces/'
    self.img_location = '../data/images/'
    self.touch_location = '../data/scene_info/'
    self.sheet_location = '../data/sheets/'
    self.sample_num = sample_num
    self.set_type = set_type
    # split.npy stores a 0-d object array wrapping a mapping from split name
    # to object keys, so it must be loaded with the boolean allow_pickle=True.
    # Unpickling can execute code: only load this file from a trusted source.
    self.set_list = np.load('../data/split.npy', allow_pickle=True).item()

    # Membership is tested once per image, so build O(1) lookups up front.
    # class_index keeps the FIRST index of each class, matching list.index
    # semantics if `classes` contains duplicates.
    split_keys = set(self.set_list[self.set_type])
    class_index = {}
    for idx, cls in enumerate(classes):
        class_index.setdefault(cls, idx)

    # Each image path is <img_location>/<class>/<object>. os.path keeps the
    # split correct on platforms whose glob returns backslash separators.
    names = [
        [os.path.basename(path), os.path.basename(os.path.dirname(path))]
        for path in glob(os.path.join(self.img_location, '*', '*'))
    ]
    self.names = []
    self.classes_names = [[] for _ in classes]
    np.random.shuffle(names)
    for n in tqdm(names):
        obj_name, class_name = n
        if class_name not in class_index:
            continue
        # Cheap in-memory split check first, filesystem checks afterwards.
        if obj_name + class_name not in split_keys:
            continue
        if not os.path.exists(os.path.join(self.surf_location, class_name, obj_name + '.npy')):
            continue
        if not os.path.exists(os.path.join(self.touch_location, class_name, obj_name)):
            continue
        self.names.append(n)
        self.classes_names[class_index[class_name]].append(n)
    print(f'The number of {set_type} set objects found : {len(self.names)}')