vision/run_weak_strong.py (135 lines of code) (raw):
import fire
import numpy as np
import torch
import tqdm
from data import get_imagenet
from models import alexnet, resnet50_dino, vitb8_dino
from torch import nn
def get_model(name):
    """Build the named backbone and prepare it for GPU inference.

    Args:
        name: One of ``"alexnet"``, ``"resnet50_dino"`` or ``"vitb8_dino"``.

    Returns:
        The model, moved to CUDA, switched to eval mode and wrapped in
        ``nn.DataParallel``.

    Raises:
        ValueError: If ``name`` is not one of the supported model names.
    """
    # Constructors sit behind lambdas so that only the requested one is
    # looked up and called.
    factories = (
        ("alexnet", lambda: alexnet()),
        ("resnet50_dino", lambda: resnet50_dino()),
        ("vitb8_dino", lambda: vitb8_dino()),
    )
    for known_name, build in factories:
        if name == known_name:
            net = build()
            break
    else:
        raise ValueError(f"Unknown model {name}")

    net.cuda()
    net.eval()
    return nn.DataParallel(net)
def get_embeddings(model, loader):
    """Run ``model`` over every batch of ``loader`` and gather the results on CPU.

    For each batch the model may return either a single embeddings tensor or
    an ``(embeddings, logits)`` pair. When logits are returned they are turned
    into softmax probabilities and an accuracy against the labels is computed.

    Args:
        model: Callable applied to each batch after the batch is moved to CUDA.
        loader: Iterable yielding ``(x, y)`` batches; ``y`` is kept as-is.

    Returns:
        A tuple ``(all_embeddings, all_y, all_probs, acc)`` where the first two
        are the per-batch embeddings and labels concatenated along dim 0.
        ``all_probs`` is the concatenated softmax output and ``acc`` a 0-dim
        tensor with the top-1 accuracy; both are ``None`` when the model never
        returned logits.
    """
    all_embeddings, all_y, all_probs = [], [], []

    # Pure inference: without no_grad every forward pass would keep its
    # autograd graph alive until the outputs are detached.
    with torch.no_grad():
        for x, y in tqdm.tqdm(loader):
            output = model(x.cuda())
            # Check the container type, not only the length: len() of a bare
            # tensor is its first dimension, so a 2-sample batch of embeddings
            # would otherwise be unpacked as (embeddings, logits).
            if isinstance(output, (tuple, list)) and len(output) == 2:
                embeddings, logits = output
                probs = torch.nn.functional.softmax(logits, dim=-1).detach().cpu()
                all_probs.append(probs)
            else:
                embeddings = output

            all_embeddings.append(embeddings.detach().cpu())
            all_y.append(y)

    all_embeddings = torch.cat(all_embeddings, dim=0)
    all_y = torch.cat(all_y, dim=0)

    if all_probs:
        all_probs = torch.cat(all_probs, dim=0)
        acc = (torch.argmax(all_probs, dim=1) == all_y).float().mean()
    else:
        all_probs = None
        acc = None
    return all_embeddings, all_y, all_probs, acc
def train_logreg(
    x_train,
    y_train,
    eval_datasets,
    n_epochs=10,
    weight_decay=0.0,
    lr=1.0e-3,
    batch_size=100,
    n_classes=1000,
):
    """Train a linear classifier on fixed features and evaluate it after every epoch.

    A single ``torch.nn.Linear`` layer is trained on CUDA with Adam,
    cross-entropy loss and a cosine-annealed learning rate that is stepped
    once per batch over the whole run.

    Args:
        x_train: 2-D feature tensor; cast to float before training.
        y_train: Training targets. Targets with more than one dimension
            (e.g. per-class probability rows) are passed to the loss
            unchanged and reduced with argmax only for the train accuracy.
        eval_datasets: Mapping ``name -> (x_test, y_test)`` evaluated at the
            end of each epoch. ``y_test`` is compared on the CPU against the
            predicted class indices.
        n_epochs: Number of passes over the training data. Must be at least 1,
            because the final result is read from the last epoch's entry.
        weight_decay: Weight decay passed to Adam.
        lr: Initial learning rate passed to Adam.
        batch_size: Training mini-batch size.
        n_classes: Output dimension of the linear layer.

    Returns:
        A dict with, for every key of ``eval_datasets``, ``"<key>_all"`` (list
        of per-epoch accuracies as 0-dim tensors) and ``"<key>"`` (the accuracy
        after the final epoch).
    """
    x_train = x_train.float()
    train_ds = torch.utils.data.TensorDataset(x_train, y_train)
    train_loader = torch.utils.data.DataLoader(train_ds, shuffle=True, batch_size=batch_size)

    d = x_train.shape[1]
    model = torch.nn.Linear(d, n_classes).cuda()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=weight_decay, lr=lr)

    # The scheduler is stepped after every batch, so one cosine cycle has to
    # span all batches of all epochs.
    n_batches = len(train_loader)
    n_iter = n_batches * n_epochs
    schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_iter)

    results = {f"{key}_all": [] for key in eval_datasets}
    for epoch in (pbar := tqdm.tqdm(range(n_epochs), desc="Epoch 0")):
        correct, total = 0, 0
        for x, y in train_loader:
            x, y = x.cuda(), y.cuda()
            optimizer.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()
            schedule.step()

            # Multi-dimensional targets are reduced to class indices so the
            # running train accuracy can be computed.
            if len(y.shape) > 1:
                y = torch.argmax(y, dim=1)
            correct += (torch.argmax(pred, -1) == y).detach().float().sum().item()
            total += len(y)
        pbar.set_description(f"Epoch {epoch}, Train Acc {correct / total:.3f}")

        # Evaluation needs no gradients; no_grad avoids building an autograd
        # graph over the full evaluation tensors on every epoch.
        with torch.no_grad():
            for key, (x_test, y_test) in eval_datasets.items():
                x_test = x_test.float().cuda()
                pred = torch.argmax(model(x_test), dim=-1).detach().cpu()
                acc = (pred == y_test).float().mean()
                results[f"{key}_all"].append(acc)

    for key in eval_datasets:
        results[key] = results[f"{key}_all"][-1]
    return results
def main(
    batch_size: int = 128,
    weak_model_name: str = "alexnet",
    strong_model_name: str = "resnet50_dino",
    n_train: int = 40000,
    seed: int = 0,
    data_path: str = "/root/",
    n_epochs: int = 10,
    lr: float = 1e-3,
):
    """Run the weak-to-strong experiment on the ImageNet ``val`` split.

    The weak model's softmax outputs serve as weak labels and the strong
    model's embeddings as features. After a seeded shuffle, the first
    ``n_train`` samples train a linear classifier (once on the weak labels,
    once on the ground-truth labels) and the remaining samples are used for
    evaluation. Accuracies are printed to stdout.

    Args:
        batch_size: Batch size used when running the two models over the data.
        weak_model_name: Name passed to ``get_model`` for the weak supervisor.
            The model must return logits in addition to embeddings.
        strong_model_name: Name passed to ``get_model`` for the feature extractor.
        n_train: Number of shuffled samples used for training; the rest is
            held out for testing.
        seed: Seed for the train/test shuffle and for torch's global RNG.
        data_path: Dataset location passed to ``get_imagenet``.
        n_epochs: Epochs for each linear classifier.
        lr: Learning rate for each linear classifier.

    Raises:
        ValueError: If the weak model yields no class probabilities, or if
            ``n_train`` leaves the train or the test split empty.
    """
    # Seed torch as well as NumPy: batch shuffling and the linear layer's
    # initialisation inside train_logreg draw from torch's global RNG, so
    # without this `seed` would only fix the train/test split.
    torch.manual_seed(seed)

    weak_model = get_model(weak_model_name)
    strong_model = get_model(strong_model_name)
    _, loader = get_imagenet(data_path, split="val", batch_size=batch_size, shuffle=False)

    print("Getting weak labels...")
    _, gt_labels, weak_labels, weak_acc = get_embeddings(weak_model, loader)
    if weak_labels is None:
        # get_embeddings returns None for probabilities and accuracy when the
        # model never produced logits; fail with a clear message instead of a
        # TypeError from the format string below.
        raise ValueError(
            f"Weak model {weak_model_name!r} returned no logits, so it cannot provide weak labels"
        )
    print(f"Weak label accuracy: {weak_acc:.3f}")

    print("Getting strong embeddings...")
    embeddings, strong_gt_labels, _, _ = get_embeddings(strong_model, loader)
    # Both passes use the same unshuffled loader, so the labels must line up.
    assert torch.all(gt_labels == strong_gt_labels)
    del strong_gt_labels

    order = np.arange(len(embeddings))
    rng = np.random.default_rng(seed)
    rng.shuffle(order)
    x = embeddings[order]
    y = gt_labels[order]
    yw = weak_labels[order]

    x_train, x_test = x[:n_train], x[n_train:]
    y_train, y_test = y[:n_train], y[n_train:]
    yw_train, yw_test = yw[:n_train], yw[n_train:]
    if len(x_train) == 0 or len(x_test) == 0:
        raise ValueError(
            f"n_train={n_train} leaves an empty train or test split for {len(x)} samples"
        )

    # Training uses the weak probabilities as soft targets; evaluation against
    # the supervisor compares hard class predictions.
    yw_test = torch.argmax(yw_test, dim=1)
    eval_datasets = {"test": (x_test, y_test), "test_weak": (x_test, yw_test)}

    print("Training logreg on weak labels...")
    results_weak = train_logreg(x_train, yw_train, eval_datasets, n_epochs=n_epochs, lr=lr)
    print(f"Final accuracy: {results_weak['test']:.3f}")
    print(f"Final supervisor-student agreement: {results_weak['test_weak']:.3f}")
    print(f"Accuracy by epoch: {[acc.item() for acc in results_weak['test_all']]}")
    print(
        f"Supervisor-student agreement by epoch: {[acc.item() for acc in results_weak['test_weak_all']]}"
    )

    print("Training logreg on ground truth labels...")
    results_gt = train_logreg(x_train, y_train, eval_datasets, n_epochs=n_epochs, lr=lr)
    print(f"Final accuracy: {results_gt['test']:.3f}")
    print(f"Accuracy by epoch: {[acc.item() for acc in results_gt['test_all']]}")

    print("\n\n" + "=" * 100)
    print(f"Weak label accuracy: {weak_acc:.3f}")
    print(f"Weak→Strong accuracy: {results_weak['test']:.3f}")
    print(f"Strong accuracy: {results_gt['test']:.3f}")
    print("=" * 100)
# Script entry point: hand `main` to python-fire, which builds the
# command-line interface from the function's signature.
if __name__ == "__main__":
    fire.Fire(main)