in eval_knn.py
def extract_feature_pipeline(args):
    # ============ preparing data ... ============
    transform = pth_transforms.Compose([
        pth_transforms.Resize(256, interpolation=3),
        pth_transforms.CenterCrop(224),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    dataset_train = ReturnIndexDataset(os.path.join(args.data_path, "train"), transform=transform)
    dataset_val = ReturnIndexDataset(os.path.join(args.data_path, "val"), transform=transform)
    sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")
    # ============ building network ... ============
    if "vit" in args.arch:
        model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
        print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    elif "xcit" in args.arch:
        model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0)
    elif args.arch in torchvision_models.__dict__.keys():
        model = torchvision_models.__dict__[args.arch](num_classes=0)
        model.fc = nn.Identity()
    else:
        print(f"Architecture {args.arch} not supported")
        sys.exit(1)
    model.cuda()
    utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
    model.eval()
    # ============ extract features ... ============
    print("Extracting features for train set...")
    train_features = extract_features(model, data_loader_train, args.use_cuda)
    print("Extracting features for val set...")
    test_features = extract_features(model, data_loader_val, args.use_cuda)

    if utils.get_rank() == 0:
        # L2-normalize so that k-NN distances reduce to cosine similarity
        train_features = nn.functional.normalize(train_features, dim=1, p=2)
        test_features = nn.functional.normalize(test_features, dim=1, p=2)

    # labels come from the (path, class_index) tuples of the underlying ImageFolder
    train_labels = torch.tensor([s[-1] for s in dataset_train.samples]).long()
    test_labels = torch.tensor([s[-1] for s in dataset_val.samples]).long()
    # save features and labels
    if args.dump_features and dist.get_rank() == 0:
        torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth"))
        torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth"))
        torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth"))
        torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth"))
    return train_features, test_features, train_labels, test_labels
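

# --- Illustrative sketch (not part of the original file) ---
# The helper below shows how the L2-normalized features returned by
# extract_feature_pipeline could be fed to a simple cosine-similarity k-NN
# classifier. The function name, the plain majority vote, and the default k
# are assumptions for illustration; the repository's own k-NN evaluation may
# use a different (e.g. temperature-weighted) voting scheme.
@torch.no_grad()
def knn_accuracy_sketch(train_features, train_labels, test_features, test_labels, k=20):
    # Move labels next to the features so all tensors share one device.
    device = train_features.device
    train_labels = train_labels.to(device)
    test_labels = test_labels.to(device)
    # Features are already L2-normalized, so a matrix product gives cosine similarity.
    similarity = test_features @ train_features.t()   # (n_test, n_train)
    _, topk_idx = similarity.topk(k, dim=1)           # indices of the k nearest train samples
    topk_labels = train_labels[topk_idx]              # (n_test, k)
    # Plain majority vote among the k neighbours.
    predictions = torch.mode(topk_labels, dim=1).values
    accuracy = 100.0 * (predictions == test_labels).float().mean().item()
    return accuracy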