recommended-item-search/input_pipeline.py
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

# Uncomment to debug the pipeline eagerly under TF 1.x.
# tf.enable_eager_execution()


def parse_fn(serialized_example):
  """Parses a serialized SequenceExample into a sequence of movie ids.

  Args:
    serialized_example: a serialized tf.train.SequenceExample with a
      'user_id' context feature and a 'movie_ids' sequence feature.

  Returns:
    A 1-D int64 tensor of movie ids.
  """
  # user_id is parsed but not currently used.
  context_features = {
      'user_id': tf.FixedLenFeature([], dtype=tf.int64)
  }
  sequence_features = {
      'movie_ids': tf.FixedLenSequenceFeature([], dtype=tf.int64)
  }
  parsed_feature, parsed_sequence_feature = tf.parse_single_sequence_example(
      serialized=serialized_example,
      context_features=context_features,
      sequence_features=sequence_features)
  movie_ids = parsed_sequence_feature['movie_ids']
  return movie_ids
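

# Illustrative sketch, not part of the original pipeline: how a record with
# the schema parse_fn expects could be built as a tf.train.SequenceExample.
# The helper name and any values passed by callers are assumptions made here
# for illustration only.
def _example_sequence_example(user_id, movie_ids):
  """Builds a SequenceExample matching parse_fn's schema (sketch only)."""
  return tf.train.SequenceExample(
      context=tf.train.Features(feature={
          'user_id': tf.train.Feature(
              int64_list=tf.train.Int64List(value=[user_id]))
      }),
      feature_lists=tf.train.FeatureLists(feature_list={
          'movie_ids': tf.train.FeatureList(feature=[
              tf.train.Feature(int64_list=tf.train.Int64List(value=[mid]))
              for mid in movie_ids
          ])
      }))
# Usage: parse_fn(_example_sequence_example(42, [1, 23]).SerializeToString())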


def generate_input_fn(file_pattern, batch_size,
                      mode=tf.estimator.ModeKeys.EVAL):
  """Generates an input function for an Estimator.

  Args:
    file_pattern: file pattern of the input TFRecord files.
    batch_size: batch size used in the input function.
    mode: tf.estimator.ModeKeys value; shuffling is applied only in TRAIN.

  Returns:
    An input function that returns padded batches of movie_id sequences.
  """
  def _input_fn():
    # TODO(yaboo): num_cpu (cycle_length / num_parallel_calls) should be
    # parameterized.
    files = tf.data.Dataset.list_files(file_pattern)
    dataset = files.interleave(tf.data.TFRecordDataset, cycle_length=8)
    # TODO(yaboo): buffer_size should be parameterized.
    if mode == tf.estimator.ModeKeys.TRAIN:
      dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.map(map_func=parse_fn, num_parallel_calls=8)
    dataset = dataset.repeat()
    dataset = dataset.prefetch(2 * batch_size)
    # Movie_id sequences are padded with -1 up to the longest sequence in
    # each batch.
    dataset = dataset.padded_batch(
        batch_size=batch_size,
        padded_shapes=tf.TensorShape([None]),
        padding_values=tf.constant(-1, dtype=tf.int64))
    return dataset

  return _input_fn
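

# Usage sketch (an assumption, not from the original module): wiring the
# input function into a TF 1.x Estimator. `my_model_fn` and the paths used
# below are hypothetical placeholders.
#
#   estimator = tf.estimator.Estimator(model_fn=my_model_fn,
#                                      model_dir='/tmp/item_search_model')
#   estimator.train(
#       input_fn=generate_input_fn(file_pattern='train-*.tfrecord',
#                                  batch_size=32,
#                                  mode=tf.estimator.ModeKeys.TRAIN),
#       steps=1000)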