in training/flax/distil_whisper/pipeline.py [0:0]
def get_forced_decoder_ids(self, generation_config=None, task=None, language=None, return_timestamps=False):
    """Build the ``(position, token_id)`` pairs that force the decoder's prompt tokens.

    For multilingual configs, the language token is forced at position 1 when ``language``
    is given, and the task token is always forced at position 2 (``"transcribe"`` by default).
    When ``return_timestamps`` is False, the no-timestamps token is forced at the position
    right after the last forced token, or at position 1 if nothing else was forced.

    Args:
        generation_config: Object providing ``no_timestamps_token_id`` and, for multilingual
            models, ``is_multilingual``, ``lang_to_id`` and ``task_to_id``. Defaults to
            ``self.model.generation_config``.
        task: Key into ``generation_config.task_to_id``. Defaults to ``"transcribe"``.
            Ignored unless the config is multilingual.
        language: A token already present in ``generation_config.lang_to_id`` (e.g.
            ``"<|en|>"``), a value of ``TO_LANGUAGE_CODE``, or a key of ``TO_LANGUAGE_CODE``.
            Matching is done on the lower-cased string. Ignored unless the config is
            multilingual.
        return_timestamps: If True, the no-timestamps token is not forced.

    Returns:
        A list of ``(position, token_id)`` tuples in increasing position order.

    Raises:
        ValueError: If ``language`` matches none of the accepted forms.
        KeyError: If ``task`` is not a key of ``generation_config.task_to_id``.
    """
    if generation_config is None:
        generation_config = self.model.generation_config

    # Configs without the attribute are treated as not multilingual.
    is_multilingual = getattr(generation_config, "is_multilingual", None)

    forced_decoder_ids = []

    if is_multilingual:
        if language is not None:
            language = language.lower()
            if language in generation_config.lang_to_id:
                # Already in the config's token form, e.g. "<|en|>".
                language_token = language
            elif language in TO_LANGUAGE_CODE.values():
                # A bare language code: wrap it into token form.
                language_token = f"<|{language}|>"
            elif language in TO_LANGUAGE_CODE:
                # A key of TO_LANGUAGE_CODE: map it to its code, then wrap it.
                language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
            else:
                # Tailor the list of suggestions to the form the caller seems to have used.
                if len(language) == 2:
                    # ISO 639-1 language code
                    acceptable_languages = list(TO_LANGUAGE_CODE.values())
                elif "<" in language or "|" in language or ">" in language:
                    # generation config language code
                    acceptable_languages = list(generation_config.lang_to_id.keys())
                else:
                    # language passed as a string
                    acceptable_languages = list(TO_LANGUAGE_CODE.keys())
                raise ValueError(
                    f"Unsupported language: {language}. Language should be one of: {acceptable_languages}."
                )
            forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))

        # Position 1 is left unforced when no language is given; the task always sits at position 2.
        task_key = task if task is not None else "transcribe"
        forced_decoder_ids.append((2, generation_config.task_to_id[task_key]))

    if not return_timestamps:
        no_timestamps_id = generation_config.no_timestamps_token_id
        # Compare the last forced *token id* (element 1), not its position (element 0), so the
        # no-timestamps token is never forced twice.
        already_forced = bool(forced_decoder_ids) and forced_decoder_ids[-1][1] == no_timestamps_id
        if not already_forced:
            next_position = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
            forced_decoder_ids.append((next_position, no_timestamps_id))

    return forced_decoder_ids