def main()

in sample_info/scripts/compute_influence_functions.py
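
The function relies on module-level imports from the script. The standard-library and third-party ones can be inferred directly from the calls it makes; the project-specific helpers (load_data_from_arguments, CacheDatasetWrapper, methods, gradients, utils, JacobianEstimator, inverse_hvp_lissa, process_results) are imported from elsewhere in the repository and are not reproduced here.

import argparse
import json
import os
import pickle

import torch
from tqdm import tqdm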


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, required=True)
    parser.add_argument('--device', '-d', default='cuda', help='specifies the main device')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--cpu', dest='cpu', action='store_true')
    parser.set_defaults(cpu=False)

    # data parameters
    parser.add_argument('--dataset', '-D', type=str, default='mnist4vs9',
                        choices=['mnist4vs9', 'synthetic', 'cifar10-cat-vs-dog', 'cats-and-dogs'],
                        help='Which dataset to use. One can add more choices if needed.')
    parser.add_argument('--data_augmentation', '-A', action='store_true', dest='data_augmentation')
    parser.set_defaults(data_augmentation=False)
    parser.add_argument('--error_prob', '-n', type=float, default=0.0)
    parser.add_argument('--num_train_examples', type=int, default=None)
    parser.add_argument('--clean_validation', action='store_true', default=False)
    parser.add_argument('--resize_to_imagenet', action='store_true', dest='resize_to_imagenet')
    parser.set_defaults(resize_to_imagenet=False)
    parser.add_argument('--cache_dataset', action='store_true', dest='cache_dataset')
    parser.set_defaults(cache_dataset=False)

    # hyper-parameters
    parser.add_argument('--model_class', '-m', type=str, default='ClassifierL2')

    parser.add_argument('--l2_reg_coef', type=float, default=0.0)
    parser.add_argument('--damping', type=float, default=1e-10)
    parser.add_argument('--scale', type=float, default=10.0)
    parser.add_argument('--recursion_depth', type=int, default=10000)
    parser.add_argument('--batch_size', type=int, default=128)

    parser.add_argument('--output_dir', '-o', type=str, default='sample_info/results/ground-truth/')
    parser.add_argument('--exp_name', '-E', type=str, required=True)
    args = parser.parse_args()
    print(args)

    # Build data
    train_data, val_data, test_data, _ = load_data_from_arguments(args, build_loaders=False)
    if args.cache_dataset:
        train_data = CacheDatasetWrapper(train_data)
        val_data = CacheDatasetWrapper(val_data)
        test_data = CacheDatasetWrapper(test_data)

    with open(args.config, 'r') as f:
        architecture_args = json.load(f)

    model_class = getattr(methods, args.model_class)

    model = model_class(input_shape=train_data[0][0].shape,
                        architecture_args=architecture_args,
                        l2_reg_coef=args.l2_reg_coef,
                        seed=args.seed,
                        device=args.device)

    # load the final parameters
    saved_file_path = os.path.join(args.output_dir, 'ground-truth', args.exp_name, 'full-data-training.pkl')
    with open(saved_file_path, 'rb') as f:
        saved_data = pickle.load(f)

    params = dict(model.named_parameters())
    for k, v in saved_data['weights'].items():
        params[k].data = v.to(args.device)

    # compute per-example gradients (d loss / d weights on train data, d pred / d weights on validation data)
    train_grads = gradients.get_weight_gradients(model=model, dataset=train_data, cpu=args.cpu,
                                                 description='computing per example gradients on train data')

    jacobian_estimator = JacobianEstimator()
    val_grads = jacobian_estimator.compute_jacobian(model=model, dataset=val_data, cpu=args.cpu,
                                                    description='computing jacobian on validation data')

    # compute weight and prediction influences
    weight_vectors = []
    weight_quantities = []

    pred_vectors = []
    pred_quantities = []

    for sample_idx in tqdm(range(len(train_data)), desc='computing influences'):
        # compute the weight influence of this training example
        v = []
        for k in dict(model.named_parameters()).keys():
            v.append(train_grads[k][sample_idx].to(model.device))
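        # LiSSA stochastic approximation of the inverse-Hessian-vector product H^{-1} v,
        # where H is the (damped) Hessian of the training loss at the loaded parameters
        # and v is this example's gradient; `damping` and `scale` stabilize the recursion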
        inv_hvp = inverse_hvp_lissa(model, dataset=train_data, v=v, batch_size=args.batch_size,
                                    recursion_depth=args.recursion_depth, damping=args.damping,
                                    scale=args.scale)
        if args.cpu:
            inv_hvp = [utils.to_cpu(a) for a in inv_hvp]

        for a in inv_hvp:
            if torch.isnan(a).any():
                raise ValueError("Inverse hessian vector product contains NaNs. Increase the scale.")

        cur_weight_influence = 1.0 / len(train_data) * torch.cat([a.flatten() for a in inv_hvp])
        weight_vectors.append(cur_weight_influence)
        weight_quantities.append(torch.sum(cur_weight_influence ** 2))

        # compute influences on the validation predictions
        cur_pred_influences = []
        for val_sample_idx in range(len(val_data)):
            val_grad_flat = []
            for k in dict(model.named_parameters()).keys():
                val_grad_flat.append(val_grads[k][val_sample_idx].flatten())
            val_grad_flat = torch.cat(val_grad_flat, dim=0)
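            # dot(H^{-1} g_train, d f_val / d theta): first-order estimate of how this
            # training example influences the prediction on this validation example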
            cur_pred_influences.append(torch.dot(cur_weight_influence, val_grad_flat))

        cur_pred_influences = torch.stack(cur_pred_influences)
        pred_vectors.append(cur_pred_influences)
        pred_quantities.append(torch.sum(cur_pred_influences ** 2))

    # save weights
    meta = {
        'description': 'weight influence functions',
        'args': args
    }

    exp_dir = os.path.join(args.output_dir, 'influence-functions', args.exp_name)
    process_results(vectors=weight_vectors, quantities=weight_quantities, meta=meta,
                    exp_name='weights', output_dir=exp_dir, train_data=train_data)

    # save preds
    meta = {
        'description': 'pred influence functions',
        'args': args
    }

    exp_dir = os.path.join(args.output_dir, 'influence-functions', args.exp_name)
    process_results(vectors=pred_vectors, quantities=pred_quantities, meta=meta,
                    exp_name='pred', output_dir=exp_dir, train_data=train_data)
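
For reference, the script is driven entirely by the arguments parsed above. A typical invocation might look like the following, where the config path and experiment name are placeholders, and it is assumed that a prior ground-truth training run has already produced the corresponding full-data-training.pkl under the output directory:

python sample_info/scripts/compute_influence_functions.py \
    --config <path/to/architecture.json> \
    --exp_name <experiment-name> \
    --dataset mnist4vs9 \
    --batch_size 128 \
    --recursion_depth 10000 \
    --damping 1e-10 \
    --scale 10.0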