in ssl/real-dataset/linear_feature_eval.py [0:0]
def eval_model(self, encoder, save_path=None, num_epoch=50):
    """Fit a linear classifier on frozen encoder features and report accuracy.

    Features are extracted once from ``self.stl_train_loader`` and
    ``self.stl_test_loader``, standardized with statistics fitted on the
    training features only, and then used to train a ``LogisticRegression``
    head. The head is evaluated on the test features after every epoch.

    Args:
        encoder: Model exposing ``feature_dim`` and ``children()``. Its last
            child module is dropped before feature extraction.
        save_path: Optional path; when truthy, the raw (pre-pooling,
            pre-scaling) features and labels are written with ``np.savez``
            in the order x_train, x_test, y_train, y_test.
        num_epoch: Number of training epochs for the linear head.

    Returns:
        The best test accuracy (in percent) observed over all epochs, or
        ``0.0`` when ``num_epoch`` is 0.
    """
    # The flag is hard-wired on; the else branch is kept as the documented
    # alternative of evaluating on the projection-head output instead.
    remove_projection_head = True
    if remove_projection_head:
        output_feature_dim = encoder.feature_dim
        # Rebuild the encoder from all child modules except the last one.
        encoder = torch.nn.Sequential(*list(encoder.children())[:-1])
    else:
        # NOTE(review): attribute spelling "projetion" is kept verbatim;
        # presumably it matches the encoder class -- confirm before renaming.
        output_feature_dim = encoder.projetion.net[-1].out_features

    device = self.device
    encoder = encoder.to(device)
    # 10 output classes -- presumably STL-10, given the loader names.
    logreg = LogisticRegression(output_feature_dim, 10)
    logreg = logreg.to(device)

    encoder.eval()
    x_train, y_train = get_features_from_encoder(encoder, self.stl_train_loader, device)
    x_test, y_test = get_features_from_encoder(encoder, self.stl_test_loader, device)

    if save_path:
        np.savez(
            save_path,
            x_train.cpu().numpy(),
            x_test.cpu().numpy(),
            y_train.cpu().numpy(),
            y_test.cpu().numpy(),
        )

    # Features with more than two dimensions are averaged over dims 2 and 3
    # so that each sample becomes a flat vector.
    if len(x_train.shape) > 2:
        x_train = torch.mean(x_train, dim=[2, 3])
        x_test = torch.mean(x_test, dim=[2, 3])

    x_train = x_train.cpu().numpy()
    x_test = x_test.cpu().numpy()

    # Fit the scaler on training features only, then apply it to both splits.
    scaler = preprocessing.StandardScaler()
    scaler.fit(x_train)
    x_train = scaler.transform(x_train).astype(np.float32)
    x_test = scaler.transform(x_test).astype(np.float32)

    train_loader, test_loader = create_data_loaders_from_arrays(
        torch.from_numpy(x_train), y_train, torch.from_numpy(x_test), y_test
    )

    optimizer = torch.optim.Adam(logreg.parameters(), lr=3e-4, weight_decay=1e-3)
    criterion = torch.nn.CrossEntropyLoss()
    eval_every_n_epochs = 1
    best_acc = 0.

    for epoch in range(num_epoch):
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            logits = logreg(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

        if epoch % eval_every_n_epochs == 0:
            correct = 0
            total = 0
            # Gradients are not needed to measure accuracy; skipping them
            # avoids building an autograd graph for every test batch.
            with torch.no_grad():
                for x, y in test_loader:
                    x = x.to(device)
                    y = y.to(device)

                    logits = logreg(x)
                    predictions = torch.argmax(logits, dim=1)

                    total += y.size(0)
                    correct += (predictions == y).sum().item()

            acc = 100 * correct / total
            if acc > best_acc:
                best_acc = acc

    return best_acc