tensorflow_ranking/python/data.py

import tensorflow as tf


def build_ranking_dataset_with_parsing_fn(
file_pattern,
parsing_fn,
batch_size,
reader=tf.data.TFRecordDataset,
reader_args=None,
num_epochs=None,
shuffle=True,
shuffle_buffer_size=10000,
shuffle_seed=None,
prefetch_buffer_size=tf.data.experimental.AUTOTUNE,
reader_num_threads=tf.data.experimental.AUTOTUNE,
sloppy_ordering=False,
drop_final_batch=False,
num_parser_threads=tf.data.experimental.AUTOTUNE):
"""Builds a ranking tf.dataset using the provided `parsing_fn`.
Args:
file_pattern: (str | list(str)) List of files or patterns of file paths
containing serialized data. See `tf.gfile.Glob` for pattern rules.
parsing_fn: (function) A function that takes a single argument, a batch of
serialized strings, and returns a `dict` of parsed features. Users can
customize this for their own data formats (see the usage sketch after this
function).
batch_size: (int) Number of records to combine in a single batch.
reader: A function or class that can be called with a `filenames` tensor and
(optional) `reader_args` and returns a `Dataset`. Defaults to
`tf.data.TFRecordDataset`.
reader_args: (list) Additional argument list to pass to the reader class.
num_epochs: (int) Number of times to read through the dataset. If None,
cycles through the dataset forever. Defaults to `None`.
shuffle: (bool) Indicates whether the input should be shuffled. Defaults to
`True`.
shuffle_buffer_size: (int) Buffer size of the ShuffleDataset. A larger
capacity gives better shuffling but increases memory usage and startup time.
shuffle_seed: (int) Randomization seed to use for shuffling.
prefetch_buffer_size: (int) Number of feature batches to prefetch in order
to improve performance. Recommended value is the number of batches
consumed per training step. Defaults to auto-tune.
reader_num_threads: (int) Number of threads used to read records. If greater
than 1, the results will be interleaved. Defaults to auto-tune.
sloppy_ordering: (bool) If `True`, reading performance will be improved at
the cost of non-deterministic ordering. If `False`, the order of elements
produced is deterministic prior to shuffling (elements are still randomized
if `shuffle=True`; if a seed is set, the order after shuffling is also
deterministic). Defaults to `False`.
drop_final_batch: (bool) If `True`, and the batch size does not evenly
divide the input dataset size, the final smaller batch will be dropped.
Defaults to `False`. If `True`, the batch_size can be statically inferred.
num_parser_threads: (int) Optional number of threads to be used with
dataset.map() when invoking parsing_fn. Defaults to auto-tune.

Returns:
A dataset of `dict` elements. Each `dict` maps feature keys to
`Tensor` or `SparseTensor` objects.
"""
dataset = tf.data.Dataset.list_files(
file_pattern, shuffle=shuffle, seed=shuffle_seed)
if reader_num_threads == tf.data.experimental.AUTOTUNE:
dataset = dataset.interleave(
lambda filename: reader(filename, *(reader_args or [])),
num_parallel_calls=reader_num_threads)
else:
# cycle_length needs to be set when reader_num_threads is not AUTOTUNE.
dataset = dataset.interleave(
lambda filename: reader(filename, *(reader_args or [])),
cycle_length=reader_num_threads,
num_parallel_calls=reader_num_threads)
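# Honor `sloppy_ordering`: when it is True, the parallel reads above may emit
# elements as soon as they are ready rather than in a deterministic order.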
options = tf.data.Options()
options.experimental_deterministic = not sloppy_ordering
dataset = dataset.with_options(options)
# Extract values if tensors are stored as key-value tuples.
if tf.compat.v1.data.get_output_types(dataset) == (tf.string, tf.string):
dataset = dataset.map(lambda _, v: v)
# Repeat and shuffle, if needed.
if num_epochs != 1:
dataset = dataset.repeat(num_epochs)
if shuffle:
dataset = dataset.shuffle(
buffer_size=shuffle_buffer_size, seed=shuffle_seed)
# drop_remainder=True allows the batch size to be statically inferred. It is
# also set whenever num_epochs is None, since an indefinitely repeated dataset
# has no final partial batch to preserve.
dataset = dataset.batch(
batch_size, drop_remainder=drop_final_batch or num_epochs is None)
# Parse a batch.
dataset = dataset.map(parsing_fn, num_parallel_calls=num_parser_threads)
# Prefetching allows for data fetching to happen on host while model runs
# on the accelerator. When run on CPU, makes data fetching asynchronous.
dataset = dataset.prefetch(buffer_size=prefetch_buffer_size)
return dataset
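

# --- Usage sketch (illustrative only, not part of the library API) ---
# A minimal example of a `parsing_fn` and a call to the builder above. The
# feature spec, feature names ("unigrams", "relevance"), and the file pattern
# are assumptions made for illustration; real pipelines would substitute their
# own feature specs or use the parsing helpers defined elsewhere in this
# module.


def _example_parsing_fn(serialized):
  """Parses a batch of serialized `tf.train.Example` protos into a dict."""
  feature_spec = {
      "unigrams": tf.io.VarLenFeature(tf.string),
      "relevance": tf.io.FixedLenFeature([1], tf.int64, default_value=[0]),
  }
  # `serialized` is a batch of strings because `parsing_fn` is mapped over the
  # dataset after batching.
  return tf.io.parse_example(serialized, feature_spec)


def _example_input_fn():
  """Returns a batched, parsed dataset built with the function above."""
  return build_ranking_dataset_with_parsing_fn(
      file_pattern="/tmp/train-*.tfrecord",  # hypothetical path.
      parsing_fn=_example_parsing_fn,
      batch_size=32)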