in recommended-item-search/input_pipeline.py [0:0]
def generate_input_fn(file_pattern, batch_size,
                      mode=tf.estimator.ModeKeys.EVAL,
                      num_cpus=8, shuffle_buffer_size=1000):
    """Generate input function for Estimator.

    Args:
        file_pattern: pattern of input TFRecord file names.
        batch_size: number of sequences per padded batch.
        mode: a `tf.estimator.ModeKeys` value. In TRAIN mode the file order
            and the records are shuffled and the data is repeated
            indefinitely (training length is controlled by the Estimator's
            steps). In any other mode files are read in a deterministic
            order and the data is read exactly once, so evaluation and
            prediction terminate at the end of the input.
        num_cpus: parallelism used both for interleaving file reads
            (`cycle_length`) and for parsing records (`num_parallel_calls`).
        shuffle_buffer_size: size of the record shuffle buffer; only used in
            TRAIN mode.

    Returns:
        A zero-argument input function which returns a `tf.data.Dataset` of
        movie_id sequences batched with `padded_batch`, padded with -1 to the
        longest sequence in each batch.
    """
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    def _input_fn():
        # Shuffle file order only when training; eval/predict should read the
        # same files in the same order on every run.
        files = tf.data.Dataset.list_files(file_pattern, shuffle=is_training)
        dataset = files.interleave(tf.data.TFRecordDataset,
                                   cycle_length=num_cpus)
        if is_training:
            dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
        dataset = dataset.map(map_func=parse_fn, num_parallel_calls=num_cpus)
        if is_training:
            # Infinite stream; the Estimator decides when training stops.
            dataset = dataset.repeat()
        # Note that movie_id sequences are padded with -1.
        # NOTE(review): assumes parse_fn yields a single 1-D int64 tensor per
        # record (required by padded_shapes/padding_values below) — confirm.
        dataset = dataset.padded_batch(
            batch_size=batch_size,
            padded_shapes=tf.TensorShape([None]),
            padding_values=tf.constant(-1, dtype=tf.int64))
        # Prefetch whole batches at the end of the pipeline so that input
        # preparation overlaps with model execution.
        dataset = dataset.prefetch(2)
        return dataset

    return _input_fn