def main()

in example_zoo/tensorflow/probability/bayesian_neural_network/trainer/bayesian_neural_network.py


def main(argv):
  del argv  # unused
  if tf.io.gfile.exists(FLAGS.model_dir):
    tf.compat.v1.logging.warning(
        "Warning: deleting old log directory at {}".format(FLAGS.model_dir))
    tf.io.gfile.rmtree(FLAGS.model_dir)
  tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.fake_data:
    mnist_data = build_fake_data()
  else:
    mnist_data = mnist.read_data_sets(FLAGS.data_dir, reshape=False)

  (images, labels, handle,
   training_iterator, heldout_iterator) = build_input_pipeline(
       mnist_data, FLAGS.batch_size, mnist_data.validation.num_examples)
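  # `handle` is a feedable string placeholder; feeding it the string handle of
  # either the training or held-out iterator below selects which dataset the
  # graph reads batches from.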

  # Build a Bayesian LeNet5 network. We use the Flipout Monte Carlo estimator
  # for the convolution and fully-connected layers: this enables lower
  # variance stochastic gradients than naive reparameterization.
  with tf.compat.v1.name_scope("bayesian_neural_net", values=[images]):
    neural_net = tf.keras.Sequential([
        tfp.layers.Convolution2DFlipout(6,
                                        kernel_size=5,
                                        padding="SAME",
                                        activation=tf.nn.relu),
        tf.keras.layers.MaxPooling2D(pool_size=[2, 2],
                                     strides=[2, 2],
                                     padding="SAME"),
        tfp.layers.Convolution2DFlipout(16,
                                        kernel_size=5,
                                        padding="SAME",
                                        activation=tf.nn.relu),
        tf.keras.layers.MaxPooling2D(pool_size=[2, 2],
                                     strides=[2, 2],
                                     padding="SAME"),
        tfp.layers.Convolution2DFlipout(120,
                                        kernel_size=5,
                                        padding="SAME",
                                        activation=tf.nn.relu),
        tf.keras.layers.Flatten(),
        tfp.layers.DenseFlipout(84, activation=tf.nn.relu),
        tfp.layers.DenseFlipout(10)
        ])
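    # By default each Flipout layer places a learned mean-field normal
    # posterior and a standard-normal prior over its kernel.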

    logits = neural_net(images)
    labels_distribution = tfd.Categorical(logits=logits)

  # Compute the -ELBO as the loss, averaged over the batch size.
  neg_log_likelihood = -tf.reduce_mean(
      input_tensor=labels_distribution.log_prob(labels))
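  # Each Flipout layer adds its KL(q || p) term to `neural_net.losses`;
  # dividing the summed KL by the number of training examples puts it on the
  # same per-example scale as the negative log-likelihood above.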
  kl = sum(neural_net.losses) / mnist_data.train.num_examples
  elbo_loss = neg_log_likelihood + kl

  # Build metrics for evaluation. Predictions are formed from a single forward
  # pass of the probabilistic layers. They are cheap but noisy predictions.
  predictions = tf.argmax(input=logits, axis=1)
  accuracy, accuracy_update_op = tf.compat.v1.metrics.accuracy(
      labels=labels, predictions=predictions)
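  # `accuracy` reads a streaming estimate; `accuracy_update_op` updates the
  # underlying local variables, which is why `init_op` below also runs
  # local_variables_initializer().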

  # Extract weight posterior statistics for layers with weight distributions
  # for later visualization.
  names = []
  qmeans = []
  qstds = []
  for i, layer in enumerate(neural_net.layers):
    try:
      q = layer.kernel_posterior
    except AttributeError:
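      # Deterministic layers (MaxPooling2D, Flatten) have no kernel posterior.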
      continue
    names.append("Layer {}".format(i))
    qmeans.append(q.mean())
    qstds.append(q.stddev())

  with tf.compat.v1.name_scope("train"):
    optimizer = tf.compat.v1.train.AdamOptimizer(
        learning_rate=FLAGS.learning_rate)
    train_op = optimizer.minimize(elbo_loss)

  init_op = tf.group(tf.compat.v1.global_variables_initializer(),
                     tf.compat.v1.local_variables_initializer())

  with tf.compat.v1.Session() as sess:
    sess.run(init_op)

    # Run the training loop.
    train_handle = sess.run(training_iterator.string_handle())
    heldout_handle = sess.run(heldout_iterator.string_handle())
    for step in range(FLAGS.max_steps):
      _ = sess.run([train_op, accuracy_update_op],
                   feed_dict={handle: train_handle})

      if step % 100 == 0:
        loss_value, accuracy_value = sess.run(
            [elbo_loss, accuracy], feed_dict={handle: train_handle})
        print("Step: {:>3d} Loss: {:.3f} Accuracy: {:.3f}".format(
            step, loss_value, accuracy_value))

      if (step+1) % FLAGS.viz_steps == 0:
        # Compute log prob of heldout set by averaging draws from the model:
        # p(heldout | train) = int_model p(heldout|model) p(model|train)
        #                   ~= 1/n * sum_{i=1}^n p(heldout | model_i)
        # where model_i is a draw from the posterior p(model|train).
        probs = np.asarray([sess.run((labels_distribution.probs),
                                     feed_dict={handle: heldout_handle})
                            for _ in range(FLAGS.num_monte_carlo)])
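        # `probs` stacks one [batch, num_classes] probability matrix per
        # posterior draw; averaging over draws gives the Monte Carlo estimate
        # of the posterior predictive distribution.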
        mean_probs = np.mean(probs, axis=0)

        image_vals, label_vals = sess.run((images, labels),
                                          feed_dict={handle: heldout_handle})
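        # Pick each example's predictive probability of its true label and
        # average the logs: the mean held-out log-likelihood in nats.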
        heldout_lp = np.mean(np.log(mean_probs[np.arange(mean_probs.shape[0]),
                                               label_vals.flatten()]))
        print(" ... Held-out nats: {:.3f}".format(heldout_lp))

        qm_vals, qs_vals = sess.run((qmeans, qstds))

        if HAS_SEABORN:
          plot_weight_posteriors(names, qm_vals, qs_vals,
                                 fname=os.path.join(
                                     FLAGS.model_dir,
                                     "step{:05d}_weights.png".format(step)))

          plot_heldout_prediction(image_vals, probs,
                                  fname=os.path.join(
                                      FLAGS.model_dir,
                                      "step{:05d}_pred.png".format(step)),
                                  title="mean heldout logprob {:.2f}"
                                  .format(heldout_lp))
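
For reference, the snippet below is a minimal standalone sketch (not part of the
original file) of how a single Flipout layer contributes the KL(q || p) term
that the `kl` line above sums over `neural_net.losses`; the batch shape and
variable names here are illustrative assumptions only.

import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

tf.disable_v2_behavior()

dummy_images = tf.zeros([32, 784])   # stand-in batch of flattened images
layer = tfp.layers.DenseFlipout(10)  # default mean-field posterior, N(0, 1) prior
logits = layer(dummy_images)         # the forward pass builds the posterior
print(layer.losses)                  # [KL(q || p)] registered via add_loss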