def dataloader_fn()

in tftrt/blog_posts/Leveraging TensorFlow-TensorRT integration for Low latency Inference/tf2_inference.py [0:0]


    def dataloader_fn(data_dir, batch_size):
        """Build an endlessly repeating, batched MNIST test-split input pipeline.

        Args:
            data_dir: Directory passed to TFDS for downloading/storing MNIST.
            batch_size: Number of examples per batch.

        Returns:
            A ``tf.data.Dataset`` over the MNIST ``'test'`` split, loaded with
            ``as_supervised=True`` (so elements are ``(input, label)`` tuples),
            cached, repeated indefinitely, batched and prefetched.
        """
        # Function-scope imports: presumably so these dependencies are only
        # required when this loader is actually used.
        import tensorflow_datasets as tfds
        from official.vision.image_classification.mnist_main import decode_image

        # NOTE(review): ``tf`` is not imported in this block; it is expected to
        # be available from the enclosing/module scope.

        mnist = tfds.builder('mnist', data_dir=data_dir)
        mnist.download_and_prepare()

        # Only the test split is consumed, so request just that split rather
        # than also building (and discarding) a pipeline for the train split.
        mnist_test = mnist.as_dataset(
            split='test',
            decoders={'image': decode_image()},
            as_supervised=True,
        )

        # cache() precedes repeat() so decoded examples are materialized on the
        # first pass and reused; repeat() precedes batch() so batches may span
        # epoch boundaries and every batch is full.
        ds = mnist_test.cache().repeat().batch(batch_size)
        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

        return ds