def main()

in private_model_inversion.py [0:0]


def main(args):
    """Run model inversion once per set of per-example weights.

    Loads the train and test splits of ``args.dataset`` (unnormalized,
    two classes) and optionally subsamples the training split. Then
    calls ``run_inversion`` for every weight vector. When
    ``args.results_file`` is set, the list of results is written there
    as JSON.

    Args:
        args: parsed command-line namespace. Reads ``dataset``,
            ``data_folder``, ``subsample``, ``weights_file`` and
            ``results_file``. The whole namespace is also forwarded to
            ``run_inversion``.
    """
    # Only these two datasets are loaded in regression mode.
    is_regression = args.dataset in ("iwpc", "synth")
    load_kwargs = dict(
        name=args.dataset, normalize=False, num_classes=2,
        root=args.data_folder, regression=is_regression)
    data = dataloading.load_dataset(split="train", **load_kwargs)
    test_data = dataloading.load_dataset(split="test", **load_kwargs)

    # Subsampling is applied to the training split only.
    if args.subsample > 0:
        data = dataloading.subsample(data, args.subsample)

    # Without a weights file, run a single inversion with uniform weights:
    # one weight per training target, counted after any subsampling.
    # NOTE(review): torch.load unpickles the file (unless weights_only is
    # in effect for the installed torch version); only load trusted files.
    weight_sets = (
        torch.load(args.weights_file) if args.weights_file is not None
        else [torch.ones(len(data["targets"]))])

    # Per-iteration progress is only logged when there is more than one run.
    log_each = len(weight_sets) > 1
    results = []
    for idx, weight_vec in enumerate(weight_sets):
        if log_each:
            logging.info(f"Iteration {idx} weights for model inversion.")
        results.append(run_inversion(args, data, test_data, weight_vec))

    if args.results_file is None:
        return
    with open(args.results_file, 'w') as handle:
        json.dump(results, handle)