def __init__()

in benchmarks/rnnt/ootb/inference/pytorch/parts/manifest.py [0:0]


    def __init__(self, data_dir, manifest_paths, labels, blank_index, max_duration=None, pad_to_max=False,
                 min_duration=None, sort_by_duration=False, max_utts=0,
                 normalize=True, speed_perturbation=False, filter_speed=1.0):
        """Load one or more JSON manifest files into a filtered list of utterances.

        Each manifest is a JSON array of dicts; every entry is expected to carry
        'files' (a list of {'fname', 'speed', 'duration'} dicts), 'original_duration',
        and either 'transcript' or 'text_filepath'.  Entries outside the
        [min_duration, max_duration] window are dropped (their time is accumulated
        in the filtered-duration counter).  Transcripts are optionally normalized,
        then converted to vocabulary indices via self.parse_transcript.

        Args:
            data_dir: Directory prepended to every audio file name.
            manifest_paths: Iterable of manifest JSON file paths.
            labels: Vocabulary; index in this sequence becomes the token id.
            blank_index: Index of the CTC/RNNT blank symbol.
            max_duration: Drop utterances longer than this (seconds); None disables.
            pad_to_max: If True, tighten max_duration by the minimum perturbation
                speed so padded audio still fits.
            min_duration: Drop utterances shorter than this (seconds); None disables.
            sort_by_duration: If True, sort kept utterances by ascending duration.
            max_utts: Stop after this many utterances (0 = unlimited).
            normalize: If True, lowercase/clean transcripts via normalize_string.
            speed_perturbation: If True, keep every speed variant of each utterance;
                otherwise keep only the variant matching filter_speed.
            filter_speed: Speed variant selected when speed_perturbation is False.
        """
        # Token -> index lookup derived from vocabulary order.
        self.labels_map = {label: i for i, label in enumerate(labels)}
        self.blank_index = blank_index
        self.max_duration = max_duration
        ids = []
        duration = 0.0           # total seconds of kept utterances
        filtered_duration = 0.0  # total seconds of dropped utterances

        # If normalizing, build a translation table mapping punctuation to spaces.
        table = None
        if normalize:
            # Punctuation to remove; '+' and '&' are deliberately kept.
            punctuation = string.punctuation
            punctuation = punctuation.replace("+", "")
            punctuation = punctuation.replace("&", "")
            # We might also want to consider:
            # @ -> at
            # # -> number, pound, hashtag
            # ~ -> tilde
            # _ -> underscore
            # % -> percent
            # If a punctuation symbol is inside our vocab, we do not remove from text
            for l in labels:
                punctuation = punctuation.replace(l, "")
            # Turn all remaining punctuation to whitespace in one C-level pass.
            table = str.maketrans(punctuation, " " * len(punctuation))

        # Set once max_utts is hit so parsing stops across ALL manifests, not
        # just the current one (a bare inner `break` would let every following
        # manifest append at least one more utterance past the limit).
        reached_max_utts = False
        for manifest_path in manifest_paths:
            with open(manifest_path, "r", encoding="utf-8") as fh:
                a = json.load(fh)
                for data in a:
                    files_and_speeds = data['files']

                    if pad_to_max:
                        if not speed_perturbation:
                            min_speed = filter_speed
                        else:
                            min_speed = min(x['speed']
                                            for x in files_and_speeds)
                        # NOTE(review): assumes self.max_duration is not None
                        # whenever pad_to_max is True — TODO confirm at call sites.
                        max_duration = self.max_duration * min_speed

                    data['duration'] = data['original_duration']
                    # Duration window filtering; dropped time is still tallied.
                    if min_duration is not None and data['duration'] < min_duration:
                        filtered_duration += data['duration']
                        continue
                    if max_duration is not None and data['duration'] > max_duration:
                        filtered_duration += data['duration']
                        continue

                    # Prune and normalize according to transcript
                    transcript_text = data[
                        'transcript'] if "transcript" in data else self.load_transcript(
                        data['text_filepath'])
                    if normalize:
                        transcript_text = normalize_string(transcript_text, labels=labels,
                                                           table=table)
                    if not isinstance(transcript_text, str):
                        print(
                            "WARNING: Got transcript: {}. It is not a string. Dropping data point".format(
                                transcript_text))
                        filtered_duration += data['duration']
                        continue
                    data["transcript"] = self.parse_transcript(
                        transcript_text)  # convert to vocab indices

                    if speed_perturbation:
                        # Keep all speed variants of this utterance.
                        audio_paths = [x['fname'] for x in files_and_speeds]
                        data['audio_duration'] = [x['duration']
                                                  for x in files_and_speeds]
                    else:
                        # Keep only the variant recorded at filter_speed.
                        audio_paths = [
                            x['fname'] for x in files_and_speeds if x['speed'] == filter_speed]
                        data['audio_duration'] = [x['duration']
                                                  for x in files_and_speeds if x['speed'] == filter_speed]
                    data['audio_filepath'] = [os.path.join(
                        data_dir, x) for x in audio_paths]
                    data.pop('files')
                    data.pop('original_duration')

                    ids.append(data)
                    duration += data['duration']

                    if max_utts > 0 and len(ids) >= max_utts:
                        print(
                            'Stopping parsing %s as max_utts=%d' % (manifest_path, max_utts))
                        reached_max_utts = True
                        break
            if reached_max_utts:
                break

        if sort_by_duration:
            ids = sorted(ids, key=lambda x: x['duration'])
        self._data = ids
        self._size = len(ids)
        self._duration = duration
        self._filtered_duration = filtered_duration