def main()

in reweighted.py [0:0]


def _log_metrics(train_metric, test_metric, regression, weighted=False):
    """Log the train/test metric returned by ``compute_accuracy``.

    The metric is labelled MSE for regression datasets and accuracy
    otherwise; ``weighted`` selects the "Weighted model" message prefix.
    """
    if regression:
        label = "Weighted model MSE" if weighted else "MSE"
    else:
        label = "Weighted model accuracy" if weighted else "Accuracy"
    logging.info(f"{label} train {train_metric:.3f},"
        f" test: {test_metric:.3f}.")


def _eta_stats(etas):
    """Return ``(max, mean, std)`` of a tensor of etas as Python floats."""
    return etas.max().item(), etas.mean().item(), etas.std().item()


def main(args):
    """Iteratively reweight training examples by their Fisher information loss.

    Trains an unweighted model and computes per-example etas. Then, for
    ``args.iters`` rounds, it:

    1. Derives weights from the running inverse-eta scores.
    2. Retrains the model with those weights.
    3. Recomputes the etas and folds them back into the scores.

    Args read from ``args``: dataset, no_norm, data_folder, pca_dims, model,
    l2, attribute, constrained, iters, weight_method, min_weight, max_weight,
    results_file.

    Side effects:
        Writes ``<results_file>.json``, which holds the final weights, the
        unweighted and final etas, and per-iteration eta stats and metrics.
        Also writes ``<results_file>.pth``, which holds the stacked weights
        of every iteration, with uniform weights first.
    """
    regression = args.dataset in ("iwpc", "synth")
    data = dataloading.load_dataset(
        name=args.dataset, split="train", normalize=not args.no_norm,
        num_classes=2, root=args.data_folder, regression=regression)
    test_data = dataloading.load_dataset(
        name=args.dataset, split="test", normalize=not args.no_norm,
        num_classes=2, root=args.data_folder, regression=regression)
    if args.pca_dims > 0:
        # Fit PCA on the training split only and reuse its mapping for test.
        data, pca = dataloading.pca(data, num_dims=args.pca_dims)
        test_data, _ = dataloading.pca(test_data, mapping=pca)

    model = models.get_model(args.model)

    # Find the optimal parameters for the model:
    logging.info(f"Training {args.model} model.")
    model.train(data, l2=args.l2)

    train_accuracy = compute_accuracy(model, data, regression=regression)
    test_accuracy = compute_accuracy(model, test_data, regression=regression)
    _log_metrics(train_accuracy, test_accuracy, regression)

    # Compute the Fisher information loss, eta, for each example in the
    # training set:
    logging.info("Computing unweighted etas on training set...")
    J = model.influence_jacobian(data)
    etas = models.compute_information_loss(J, target_attribute=args.attribute,
                                           constrained=args.constrained)
    eta_max, eta_mean, eta_std = _eta_stats(etas)
    logging.info(f"etas max: {eta_max:.4f},"
        f" mean: {eta_mean:.4f}, std: {eta_std:.4f}.")

    # Reweight using the fisher information loss:
    updated_fi = etas.reciprocal().detach()
    maxs = [eta_max]
    means = [eta_mean]
    stds = [eta_std]
    train_accs = [train_accuracy]
    test_accs = [test_accuracy]
    # Values reported when no reweighting iteration runs (args.iters == 0).
    # The unweighted model is the uniformly-weighted one, and its etas are
    # the unweighted etas. Previously these names were unbound in that case.
    weights = torch.ones(len(updated_fi))
    weighted_etas = etas
    all_weights = [weights]
    for i in range(args.iters):
        logging.info(f"Iter {i}: Training weighted model...")
        # Rescale so the scores sum to the number of examples (mean 1).
        updated_fi *= (len(updated_fi) / updated_fi.sum())
        # TODO does it make sense to renormalize after clamping?
        updated_fi.clamp_(min=args.min_weight, max=args.max_weight)
        weights = get_weights(args.weight_method, updated_fi, data).detach()
        model.train(data, l2=args.l2, weights=weights)

        # Check predictions of weighted model:
        train_accuracy = compute_accuracy(model, data, regression=regression)
        test_accuracy = compute_accuracy(model, test_data, regression=regression)
        _log_metrics(train_accuracy, test_accuracy, regression, weighted=True)

        J = model.influence_jacobian(data)
        weighted_etas = models.compute_information_loss(J, target_attribute=args.attribute,
                                                        constrained=args.constrained)
        # updated_fi is kept detached (see its initialization). Detach the
        # divisor as well, so an in-place division cannot re-attach it to an
        # autograd graph that would grow every iteration.
        updated_fi /= weighted_etas.detach()
        w_max, w_mean, w_std = _eta_stats(weighted_etas)
        maxs.append(w_max)
        means.append(w_mean)
        stds.append(w_std)
        train_accs.append(train_accuracy)
        test_accs.append(test_accuracy)
        all_weights.append(weights)
        logging.info(f"Weighted etas max: {maxs[-1]:.4f},"
            f" mean: {means[-1]:.4f},"
            f" std: {stds[-1]:.4f}.")

    results = {
        "weights" : weights.tolist(),
        "etas" : etas.tolist(),
        "weighted_etas" : weighted_etas.tolist(),
        "eta_maxes" : maxs,
        "eta_means" : means,
        "eta_stds" : stds,
        "train_accs" : train_accs,
        "test_accs" : test_accs,
    }

    with open(args.results_file + ".json", 'w') as fid:
        json.dump(results, fid)
    torch.save(torch.stack(all_weights), args.results_file + ".pth")