def train(config)

in src/autotrain/trainers/clm/train_clm_dpo.py


import torch
from peft import LoraConfig
from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig
from transformers.trainer_callback import PrinterCallback
from trl import DPOConfig, DPOTrainer

from autotrain import logger
from autotrain.trainers.clm import utils
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.trainers.common import ALLOW_REMOTE_CODE


def train(config):
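    """Run DPO (Direct Preference Optimization) training for a causal LM.

    `config` may be an LLMTrainingParams instance or a plain dict with the
    same fields.
    """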
    logger.info("Starting DPO training...")
    if isinstance(config, dict):
        config = LLMTrainingParams(**config)
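    # Load the train/valid splits and, when a chat template is configured,
    # render the preference columns (prompt/chosen/rejected) through it.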
    train_data, valid_data = utils.process_input_data(config)
    tokenizer = utils.get_tokenizer(config)
    train_data, valid_data = utils.process_data_with_chat_template(config, tokenizer, train_data, valid_data)

    logging_steps = utils.configure_logging_steps(config, train_data, valid_data)
    training_args = utils.configure_training_args(config, logging_steps)
    config = utils.configure_block_size(config, tokenizer)

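    # Map DPO-specific knobs onto the TRL config: sequence-length budgets and
    # the beta coefficient that weights the implicit KL penalty toward the
    # reference model.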
    training_args["max_length"] = config.block_size
    training_args["max_prompt_length"] = config.max_prompt_length
    training_args["max_completion_length"] = config.max_completion_length
    training_args["beta"] = config.dpo_beta
    args = DPOConfig(**training_args)

    logger.info("loading model config...")
    model_config = AutoConfig.from_pretrained(
        config.model,
        token=config.token,
        trust_remote_code=ALLOW_REMOTE_CODE,
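        # use_cache conflicts with gradient checkpointing, so it is enabled
        # only when checkpointing is disabled.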
        use_cache=config.disable_gradient_checkpointing,
    )

    logger.info("loading model...")
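    # PEFT path: optionally quantize the base model with bitsandbytes (NF4 for
    # int4, plain int8 otherwise). No separate reference model is loaded; TRL
    # can score the reference policy by disabling the adapters.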
    if config.peft:
        if config.quantization == "int4":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=False,
            )
        elif config.quantization == "int8":
            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        else:
            bnb_config = None

        model = AutoModelForCausalLM.from_pretrained(
            config.model,
            config=model_config,
            token=config.token,
            quantization_config=bnb_config,
            trust_remote_code=ALLOW_REMOTE_CODE,
            use_flash_attention_2=config.use_flash_attention_2,
        )
        logger.info("Using PEFT, model_ref will be set to None")
        model_ref = None
    else:
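        # Full fine-tuning: load an explicit reference model if one is
        # configured; otherwise DPOTrainer creates a frozen copy of the policy.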
        model = AutoModelForCausalLM.from_pretrained(
            config.model,
            config=model_config,
            token=config.token,
            trust_remote_code=ALLOW_REMOTE_CODE,
            use_flash_attention_2=config.use_flash_attention_2,
        )
        if config.model_ref is not None:
            model_ref = AutoModelForCausalLM.from_pretrained(
                config.model_ref,
                config=model_config,
                token=config.token,
                trust_remote_code=ALLOW_REMOTE_CODE,
                use_flash_attention_2=config.use_flash_attention_2,
            )
        else:
            model_ref = None

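    # Keep the embedding matrices in sync with the tokenizer, which may have
    # gained special tokens (e.g. a pad token).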
    logger.info(f"model dtype: {model.dtype}")
    model.resize_token_embeddings(len(tokenizer))

    if model_ref is not None:
        logger.info(f"model_ref dtype: {model_ref.dtype}")
        model_ref.resize_token_embeddings(len(tokenizer))

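    # LoRA adapter settings; the actual wrapping is performed by DPOTrainer
    # via the peft_config argument below.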
    if config.peft:
        peft_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=utils.get_target_modules(config),
        )

    logger.info("creating trainer")
    callbacks = utils.get_callbacks(config)
    trainer_args = dict(
        args=args,
        model=model,
        callbacks=callbacks,
    )

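    # DPOTrainer tokenizes the preference pairs itself; the tokenizer is
    # supplied as processing_class.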
    trainer = DPOTrainer(
        **trainer_args,
        ref_model=model_ref,
        train_dataset=train_data,
        eval_dataset=valid_data if config.valid_split is not None else None,
        processing_class=tokenizer,
        peft_config=peft_config if config.peft else None,
    )

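    # Remove the default stdout printer, train, and hand off to AutoTrain's
    # post-training steps (e.g. saving the model and creating a model card).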
    trainer.remove_callback(PrinterCallback)
    trainer.train()
    utils.post_training_steps(config, trainer)
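
For context, a minimal sketch of how this entry point can be invoked: train() accepts a plain dict and coerces it into LLMTrainingParams, so a hypothetical call looks like the following. The field names assume the LLMTrainingParams schema and the values are illustrative placeholders, not recommendations.

if __name__ == "__main__":
    # Hypothetical configuration; keys follow LLMTrainingParams and the
    # values are placeholders only.
    train(
        {
            "model": "meta-llama/Llama-2-7b-hf",
            "project_name": "my-dpo-run",
            "data_path": "my_org/my_preference_dataset",
            "trainer": "dpo",
            "peft": True,
            "quantization": "int4",
            "dpo_beta": 0.1,
            "max_prompt_length": 512,
            "max_completion_length": 1024,
        }
    )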