in src/loader.py [0:0]
def load_images(params):
    """
    Load the celebA dataset (images + attributes) and split it into
    train / validation / test sets.
    """
    # load data
    images_filename = 'images_%i_%i_20000.pth' if params.debug else 'images_%i_%i.pth'
    images_filename = images_filename % (params.img_sz, params.img_sz)
    images = torch.load(os.path.join(DATA_PATH, images_filename))
    attributes = torch.load(os.path.join(DATA_PATH, 'attributes.pth'))
    # parse attributes: for each (name, n_cat) pair in params.attr, build one
    # 0/1 column per category, then concatenate everything into a float matrix
    attrs = []
    for name, n_cat in params.attr:
        for i in range(n_cat):
            attrs.append(torch.FloatTensor((attributes[name] == i).astype(np.float32)))
    attributes = torch.cat([x.unsqueeze(1) for x in attrs], 1)
    # split train / valid / test
    if params.debug:
        # debug mode uses the small 20,000-image dump
        train_index = 10000
        valid_index = 15000
        test_index = 20000
    else:
        # standard celebA partition: 162,770 train / 19,867 valid / rest test
        train_index = 162770
        valid_index = 162770 + 19867
        test_index = len(images)
    train_images = images[:train_index]
    valid_images = images[train_index:valid_index]
    test_images = images[valid_index:test_index]
    train_attributes = attributes[:train_index]
    valid_attributes = attributes[train_index:valid_index]
    test_attributes = attributes[valid_index:test_index]
    # log dataset statistics / return dataset
    logger.info('%i / %i / %i images with attributes for train / valid / test sets'
                % (len(train_images), len(valid_images), len(test_images)))
    log_attributes_stats(train_attributes, valid_attributes, test_attributes, params)
    images = train_images, valid_images, test_images
    attributes = train_attributes, valid_attributes, test_attributes
    return images, attributes
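
# Usage sketch (illustration only, not part of the original module): `load_images`
# only reads `params.debug`, `params.img_sz` and `params.attr`, where `attr` is a
# list of (attribute_name, n_categories) pairs, so any namespace-like object with
# these fields works. The attribute names and sizes below are hypothetical values;
# DATA_PATH, logger, log_attributes_stats, os, np and torch are assumed to be
# defined / imported at module level as in the rest of this file.
if __name__ == '__main__':
    from argparse import Namespace

    example_params = Namespace(
        debug=True,                          # hypothetical: use the small 20,000-image dump
        img_sz=256,                          # hypothetical: selects images_256_256_20000.pth
        attr=[('Smiling', 2), ('Male', 2)],  # hypothetical: two binary attributes -> 4 columns
    )
    (train_x, valid_x, test_x), (train_y, valid_y, test_y) = load_images(example_params)
    print(train_x.size(0), train_y.size())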