def loss_perturbation_method()

in private_prediction.py


def loss_perturbation_method(data, args, visualizer=None, title=None):
    """
    Given a dataset `data` and arguments `args`, run a full test of the private
    prediction algorithms of Chaudhuri et al. (2011) / Kifer et al. (2012)
    generalized to the multi-class setting. Returns a `dict` mapping each data
    split ("train" and "test") to its predicted labels.

    Note: This algorithm only guarantees privacy under the following assumptions:
    - The loss is strictly convex and has a continuous Hessian.
    - The model is linear.
    - The inputs have a 2-norm less than or equal to 1.
    - The Lipschitz constant of the loss function and the spectral norm of
        its Hessian are bounded.
    """
    assert args.model == "linear", f"Model {args.model} not supported."
    assert args.noise_dist != "advanced_gaussian", \
        "Advanced Gaussian method not supported for loss perturbation."

    # get dataset properties:
    num_classes = int(data["train"]["targets"].max()) + 1
    num_samples, num_features = data["train"]["features"].size()

    # initialize model and criterion:
    model = modeling.initialize_model(num_features, num_classes, device=args.device)
    criterion = nn.CrossEntropyLoss()

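    # compute the noise precision and the minimum L2 regularization required
    # by the privacy analysis (the caller may request stronger regularization):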
    precision, weight_decay = loss_perturbation_params(
        args.epsilon, args.delta, args.noise_dist,
        criterion, num_samples, num_classes)
    weight_decay = max(weight_decay, args.weight_decay)

    # sample loss perturbation vector:
    param = modeling.get_parameter_vector(model)
    mean = torch.zeros_like(param)
    perturbation = getattr(noise, args.noise_dist)(mean, precision)
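    # copy the flat perturbation vector into per-parameter tensors: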
    perturbations = [torch.zeros_like(p) for p in model.parameters()]
    modeling.set_parameter_vector(perturbations, perturbation)

    # closure implementing the loss-perturbation criterion:
    def loss_perturbation_criterion(predictions, targets):
        loss = criterion(predictions, targets)
        # add the <perturbation, parameters> / num_samples term of the
        # perturbed objective:
        for model_param, perturb in zip(model.parameters(), perturbations):
            loss = loss + (model_param * perturb).sum() / num_samples
        return loss

    # add L2-regularizer to the loss:
    regularized_criterion = modeling.add_l2_regularization(
        loss_perturbation_criterion, model, weight_decay
    )

    # train classifier:
    logging.info("Training classifier with loss perturbation...")
    modeling.train_model(model, data["train"],
                         criterion=regularized_criterion,
                         optimizer=args.optimizer,
                         num_epochs=args.num_epochs,
                         learning_rate=args.learning_rate,
                         batch_size=args.batch_size,
                         visualizer=visualizer,
                         title=title)

    # perform inference on both the training and test sets:
    logging.info("Performing inference with loss-perturbed predictor...")
    with torch.no_grad():
        predictions = {split: model(data_split["features"]).argmax(dim=1)
                       for split, data_split in data.items()}
    return predictions
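

The snippet below is a minimal usage sketch, not taken from the repository: the import path, the "gaussian" value for `noise_dist`, and the attribute values in `args` are assumptions. The attribute names, however, are exactly the ones this function reads, and the feature rows are rescaled so that every input has 2-norm at most 1, as the privacy guarantee requires.

import argparse

import torch

# assumed import path for this repository's module:
from private_prediction import loss_perturbation_method


def unit_ball(x):
    # rescale rows so that every input has 2-norm <= 1:
    return x / x.norm(dim=1, keepdim=True).clamp(min=1.0)


# toy three-class dataset in the format this function expects:
data = {
    "train": {"features": unit_ball(torch.randn(100, 20)),
              "targets": torch.randint(0, 3, (100,))},
    "test": {"features": unit_ball(torch.randn(50, 20)),
             "targets": torch.randint(0, 3, (50,))},
}

# the values below (and the "gaussian" noise name; see noise.py for the
# actual options) are illustrative assumptions:
args = argparse.Namespace(
    model="linear",
    noise_dist="gaussian",
    device="cpu",
    epsilon=1.0,
    delta=1e-5,
    weight_decay=0.0,
    optimizer="sgd",
    num_epochs=10,
    learning_rate=0.1,
    batch_size=32,
)

predictions = loss_perturbation_method(data, args)
print({split: preds.shape for split, preds in predictions.items()})

Note that because the function raises `weight_decay` to at least the level the privacy analysis requires, the `args.weight_decay` value only matters when it exceeds that minimum.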