in svhn_based_exp/utils.py [0:0]
def get_repre(outputs, repre, device):
    """Convert model outputs into the requested representation.

    Args:
        outputs: Model output tensor. For ``"preds"`` the class axis is taken
            to be dim 1 -- presumably ``(batch, num_classes)``; TODO confirm
            against callers.
        repre: One of ``"preds"``, ``"logits"`` or ``"feat"``.
        device: Device on which the one-hot tensor is allocated for ``"preds"``.

    Returns:
        For ``"preds"``: a new tensor with the same shape as ``outputs``, in the
        default floating dtype, on ``device``, holding 1 at the argmax position
        along dim 1 and 0 elsewhere (one-hot predictions).
        For ``"logits"`` / ``"feat"``: ``outputs`` itself, neither copied nor
        moved to ``device``.

    Raises:
        ValueError: If ``repre`` is not a supported representation.
    """
    if repre == "preds":
        # Index of the maximum along the class axis (dim 1).
        _, preds = outputs.max(1)
        # Allocate directly on the target device (no host allocation + copy),
        # then write a 1 at each predicted class along dim 1.
        new_outputs = torch.zeros(outputs.size(), device=device)
        new_outputs.scatter_(1, preds.unsqueeze(1).to(device), 1.0)
        return new_outputs
    if repre in ("logits", "feat"):
        return outputs
    # Previously printed and called exit(0), terminating the whole process
    # with a success status; raise so the caller sees the error instead.
    raise ValueError(f"invalid representation: {repre!r}")