in examples/mnist/tf_example.py [0:0]
def train_and_test(dataset_url, training_iterations, batch_size, evaluation_interval):
    """
    Train a softmax-regression model on MNIST and periodically print its accuracy on a test batch.

    The model is a single linear layer (784 inputs, 10 outputs) trained with gradient descent on
    batches read from the ``train`` sub-path of ``dataset_url``. Every ``evaluation_interval``
    iterations, and once more on the final iteration, one batch is read from the ``test`` sub-path
    and the accuracy of the model on that batch is printed.

    :param dataset_url: The MNIST dataset url; ``train`` and ``test`` are joined onto it.
    :param training_iterations: The number of training steps to run.
    :param batch_size: The batch size, used for both the training and the evaluation batches.
    :param evaluation_interval: The number of iterations between accuracy printouts. Must not be 0.
    :return: None. Progress and accuracy are printed to stdout.
    :raises ValueError: If ``evaluation_interval`` is 0.
    """
    # Fail fast: the training loop computes ``i % evaluation_interval``, which would otherwise raise
    # ZeroDivisionError only after the readers are open and the queue-runner threads are running.
    if evaluation_interval == 0:
        raise ValueError('evaluation_interval must not be 0')

    # NOTE(review): num_epochs=None presumably makes the readers cycle over the data indefinitely,
    # so the loop below is bounded by training_iterations only -- confirm against the reader docs.
    with make_reader(os.path.join(dataset_url, 'train'), num_epochs=None) as train_reader:
        with make_reader(os.path.join(dataset_url, 'test'), num_epochs=None) as test_reader:
            # Training input pipeline: flatten each image to 784 float32 values and batch it
            # together with its digit label.
            train_readout = tf_tensors(train_reader)
            train_image = tf.cast(tf.reshape(train_readout.image, [784]), tf.float32)
            train_label = train_readout.digit
            batch_image, batch_label = tf.train.batch(
                [train_image, train_label], batch_size=batch_size
            )

            # Linear model: logits = batch_image * W + b, with zero-initialized parameters.
            W = tf.Variable(tf.zeros([784, 10]))
            b = tf.Variable(tf.zeros([10]))
            y = tf.matmul(batch_image, W) + b

            # The raw formulation of cross-entropy,
            #
            #   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
            #                                 reduction_indices=[1]))
            #
            # can be numerically unstable.
            #
            # So here we use tf.losses.sparse_softmax_cross_entropy on the raw
            # outputs of 'y', and then average across the batch.
            cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=batch_label, logits=y)
            train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

            # Accuracy is the fraction of samples whose highest-scoring class equals the label.
            correct_prediction = tf.equal(tf.argmax(y, 1), batch_label)
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

            # Evaluation input pipeline: same preprocessing as training, fed by the test reader.
            test_readout = tf_tensors(test_reader)
            test_image = tf.cast(tf.reshape(test_readout.image, [784]), tf.float32)
            test_label = test_readout.digit
            test_batch_image, test_batch_label = tf.train.batch(
                [test_image, test_label], batch_size=batch_size
            )

            # Train
            print('Training model for {0} training iterations with batch size {1} and evaluation interval {2}'.format(
                training_iterations, batch_size, evaluation_interval
            ))
            with tf.Session() as sess:
                sess.run([
                    tf.local_variables_initializer(),
                    tf.global_variables_initializer(),
                ])
                # tf.train.batch is queue based; its queue runners must be started before any
                # batch tensor can be evaluated.
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                try:
                    for i in range(training_iterations):
                        if coord.should_stop():
                            break

                        sess.run(train_step)

                        # Report on every evaluation_interval-th iteration and on the last one.
                        if (i % evaluation_interval) == 0 or i == (training_iterations - 1):
                            # Materialize one test batch, then feed it in place of the training
                            # batch tensors so the same accuracy op is computed on test data.
                            feed_batch_image, feed_batch_label = sess.run([test_batch_image, test_batch_label])
                            print('After {0} training iterations, the accuracy of the model is: {1:.2f}'.format(
                                i,
                                sess.run(accuracy, feed_dict={
                                    batch_image: feed_batch_image, batch_label: feed_batch_label
                                })))
                finally:
                    # Always ask the queue-runner threads to stop and wait for them, even if
                    # training failed part-way through.
                    coord.request_stop()
                    coord.join(threads)