# build_ranking_dataset_with_parsing_fn()
# Extracted from tensorflow_ranking/python/data.py


def build_ranking_dataset_with_parsing_fn(
    file_pattern,
    parsing_fn,
    batch_size,
    reader=tf.data.TFRecordDataset,
    reader_args=None,
    num_epochs=None,
    shuffle=True,
    shuffle_buffer_size=10000,
    shuffle_seed=None,
    prefetch_buffer_size=tf.data.experimental.AUTOTUNE,
    reader_num_threads=tf.data.experimental.AUTOTUNE,
    sloppy_ordering=False,
    drop_final_batch=False,
    num_parser_threads=tf.data.experimental.AUTOTUNE):
  """Builds a ranking tf.dataset using the provided `parsing_fn`.

  The pipeline is: list files -> interleaved reads -> (optional key
  stripping) -> repeat -> shuffle -> batch -> parse -> prefetch.

  Args:
    file_pattern: (str | list(str)) List of files or patterns of file paths
      containing serialized data. See `tf.gfile.Glob` for pattern rules.
    parsing_fn: (function) It has a single argument parsing_fn(serialized).
      Users can customize this for their own data formats.
    batch_size: (int) Number of records to combine in a single batch.
    reader: A function or class that can be called with a `filenames` tensor and
      (optional) `reader_args` and returns a `Dataset`. Defaults to
      `tf.data.TFRecordDataset`.
    reader_args: (list) Additional argument list to pass to the reader class.
    num_epochs: (int) Number of times to read through the dataset. If None,
      cycles through the dataset forever. Defaults to `None`.
    shuffle: (bool) Indicates whether the input should be shuffled. Defaults to
      `True`.
    shuffle_buffer_size: (int) Buffer size of the ShuffleDataset. A large
      capacity ensures better shuffling but would increase memory usage and
      startup time.
    shuffle_seed: (int) Randomization seed to use for shuffling.
    prefetch_buffer_size: (int) Number of feature batches to prefetch in order
      to improve performance. Recommended value is the number of batches
      consumed per training step. Defaults to auto-tune.
    reader_num_threads: (int) Number of threads used to read records. If greater
      than 1, the results will be interleaved. Defaults to auto-tune.
    sloppy_ordering: (bool) If `True`, reading performance will be improved at
      the cost of non-deterministic ordering. If `False`, the order of elements
      produced is deterministic prior to shuffling (elements are still
      randomized if `shuffle=True`. Note that if the seed is set, then order of
      elements after shuffling is deterministic). Defaults to `False`.
    drop_final_batch: (bool) If `True`, and the batch size does not evenly
      divide the input dataset size, the final smaller batch will be dropped.
      Defaults to `False`. If `True`, the batch_size can be statically inferred.
    num_parser_threads: (int) Optional number of threads to be used with
      dataset.map() when invoking parsing_fn. Defaults to auto-tune.

  Returns:
    A dataset of `dict` elements. Each `dict` maps feature keys to
    `Tensor` or `SparseTensor` objects.
  """
  ds = tf.data.Dataset.list_files(
      file_pattern, shuffle=shuffle, seed=shuffle_seed)

  def _read_one_file(filename):
    return reader(filename, *(reader_args or []))

  # `cycle_length` must be supplied explicitly whenever the thread count is
  # a fixed integer; it is omitted only for AUTOTUNE.
  interleave_kwargs = {"num_parallel_calls": reader_num_threads}
  if reader_num_threads != tf.data.experimental.AUTOTUNE:
    interleave_kwargs["cycle_length"] = reader_num_threads
  ds = ds.interleave(_read_one_file, **interleave_kwargs)

  opts = tf.data.Options()
  opts.experimental_deterministic = not sloppy_ordering
  ds = ds.with_options(opts)

  # Keep only the values when the reader yields (key, value) string pairs.
  if tf.compat.v1.data.get_output_types(ds) == (tf.string, tf.string):
    ds = ds.map(lambda _, value: value)

  # Repeat and shuffle, if needed.
  if num_epochs != 1:
    ds = ds.repeat(num_epochs)
  if shuffle:
    ds = ds.shuffle(buffer_size=shuffle_buffer_size, seed=shuffle_seed)

  # drop_remainder=True allows downstream code to statically infer the
  # batch dimension.
  ds = ds.batch(
      batch_size, drop_remainder=drop_final_batch or num_epochs is None)

  # Parse each serialized batch into a feature dict.
  ds = ds.map(parsing_fn, num_parallel_calls=num_parser_threads)

  # Prefetching overlaps input production on the host with model execution
  # on the accelerator (or makes fetching asynchronous on CPU).
  return ds.prefetch(buffer_size=prefetch_buffer_size)