in utils.py [0:0]
def _dataset_to_tensors(dataset, num_features=784):
    """Flatten an indexable (image, label) dataset into dense float tensors.

    Each sample is reshaped to ``num_features`` values and shifted by -0.5.
    Labels are stored as floats.

    Returns:
        (X, y): X has shape (len(dataset), num_features); y has shape (len(dataset),).
    """
    X = torch.zeros(len(dataset), num_features)
    y = torch.zeros(len(dataset))
    # Index explicitly rather than iterate: Dataset only guarantees __len__/__getitem__.
    for i in range(len(dataset)):
        image, label = dataset[i]
        X[i] = image.view(num_features) - 0.5
        y[i] = label
    return X, y


def _select_binary_classes(X, y, positive=3, negative=8):
    """Keep only samples labelled ``positive`` or ``negative``.

    The surviving labels are relabelled as 1.0 (positive) / 0.0 (negative).
    """
    mask = y.eq(positive) | y.eq(negative)
    return X[mask], y[mask].eq(positive).float()


def _load_raw_mnist_3_vs_8(args):
    """Load raw MNIST pixels from ``args.data_dir``, restricted to digits 3 vs 8.

    Returns:
        (X_train, y_train, X_test, y_test) with 784-dim flattened images and
        labels 1.0 for digit 3 and 0.0 for digit 8.
    """
    trainset = datasets.MNIST(args.data_dir, train=True, transform=transforms.ToTensor())
    testset = datasets.MNIST(args.data_dir, train=False, transform=transforms.ToTensor())
    X_train, y_train = _dataset_to_tensors(trainset)
    X_test, y_test = _dataset_to_tensors(testset)
    X_train, y_train = _select_binary_classes(X_train, y_train)
    X_test, y_test = _select_binary_classes(X_test, y_test)
    return X_train, y_train, X_test, y_test


def _l2_normalize_rows(X):
    """Return ``X`` with each row scaled to unit L2 norm; all-zero rows stay zero."""
    norms = X.norm(2, 1, keepdim=True)
    # Without this, an all-zero row would be divided by 0 and become NaN.
    norms = norms.masked_fill(norms == 0, 1)
    return X / norms


def load_features(args):
    """Load train/test features and labels for the experiment described by ``args``.

    Pre-extracted features are read from
    ``'<data_dir>/<extractor>_<dataset>_extracted.pth'`` when that file exists.
    It must contain the keys 'X_train', 'y_train', 'X_test' and 'y_test'.
    Otherwise raw MNIST pixels are loaded (digits 3 vs 8 only). Any other
    dataset name prints an error and exits the process.

    Feature rows are then L2-normalized. In 'binary' train mode the training
    labels are mapped via ``2 * y - 1``; this gives {-1, +1} assuming the
    labels are {0, 1}, which the raw-MNIST path guarantees. The untransformed
    labels are kept as ``y_train_onehot``. In any other mode
    ``y_train_onehot`` comes from the module's ``onehot`` helper.

    Args:
        args: namespace read for ``data_dir``, ``extractor``, ``dataset`` and
            ``train_mode``.

    Returns:
        (X_train, X_test, y_train, y_train_onehot, y_test). ``y_train_onehot``
        is always at least 2-D (a 1-D result is unsqueezed to a column).

    Raises:
        SystemExit: if no checkpoint exists and ``args.dataset`` is not 'MNIST'.
    """
    ckpt_file = '%s/%s_%s_extracted.pth' % (args.data_dir, args.extractor, args.dataset)
    if os.path.exists(ckpt_file):
        # map_location='cpu' lets checkpoints saved from GPU tensors load on
        # CPU-only machines; the .cpu() calls below are then no-ops.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(ckpt_file, map_location='cpu')
        X_train = checkpoint['X_train'].cpu()
        y_train = checkpoint['y_train'].cpu()
        X_test = checkpoint['X_test'].cpu()
        y_test = checkpoint['y_test'].cpu()
    else:
        print('Extracted features not found, loading raw features.')
        if args.dataset == 'MNIST':
            X_train, y_train, X_test, y_test = _load_raw_mnist_3_vs_8(args)
        else:
            print("Error: Unknown dataset %s. Aborting." % args.dataset)
            sys.exit(1)

    X_train = _l2_normalize_rows(X_train)
    X_test = _l2_normalize_rows(X_test)

    # Convert labels to +/-1 (binary mode) or one-hot vectors.
    if args.train_mode == 'binary':
        y_train_onehot = y_train
        y_train = 2 * y_train - 1
    else:
        y_train_onehot = onehot(y_train)
    if y_train_onehot.dim() == 1:
        y_train_onehot = y_train_onehot.unsqueeze(1)
    return X_train, X_test, y_train, y_train_onehot, y_test