def get_forced_decoder_ids()

in training/flax/distil_whisper/pipeline.py [0:0]


    def get_forced_decoder_ids(self, generation_config=None, task=None, language=None, return_timestamps=False):
        """Build the prompt tokens to force at fixed decoder positions.

        Args:
            generation_config: Config supplying ``lang_to_id``, ``task_to_id``,
                ``no_timestamps_token_id`` and, optionally, ``is_multilingual``.
                Defaults to ``self.model.generation_config``.
            task: Key into ``generation_config.task_to_id``. Only used when the
                config is multilingual, where it defaults to ``"transcribe"``.
            language: Only used when the config is multilingual. Lower-cased, then
                matched as a ``lang_to_id`` key (a token string like ``"<|en|>"``),
                a value of ``TO_LANGUAGE_CODE`` (a language code), or a key of
                ``TO_LANGUAGE_CODE`` (a language name). When ``None``, position 1
                is left unforced.
            return_timestamps: When ``False``, the no-timestamps token is forced
                right after the last forced token (or at position 1 if no other
                token is forced).

        Returns:
            A list of ``(position, token_id)`` tuples in increasing position order.

        Raises:
            ValueError: If ``language`` matches none of the accepted forms.
            KeyError: If ``task`` is not a key of ``task_to_id``, or the resolved
                language token is not a key of ``lang_to_id``.
        """
        if generation_config is None:
            generation_config = self.model.generation_config

        # Configs lacking the attribute are treated as not multilingual: no
        # language/task tokens are forced for them.
        is_multilingual = getattr(generation_config, "is_multilingual", None)

        forced_decoder_ids = []

        if is_multilingual:
            if language is not None:
                language = language.lower()
                if language in generation_config.lang_to_id.keys():
                    language_token = language
                elif language in TO_LANGUAGE_CODE.values():
                    language_token = f"<|{language}|>"
                elif language in TO_LANGUAGE_CODE.keys():
                    language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
                else:
                    # Tailor the list in the error message to the form the caller
                    # appears to have used.
                    if len(language) == 2:
                        # ISO 639-1 language code
                        acceptable_languages = list(TO_LANGUAGE_CODE.values())
                    elif "<" in language or "|" in language or ">" in language:
                        # generation config language code
                        acceptable_languages = list(generation_config.lang_to_id.keys())
                    else:
                        # language passed as a string
                        acceptable_languages = list(TO_LANGUAGE_CODE.keys())
                    raise ValueError(
                        f"Unsupported language: {language}. Language should be one of:" f" {acceptable_languages}."
                    )
                # Position 1 holds the language token.
                forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))

            # Position 2 holds the task token; transcription is the default task.
            task_key = "transcribe" if task is None else task
            forced_decoder_ids.append((2, generation_config.task_to_id[task_key]))

        if not return_timestamps:
            no_timestamps_id = generation_config.no_timestamps_token_id
            if not forced_decoder_ids:
                forced_decoder_ids.append((1, no_timestamps_id))
            elif forced_decoder_ids[-1][1] != no_timestamps_id:
                # Compare the last entry's token id (element 1), not its position
                # (element 0), and force no-timestamps at the following position.
                forced_decoder_ids.append((forced_decoder_ids[-1][0] + 1, no_timestamps_id))

        return forced_decoder_ids