def get_calibrated_confidences()

in attacks/privacy_attacks.py [0:0]


def _calibrated_scores(dataset, private_model, attack_model, ids):
    """Score ``dataset[i][0]`` for every ``i`` in ``ids`` with both models.

    Returns a 1-D tensor aligned with ``ids``: the private model's max
    log-softmax value minus the attack model's max log-softmax value, with
    dim=1 of the model outputs treated as the class dimension.
    """
    images = torch.stack([dataset[i][0] for i in ids]).cuda()
    # Inference only: no_grad avoids recording an autograd graph for the batch.
    with torch.no_grad():
        private_log_probs = torch.nn.functional.log_softmax(private_model(images), dim=1)
        attack_log_probs = torch.nn.functional.log_softmax(attack_model(images), dim=1)
    private_confidences, _ = torch.max(private_log_probs, dim=1)
    attack_confidences, _ = torch.max(attack_log_probs, dim=1)
    return private_confidences - attack_confidences


def get_calibrated_confidences(params, private_model, attack_model, ids, mask, aug_style='mean', n_aug=10):
    """
    Return calibrated confidences for the samples in ``ids``.

    For each sample the score is the private model's max log-softmax value
    minus the attack model's max log-softmax value on the same image.

    Args:
        params: experiment configuration. It is passed to ``get_dataset``, and
            ``params.aug`` selects the augmented (multi-draw) path.
        private_model: model whose confidences are being calibrated.
        attack_model: reference model used for calibration, called on the
            same images.
        ids: dataset indices to score. In the augmented path they also index
            the returned array.
        mask: only ``len(mask)`` is used, to size the augmented-path output.
        aug_style: aggregation over the ``n_aug`` draws when ``params.aug`` is
            set. One of 'mean', 'max', 'median', 'std'.
        n_aug: number of times each sample is re-drawn from the dataset when
            ``params.aug`` is set (default 10, the previously hard-coded value).

    Returns:
        If ``params.aug``: a float numpy array of length ``len(mask)``. The
        entries at ``ids`` hold the aggregated scores; all others are 0.
        Otherwise: a tensor of length ``len(ids)`` aligned with ``ids``,
        computed on the GPU (images are moved with ``.cuda()``).

    Raises:
        ValueError: if ``params.aug`` is set and ``aug_style`` is unknown, or
            ``n_aug`` < 1.
    """
    aggregators = {
        'mean': np.mean,
        'max': np.max,
        'median': np.median,
        'std': np.std,
    }
    if params.aug:
        # Fail fast. Previously an unknown style silently left every output at 0.
        if aug_style not in aggregators:
            raise ValueError(
                "unknown aug_style {!r}; expected one of {}".format(aug_style, sorted(aggregators))
            )
        if n_aug < 1:
            raise ValueError("n_aug must be >= 1, got {}".format(n_aug))

    dataset = get_dataset(params)

    if not params.aug:
        # NOTE(review): this path returns a tensor of length len(ids), while the
        # augmented path returns a numpy array of length len(mask) indexed by id.
        # It is kept as-is for caller compatibility.
        return _calibrated_scores(dataset, private_model, attack_model, ids)

    # One list of scores per output slot. If an id repeats in `ids`, all of its
    # draws are pooled into the same list.
    per_sample_scores = [[] for _ in range(len(mask))]
    for j in range(n_aug):
        print('Aug', j)
        # The dataset is re-read each round. Presumably its transform is random,
        # so each round sees a different augmentation. TODO confirm.
        scores = _calibrated_scores(dataset, private_model, attack_model, ids)
        scores = scores.cpu().detach().numpy()
        for i, sample_id in enumerate(ids):
            per_sample_scores[sample_id].append(scores[i])

    aggregate = aggregators[aug_style]
    confidences = np.zeros(len(mask))
    for sample_id in ids:
        confidences[sample_id] = aggregate(per_sample_scores[sample_id])
    return confidences