in benchmarks/rnnt/ootb/inference/pytorch/parts/manifest.py [0:0]
def __init__(self, data_dir, manifest_paths, labels, blank_index, max_duration=None, pad_to_max=False,
             min_duration=None, sort_by_duration=False, max_utts=0,
             normalize=True, speed_perturbation=False, filter_speed=1.0):
    """Parse one or more JSON manifests into a list of utterance records.

    Each manifest entry is duration-filtered, its transcript is (optionally)
    normalized and converted to vocabulary indices, and its audio file names
    are resolved to paths under ``data_dir``.

    Args:
        data_dir: Root directory that audio file names are joined to.
        manifest_paths: Iterable of JSON manifest file paths.
        labels: Output vocabulary; a symbol's index in this list is its id.
        blank_index: Index of the blank symbol.
        max_duration: Drop utterances longer than this (seconds); ``None``
            disables the upper bound (unless ``pad_to_max`` recomputes it).
        pad_to_max: If True, rescale the duration cutoff by the slowest
            perturbation speed (``self.max_duration * min_speed``).
            NOTE(review): assumes ``max_duration`` is not None in this mode,
            otherwise ``None * min_speed`` raises — TODO confirm with callers.
        min_duration: Drop utterances shorter than this (seconds); ``None``
            disables the lower bound.
        sort_by_duration: If True, sort kept utterances by duration ascending.
        max_utts: If > 0, stop parsing once this many utterances are kept.
        normalize: If True, map punctuation (except '+', '&', and any symbol
            present in ``labels``) to whitespace and normalize the transcript.
        speed_perturbation: If True, keep every speed-perturbed copy of each
            utterance; otherwise keep only the copy at ``filter_speed``.
        filter_speed: Speed selecting the copy kept when
            ``speed_perturbation`` is False.
    """
    self.labels_map = {label: i for i, label in enumerate(labels)}
    self.blank_index = blank_index
    self.max_duration = max_duration
    ids = []
    duration = 0.0
    filtered_duration = 0.0

    # If removing punctuation, build a translation table mapping each
    # punctuation character we want gone to a space.
    table = None
    if normalize:
        punctuation = string.punctuation
        # '+' and '&' are deliberately kept (spoken symbols).
        punctuation = punctuation.replace("+", "")
        punctuation = punctuation.replace("&", "")
        # We might also want to consider:
        # @ -> at
        # # -> number, pound, hashtag
        # ~ -> tilde
        # _ -> underscore
        # % -> percent
        # If a punctuation symbol is inside our vocab, we do not remove it
        # from text.
        for label in labels:
            punctuation = punctuation.replace(label, "")
        # Turn all remaining punctuation to whitespace.
        table = str.maketrans(punctuation, " " * len(punctuation))

    done = False  # set once max_utts is reached; stops ALL manifest parsing
    for manifest_path in manifest_paths:
        if done:
            break
        with open(manifest_path, "r", encoding="utf-8") as fh:
            a = json.load(fh)
        for data in a:
            files_and_speeds = data['files']

            if pad_to_max:
                # Padding stretches audio by 1/speed, so shrink the duration
                # cutoff by the slowest speed to compensate.
                if not speed_perturbation:
                    min_speed = filter_speed
                else:
                    min_speed = min(x['speed'] for x in files_and_speeds)
                max_duration = self.max_duration * min_speed

            data['duration'] = data['original_duration']
            # Duration filters; keep track of how much audio is dropped.
            if min_duration is not None and data['duration'] < min_duration:
                filtered_duration += data['duration']
                continue
            if max_duration is not None and data['duration'] > max_duration:
                filtered_duration += data['duration']
                continue

            # Prune and normalize according to transcript.
            transcript_text = data[
                'transcript'] if "transcript" in data else self.load_transcript(
                data['text_filepath'])
            if normalize:
                transcript_text = normalize_string(transcript_text, labels=labels,
                                                   table=table)
            if not isinstance(transcript_text, str):
                print(
                    "WARNING: Got transcript: {}. It is not a string. Dropping data point".format(
                        transcript_text))
                filtered_duration += data['duration']
                continue
            data["transcript"] = self.parse_transcript(
                transcript_text)  # convert to vocab indices

            if speed_perturbation:
                audio_paths = [x['fname'] for x in files_and_speeds]
                data['audio_duration'] = [x['duration']
                                          for x in files_and_speeds]
            else:
                audio_paths = [
                    x['fname'] for x in files_and_speeds if x['speed'] == filter_speed]
                data['audio_duration'] = [x['duration']
                                          for x in files_and_speeds if x['speed'] == filter_speed]
            data['audio_filepath'] = [os.path.join(
                data_dir, x) for x in audio_paths]
            data.pop('files')
            data.pop('original_duration')

            ids.append(data)
            duration += data['duration']

            if max_utts > 0 and len(ids) >= max_utts:
                print(
                    'Stopping parsing %s as max_utts=%d' % (manifest_path, max_utts))
                # BUGFIX: previously this `break` only exited the inner loop,
                # so each remaining manifest was still loaded and contributed
                # one extra utterance before re-triggering this check. Set a
                # flag so the outer manifest loop stops as well.
                done = True
                break

    if sort_by_duration:
        ids = sorted(ids, key=lambda x: x['duration'])
    self._data = ids
    self._size = len(ids)
    self._duration = duration
    self._filtered_duration = filtered_duration