in tensorflow_examples/lite/model_maker/core/data_util/audio_dataloader.py
def gen_dataset(self,
batch_size=1,
is_training=False,
shuffle=False,
input_pipeline_context=None,
preprocess=None,
drop_remainder=False):
"""Generate a shared and batched tf.data.Dataset for training/evaluation.
Args:
batch_size: A integer, the returned dataset will be batched by this size.
is_training: A boolean, when True, the returned dataset will be optionally
shuffled. Data augmentation, if exists, will also be applied to the
returned dataset.
shuffle: A boolean, when True, the returned dataset will be shuffled to
create randomness during model training. Only applies when `is_training`
is set to True.
input_pipeline_context: A InputContext instance, used to shared dataset
among multiple workers when distribution strategy is used.
preprocess: Not in use.
drop_remainder: boolean, whether the finaly batch drops remainder.
Returns:
A TF dataset ready to be consumed by a Keras model.
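Example (illustrative sketch; `loader` and `model` are assumed to be an
instance of this class and a compiled Keras model, respectively):
train_ds = loader.gen_dataset(batch_size=32, is_training=True, shuffle=True)
model.fit(train_ds, epochs=10)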
"""
# This argument is only used by the image dataset for now. Audio
# preprocessing is defined in the spec.
del preprocess
ds = self._dataset
spec = self._spec
autotune = tf.data.AUTOTUNE
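# When shuffling for training, allow non-deterministic element order so that
# parallel stages of the pipeline don't have to preserve ordering; this trades
# reproducibility for throughput.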
if is_training and shuffle:
options = tf.data.Options()
options.experimental_deterministic = False
ds = ds.with_options(options)
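# Shard the dataset when an input pipeline context is provided, so each worker
# under a distribution strategy reads its own subset of the data.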
ds = dataloader.shard(ds, input_pipeline_context)
@tf.function
def _load_wav(filepath, label):
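# `tf.audio.decode_wav` expects 16-bit PCM WAV content and returns samples
# as float32 values in [-1.0, 1.0].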
file_contents = tf.io.read_file(filepath)
# shape: (audio_samples, 1), dtype: float32
wav, sample_rate = tf.audio.decode_wav(file_contents, desired_channels=1)
# shape: (audio_samples,)
wav = tf.squeeze(wav, axis=-1)
return wav, sample_rate, label
# This is an eager-mode numpy_function. It can be replaced with a
# graph-compatible op, e.g.
# https://www.tensorflow.org/io/api_docs/python/tfio/audio/resample
def _resample_numpy(waveform, sample_rate, label):
if ENABLE_RESAMPLE:
waveform = librosa.resample(
waveform, orig_sr=sample_rate, target_sr=spec.target_sample_rate)
else:
error_message = (
'Failed to import librosa. You might be missing sndfile, which '
'can be installed via `sudo apt-get install libsndfile1` on '
'Ubuntu/Debian.')
raise RuntimeError(error_message) from error_import_librosa
return waveform, label
@tf.function
def _resample(waveform, sample_rate, label):
# Short circuit resampling if possible.
if sample_rate == spec.target_sample_rate:
return [waveform, label]
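# `tf.numpy_function` executes `_resample_numpy` eagerly in Python, so it
# cannot run on accelerators and its outputs carry no static shape
# information.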
return tf.numpy_function(
_resample_numpy,
inp=(waveform, sample_rate, label),
Tout=[tf.float32, tf.int32])
@tf.function
def _elements_finite(preprocess_data, unused_label):
# Make sure that the data sent to the model does not contain nan or inf
# values. This should be the last filter applied to the dataset.
# Arguably we could possibly apply this filter to all tasks.
return tf.math.logical_and(
tf.size(preprocess_data) > 0,
tf.math.reduce_all(tf.math.is_finite(preprocess_data)))
ds = ds.map(_load_wav, num_parallel_calls=autotune)
ds = ds.map(_resample, num_parallel_calls=autotune)
def _cache_fn(dataset):
if self._cache:
if isinstance(self._cache, str):
# Cache to a file
dataset = dataset.cache(self._cache)
else:
# In-memory cache.
dataset = dataset.cache()
return dataset
# `preprocess_ds` contains data augmentation, so it knows the best point in
# the pipeline to apply caching.
ds = spec.preprocess_ds(ds, is_training=is_training, cache_fn=_cache_fn)
ds = ds.filter(_elements_finite)
# Apply one-hot encoding after caching to reduce the cache size.
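# (A one-hot label is a float vector of length `num_classes`, so caching the
# integer label instead keeps the cache small.)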
@tf.function
def _one_hot_encoding_label(wav, label):
return wav, tf.one_hot(label, len(self.index_to_label))
ds = ds.map(_one_hot_encoding_label, num_parallel_calls=autotune)
# Shuffle needs to be done after caching to create randomness across epochs.
if is_training:
if shuffle:
# Shuffle size should be bigger than the batch_size. Otherwise it's only
# shuffling within the batch, which is effectively no shuffling at all.
buffer_size = 3 * batch_size
# But since we shuffle before repeat, it doesn't make sense to use a buffer
# larger than the total number of available entries.
# TODO(wangtz): Do we want to do shuffle before / after repeat?
# Shuffle after repeat will give a more randomized dataset and mix the
# epoch boundary: https://www.tensorflow.org/guide/data
ds = ds.shuffle(buffer_size=min(self._size, buffer_size))
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
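# Prefetch lets the input pipeline prepare upcoming batches while the model
# trains on the current one.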
ds = ds.prefetch(autotune)
# TODO(b/171449557): Consider converting ds to distributed ds here.
return ds