def get_calibrated_losses()

in attacks/privacy_attacks.py [0:0]


def _loss_diffs(private_model, attack_model, images, targets):
    """Return per-example ``CE(private) - CE(attack)`` for one batch as a 1-D numpy array.

    ``images`` / ``targets`` must already live on the device the models run on.
    ``reduction='none'`` yields the same per-example values as calling
    ``F.cross_entropy`` separately on each single-example slice.
    """
    private_loss = F.cross_entropy(private_model(images), targets, reduction='none')
    attack_loss = F.cross_entropy(attack_model(images), targets, reduction='none')
    return (private_loss - attack_loss).detach().cpu().numpy()


# Reducers applied across augmentation passes when ``params.aug`` is set.
_AUG_STYLE_REDUCERS = {
    'mean': np.mean,
    'max': np.max,
    'median': np.median,
    'std': np.std,
}


def get_calibrated_losses(params, private_model, attack_model, ids, mask, aug_style='mean',
                          *, n_aug=10, n_batches=1000):
    """Return calibrated losses: private-model loss minus attack-model loss per example.

    For every index in ``ids`` the cross-entropy of ``private_model`` is
    reduced by the cross-entropy of ``attack_model`` on the same example.
    The result is stored at position ``id`` of an array of length
    ``len(mask)``; positions not in ``ids`` stay 0.

    Args:
        params: Config passed to ``get_dataset``; ``params.aug`` selects the
            multi-pass mode.
        private_model: Model whose loss is calibrated.
        attack_model: Model whose loss is subtracted.
        ids: Dataset indices to score; each must be < ``len(mask)``.
        mask: Only its length is used, to size the output array.
        aug_style: When ``params.aug`` is set, how the per-pass differences are
            aggregated: ``'mean'``, ``'max'``, ``'median'`` or ``'std'``.
            Ignored otherwise.
        n_aug: Number of passes over ``ids`` when ``params.aug`` is set.
        n_batches: Number of chunks ``ids`` is split into per pass
            (a chunk count, not a chunk size).

    Returns:
        ``np.ndarray`` of shape ``(len(mask),)``.

    Raises:
        ValueError: If ``params.aug`` is set and ``aug_style`` is unknown.

    NOTE(review): models are not switched to eval mode here; presumably the
    caller does that — confirm.
    """
    # Validate before the (potentially expensive) dataset load.
    if params.aug and aug_style not in _AUG_STYLE_REDUCERS:
        raise ValueError(
            f"unknown aug_style {aug_style!r}; expected one of {sorted(_AUG_STYLE_REDUCERS)}"
        )

    dataset = get_dataset(params)
    losses = np.zeros(len(mask))

    # Scores only: no autograd graph is needed, which saves GPU memory.
    with torch.no_grad():
        if params.aug:
            reduce_fn = _AUG_STYLE_REDUCERS[aug_style]
            per_pass = [[] for _ in range(len(mask))]
            # Repeated passes presumably differ because the dataset applies
            # random augmentation in __getitem__ — TODO confirm.
            for j in range(n_aug):
                print('aug', j)
                for b_ids in np.array_split(ids, n_batches):
                    # array_split yields empty chunks when len(ids) < n_batches;
                    # torch.stack would raise on them.
                    if len(b_ids) == 0:
                        continue
                    image_data = torch.stack([dataset[i][0] for i in b_ids]).cuda()
                    target_data = torch.stack(
                        [torch.tensor(dataset[i][1]) for i in b_ids]
                    ).cuda()
                    diffs = _loss_diffs(private_model, attack_model, image_data, target_data)
                    for idx, diff in zip(b_ids, diffs):
                        per_pass[idx].append(diff)
            for idx in ids:
                losses[idx] = reduce_fn(per_pass[idx])
        else:
            for idx in ids:
                # One example per forward pass, as before (keeps behaviour
                # identical for models with batch-dependent layers).
                image = dataset[idx][0].unsqueeze(0).cuda(non_blocking=True)
                target = torch.tensor(dataset[idx][1]).unsqueeze(0).cuda(non_blocking=True)
                losses[idx] = _loss_diffs(private_model, attack_model, image, target)[0]

    return losses