def make_petastorm_dataset()

in petastorm/tf_utils.py [0:0]


def make_petastorm_dataset(reader):
    """Creates a `tensorflow.data.Dataset <https://www.tensorflow.org/api_docs/python/tf/data/Dataset>`_ object from
    a Petastorm :class:`~petastorm.reader.Reader`.

    The returned object can be used as any ``tf.data.Dataset`` with some limitations described below.

    * ``repeat``: An error will be raised if you call ``repeat`` on the returned dataset. Please use ``num_epochs``
      argument of the :class:`~petastorm.reader.Reader` constructor.
    * ``shard``: Consider using ``training_partition`` and ``num_training_partitions`` arguments of the
      :class:`~petastorm.reader.Reader` constructor as it will not load any unused shards.
    * ``filter``: Consider using :class:`~petastorm.reader.Reader` ``predicate`` constructor argument.
      It will make use of columnar nature of the underlying Apache Parquet store to load only the columns that the
      predicate operates on prior to loading and decoding other columns. :class:`~petastorm.reader.Reader`'s predicate
      feature will also make use of Parquet partitioning (if the dataset is partitioned).

    When ``reader.ngram`` is not set, the elements produced by the returned dataset object are namedtuples based on
    the :class:`~petastorm.unischema.Unischema`.

    >>> import tensorflow.compat.v1 as tf  # pylint: disable=import-error
    >>> from petastorm.reader import Reader
    >>> from petastorm.tf_utils import make_petastorm_dataset
    >>>
    >>> with Reader('file:///some/path') as reader:
    >>>     dataset = make_petastorm_dataset(reader)
    >>>     next_sample = dataset.make_one_shot_iterator().get_next()
    >>>     with tf.Session() as sess:
    >>>         x = sess.run(next_sample)


    When ``reader.ngram`` is set, the dataset is fed by ``_ngrams_generator(reader)`` and every flat tuple is passed
    through ``_unflatten_and_set_shape``, so the elements are not the plain schema namedtuples described above.
    The multiple-iteration check described below is only performed on the non-NGram path of this function.

    Iterating the (non-NGram) dataset again after the reader has consumed its last row makes the underlying generator
    raise ``RuntimeError``; this is what happens when ``repeat`` is applied to the returned dataset.

    :param reader: An instance of :class:`~petastorm.reader.Reader` object that would serve as a data source.
    :return: A ``tf.data.Dataset`` instance.
    """

    if reader.ngram:
        # flat_dataset is a tf.data.Dataset with a tuple containing ngram fields stored in one flat tuple produced by
        # _flatten() function.
        flat_dataset = tf.data.Dataset.from_generator(lambda: _ngrams_generator(reader),
                                                      tuple(_schema_to_tf_dtypes_ngram(reader.schema, reader.ngram)))

        # Unflatten the tuple into a dictionary
        return flat_dataset.map(
            lambda *nargs: _unflatten_and_set_shape(reader.schema, reader.ngram, nargs))

    def dequeue_sample_impl():
        if reader.last_row_consumed:
            # This means that Dataset is trying to create a new instance of the generator. Can not do that
            # (nor want to do that) since this is an expensive operation. num_epochs is a more efficient way
            # to do this.
            raise RuntimeError("Multiple iterations over make_petastorm_dataset are not supported. "
                               "Multiple iterations can be triggered by calling 'repeat' method of Dataset class. "
                               "Use Reader's num_epochs constructor argument to set number of iterations.")
        for row in reader:
            yield _sanitize_field_tf_types(row)

    flat_dataset = tf.data.Dataset.from_generator(dequeue_sample_impl, tuple(_schema_to_tf_dtypes(reader.schema)))

    # Don't write this function as an inline lambda like `dataset.map(lambda row: _set_shape_to_named_tuple(...))`,
    # It can avoid this error: https://github.com/tensorflow/tensorflow/issues/30149
    def set_shape(row):
        return _set_shape_to_named_tuple(reader.schema, row, reader.batched_output)

    # First map packs the flat tuple of tensors into the schema's namedtuple type, second map attaches shapes.
    schema_tuple = reader.schema._get_namedtuple()
    return flat_dataset.map(schema_tuple).map(set_shape)