def auxiliary_attack()

in attacks/privacy_attacks.py [0:0]


def _load_checkpoint_weights(model, checkpoint, dataset):
    """Load ``checkpoint['model']`` into ``model`` in place.

    For ``dataset == 'imagenet'`` a leading ``'module.'`` prefix is stripped
    from every key first (presumably added by a DataParallel /
    DistributedDataParallel wrapper when the checkpoint was saved -- TODO
    confirm). Key order is preserved.
    """
    weights = checkpoint['model']
    if dataset == 'imagenet':
        prefix = 'module.'
        weights = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in weights.items()
        )
    model.load_state_dict(weights)


def auxiliary_attack(params, aux_epochs, attack_type='loss', aug_style='mean', norm_type=None, public_data='train', num_aux=1, aux_style='sum'):
    """
    Run an auxiliary (calibrated) attack against the private model.

    Each of ``num_aux`` auxiliary models starts from the private checkpoint and
    is trained further on public data. The private model's scores are then
    calibrated against it.

    Args:
        params: experiment config. Reads ``mask_path``, ``model_path``,
            ``dataset``, ``batch_size``, ``optimizer``, ``public_data``,
            ``epochs``, ``local_rank`` and (for ``'dist'``) ``num_points``.
            A shallow copy is modified for auxiliary training; ``params``
            itself is not mutated.
        aux_epochs: added to ``params.epochs`` for the auxiliary training run.
        attack_type: ``'loss'``, ``'conf'`` or ``'dist'``; any other value runs
            the gradient-norm attack.
        aug_style: forwarded to the calibrated loss/confidence/gradnorm scorers.
        norm_type: forwarded to the gradient-norm scorer.
        public_data: ``'train'`` uses the hidden public-train mask;
            ``'randXX'`` uses a random XX% of the known public points; anything
            else uses all known public points.
        num_aux: number of auxiliary models to train.
        aux_style: ``'max'`` takes the element-wise max over auxiliary models.
            ``'mean'`` divides the sum by ``num_aux``. Anything else returns
            the sum.

    Returns:
        ``(train_scores, heldout_scores)`` for the private train / heldout
        points.

    Side effects:
        Creates a directory (relative to the CWD) per auxiliary model and
        writes ``checkpoint.pth`` into it.

    NOTE: with ``aux_style == 'max'`` the accumulators start at zero, so
    returned scores are floored at 0.
    """
    import copy

    # Load the masks.
    known_masks, hidden_masks = {}, {}
    hidden_masks['public'], hidden_masks['private'] = {}, {}
    known_masks['public'] = torch.load(params.mask_path + "public.pth")
    known_masks['private'] = torch.load(params.mask_path + "private.pth")
    hidden_masks['private']['train'] = torch.load(params.mask_path + "hidden/train.pth")
    hidden_masks['private']['heldout'] = torch.load(params.mask_path + "hidden/heldout.pth")
    hidden_masks['public']['train'] = torch.load(params.mask_path + "hidden/public_train.pth")
    hidden_masks['public']['heldout'] = torch.load(params.mask_path + "hidden/public_heldout.pth")

    # Load the final (private) model parameters.
    private_model = build_model(params)
    private_model_path = os.path.join(params.model_path, "checkpoint.pth")
    state_dict_private = torch.load(private_model_path, map_location='cuda:0')
    _load_checkpoint_weights(private_model, state_dict_private, params.dataset)
    private_model = private_model.cuda()

    # Shallow copy so the caller's config is not mutated (epochs would
    # otherwise grow by aux_epochs on every call).
    updated_params = copy.copy(params)
    updated_params.epochs = updated_params.epochs + aux_epochs

    private_train_ids = (hidden_masks['private']['train'] == True).nonzero().flatten().numpy()
    private_heldout_ids = (hidden_masks['private']['heldout'] == True).nonzero().flatten().numpy()

    train_losses = np.zeros(len(known_masks['public']))
    heldout_losses = np.zeros(len(known_masks['public']))

    for i in np.arange(num_aux):
        # Model index parsed from fixed offsets in model_path, per dataset.
        if params.dataset in ('cifar10', 'credit', 'hep', 'adult', 'mnist'):
            model_num = params.model_path[-6:-5]
        elif params.dataset == 'cifar100':
            model_num = params.model_path[-15:-14]
        else:
            model_num = '0'
        # NOTE(review): the name uses params.public_data rather than the
        # public_data argument; kept as-is so on-disk directory names stay stable.
        new_model_path = (
            f'updated_model_{aux_epochs}_{params.batch_size}_{params.optimizer}'
            f'_aux_model_{i}_num_aux_{num_aux}_public_data_{params.public_data}'
            f'_model_{model_num}'
        )
        os.makedirs(new_model_path, exist_ok=True)
        updated_params.dump_path = new_model_path
        updated_params.local_rank = -1
        # Seed the auxiliary run with the private checkpoint (presumably
        # train() resumes from dump_path/checkpoint.pth -- TODO confirm).
        torch.save(state_dict_private, os.path.join(updated_params.dump_path, 'checkpoint.pth'))

        if public_data == 'train':
            print('Using public training data for auxiliary model')
            updated_model = train(updated_params, hidden_masks['public']['train'])
        elif public_data.startswith('rand'):
            print('Using random subset for auxiliary model')
            public_ids = (known_masks['public'] == True).nonzero().flatten().numpy()
            prop_selected = float(public_data[4:]) / 100
            num_selected = math.ceil(prop_selected * len(public_ids))
            permuted_ids = np.random.permutation(public_ids)
            aux_data_mask = to_mask(len(known_masks['public']), permuted_ids[:num_selected])
            print('Number of public model training points', len((aux_data_mask == True).nonzero().flatten().numpy()))
            updated_model = train(updated_params, aux_data_mask)
        else:
            print('Using all public data for auxiliary model')
            updated_model = train(updated_params, known_masks['public'])
        updated_model = updated_model.cuda()

        # Reload the auxiliary checkpoint into a fresh model (used by the
        # gradient-norm attack).
        new_model = build_model(updated_params)
        aux_checkpoint_path = os.path.join(updated_params.dump_path, "checkpoint.pth")
        state_dict_new = torch.load(aux_checkpoint_path, map_location='cuda:0')
        _load_checkpoint_weights(new_model, state_dict_new, params.dataset)
        new_model = new_model.cuda()

        # Score the private train / heldout points.
        if attack_type == 'loss':
            train_vals = get_calibrated_losses(updated_params, private_model, updated_model, private_train_ids, hidden_masks['private']['train'], aug_style)
            heldout_vals = get_calibrated_losses(updated_params, private_model, updated_model, private_heldout_ids, hidden_masks['private']['heldout'], aug_style)
        elif attack_type == 'conf':
            train_vals = get_calibrated_confidences(updated_params, private_model, updated_model, private_train_ids, hidden_masks['private']['train'], aug_style)
            heldout_vals = get_calibrated_confidences(updated_params, private_model, updated_model, private_heldout_ids, hidden_masks['private']['heldout'], aug_style)
        elif attack_type == 'dist':
            if i == 0:
                # Subsample once so every auxiliary model scores the same
                # points; resampling per iteration would aggregate scores of
                # mismatched points.
                private_train_ids = private_train_ids[np.random.choice(len(private_train_ids), size=params.num_points, replace=False)]
                private_heldout_ids = private_heldout_ids[np.random.choice(len(private_heldout_ids), size=params.num_points, replace=False)]
            train_vals = get_calibrated_distances(updated_params, private_model, updated_model, private_train_ids)
            heldout_vals = get_calibrated_distances(updated_params, private_model, updated_model, private_heldout_ids)
        else:
            # Flattened snapshots of both models' parameters.
            original_private_model = torch.cat([p.view(-1) for p in private_model.parameters()])
            original_updated_model = torch.cat([p.view(-1) for p in new_model.parameters()])

            # private_model is reused across iterations, so wrap it only once.
            if i == 0:
                private_model = GradSampleModule(private_model)
            attack_model = GradSampleModule(new_model)

            train_vals = get_calibrated_gradnorm(updated_params, private_model, original_private_model, attack_model, original_updated_model, private_train_ids, hidden_masks['private']['train'], aug_style=aug_style, norm_type=norm_type)
            heldout_vals = get_calibrated_gradnorm(updated_params, private_model, original_private_model, attack_model, original_updated_model, private_heldout_ids, hidden_masks['private']['heldout'], aug_style=aug_style, norm_type=norm_type)

        # Aggregate across auxiliary models.
        if aux_style == 'max':
            train_losses = np.maximum(train_losses, train_vals)
            heldout_losses = np.maximum(heldout_losses, heldout_vals)
        elif attack_type in ('conf', 'dist'):
            # These scorers' outputs are not added to the zero buffers
            # (presumably their shape differs from the full-length buffers --
            # TODO confirm). Start from the first model's scores and sum the rest.
            if i == 0:
                train_losses = train_vals
                heldout_losses = heldout_vals
            else:
                train_losses = train_losses + train_vals
                heldout_losses = heldout_losses + heldout_vals
        else:
            train_losses += train_vals
            heldout_losses += heldout_vals
    if aux_style == 'mean':
        train_losses = train_losses / num_aux
        heldout_losses = heldout_losses / num_aux
    return train_losses, heldout_losses