# Excerpt: main() from research/GDP_2019/adult_tutorial.py

def main(unused_argv):
  """Trains the Adult-dataset classifier with per-step Poisson subsampling.

  Each "step" trains the estimator on an independently Poisson-subsampled
  batch (each example kept with probability sampling_batch / num_examples),
  which is the sampling scheme assumed by the Gaussian-DP accountants
  (compute_eps_poisson / compute_mu_poisson). After each epoch the model is
  evaluated on the test set and, when FLAGS.dpsgd is set, the current
  (epsilon, mu) privacy expenditure is printed; training stops early once
  mu exceeds FLAGS.max_mu.

  Args:
    unused_argv: Unused; required by the app.run() entry-point convention.
  """
  tf.compat.v1.logging.set_verbosity(0)

  # Load training and test data.
  train_data, train_labels, test_data, test_labels = load_adult()

  # Instantiate the tf.Estimator.
  adult_classifier = tf.compat.v1.estimator.Estimator(
      model_fn=nn_model_fn, model_dir=FLAGS.model_dir)

  # Create the tf.Estimator input function for the test data. (The training
  # input function is rebuilt every step, since each step uses a fresh
  # Poisson subsample.)
  eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
      x={'x': test_data}, y=test_labels, num_epochs=1, shuffle=False)

  # Training loop.
  steps_per_epoch = num_examples // sampling_batch
  test_accuracy_list = []
  for epoch in range(1, FLAGS.epochs + 1):
    for _ in range(steps_per_epoch):
      # Poisson subsampling: keep example i independently with probability
      # sampling_batch / num_examples.
      whether = np.random.random_sample(num_examples) > (
          1 - sampling_batch / num_examples)
      subsampling = np.where(whether)[0]
      if subsampling.size == 0:
        # Poisson sampling can (rarely) select zero examples; skip the step
        # rather than crash numpy_input_fn with batch_size=0.
        continue
      # The model_fn reads this module-level value to size its microbatches,
      # which must match the (random) size of the current subsample.
      global microbatches
      microbatches = len(subsampling)

      train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
          x={'x': train_data[subsampling]},
          y=train_labels[subsampling],
          batch_size=len(subsampling),
          num_epochs=1,
          shuffle=True)
      # Train the model for one step.
      adult_classifier.train(input_fn=train_input_fn, steps=1)

    # Evaluate the model and print results.
    eval_results = adult_classifier.evaluate(input_fn=eval_input_fn)
    test_accuracy = eval_results['accuracy']
    test_accuracy_list.append(test_accuracy)
    print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))

    # Compute the privacy budget expended so far.
    if FLAGS.dpsgd:
      eps = compute_eps_poisson(epoch, FLAGS.noise_multiplier, num_examples,
                                sampling_batch, 1e-5)
      mu = compute_mu_poisson(epoch, FLAGS.noise_multiplier, num_examples,
                              sampling_batch)
      print('For delta=1e-5, the current epsilon is: %.2f' % eps)
      # NOTE(review): mu in Gaussian DP does not depend on delta, so the
      # "For delta=1e-5" prefix below is misleading; kept verbatim to
      # preserve the tutorial's output format.
      print('For delta=1e-5, the current mu is: %.2f' % mu)

      # Stop once the Gaussian-DP budget is exhausted.
      if mu > FLAGS.max_mu:
        break
    else:
      print('Trained with vanilla non-private SGD optimizer')