def __init__()

in ultravox/training/model_types.py


    def __init__(self, args: config_base.TrainConfig):
        self.args = args
        self.text_tokenizer: transformers.PreTrainedTokenizerFast = (
            transformers.AutoTokenizer.from_pretrained(args.text_model)
        )
        # The text tokenizer may not define a pad token; reuse EOS and pad on the right
        self.text_tokenizer.padding_side = "right"
        self.text_tokenizer.pad_token = self.text_tokenizer.eos_token
        audio_processor = transformers.AutoProcessor.from_pretrained(args.audio_model)

        # Build the Ultravox config that ties the text and audio models together
        self.config = ultravox_config.UltravoxConfig(
            audio_model_id=args.audio_model,
            text_model_id=args.text_model,
            text_model_lora_config=args.text_model_lora_config,
            audio_model_lora_config=args.audio_model_lora_config,
            torch_dtype=args.data_type,
            pad_token_id=self.text_tokenizer.eos_token_id,
            projector_ln_mid=args.projector_ln_mid,
        )

        # Instantiate the model
        self.model: ultravox_model.UltravoxModel = ultravox_model.UltravoxModel(
            self.config
        )

        # Wrap the audio processor and text tokenizer in a single Ultravox processor
        self.processor = ultravox_processing.UltravoxProcessor(
            audio_processor,
            self.text_tokenizer,
            audio_context_size=self.model.audio_tower_context_length,
        )

        # The loss config is only needed for training, so it is set on the model separately
        if args.loss_config is not None:
            self.model.set_loss_config(args.loss_config)

        # Set up the data collator
        self.data_collator = ultravox_processing.DataCollatorForSeq2SeqWithAudio(
            tokenizer=self.text_tokenizer,
            include_alt_fields=self.model.loss_config.requires_alt_fields,
        )

        assert self.model.get_input_embeddings().num_embeddings == len(
            self.text_tokenizer
        ), f"Model and tokenizer mismatch: {self.model.get_input_embeddings().num_embeddings} != {len(self.text_tokenizer)}"

        # KV caching is only useful for generation; disable it for training
        self.model.language_model.config.use_cache = False
        if args.disable_layerdrop and hasattr(
            self.model.audio_tower.config, "layerdrop"
        ):
            # layerdrop causes issues when training with DDP
            # https://github.com/huggingface/transformers/issues/17116#issuecomment-1121340890
            self.model.audio_tower.config.layerdrop = 0.0
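
A minimal usage sketch follows. The enclosing class name is not shown in this excerpt, so UltravoxModelPack below is a placeholder, and args is assumed to come from the project's normal config parsing rather than being built by hand.

    from ultravox.training import config_base  # import path assumed from the listing above

    args: config_base.TrainConfig = ...   # normally parsed from the training YAML/CLI
    pack = UltravoxModelPack(args)        # placeholder name for the class defining __init__

    # pack.processor turns (text, audio) pairs into per-example features,
    # pack.data_collator pads a list of those features into one batch, and
    # pack.model(**batch) consumes that batch during training.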