def _input_fn()

in benchmarking/pipemode_benchmark/docker/gpu_pipeline_script.py [0:0]


def _input_fn():
    """Build the benchmark input pipeline over the configured Pipe Mode channel.

    Pipeline order:

    1. Create a ``PipeModeDataset`` on ``config.channel`` with ``benchmark=True``.
    2. Repeat it ``config.epochs`` times, but only when ``config.epochs > 1``.
    3. Prefetch ``config.prefetch_size`` elements, but only when that value
       is positive.
    4. Parse each serialized example with ``tf.parse_single_example`` and
       batch the results via ``map_and_batch``. The batch size is
       ``config.batch_size`` and the parallelism is
       ``config.parallel_transform_calls``.

    Returns:
        The resulting dataset. Each element is a batched feature dict with the
        keys ``'data'`` (string) and ``'labels'`` (int64).
    """
    schema = dict(
        data=tf.FixedLenFeature([], tf.string),
        labels=tf.FixedLenFeature([], tf.int64),
    )

    def decode_example(serialized):
        # One serialized record in, one dict of scalar feature tensors out.
        return tf.parse_single_example(serialized, schema)

    dataset = PipeModeDataset(config.channel, benchmark=True)

    # A single epoch needs no repeat() wrapper at all.
    dataset = dataset.repeat(config.epochs) if config.epochs > 1 else dataset

    # NOTE(review): prefetch is applied before parsing/batching, so it buffers
    # raw records rather than finished batches. This is presumably deliberate
    # for this benchmark; confirm before moving it to the end.
    wants_prefetch = config.prefetch_size > 0
    dataset = dataset.prefetch(config.prefetch_size) if wants_prefetch else dataset

    return dataset.apply(
        map_and_batch(
            decode_example,
            batch_size=config.batch_size,
            num_parallel_batches=config.parallel_transform_calls,
        )
    )