def gen_dataset()

in tensorflow_examples/lite/model_maker/core/data_util/audio_dataloader.py [0:0]


  def gen_dataset(self,
                  batch_size=1,
                  is_training=False,
                  shuffle=False,
                  input_pipeline_context=None,
                  preprocess=None,
                  drop_remainder=False):
    """Generate a sharded and batched tf.data.Dataset for training/evaluation.

    Pipeline order: shard -> decode WAV (mono) -> resample to
    `spec.target_sample_rate` when needed -> `spec.preprocess_ds` (which may
    cache and augment) -> drop empty / non-finite examples -> one-hot encode
    labels -> optionally shuffle -> batch -> prefetch.

    Args:
      batch_size: An integer, the returned dataset will be batched by this size.
      is_training: A boolean, when True, the returned dataset will be optionally
        shuffled. It is also forwarded to `spec.preprocess_ds`, so data
        augmentation, if the spec defines any, will be applied.
      shuffle: A boolean, when True, the returned dataset will be shuffled to
        create randomness during model training. Only applies when `is_training`
        is set to True.
      input_pipeline_context: An InputContext instance, used to shard the
        dataset among multiple workers when a distribution strategy is used.
      preprocess: Not in use. Audio preprocessing is defined in the spec.
      drop_remainder: A boolean, whether the final batch drops the remainder.

    Returns:
      A batched, prefetched tf.data.Dataset of (preprocessed audio, one-hot
      label) pairs ready to be consumed by a Keras model.

    Note:
      Resampling is only performed for files whose sample rate differs from
      `spec.target_sample_rate`, and it requires librosa. If librosa could not
      be imported, a RuntimeError is raised from inside `tf.numpy_function`
      when such a file is reached while iterating the dataset (not when this
      method is called).
    """
    # This argument is only used for image dataset for now. Audio preprocessing
    # is defined in the spec.
    del preprocess
    ds = self._dataset
    spec = self._spec
    autotune = tf.data.AUTOTUNE

    if is_training and shuffle:
      # Element order will be randomized anyway, so allow tf.data to produce
      # elements out of order in exchange for throughput.
      options = tf.data.Options()
      options.experimental_deterministic = False
      ds = ds.with_options(options)

    ds = dataloader.shard(ds, input_pipeline_context)

    @tf.function
    def _load_wav(filepath, label):
      file_contents = tf.io.read_file(filepath)
      # shape: (audio_samples, 1), dtype: float32
      wav, sample_rate = tf.audio.decode_wav(file_contents, desired_channels=1)
      # shape: (audio_samples,)
      wav = tf.squeeze(wav, axis=-1)
      return wav, sample_rate, label

    # This is an eager-mode numpy_function. It could be converted to a
    # tf.function using https://www.tensorflow.org/io/api_docs/python/tfio/audio/resample
    def _resample_numpy(waveform, sample_rate):
      """Resample a numpy waveform to `spec.target_sample_rate` via librosa."""
      if not ENABLE_RESAMPLE:
        error_message = (
            'Failed to import librosa. You might be missing sndfile, which '
            'can be installed via `sudo apt-get install libsndfile1` on '
            'Ubuntu/Debian.')
        raise RuntimeError(error_message) from error_import_librosa
      return librosa.resample(
          waveform, orig_sr=sample_rate, target_sr=spec.target_sample_rate)

    @tf.function
    def _resample(waveform, sample_rate, label):
      # Short circuit resampling if possible.
      if sample_rate == spec.target_sample_rate:
        return [waveform, label]
      # Only the waveform goes through numpy. The label stays a tensor so its
      # dtype and static shape are preserved (no hard-coded int32 round trip).
      resampled = tf.numpy_function(
          _resample_numpy, inp=(waveform, sample_rate), Tout=tf.float32)
      return [resampled, label]

    @tf.function
    def _elements_finite(preprocess_data, unused_label):
      # Make sure that the data sent to the model is non-empty and does not
      # contain nan or inf values. This should be the last filter applied to
      # the dataset. Arguably we could apply this filter to all tasks.
      return tf.logical_and(
          tf.size(preprocess_data) > 0,
          tf.math.reduce_all(tf.math.is_finite(preprocess_data)))

    ds = ds.map(_load_wav, num_parallel_calls=autotune)
    ds = ds.map(_resample, num_parallel_calls=autotune)

    def _cache_fn(dataset):
      # `self._cache` may be a filename (cache to disk) or any other truthy
      # value (cache in memory); falsy disables caching.
      if self._cache:
        if isinstance(self._cache, str):
          # Cache to a file
          dataset = dataset.cache(self._cache)
        else:
          # In ram cache.
          dataset = dataset.cache()
      return dataset

    # `preprocess_ds` contains data augmentation, so it knows when it's the best
    # time to do caching.
    ds = spec.preprocess_ds(ds, is_training=is_training, cache_fn=_cache_fn)
    ds = ds.filter(_elements_finite)

    # Apply one-hot encoding after caching to reduce the cache size.
    @tf.function
    def _one_hot_encoding_label(wav, label):
      return wav, tf.one_hot(label, len(self.index_to_label))

    ds = ds.map(_one_hot_encoding_label, num_parallel_calls=autotune)

    # Shuffle needs to be done after caching to create randomness across epochs.
    if is_training and shuffle:
      # Shuffle size should be bigger than the batch_size. Otherwise it's only
      # shuffling within the batch, which equals to not having shuffle.
      buffer_size = 3 * batch_size
      # But since we are doing shuffle before repeat, it doesn't make sense to
      # shuffle more than total available entries.
      # TODO(wangtz): Do we want to do shuffle before / after repeat?
      # Shuffle after repeat will give a more randomized dataset and mix the
      # epoch boundary: https://www.tensorflow.org/guide/data
      ds = ds.shuffle(buffer_size=min(self._size, buffer_size))

    ds = ds.batch(batch_size, drop_remainder=drop_remainder)
    ds = ds.prefetch(autotune)
    # TODO(b/171449557): Consider converting ds to distributed ds here.
    return ds