in utils.py [0:0]
def extract_features(extr, device, data_loader):
    """Run ``extr`` over every batch of ``data_loader`` and stack the outputs.

    The extractor is switched to eval mode (and left in it) and run under
    ``torch.no_grad()``. Each batch output is moved to the CPU and has its
    singleton dimensions removed -- except the leading batch dimension, so a
    batch of size 1 (e.g. the last, partial batch) still concatenates
    correctly with the others.

    Args:
        extr: module mapping a batch of inputs to a batch of outputs.
        device: device that ``data`` and ``target`` are moved to before the
            forward pass.
        data_loader: iterable yielding ``(data, target)`` pairs.

    Returns:
        ``(features, labels)`` -- features concatenated along dim 0 on the
        CPU, labels concatenated along dim 0 on ``device``. Returns
        ``(None, None)`` if ``data_loader`` yields no batches.
    """
    extr.eval()
    feature_chunks = []
    label_chunks = []
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = extr(data).cpu()
            # Drop singleton dims (e.g. 1x1 spatial dims of a pooled output)
            # but never dim 0: a plain ``squeeze()`` turns a size-1 batch into
            # a lower-rank tensor that can no longer be concatenated.
            kept_dims = [size for size in output.shape[1:] if size != 1]
            feature_chunks.append(output.reshape(output.shape[0], *kept_dims))
            # NOTE(review): labels stay on ``device`` (original behavior) while
            # features are on the CPU -- confirm callers expect this split.
            label_chunks.append(target)
    if not feature_chunks:
        return None, None
    # Concatenate once at the end instead of per batch to avoid quadratic copying.
    return torch.cat(feature_chunks, dim=0), torch.cat(label_chunks, dim=0)