vision/run_weak_strong.py (135 lines of code) (raw):

import fire
import numpy as np
import torch
import tqdm
from torch import nn

from data import get_imagenet
from models import alexnet, resnet50_dino, vitb8_dino

# Constructors for the supported backbones, keyed by the name used on the CLI.
_MODEL_BUILDERS = {
    "alexnet": alexnet,
    "resnet50_dino": resnet50_dino,
    "vitb8_dino": vitb8_dino,
}


def get_model(name):
    """Build backbone ``name`` on the GPU in eval mode, wrapped in ``nn.DataParallel``.

    Raises:
        ValueError: if ``name`` is not one of the supported model names.
    """
    try:
        builder = _MODEL_BUILDERS[name]
    except KeyError:
        raise ValueError(f"Unknown model {name}") from None
    model = builder()
    model.cuda()
    model.eval()
    return nn.DataParallel(model)


def get_embeddings(model, loader):
    """Run ``model`` over every ``(x, y)`` batch of ``loader`` and gather the outputs on the CPU.

    The model may return either a single embeddings tensor or a 2-element
    tuple/list ``(embeddings, logits)``. In the latter case softmax
    probabilities and top-1 accuracy against the loader's labels are computed
    as well.

    Returns:
        ``(embeddings, labels, probs, acc)``, concatenated along dim 0.
        ``probs`` and ``acc`` are ``None`` when the model returns no logits.
    """
    all_embeddings, all_y, all_probs = [], [], []
    # Pure inference: no autograd graph is needed, which saves GPU memory and time.
    with torch.no_grad():
        for x, y in tqdm.tqdm(loader):
            output = model(x.cuda())
            # Test the container type, not just len(output): a bare tensor's
            # len() is its batch size, so a batch of 2 would be mis-unpacked.
            if isinstance(output, (tuple, list)) and len(output) == 2:
                embeddings, logits = output
                all_probs.append(torch.nn.functional.softmax(logits, dim=-1).cpu())
            else:
                embeddings = output
            all_embeddings.append(embeddings.cpu())
            all_y.append(y)

    all_embeddings = torch.cat(all_embeddings, dim=0)
    all_y = torch.cat(all_y, dim=0)
    if all_probs:
        all_probs = torch.cat(all_probs, dim=0)
        acc = (torch.argmax(all_probs, dim=1) == all_y).float().mean()
    else:
        all_probs = None
        acc = None
    return all_embeddings, all_y, all_probs, acc


def train_logreg(
    x_train,
    y_train,
    eval_datasets,
    n_epochs=10,
    weight_decay=0.0,
    lr=1.0e-3,
    batch_size=100,
    n_classes=1000,
):
    """Train a linear classifier on fixed features and record eval accuracy after each epoch.

    Args:
        x_train: 2-D feature tensor; its second dimension is the input width.
        y_train: either integer class indices or per-class probability targets.
            Both forms are passed straight to ``CrossEntropyLoss``. The printed
            training accuracy compares against the argmax of soft targets.
        eval_datasets: mapping ``name -> (x, y)`` with integer labels ``y``.
            Every entry is evaluated at the end of each epoch.
        n_epochs, weight_decay, lr, batch_size: Adam / training hyper-parameters.
            The learning rate follows a cosine schedule stepped once per batch.
        n_classes: output width of the linear layer.

    Returns:
        dict with, for each eval dataset ``name``, ``"<name>_all"`` (list of
        per-epoch accuracies as 0-d tensors) and ``"<name>"`` (the last one).
    """
    x_train = x_train.float()
    train_ds = torch.utils.data.TensorDataset(x_train, y_train)
    train_loader = torch.utils.data.DataLoader(train_ds, shuffle=True, batch_size=batch_size)

    d = x_train.shape[1]
    model = torch.nn.Linear(d, n_classes).cuda()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=weight_decay, lr=lr)
    n_iter = len(train_loader) * n_epochs
    schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_iter)

    # Cast and upload eval features once instead of on every epoch.
    eval_on_device = {
        key: (x_test.float().cuda(), y_test) for key, (x_test, y_test) in eval_datasets.items()
    }
    results = {f"{key}_all": [] for key in eval_datasets}

    for epoch in (pbar := tqdm.tqdm(range(n_epochs), desc="Epoch 0")):
        correct, total = 0, 0
        for x, y in train_loader:
            x, y = x.cuda(), y.cuda()
            optimizer.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()
            schedule.step()
            # Soft targets: measure agreement with their most likely class.
            if y.ndim > 1:
                y = torch.argmax(y, dim=1)
            correct += (torch.argmax(pred, -1) == y).detach().float().sum().item()
            total += len(y)
            pbar.set_description(f"Epoch {epoch}, Train Acc {correct / total:.3f}")

        with torch.no_grad():
            for key, (x_test, y_test) in eval_on_device.items():
                pred = torch.argmax(model(x_test), dim=-1).cpu()
                results[f"{key}_all"].append((pred == y_test).float().mean())

    for key in eval_datasets:
        results[key] = results[f"{key}_all"][-1]
    return results


def main(
    batch_size: int = 128,
    weak_model_name: str = "alexnet",
    strong_model_name: str = "resnet50_dino",
    n_train: int = 40000,
    seed: int = 0,
    data_path: str = "/root/",
    n_epochs: int = 10,
    lr: float = 1e-3,
):
    """Compare a linear probe trained on weak-model labels with one trained on ground truth.

    The weak model's softmax outputs on the ImageNet validation split act as
    weak labels. The strong model's embeddings are the probe's features. A
    seeded shuffle splits the examples into ``n_train`` training examples and
    a held-out remainder.

    Raises:
        ValueError: if the weak model produces no logits, or ``n_train`` does
            not leave both a non-empty train and a non-empty test split.
    """
    # Seed torch as well as numpy so probe initialisation and minibatch
    # shuffling are reproducible, not only the train/test split.
    torch.manual_seed(seed)

    weak_model = get_model(weak_model_name)
    strong_model = get_model(strong_model_name)
    _, loader = get_imagenet(data_path, split="val", batch_size=batch_size, shuffle=False)

    print("Getting weak labels...")
    _, gt_labels, weak_labels, weak_acc = get_embeddings(weak_model, loader)
    if weak_labels is None:
        raise ValueError(
            f"Weak model {weak_model_name!r} does not return logits, so it cannot supply weak labels"
        )
    n_total = len(gt_labels)
    # Fail before the (expensive) strong-model pass if the split would be empty.
    if not 0 < n_train < n_total:
        raise ValueError(f"n_train must be in (0, {n_total}) to leave a non-empty test set; got {n_train}")
    print(f"Weak label accuracy: {weak_acc:.3f}")

    print("Getting strong embeddings...")
    embeddings, strong_gt_labels, _, _ = get_embeddings(strong_model, loader)
    # Both passes use the same unshuffled loader, so the labels must line up.
    assert torch.all(gt_labels == strong_gt_labels)
    del strong_gt_labels

    order = np.arange(len(embeddings))
    rng = np.random.default_rng(seed)
    rng.shuffle(order)
    x = embeddings[order]
    y = gt_labels[order]
    yw = weak_labels[order]
    x_train, x_test = x[:n_train], x[n_train:]
    y_train, y_test = y[:n_train], y[n_train:]
    yw_train, yw_test = yw[:n_train], yw[n_train:]
    # Test-time weak labels become hard class indices (supervisor predictions).
    yw_test = torch.argmax(yw_test, dim=1)
    eval_datasets = {"test": (x_test, y_test), "test_weak": (x_test, yw_test)}

    print("Training logreg on weak labels...")
    results_weak = train_logreg(x_train, yw_train, eval_datasets, n_epochs=n_epochs, lr=lr)
    print(f"Final accuracy: {results_weak['test']:.3f}")
    print(f"Final supervisor-student agreement: {results_weak['test_weak']:.3f}")
    print(f"Accuracy by epoch: {[acc.item() for acc in results_weak['test_all']]}")
    print(
        f"Supervisor-student agreement by epoch: {[acc.item() for acc in results_weak['test_weak_all']]}"
    )

    print("Training logreg on ground truth labels...")
    results_gt = train_logreg(x_train, y_train, eval_datasets, n_epochs=n_epochs, lr=lr)
    print(f"Final accuracy: {results_gt['test']:.3f}")
    print(f"Accuracy by epoch: {[acc.item() for acc in results_gt['test_all']]}")

    print("\n\n" + "=" * 100)
    print(f"Weak label accuracy: {weak_acc:.3f}")
    print(f"Weak→Strong accuracy: {results_weak['test']:.3f}")
    print(f"Strong accuracy: {results_gt['test']:.3f}")
    print("=" * 100)


if __name__ == "__main__":
    fire.Fire(main)