in example_zoo/tensorflow/probability/vq_vae/trainer/vq_vae.py [0:0]
def build_input_pipeline(data_dir, batch_size, heldout_size, mnist_type):
  """Builds an Iterator switching between train and heldout data.

  Two one-shot iterators (training and heldout) are created, together with a
  feedable iterator driven by a string-handle placeholder. Feeding the string
  handle of either one-shot iterator into `handle` selects which of them
  supplies the returned `images` and `labels` tensors.

  Args:
    data_dir: Directory passed to the MNIST loaders (`mnist.read_data_sets`
      or `load_bernoulli_mnist_dataset`). Not used for `MnistType.FAKE_DATA`.
    batch_size: Number of examples per training batch.
    heldout_size: Number of heldout examples to take; it is also used as the
      heldout batch size, so that (assuming the heldout set has at least this
      many examples) every heldout batch is the same fixed subset.
    mnist_type: A `MnistType` member selecting the data source:
      `FAKE_DATA` (from `build_fake_data`), `THRESHOLD` (from
      `mnist.read_data_sets`, binarized here by thresholding), or
      `BERNOULLI` (from `load_bernoulli_mnist_dataset`).

  Returns:
    images: Tensor produced by the feedable iterator, reshaped to
      `[-1] + IMAGE_SHAPE`. For `FAKE_DATA` and `THRESHOLD` it is binarized
      with a 0.5 threshold and cast to `tf.int32`; for `BERNOULLI` it is
      returned as loaded.
    labels: Second component produced by the feedable iterator.
    handle: Scalar `tf.string` placeholder that selects the active iterator.
    training_iterator: One-shot iterator over repeated training batches.
    heldout_iterator: One-shot iterator over the repeated heldout batch.

  Raises:
    ValueError: If `mnist_type` is not one of the supported `MnistType`
      values.
  """
  # These two variants provide real-valued pixels that are thresholded below.
  needs_threshold = mnist_type in [MnistType.FAKE_DATA, MnistType.THRESHOLD]

  # Build the raw training and heldout datasets.
  if needs_threshold:
    if mnist_type == MnistType.FAKE_DATA:
      mnist_data = build_fake_data()
    else:
      mnist_data = mnist.read_data_sets(data_dir)
    training_dataset = tf.data.Dataset.from_tensor_slices(
        (mnist_data.train.images, np.int32(mnist_data.train.labels)))
    heldout_dataset = tf.data.Dataset.from_tensor_slices(
        (mnist_data.validation.images,
         np.int32(mnist_data.validation.labels)))
  elif mnist_type == MnistType.BERNOULLI:
    training_dataset = load_bernoulli_mnist_dataset(data_dir, "train")
    heldout_dataset = load_bernoulli_mnist_dataset(data_dir, "valid")
  else:
    raise ValueError("Unknown MNIST type.")

  # Build an iterator over training batches.
  training_batches = training_dataset.repeat().batch(batch_size)
  training_iterator = tf.compat.v1.data.make_one_shot_iterator(training_batches)

  # Build an iterator over the heldout set with batch_size=heldout_size,
  # i.e., return the entire heldout set as a constant.
  heldout_frozen = (heldout_dataset.take(heldout_size).
                    repeat().batch(heldout_size))
  heldout_iterator = tf.compat.v1.data.make_one_shot_iterator(heldout_frozen)

  # Combine these into a feedable iterator that can switch between training
  # and validation inputs.
  handle = tf.compat.v1.placeholder(tf.string, shape=[])
  # Query the dataset structure through the compat helpers rather than the
  # deprecated `Dataset.output_types` / `Dataset.output_shapes` properties,
  # which are not available on TF2 (`DatasetV2`) dataset objects.
  feedable_iterator = tf.compat.v1.data.Iterator.from_string_handle(
      handle,
      tf.compat.v1.data.get_output_types(training_batches),
      tf.compat.v1.data.get_output_shapes(training_batches))
  images, labels = feedable_iterator.get_next()

  # Reshape as a pixel image and binarize pixels.
  images = tf.reshape(images, shape=[-1] + IMAGE_SHAPE)
  if needs_threshold:
    images = tf.cast(images > 0.5, dtype=tf.int32)

  return images, labels, handle, training_iterator, heldout_iterator