def generate_input_fn()

in recommended-item-search/input_pipeline.py [0:0]


def generate_input_fn(file_pattern, batch_size, mode=tf.estimator.ModeKeys.EVAL,
                      num_parallel_calls=8, shuffle_buffer_size=1000,
                      prefetch_batches=2):
  """Generate input function for Estimator.

  Args:
    file_pattern: glob pattern of the input TFRecord file names.
    batch_size: number of sequences per (padded) batch.
    mode: a `tf.estimator.ModeKeys` value. Input files and records are
      shuffled only in TRAIN mode; other modes read files in a
      reproducible order.
    num_parallel_calls: number of files read concurrently (interleave
      cycle length) and number of records parsed in parallel.
    shuffle_buffer_size: size of the record shuffle buffer used in TRAIN
      mode.
    prefetch_batches: number of batches prefetched at the end of the
      pipeline.

  Returns:
    A zero-argument input function that returns a `tf.data.Dataset`. The
    dataset repeats indefinitely and yields batches of movie_id sequences,
    each padded with -1 to the longest sequence in its batch.
  """
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  def _input_fn():
    # list_files shuffles file names by default; only do so when training so
    # that evaluation/prediction read files in a reproducible order.
    files = tf.data.Dataset.list_files(file_pattern, shuffle=is_training)
    dataset = files.interleave(tf.data.TFRecordDataset,
                               cycle_length=num_parallel_calls)

    if is_training:
      dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
    dataset = dataset.map(map_func=parse_fn,
                          num_parallel_calls=num_parallel_calls)
    dataset = dataset.repeat()

    # Note that movie_id sequences are padded with -1.
    # NOTE(review): assumes parse_fn yields a rank-1 int64 tensor per record.
    dataset = dataset.padded_batch(
      batch_size=batch_size, padded_shapes=tf.TensorShape([None]),
      padding_values=tf.constant(-1, dtype=tf.int64))

    # Prefetch last so whole batches are prepared while the model consumes
    # the current one.
    return dataset.prefetch(prefetch_batches)
  return _input_fn