def get_parser()

in attacks/privacy_attacks.py [0:0]


def get_parser():
    """
    Build the command-line argument parser for the privacy attack scripts.

    Options are grouped into config, attack, data, model, training, privacy
    and multi-GPU sections. ``--mask_path`` is the only required option.
    Boolean options are parsed with ``bool_flag`` (a helper defined elsewhere
    in this module).

    Returns:
        argparse.ArgumentParser: the configured, not-yet-parsed parser.
    """
    parser = argparse.ArgumentParser(description='Privacy attack parameters')

    # config parameters
    parser.add_argument("--dump_path", type=str, default=None,
                        help="model saving location")
    parser.add_argument('--print_freq', type=int, default=50,
                        help="training printing frequency")
    parser.add_argument("--save_periodic", type=int, default=0,
                        help="training saving frequency")

    # attack parameters
    parser.add_argument("--model_path", type=str, default="model",
                        help="path to the private model")
    parser.add_argument("--attack_type", type=str, default="loss",
                        help="type of auxiliary attack")
    parser.add_argument("--aux_epochs", type=int, default=20,
                        help="number of auxiliary training epochs")
    parser.add_argument("--num_aux", type=int, default=1,
                        help="number of auxiliary models")
    parser.add_argument("--aug_style", type=str, default="mean",
                        help="combination method for augmented data values")
    parser.add_argument("--aux_style", type=str, default="sum",
                        help="combination method for values from multiple auxiliary models")
    parser.add_argument("--public_data", type=str, default="train",
                        help="which part of the public data to use for auxiliary model "
                             "training (e.g. 'train' is the training mask, 'rand50' is a "
                             "random selection of the public data)")
    parser.add_argument("--norm_type", type=str, default=None,
                        help="norm type used for the gradient norm")
    parser.add_argument("--num_points", type=int, default=10,
                        help="number of points to use for the label-only attack")
    # Float defaults (not 0 / 1): argparse does not apply `type` to non-string
    # defaults, so int defaults would make these ints unless given on the CLI.
    parser.add_argument("--clip_min", type=float, default=0.0,
                        help="minimum value for adversarial features in the label-only attack")
    parser.add_argument("--clip_max", type=float, default=1.0,
                        help="maximum value for adversarial features in the label-only attack")

    # data parameters
    parser.add_argument("--data_root", type=str, default="data",
                        help="path to the data")
    parser.add_argument("--dataset", type=str,
                        choices=["cifar10", "imagenet", "cifar100", "gaussian", "credit",
                                 "hep", "adult", "mnist", "lfw"],
                        default="cifar10")
    parser.add_argument("--mask_path", type=str, required=True,
                        help="path to the data mask")
    parser.add_argument('--n_data', type=int, default=500,
                        help="number of data points for gaussian data")
    parser.add_argument('--data_num_dimensions', type=int, default=75,
                        help="number of features for non-image data")
    parser.add_argument('--random_seed', type=int, default=10,
                        help="seed for gaussian data")
    parser.add_argument("--num_classes", type=int, default=10,
                        help="number of classes for the classification task")
    parser.add_argument("--in_channels", type=int, default=3,
                        help="number of input channels for image data")

    # model parameters
    parser.add_argument("--architecture",
                        choices=["lenet", "smallnet", "resnet18", "kllenet", "linear", "mlp"],
                        default="lenet")

    # training parameters
    parser.add_argument("--aug", type=bool_flag, default=False,
                        help="data augmentation flag")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--optimizer", default="sgd,lr=0.1,momentum=0.9")
    parser.add_argument("--num_workers", type=int, default=2)
    parser.add_argument("--log_gradients", type=bool_flag, default=False)
    parser.add_argument("--log_batch_models", type=bool_flag, default=False,
                        help="save the model for each batch of data")
    parser.add_argument("--log_epoch_models", type=bool_flag, default=False,
                        help="save the model for each training epoch")

    # privacy parameters
    parser.add_argument("--private", type=bool_flag, default=False,
                        help="privacy flag")
    parser.add_argument("--noise_multiplier", type=float, default=None)
    parser.add_argument("--privacy_epsilon", type=float, default=None)
    parser.add_argument("--privacy_delta", type=float, default=None)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)

    # multi-GPU parameters
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--master_port", type=int, default=-1)
    parser.add_argument("--debug_slurm", type=bool_flag, default=False)

    return parser