in attacks/privacy_attacks.py [0:0]
def _load_checkpoint_weights(model, checkpoint, strip_module_prefix):
    """Load ``checkpoint['model']`` into ``model``.

    When ``strip_module_prefix`` is true, a leading ``'module.'`` is removed
    from every key that has one before the weights are loaded; other keys are
    passed through unchanged.
    """
    weights = checkpoint['model']
    if strip_module_prefix:
        cleaned = OrderedDict()
        for key, value in weights.items():
            # remove `module.`
            cleaned[key[7:] if key[:7] == 'module.' else key] = value
        weights = cleaned
    model.load_state_dict(weights)


def auxiliary_attack(params, aux_epochs, attack_type='loss', aug_style='mean', norm_type=None, public_data='train', num_aux=1,aux_style='sum'):
    """
    Run an auxiliary-model attack and return per-example scores.

    The private checkpoint is loaded from ``params.model_path``; then, for each
    of ``num_aux`` rounds, a copy of that checkpoint is written to a fresh dump
    directory, ``train`` is called on public data with ``params.epochs``
    increased by ``aux_epochs``, and calibrated statistics are computed for the
    hidden private train / heldout examples.

    Args:
        params: run configuration. Reads ``mask_path``, ``model_path``,
            ``dataset``, ``batch_size``, ``optimizer``, ``public_data``,
            ``epochs``, ``local_rank`` (and ``num_points`` for ``'dist'``).
            NOTE: this object is modified in place -- ``epochs`` is increased
            by ``aux_epochs``, ``dump_path`` is pointed at the last auxiliary
            directory and ``local_rank`` is forced to ``-1``.
        aux_epochs: number added to ``params.epochs`` for auxiliary training.
        attack_type: ``'loss'``, ``'conf'`` or ``'dist'``; any other value
            selects the gradient-norm branch.
        aug_style: forwarded to the loss / confidence / grad-norm helpers.
        norm_type: forwarded to ``get_calibrated_gradnorm``.
        public_data: ``'train'`` uses the hidden public-train mask; a value
            starting with ``'rand'`` followed by a percentage (e.g.
            ``'rand50'``) uses a random subset of the public ids of that
            proportion; anything else uses the full public mask.
        num_aux: number of auxiliary models to train.
        aux_style: ``'max'`` keeps the element-wise maximum over rounds;
            otherwise values are summed (for ``'conf'`` / ``'dist'`` the latest
            round replaces the accumulator), and ``'mean'`` additionally
            divides the result by ``num_aux``.

    Returns:
        ``(train_losses, heldout_losses)``: aggregated values for the private
        train and private heldout examples.
    """
    # load the masks
    known_masks = {
        'public': torch.load(params.mask_path + "public.pth"),
        'private': torch.load(params.mask_path + "private.pth"),
    }
    hidden_masks = {
        'private': {
            'train': torch.load(params.mask_path + "hidden/train.pth"),
            'heldout': torch.load(params.mask_path + "hidden/heldout.pth"),
        },
        'public': {
            'train': torch.load(params.mask_path + "hidden/public_train.pth"),
            'heldout': torch.load(params.mask_path + "hidden/public_heldout.pth"),
        },
    }

    # Only the imagenet checkpoints get the `module.` prefix stripped.
    strip_prefix = params.dataset == 'imagenet'

    # get the final model parameters
    private_model = build_model(params)
    private_model_path = os.path.join(params.model_path, "checkpoint.pth")
    state_dict_private = torch.load(private_model_path, map_location='cuda:0')
    _load_checkpoint_weights(private_model, state_dict_private, strip_prefix)
    private_model = private_model.cuda()

    # NOTE(review): this is an alias, not a copy (a deepcopy was commented out
    # in the original), so every assignment below is visible to the caller and
    # to the helpers that receive `params`. Repeated calls with the same
    # `params` keep increasing `epochs`.
    updated_params = params
    updated_params.epochs = updated_params.epochs + aux_epochs

    private_train_ids = (hidden_masks['private']['train'] == True).nonzero().flatten().numpy()
    private_heldout_ids = (hidden_masks['private']['heldout'] == True).nonzero().flatten().numpy()

    # NOTE(review): accumulators start at zero, so with aux_style='max' the
    # result can never fall below 0 -- confirm that is intended.
    train_losses = np.zeros(len(known_masks['public']))
    heldout_losses = np.zeros(len(known_masks['public']))

    # The model number is a single character sliced out of the model path; it
    # does not depend on the round, so compute it once.
    if params.dataset in ('cifar10', 'credit', 'hep', 'adult', 'mnist'):
        model_num = params.model_path[-6:-5]
    elif params.dataset == 'cifar100':
        model_num = params.model_path[-15:-14]
    else:
        model_num = '0'

    for i in np.arange(num_aux):
        # NOTE(review): the directory name uses params.public_data, not the
        # `public_data` argument; kept as-is so existing dump paths still match.
        aux_dump_dir = 'updated_model_'+str(aux_epochs) +'_'+str(params.batch_size)+'_'+params.optimizer+'_aux_model_'+str(i)+'_num_aux_'+str(num_aux)+'_public_data_'+params.public_data+'_model_'+model_num
        os.makedirs(aux_dump_dir, exist_ok=True)
        updated_params.dump_path = aux_dump_dir
        if updated_params.local_rank != -1:
            updated_params.local_rank = -1

        # Write the private checkpoint into the dump directory before training.
        aux_checkpoint_path = os.path.join(updated_params.dump_path, 'checkpoint.pth')
        torch.save(state_dict_private, aux_checkpoint_path)

        if public_data == 'train':
            print('Using public training data for auxiliary model')
            updated_model = train(updated_params, hidden_masks['public']['train'])
        elif public_data[:4] == 'rand':
            print('Using random subset for auxiliary model')
            public_ids = (known_masks['public'] == True).nonzero().flatten().numpy()
            prop_selected = float(public_data[4:]) / 100
            num_selected = math.ceil(prop_selected * len(public_ids))
            permuted_ids = np.random.permutation(public_ids)
            aux_data_mask = to_mask(len(known_masks['public']), permuted_ids[:num_selected])
            print('Number of public model training points', len((aux_data_mask == True).nonzero().flatten().numpy()))
            updated_model = train(updated_params, aux_data_mask)
        else:
            print('Using all public data for auxiliary model')
            updated_model = train(updated_params, known_masks['public'])
        updated_model = updated_model.cuda()

        # Reload the checkpoint found in the dump directory after training into
        # a freshly built model; the gradient-norm branch uses this copy.
        new_model = build_model(params)
        state_dict_new = torch.load(aux_checkpoint_path, map_location='cuda:0')
        _load_checkpoint_weights(new_model, state_dict_new, strip_prefix)
        new_model = new_model.cuda()

        # get losses
        if attack_type == 'loss':
            train_vals = get_calibrated_losses(params, private_model, updated_model, private_train_ids, hidden_masks['private']['train'], aug_style)
            heldout_vals = get_calibrated_losses(params, private_model, updated_model, private_heldout_ids, hidden_masks['private']['heldout'], aug_style)
        elif attack_type == 'conf':
            train_vals = get_calibrated_confidences(params, private_model, updated_model, private_train_ids, hidden_masks['private']['train'], aug_style)
            heldout_vals = get_calibrated_confidences(params, private_model, updated_model, private_heldout_ids, hidden_masks['private']['heldout'], aug_style)
        elif attack_type == 'dist':
            # The id arrays are overwritten with a random subset of
            # params.num_points entries, so later rounds draw from the
            # already-reduced arrays.
            private_train_ids = private_train_ids[np.random.choice(len(private_train_ids), size=params.num_points, replace=False)]
            private_heldout_ids = private_heldout_ids[np.random.choice(len(private_heldout_ids), size=params.num_points, replace=False)]
            train_vals = get_calibrated_distances(params, private_model, updated_model, private_train_ids)
            heldout_vals = get_calibrated_distances(params, private_model, updated_model, private_heldout_ids)
        else:
            # Flatten both models' parameters into single vectors before they
            # are wrapped for per-sample gradients.
            original_private_model = torch.cat([p.view(-1) for p in private_model.parameters()])
            original_updated_model = torch.cat([p.view(-1) for p in new_model.parameters()])
            if i == 0:
                # The private model persists across rounds; wrap it only once.
                private_model = GradSampleModule(private_model)
            attack_model = GradSampleModule(new_model)
            train_vals = get_calibrated_gradnorm(params, private_model, original_private_model, attack_model, original_updated_model, private_train_ids, hidden_masks['private']['train'], aug_style=aug_style, norm_type=norm_type)
            heldout_vals = get_calibrated_gradnorm(params, private_model, original_private_model, attack_model, original_updated_model, private_heldout_ids, hidden_masks['private']['heldout'], aug_style=aug_style, norm_type=norm_type)

        if aux_style == 'max':
            train_losses = np.maximum(train_losses, train_vals)
            heldout_losses = np.maximum(heldout_losses, heldout_vals)
        elif attack_type in ('conf', 'dist'):
            # Use the `attack_type` argument that selected the branch above
            # (the original consulted params.attack_type here, which can
            # disagree with it). For these attacks the latest round replaces
            # the accumulator instead of being added to it.
            train_losses = train_vals
            heldout_losses = heldout_vals
        else:
            train_losses += train_vals
            heldout_losses += heldout_vals

    if aux_style == 'mean':
        train_losses = train_losses / num_aux
        heldout_losses = heldout_losses / num_aux
    return train_losses, heldout_losses