def model_sensitivity_method()

in private_prediction.py [0:0]


def model_sensitivity_method(data, args, visualizer=None, title=None):
    """
    Run an end-to-end private prediction experiment with the model
    sensitivity method: train a classifier on `data["train"]`, add random
    noise to its parameter vector, and predict with the perturbed model.

    Arguments:
        data: mapping from split name to a data split; the `"train"` split
            must provide `"features"` (a 2D tensor) and `"targets"`.
        args: experiment settings; reads `model`, `device`, `weight_decay`,
            `optimizer`, `num_epochs`, `learning_rate`, `batch_size`,
            `epsilon`, `delta`, and `noise_dist`.
        visualizer, title: forwarded unchanged to `modeling.train_model`.

    Returns a dict that maps every split name in `data` to the argmax (over
    dimension 1) of the perturbed model's output on that split.

    Only `args.model == "linear"` is accepted; any other value fails the
    assertion at the top of the function.

    Note: This algorithm only guarantees privacy for models with convex losses.
    """
    assert args.model == "linear", f"Model {args.model} not supported."

    # build the model and its training loss (cross-entropy plus L2 penalty):
    train_split = data["train"]
    num_classes = int(train_split["targets"].max()) + 1
    num_samples, num_features = train_split["features"].size()
    model = modeling.initialize_model(num_features, num_classes, device=args.device)
    criterion = nn.CrossEntropyLoss()
    regularized_criterion = modeling.add_l2_regularization(
        criterion, model, args.weight_decay
    )

    # fit the classifier without any privacy mechanism:
    logging.info("Training non-private classifier...")
    training_options = {
        "criterion": regularized_criterion,
        "optimizer": args.optimizer,
        "num_epochs": args.num_epochs,
        "learning_rate": args.learning_rate,
        "batch_size": args.batch_size,
        "visualizer": visualizer,
        "title": title,
    }
    modeling.train_model(model, train_split, **training_options)

    # compute the noise scale; note that the unregularized criterion and the
    # original (unmapped) noise distribution name are passed here:
    logging.info("Applying model sensitivity method...")
    scale = sensitivity_scale(
        args.epsilon,
        args.delta,
        args.weight_decay,
        criterion,
        num_samples,
        args.noise_dist,
    )

    # draw a zero-mean perturbation with the same shape as the parameters;
    # "advanced_gaussian" is sampled with the plain "gaussian" sampler:
    param = modeling.get_parameter_vector(model)
    mean = torch.zeros_like(param)
    sampler_name = args.noise_dist
    if sampler_name in ("gaussian", "advanced_gaussian"):
        sampler_name = "gaussian"
    sample_noise = getattr(noise, sampler_name)
    perturbation = sample_noise(mean, scale)

    # add the noise in-place without tracking gradients, then write the
    # perturbed vector back into the model:
    with torch.no_grad():
        param.add_(perturbation)
    modeling.set_parameter_vector(model, param)

    # predict on every available split with the perturbed model:
    logging.info("Performing inference with perturbed predictor...")
    predictions = {}
    for split, data_split in data.items():
        scores = modeling.test_model(model, data_split)
        predictions[split] = scores.argmax(dim=1)
    return predictions