in tftrt/blog_posts/Leveraging TensorFlow-TensorRT integration for Low latency Inference/tf2_inference.py [0:0]
def dataloader_fn(data_dir, batch_size):
    """Build an endless, batched input pipeline over the MNIST test split.

    Args:
        data_dir: Directory passed to TFDS for downloading/storing MNIST.
        batch_size: Number of examples per batch.

    Returns:
        A ``tf.data.Dataset`` of batched ``(image, label)`` pairs from the
        MNIST ``'test'`` split, repeated indefinitely and prefetched.
    """
    import tensorflow_datasets as tfds
    from official.vision.image_classification.mnist_main import decode_image

    mnist = tfds.builder('mnist', data_dir=data_dir)
    mnist.download_and_prepare()
    # Only the test split is consumed, so request it directly instead of
    # also constructing a train-split dataset that is immediately discarded.
    # NOTE(review): decode_image() comes from the TF Model Garden MNIST
    # example; its exact output dtype/shape is not visible here — confirm.
    mnist_test = mnist.as_dataset(
        split='test',
        decoders={'image': decode_image()},
        as_supervised=True,  # yield (image, label) tuples
    )
    # cache() (no filename) keeps elements in memory after the first pass.
    # repeat() is applied before batch(), which makes the stream infinite,
    # so no short trailing batch is ever produced.
    ds = mnist_test.cache().repeat().batch(batch_size)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    return ds