in private_model_inversion.py [0:0]
def main(args):
    """Run model inversion once per set of training-example weights.

    Loads the train and test splits of ``args.dataset``, optionally
    subsamples the training split, then calls ``run_inversion`` once for
    each weight vector and optionally writes the collected results to a
    JSON file.

    Args:
        args: Parsed command-line namespace. Attributes read directly here
            are ``dataset``, ``data_folder``, ``subsample``,
            ``weights_file`` and ``results_file``; the whole namespace is
            also forwarded to ``run_inversion``.
    """
    # "iwpc" and "synth" are loaded as regression datasets; every other
    # dataset is loaded with regression=False.
    regression = args.dataset in ("iwpc", "synth")
    data = dataloading.load_dataset(
        name=args.dataset, split="train", normalize=False,
        num_classes=2, root=args.data_folder, regression=regression)
    test_data = dataloading.load_dataset(
        name=args.dataset, split="test", normalize=False,
        num_classes=2, root=args.data_folder, regression=regression)
    if args.subsample > 0:
        # Only the training split is subsampled; test_data is left intact.
        data = dataloading.subsample(data, args.subsample)

    if args.weights_file is not None:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code -- only load weight files from trusted sources.
        # NOTE(review): the loaded weights are not checked against
        # len(data["targets"]); when --subsample is used, confirm the file
        # was produced for the subsampled training set.
        all_weights = torch.load(args.weights_file)
    else:
        # Default: a single run with all-ones weights, one entry per
        # training target.
        all_weights = [torch.ones(len(data["targets"]))]

    # Per-iteration logging is only useful when there is more than one run.
    log_iterations = len(all_weights) > 1
    results = []
    for it, weights in enumerate(all_weights):
        if log_iterations:
            logging.info("Iteration %d weights for model inversion.", it)
        results.append(run_inversion(args, data, test_data, weights))

    if args.results_file is not None:
        with open(args.results_file, "w", encoding="utf-8") as fid:
            json.dump(results, fid)