def get_repre()

in svhn_based_exp/utils.py [0:0]


def get_repre(outputs, repre, device):
    """Convert model ``outputs`` into the requested representation.

    Args:
        outputs: Tensor whose dim 1 is reduced by ``max`` for ``"preds"`` —
            presumably (batch, num_classes) logits; TODO confirm against callers.
        repre: One of ``"preds"``, ``"logits"`` or ``"feat"``.
        device: Device on which the one-hot tensor is allocated for ``"preds"``.

    Returns:
        For ``"preds"``: a new tensor shaped like ``outputs`` (default float
        dtype, on ``device``) holding 1 at the argmax along dim 1 and 0
        elsewhere. For ``"logits"`` / ``"feat"``: ``outputs`` itself, unchanged.

    Raises:
        ValueError: If ``repre`` is not a recognised representation.
    """
    if repre == "preds":
        _, preds = outputs.max(1)
        # Allocate straight on the target device rather than on CPU + copy.
        new_outputs = torch.zeros(outputs.size(), device=device)
        # scatter_ keeps the index on the same device as the destination and
        # works for any rank, unlike mixing a CPU arange with device-side preds.
        new_outputs.scatter_(1, preds.unsqueeze(1).to(device), 1)
        return new_outputs
    if repre in ("logits", "feat"):
        return outputs
    # Previously print + exit(0): that terminated the process with a *success*
    # status. Raise instead so callers see (and can handle) the error.
    raise ValueError(f"invalid representation: {repre!r}")