def main()

in fisher_experiment.py


def main(args):
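    # The iwpc dataset is treated as a regression task; every other dataset is
    # loaded as binary classification (num_classes=2).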
    regression = (args.dataset == "iwpc")
    data = dataloading.load_dataset(
        name=args.dataset, split="train", normalize=not args.no_norm,
        num_classes=2, root=args.data_folder, regression=regression)
    test_data = dataloading.load_dataset(
        name=args.dataset, split="test", normalize=not args.no_norm,
        num_classes=2, root=args.data_folder, regression=regression)
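    # Optionally reduce dimensionality with PCA, fitting the projection on the
    # training set and reusing the same mapping for the test set: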
    if args.pca_dims > 0:
        data, pca = dataloading.pca(data, num_dims=args.pca_dims)
        test_data, _ = dataloading.pca(test_data, mapping=pca)

    model = models.get_model(args.model)

    # Find the optimal parameters for the model:
    logging.info(f"Training model {args.model}")
    model.train(data, l2=args.l2)

    # Check predictions for sanity:
    accuracy, _ = compute_accuracy(model, data, regression=regression)
    if regression:
        logging.info("Training MSE of classifier {:.3f}".format(accuracy))
    else:
        logging.info("Training accuracy of classifier {:.3f}".format(accuracy))

    # Compute the Jacobian of the influence of each example on the optimal
    # parameters:
    logging.info(f"Computing influence Jacobian on training set...")
    start = time.time()
    J = model.influence_jacobian(data)
    time_per_sample = 1e3 * (time.time() - start) / len(data["targets"])
    logging.info("Time taken per example {:.3f} (ms)".format(time_per_sample))

    # Compute the Fisher information loss from the FIM (J^T J) for each example
    # in the training set (J^T J is the Fisher information with Gaussian output
    # perturbation on the parameters at a scale of 1):
    start = time.time()
    logging.info(f"Computing Fisher information loss...")
    etas = models.compute_information_loss(J)
    time_per_sample = 1e3 * (time.time() - start) / len(etas)
    logging.info(
        "Computed {} examples, maximum eta: {:.3f}, "
        "time per sample {:.3f} (ms).".format(
            len(etas), max(etas), time_per_sample))
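    # (A larger eta means the trained parameters leak more information about
    # that training example; the maximum over the training set is the
    # worst-case loss.)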

    # Compute some comparison points:
    losses, weight_dots, grad_norms = eval_comparison_stats(model, data)

    # If clipping is enabled, remove the most lossy examples, retrain the
    # model, and recompute the etas:
    if args.clip > 0:
        clipped_data = clip_data(data, etas, args.clip)
        logging.info(
            "Kept {}/{} samples, retraining and recomputing etas...".format(
                len(clipped_data["targets"]), len(data["targets"])))
        model.train(clipped_data, l2=args.l2)
        J = model.influence_jacobian(clipped_data)
        etas = models.compute_information_loss(J)
    etamax = max(etas)

    # Measure the test accuracy as a function of the noise needed to attain a
    # desired eta:
    accuracies = []
    stds = []
    for eta in args.etas:
        # Compute the Gaussian noise scale needed for eta:
        scale = etamax / eta
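        # (Noise at scale 1 yields etamax, and eta scales inversely with the
        # noise standard deviation, so scale = etamax / eta caps the
        # worst-case information loss at the target eta.)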
        # Measure test accuracy:
        accuracy, std = compute_accuracy(
            model, test_data, noise_scale=scale, trials=args.trials,
            regression=regression)
        accuracies.append(accuracy)
        stds.append(std)

    results = {
        "clip" : args.clip,
        "accuracies" : accuracies,
        "stds" : stds,
        "etas" : etas.tolist(),
        "train_losses" : losses,
        "train_dot_weights" : weight_dots,
        "train_grad_norms" : grad_norms,
    }
    with open(args.results_file, 'w') as fid:
        json.dump(results, fid)
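
For reference, a minimal entry point that could drive main() is sketched below. The flag names mirror the attributes main() reads from args; the types and defaults are assumptions, since the actual argument parser is not part of this section.

import argparse


def parse_args():
    # Hypothetical parser: flag names follow the attributes accessed in main().
    parser = argparse.ArgumentParser(
        description="Fisher information loss experiment (sketch).")
    parser.add_argument("--dataset", default="iwpc")
    parser.add_argument("--data_folder", default="data")
    parser.add_argument("--no_norm", action="store_true")
    parser.add_argument("--pca_dims", type=int, default=0)
    parser.add_argument("--model", required=True)
    parser.add_argument("--l2", type=float, default=0.0)
    parser.add_argument("--clip", type=float, default=0.0)
    parser.add_argument("--etas", type=float, nargs="+", default=[0.01, 0.1, 1.0])
    parser.add_argument("--trials", type=int, default=10)
    parser.add_argument("--results_file", default="results.json")
    return parser.parse_args()


if __name__ == "__main__":
    main(parse_args())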