in vision/run_weak_strong.py [0:0]
def get_embeddings(model, loader, *, device="cuda"):
    """Run ``model`` over every batch in ``loader`` and collect its outputs on the CPU.

    ``model(x)`` may return either a bare embedding tensor or an
    ``(embeddings, logits)`` tuple/list. In the latter case the softmax of the
    logits is collected too, and top-1 accuracy against the labels is computed.

    Args:
        model: Callable applied to each input batch after it is moved to ``device``.
        loader: Iterable yielding ``(x, y)`` pairs of tensors.
        device: Device each input batch is moved to before calling ``model``.
            Defaults to ``"cuda"``, matching the previous hard-coded ``x.cuda()``.

    Returns:
        A tuple ``(embeddings, labels, probs, acc)``.

        - ``embeddings`` and ``labels`` are the per-batch results concatenated
          along dim 0.
        - ``probs`` is the concatenated softmax output.
        - ``acc`` is a 0-dim tensor holding the fraction of rows whose argmax
          equals the label.

        ``probs`` and ``acc`` are both ``None`` when the model returns no logits.
    """
    all_embeddings, all_y, all_probs = [], [], []
    # Pure feature extraction: don't build an autograd graph for every batch.
    with torch.no_grad():
        for x, y in tqdm.tqdm(loader):
            output = model(x.to(device))
            # Check the container type, not just len(). len() of a bare
            # embedding tensor is its batch size, so a batch of exactly two
            # samples would otherwise be mis-split into (embeddings, logits).
            if isinstance(output, (tuple, list)) and len(output) == 2:
                embeddings, logits = output
                probs = torch.nn.functional.softmax(logits, dim=-1).detach().cpu()
                all_probs.append(probs)
            else:
                embeddings = output
            all_embeddings.append(embeddings.detach().cpu())
            all_y.append(y)
    all_embeddings = torch.cat(all_embeddings, dim=0)
    all_y = torch.cat(all_y, dim=0)
    if all_probs:
        all_probs = torch.cat(all_probs, dim=0)
        acc = (torch.argmax(all_probs, dim=1) == all_y).float().mean()
    else:
        all_probs = None
        acc = None
    return all_embeddings, all_y, all_probs, acc