def main()

in sample_info/scripts/compute_influence_functions_brute_force.py
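
The listing below omits the script's import block. A plausible reconstruction, inferred from the names used in the body, is the following; the project-level names are assumptions about the repository layout, not verified imports:

import argparse
import json
import os
import pickle

import torch
from tqdm import tqdm

# ASSUMPTION: methods, utils, gradients, hessian, JacobianEstimator,
# CacheDatasetWrapper, load_data_from_arguments, get_loaders_from_datasets,
# and process_results are provided by modules of this repository; their
# exact import paths are not visible in this excerpt.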


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, required=True)
    parser.add_argument('--device', '-d', default='cuda', help='specifies the main device')
    parser.add_argument('--seed', type=int, default=42)

    # data parameters
    parser.add_argument('--dataset', '-D', type=str, default='mnist4vs9',
                        choices=['mnist4vs9', 'synthetic', 'cifar10-cat-vs-dog'],
                        help='Which dataset to use. One can add more choices if needed.')
    parser.add_argument('--data_augmentation', '-A', action='store_true', dest='data_augmentation')
    parser.set_defaults(data_augmentation=False)
    parser.add_argument('--error_prob', '-n', type=float, default=0.0)
    parser.add_argument('--num_train_examples', type=int, default=None)
    parser.add_argument('--clean_validation', action='store_true', default=False)
    parser.add_argument('--resize_to_imagenet', action='store_true', dest='resize_to_imagenet')
    parser.set_defaults(resize_to_imagenet=False)
    parser.add_argument('--cache_dataset', action='store_true', dest='cache_dataset')
    parser.set_defaults(cache_dataset=False)

    # hyper-parameters
    parser.add_argument('--model_class', '-m', type=str, default='ClassifierL2')

    parser.add_argument('--l2_reg_coef', type=float, default=0.0)

    parser.add_argument('--output_dir', '-o', type=str, default='sample_info/results/ground-truth/')
    parser.add_argument('--exp_name', '-E', type=str, required=True)
    args = parser.parse_args()
    print(args)

    # Build data
    train_data, val_data, test_data, _ = load_data_from_arguments(args, build_loaders=False)
    if args.cache_dataset:
        train_data = CacheDatasetWrapper(train_data)
        val_data = CacheDatasetWrapper(val_data)
        test_data = CacheDatasetWrapper(test_data)

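    # a deliberately huge batch size so each loader yields its entire split in a single batch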
    train_loader, val_loader, test_loader = get_loaders_from_datasets(train_data, val_data, test_data,
                                                                      batch_size=2 ** 30,
                                                                      shuffle_train=False)

    with open(args.config, 'r') as f:
        architecture_args = json.load(f)

    model_class = getattr(methods, args.model_class)

    model = model_class(input_shape=train_data[0][0].shape,
                        architecture_args=architecture_args,
                        l2_reg_coef=args.l2_reg_coef,
                        seed=args.seed,
                        device=args.device)

    # load the final parameters saved by a prior full-data training run
    saved_file_path = os.path.join(args.output_dir, 'ground-truth', args.exp_name, 'full-data-training.pkl')
    with open(saved_file_path, 'rb') as f:
        saved_data = pickle.load(f)

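    # copy the saved weights into the freshly initialized model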
    params = dict(model.named_parameters())
    for k, v in saved_data['weights'].items():
        params[k].data = v.to(args.device)

    # brute force: accumulate the total training loss, then compute its Hessian and invert it
    total_loss = 0.0
    for x, y in train_loader:
        out = model.forward(inputs=[x], labels=[y])
        losses, _ = model.compute_loss(inputs=[x], labels=[y], outputs=out)
        total_loss = total_loss + sum(losses.values())

    with utils.Timing(description='Computing the Hessian'):
        H = hessian(ys=[total_loss], xs=tuple(model.parameters()))

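    # hessian() returns a list of lists of blocks, where block (i, j) has shape
    # params[i].shape + params[j].shape; reshape each block to 2-D and stitch the
    # blocks together into a single (num_params x num_params) matrix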
    params = tuple(model.parameters())
    for i in range(len(H)):
        for j in range(len(H[i])):
            ni = params[i].nelement()
            nj = params[j].nelement()
            H[i][j] = H[i][j].reshape((ni, nj))
        H[i] = torch.cat(H[i], dim=1)
    H = torch.cat(H, dim=0)
    # without L2 regularization the Hessian can be singular; add a small damping term to the diagonal
    if args.l2_reg_coef < 1e-10:
        H += 1e-10 * torch.eye(H.shape[0], dtype=torch.float, device=H.device)
    print(f"Hessian shape: {H.shape}")
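    # dense inversion costs O(p^3) in the number of parameters p, which is why
    # this brute-force script is only feasible for small models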
    H_inv = torch.inverse(H)

    # compute per-example gradients: d loss / d weights on train data, d prediction / d weights on validation data
    train_grads = gradients.get_weight_gradients(model=model, dataset=train_data, cpu=False,
                                                 description='computing per example gradients on train data')

    jacobian_estimator = JacobianEstimator()
    val_grads = jacobian_estimator.compute_jacobian(model=model, dataset=val_data, cpu=False,
                                                    description='computing jacobian on validation data')

    # compute weight and prediction influences
    weight_vectors = []
    weight_quantities = []

    pred_vectors = []
    pred_quantities = []

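    # the weight influence of training example i is computed as (1/n) * H^{-1} g_i,
    # where g_i is the gradient of example i's loss w.r.t. the weights; its influence
    # on a validation prediction is the dot product of this vector with that
    # prediction's Jacobian w.r.t. the weights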
    for sample_idx in tqdm(range(len(train_data)), desc='computing influences'):
        # compute for weights
        train_grad_flat = []
        for k, v in dict(model.named_parameters()).items():
            train_grad_flat.append(train_grads[k][sample_idx].flatten())
        train_grad_flat = torch.cat(train_grad_flat, dim=0)

        cur_weight_influence = 1.0 / len(train_data) * torch.mm(H_inv, train_grad_flat.view((-1, 1)))
        cur_weight_influence = cur_weight_influence.view((-1,))
        weight_vectors.append(cur_weight_influence)
        weight_quantities.append(torch.sum(cur_weight_influence ** 2))

        # compute for predictions
        cur_pred_influences = []
        for val_sample_idx in range(len(val_data)):
            val_grad_flat = []
            for k, v in dict(model.named_parameters()).items():
                val_grad_flat.append(val_grads[k][val_sample_idx].flatten())
            val_grad_flat = torch.cat(val_grad_flat, dim=0)
            cur_pred_influences.append(torch.dot(cur_weight_influence, val_grad_flat))

        cur_pred_influences = torch.stack(cur_pred_influences)
        pred_vectors.append(cur_pred_influences)
        pred_quantities.append(torch.sum(cur_pred_influences ** 2))

    # save weight influences
    exp_dir = os.path.join(args.output_dir, 'influence-functions-brute-force', args.exp_name)

    meta = {
        'description': 'weight influence functions',
        'args': args
    }
    process_results(vectors=weight_vectors, quantities=weight_quantities, meta=meta,
                    exp_name='weights', output_dir=exp_dir, train_data=train_data)

    # save prediction influences
    meta = {
        'description': 'pred influence functions',
        'args': args
    }
    process_results(vectors=pred_vectors, quantities=pred_quantities, meta=meta,
                    exp_name='pred', output_dir=exp_dir, train_data=train_data)
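
A minimal sketch of how the script might be invoked from the repository root. The config path and experiment name below are placeholders; --exp_name must match an experiment for which the full-data-training.pkl checkpoint already exists:

python sample_info/scripts/compute_influence_functions_brute_force.py \
    --config configs/<architecture>.json \
    --dataset mnist4vs9 \
    --exp_name <experiment-name>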