in example_zoo/tensorflow/probability/bayesian_neural_network/trainer/bayesian_neural_network.py [0:0]
def main(argv):
    """Train a Bayesian LeNet5 classifier with Flipout layers and report on it.

    Builds the input pipeline, the probabilistic network, the negative-ELBO
    loss and a streaming accuracy metric in a TF1 graph, then runs
    ``FLAGS.max_steps`` Adam updates inside a session. Every 100 steps the
    loss and accuracy are printed. Every ``FLAGS.viz_steps`` steps the
    held-out log-probability is estimated from ``FLAGS.num_monte_carlo``
    stochastic forward passes and, when ``HAS_SEABORN`` is true, weight
    posterior and prediction plots are written into ``FLAGS.model_dir``.

    Note that an existing ``FLAGS.model_dir`` is deleted and recreated.

    Args:
      argv: Command-line arguments; unused.
    """
    del argv  # unused

    # Always start from an empty output directory.
    if tf.io.gfile.exists(FLAGS.model_dir):
        tf.compat.v1.logging.warning(
            "Warning: deleting old log directory at {}".format(
                FLAGS.model_dir))
        tf.io.gfile.rmtree(FLAGS.model_dir)
    tf.io.gfile.makedirs(FLAGS.model_dir)

    if FLAGS.fake_data:
        mnist_data = build_fake_data()
    else:
        mnist_data = mnist.read_data_sets(FLAGS.data_dir, reshape=False)

    (images, labels, handle,
     training_iterator, heldout_iterator) = build_input_pipeline(
         mnist_data, FLAGS.batch_size, mnist_data.validation.num_examples)

    # Build a Bayesian LeNet5 network. We use the Flipout Monte Carlo estimator
    # for the convolution and fully-connected layers: this enables lower
    # variance stochastic gradients than naive reparameterization.
    with tf.compat.v1.name_scope("bayesian_neural_net", values=[images]):
        neural_net = tf.keras.Sequential([
            tfp.layers.Convolution2DFlipout(6,
                                            kernel_size=5,
                                            padding="SAME",
                                            activation=tf.nn.relu),
            tf.keras.layers.MaxPooling2D(pool_size=[2, 2],
                                         strides=[2, 2],
                                         padding="SAME"),
            tfp.layers.Convolution2DFlipout(16,
                                            kernel_size=5,
                                            padding="SAME",
                                            activation=tf.nn.relu),
            tf.keras.layers.MaxPooling2D(pool_size=[2, 2],
                                         strides=[2, 2],
                                         padding="SAME"),
            tfp.layers.Convolution2DFlipout(120,
                                            kernel_size=5,
                                            padding="SAME",
                                            activation=tf.nn.relu),
            tf.keras.layers.Flatten(),
            tfp.layers.DenseFlipout(84, activation=tf.nn.relu),
            tfp.layers.DenseFlipout(10)
        ])

        logits = neural_net(images)
        labels_distribution = tfd.Categorical(logits=logits)

    # Compute the -ELBO as the loss, averaged over the batch size.
    # The summed layer losses are scaled by the number of training examples
    # before being added to the per-batch mean negative log-likelihood.
    neg_log_likelihood = -tf.reduce_mean(
        input_tensor=labels_distribution.log_prob(labels))
    kl = sum(neural_net.losses) / mnist_data.train.num_examples
    elbo_loss = neg_log_likelihood + kl

    # Build metrics for evaluation. Predictions are formed from a single forward
    # pass of the probabilistic layers. They are cheap but noisy predictions.
    predictions = tf.argmax(input=logits, axis=1)
    accuracy, accuracy_update_op = tf.compat.v1.metrics.accuracy(
        labels=labels, predictions=predictions)

    # Extract weight posterior statistics for layers with weight distributions
    # for later visualization. Layers without a `kernel_posterior` attribute
    # (pooling, flatten) are skipped.
    names = []
    qmeans = []
    qstds = []
    for i, layer in enumerate(neural_net.layers):
        try:
            q = layer.kernel_posterior
        except AttributeError:
            continue
        names.append("Layer {}".format(i))
        qmeans.append(q.mean())
        qstds.append(q.stddev())

    with tf.compat.v1.name_scope("train"):
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=FLAGS.learning_rate)
        train_op = optimizer.minimize(elbo_loss)

    # The local initializer is included alongside the global one.
    # NOTE(review): presumably needed for the accuracy metric's counters.
    init_op = tf.group(tf.compat.v1.global_variables_initializer(),
                       tf.compat.v1.local_variables_initializer())

    with tf.compat.v1.Session() as sess:
        sess.run(init_op)

        # Run the training loop.
        train_handle = sess.run(training_iterator.string_handle())
        heldout_handle = sess.run(heldout_iterator.string_handle())
        for step in range(FLAGS.max_steps):
            sess.run([train_op, accuracy_update_op],
                     feed_dict={handle: train_handle})

            if step % 100 == 0:
                # Loss and accuracy are evaluated in a separate run that is
                # also fed with the training handle.
                loss_value, accuracy_value = sess.run(
                    [elbo_loss, accuracy], feed_dict={handle: train_handle})
                print("Step: {:>3d} Loss: {:.3f} Accuracy: {:.3f}".format(
                    step, loss_value, accuracy_value))

            if (step + 1) % FLAGS.viz_steps == 0:
                # Compute log prob of heldout set by averaging draws from the
                # model:
                # p(heldout | train) = int_model p(heldout|model) p(model|train)
                #                   ~= 1/n * sum_{i=1}^n p(heldout | model_i)
                # where model_i is a draw from the posterior p(model|train).
                probs = np.asarray([
                    sess.run(labels_distribution.probs,
                             feed_dict={handle: heldout_handle})
                    for _ in range(FLAGS.num_monte_carlo)])
                mean_probs = np.mean(probs, axis=0)

                # NOTE(review): images/labels come from a separate run, so this
                # assumes the held-out iterator yields the same examples on
                # every run -- confirm against build_input_pipeline.
                image_vals, label_vals = sess.run(
                    (images, labels), feed_dict={handle: heldout_handle})
                # Pick, for each example, the averaged probability assigned to
                # its true label, then average the logs.
                heldout_lp = np.mean(np.log(
                    mean_probs[np.arange(mean_probs.shape[0]),
                               label_vals.flatten()]))
                print(" ... Held-out nats: {:.3f}".format(heldout_lp))

                if HAS_SEABORN:
                    # The posterior statistics are only consumed by the plots,
                    # so evaluate them only when plotting is available.
                    qm_vals, qs_vals = sess.run((qmeans, qstds))

                    plot_weight_posteriors(
                        names, qm_vals, qs_vals,
                        fname=os.path.join(
                            FLAGS.model_dir,
                            "step{:05d}_weights.png".format(step)))

                    plot_heldout_prediction(
                        image_vals, probs,
                        fname=os.path.join(
                            FLAGS.model_dir,
                            "step{:05d}_pred.png".format(step)),
                        title="mean heldout logprob {:.2f}"
                        .format(heldout_lp))