def __init__()

in docker_images/fairseq/app/pipelines/audio_to_audio.py [0:0]


    def __init__(self, model_id: str):
        """Load a speech/audio translation model from the HF Hub and build its generator(s).

        Loads the main model ensemble onto CPU, applies per-checkpoint
        generation overrides, and — when the model's hub config declares a
        second stage — sets up either a CodeHiFiGAN unit vocoder or a full
        fairseq TTS model for waveform synthesis.

        Args:
            model_id: Hugging Face Hub model identifier.
        """
        # Model-specific overrides. Copy so we never mutate the shared
        # ARG_OVERRIDES_MAP entry in place (the next line would otherwise
        # write into the module-level map).
        # TODO: Update on checkpoint side in the future
        arg_overrides = dict(ARG_OVERRIDES_MAP.get(model_id, {}))
        arg_overrides["config_yaml"] = "config.yaml"  # common override
        models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
            model_id,
            arg_overrides=arg_overrides,
            cache_dir=os.getenv("HUGGINGFACE_HUB_CACHE"),
        )
        self.cfg = cfg
        self.model = models[0].cpu()  # pipeline runs on CPU
        self.model.eval()
        self.task = task

        # Fall back to 16 kHz when the task does not expose a sample rate.
        self.sampling_rate = getattr(self.task, "sr", None) or 16_000

        # Hub-config keys may be prefixed with the target-language tag when
        # the task prepends language tags.
        tgt_lang = self.task.data_cfg.hub.get("tgt_lang", None)
        pfx = f"{tgt_lang}_" if self.task.data_cfg.prepend_tgt_lang_tag else ""

        # Apply checkpoint-specific generation settings before building the
        # generator so they take effect.
        generation_args = self.task.data_cfg.hub.get(f"{pfx}generation_args", None)
        if generation_args is not None:
            for key, value in generation_args.items():
                setattr(cfg.generation, key, value)
        self.generator = task.build_generator([self.model], cfg.generation)

        # Optional second stage: a unit vocoder (for unit-output models) or a
        # full TTS model; both default to absent.
        tts_model_id = self.task.data_cfg.hub.get(f"{pfx}tts_model_id", None)
        self.unit_vocoder = self.task.data_cfg.hub.get(f"{pfx}unit_vocoder", None)
        self.tts_model, self.tts_task, self.tts_generator = None, None, None
        if tts_model_id is not None:
            _id = tts_model_id.split(":")[-1]
            cache_dir = os.getenv("HUGGINGFACE_HUB_CACHE")
            if self.unit_vocoder is not None:
                # CodeHiFiGAN unit vocoder: snapshot the checkpoint and load
                # it through fairseq's hub utilities.
                library_name = "fairseq"
                cache_dir = (
                    cache_dir or (Path.home() / ".cache" / library_name).as_posix()
                )
                cache_dir = snapshot_download(
                    f"facebook/{_id}", cache_dir=cache_dir, library_name=library_name
                )

                x = hub_utils.from_pretrained(
                    cache_dir,
                    "model.pt",
                    ".",
                    archive_map=CodeHiFiGANVocoder.hub_models(),
                    config_yaml="config.json",
                    fp16=False,
                    is_vocoder=True,
                )

                with open(f"{x['args']['data']}/config.json") as f:
                    vocoder_cfg = json.load(f)
                # Explicit raise instead of `assert`: asserts are stripped
                # under `python -O`, silently skipping this validation.
                if len(x["args"]["model_path"]) != 1:
                    raise ValueError("Too many vocoder models in the input")

                vocoder = CodeHiFiGANVocoder(x["args"]["model_path"][0], vocoder_cfg)
                self.tts_model = VocoderHubInterface(vocoder_cfg, vocoder)

            else:
                # Full fairseq TTS model loaded from the Hub with a
                # Griffin-Lim vocoder, forced to fp32.
                (
                    tts_models,
                    tts_cfg,
                    self.tts_task,
                ) = load_model_ensemble_and_task_from_hf_hub(
                    f"facebook/{_id}",
                    arg_overrides={"vocoder": "griffin_lim", "fp16": False},
                    cache_dir=cache_dir,
                )
                self.tts_model = tts_models[0].cpu()
                self.tts_model.eval()
                tts_cfg["task"].cpu = True  # keep the TTS stage on CPU as well
                TTSHubInterface.update_cfg_with_data_cfg(
                    tts_cfg, self.tts_task.data_cfg
                )
                self.tts_generator = self.tts_task.build_generator(
                    [self.tts_model], tts_cfg
                )