in train_svhn.py [0:0]
def main():
    """Train or evaluate an SVHN CNN (feature extractor + linear classifier).

    Command-line entry point. The flow is:

    1. Parse the command-line options.
    2. Load the SVHN ``train`` and ``extra`` splits (concatenated into one
       training set) and the ``test`` split from ``--data-dir``.
    3. Build a ``FastGradExtractor`` / ``FastGradMLP`` pair; ``--single-layer``
       selects the one-convolution variant.
    4. If no checkpoint exists at the path derived from ``--save-dir``,
       ``--delta``, ``--std`` and ``--save-suffix``: train for ``--epochs``
       epochs (testing after each one) and, with ``--save-model``, write the
       checkpoint.
    5. Otherwise: load the checkpoint and run the action selected by
       ``--test-mode`` (``linear``, ``extract``, or a plain test pass).
    """
    parser = argparse.ArgumentParser(description='Training an SVHN model')
    parser.add_argument('--data-dir', type=str, required=True, help='directory for SVHN data')
    parser.add_argument('--save-dir', type=str, default='save', help='directory for saving trained model')
    parser.add_argument('--batch-size', type=int, default=500, help='batch size for training')
    parser.add_argument('--process-batch-size', type=int, default=500, help='batch size for processing')
    parser.add_argument('--test-batch-size', type=int, default=1000, help='batch size for testing')
    parser.add_argument('--epochs', type=int, default=20, help='number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--lam', type=float, default=0, help='L2 regularization')
    parser.add_argument('--std', type=float, default=6.0, help='noise multiplier for DP training')
    parser.add_argument('--delta', type=float, default=1e-5, help='delta for DP training')
    parser.add_argument('--num-filters', type=int, default=64, help='number of conv filters')
    parser.add_argument('--seed', type=int, default=1, help='manual random seed')
    parser.add_argument('--log-interval', type=int, default=10,
                        help='logging interval')
    parser.add_argument('--train-mode', type=str, default='default', help='train mode [default/private/full_private]')
    parser.add_argument('--test-mode', type=str, default='default', help='test mode [default/linear/extract]')
    parser.add_argument('--save-suffix', type=str, default='', help='suffix for model name')
    parser.add_argument('--normalize', action='store_true', default=False,
                        help='normalize extracted features')
    parser.add_argument('--single-layer', action='store_true', default=False,
                        help='single convolutional layer')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='for saving the trained model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    # Prefer the GPU, but fall back to the CPU so the script still runs on
    # hosts without CUDA instead of failing on the first `.to(device)`.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pinned host memory only helps host-to-GPU copies.
    kwargs = {'num_workers': 1, 'pin_memory': device.type == 'cuda'}

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0)),
    ])
    trainset = torchvision.datasets.SVHN(root=args.data_dir, split='train', download=True, transform=transform)
    extraset = torchvision.datasets.SVHN(root=args.data_dir, split='extra', download=True, transform=transform)
    # Train on the union of the 'train' and 'extra' splits.
    trainset = torch.utils.data.ConcatDataset([trainset, extraset])
    testset = torchvision.datasets.SVHN(root=args.data_dir, split='test', download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
    # The test loader honours --test-batch-size rather than reusing the
    # training batch size.
    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=True, **kwargs)

    # NOTE(review): the classifier input sizes (12*12 and 5*5 per filter) are
    # presumably the flattened extractor output sizes -- confirm against
    # FastGradExtractor.
    if args.single_layer:
        extr = FastGradExtractor([3, args.num_filters], 9, 1, 2, normalize=args.normalize).to(device)
        clf = FastGradMLP([12*12*args.num_filters, 10]).to(device)
    else:
        extr = FastGradExtractor([3, args.num_filters, args.num_filters], 5, 1, 2, normalize=args.normalize).to(device)
        clf = FastGradMLP([5*5*args.num_filters, 10]).to(device)

    def loss_fn(x, y):
        """Negative log-likelihood of targets `y` under log-softmax of `x`."""
        return F.nll_loss(F.log_softmax(x, dim=1), y)

    save_path = "%s/svhn_cnn_delta_%.2e_std_%.2f%s.pth" % (args.save_dir, args.delta, args.std, args.save_suffix)

    if not os.path.exists(save_path):
        # No checkpoint for this configuration yet: train from scratch.
        optimizer = optim.Adam(list(extr.parameters()) + list(clf.parameters()), lr=args.lr, weight_decay=args.lam)
        # Passed to train_private; presumably the per-example gradient
        # clipping bound -- TODO confirm against train_private.
        C = 4
        n = len(train_loader.dataset)
        q = float(args.batch_size) / float(n)  # fraction of the data in one batch
        T = args.epochs * len(train_loader)    # total number of optimisation steps

        # compute privacy loss using RDP analysis
        # NOTE(review): this is reported for every --train-mode; presumably it
        # is only meaningful for the private modes -- confirm.
        orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
                  list(range(5, 64)) + [128, 256, 512, 1024, 2048, 4096])
        epsilon, _ = get_privacy_spent(orders, compute_rdp(q, args.std, T, orders), args.delta)
        print('RDP computed privacy loss: epsilon = %.2f at delta = %.2e' % (epsilon, args.delta))

        start = time.time()
        for epoch in range(1, args.epochs + 1):
            if args.train_mode == 'private' or args.train_mode == 'full_private':
                include_linear = (args.train_mode == 'full_private')
                train_private(args, extr, clf, loss_fn, device, train_loader, optimizer, epoch, C, args.std,
                              include_linear=include_linear)
            else:
                train(args, extr, clf, loss_fn, device, train_loader, optimizer, epoch)
            test(args, extr, clf, loss_fn, device, test_loader)
        # Wall-clock training time in seconds.
        print(time.time() - start)

        if args.save_model:
            # Make sure the target directory exists so a finished training
            # run is not lost to a missing --save-dir.
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save({'extr': extr.state_dict(), 'clf': clf.state_dict()}, save_path)
    else:
        # A checkpoint exists: restore it and run the requested evaluation.
        # map_location lets a GPU-saved checkpoint load on the current device.
        checkpoint = torch.load(save_path, map_location=device)
        extr.load_state_dict(checkpoint['extr'])
        clf.load_state_dict(checkpoint['clf'])
        if args.test_mode == 'linear':
            test_linear(args, extr, device, train_loader, test_loader)
        elif args.test_mode == 'extract':
            # this option can be used to extract features for training the removal-enabled linear model
            X_train, y_train = utils.extract_features(extr, device, train_loader)
            X_test, y_test = utils.extract_features(extr, device, test_loader)
            torch.save({'X_train': X_train, 'y_train': y_train, 'X_test': X_test, 'y_test': y_test},
                       '%s/dp_delta_%.2e_std_%.2f_SVHN_extracted.pth' % (args.data_dir, args.delta, args.std))
        else:
            test(args, extr, clf, loss_fn, device, test_loader)