def train()

in src/autotrain/trainers/clm/train_clm_default.py [0:0]


def _tokenize_and_group(dataset, tokenize_fn, group_texts_fn, block_size, split_desc):
    """Tokenize one dataset split, then regroup it into chunks for CLM training.

    Args:
        dataset: A dataset exposing ``.map`` and ``.features`` (presumably a
            Hugging Face ``datasets.Dataset`` -- confirm against callers).
        tokenize_fn: Batched tokenization callable applied via ``dataset.map``.
        group_texts_fn: Batched callable that regroups tokenized rows.
        block_size: Only used in the progress-bar description.
        split_desc: Split label used in the progress-bar description
            (e.g. ``"train"`` or ``"validation"``).

    Returns:
        The tokenized and grouped dataset.
    """
    tokenized = dataset.map(
        tokenize_fn,
        batched=True,
        num_proc=1,
        # Drop every original column so only the tokenizer's outputs remain.
        remove_columns=list(dataset.features),
        desc=f"Running tokenizer on {split_desc} dataset",
    )
    return tokenized.map(
        group_texts_fn,
        batched=True,
        num_proc=4,
        desc=f"Grouping texts in chunks of {block_size}",
    )


def _align_module_dtypes(model, mixed_precision):
    """Cast selected submodules of *model* to the dtypes used for training.

    ``nn.Module.to`` converts a module in place, so no rebinding is needed.

    Args:
        model: The model whose ``named_modules()`` are inspected.
        mixed_precision: The configured mixed-precision mode. Only the
            value ``"bf16"`` triggers the bfloat16 casts below.
    """
    use_bf16 = mixed_precision == "bf16"
    embedding_keys = ("lm_head", "embed_tokens", "wte", "wpe")
    for name, module in model.named_modules():
        if use_bf16 and isinstance(module, LoraLayer):
            module.to(torch.bfloat16)
        # Normalization layers are always kept in float32, presumably for
        # numerical stability under reduced precision -- confirm intent.
        if "norm" in name:
            module.to(torch.float32)
        # Output head / embedding weights are only downcast when they are
        # still float32 and bf16 training was requested.
        if (
            use_bf16
            and any(key in name for key in embedding_keys)
            and hasattr(module, "weight")
            and module.weight.dtype == torch.float32
        ):
            module.to(torch.bfloat16)


def train(config):
    """Run default (generic) causal-language-model fine-tuning.

    Loads and preprocesses the data, tokenizes and groups it into chunks,
    builds a ``Trainer``, aligns submodule dtypes, trains, and finally runs
    ``utils.post_training_steps``.

    Args:
        config: An ``LLMTrainingParams`` instance, or a plain dict that is
            converted into one.
    """
    logger.info("Starting default/generic CLM training...")
    if isinstance(config, dict):
        config = LLMTrainingParams(**config)
    train_data, valid_data = utils.process_input_data(config)
    tokenizer = utils.get_tokenizer(config)
    train_data, valid_data = utils.process_data_with_chat_template(config, tokenizer, train_data, valid_data)

    train_data = process_data(data=train_data, tokenizer=tokenizer, config=config)
    if config.valid_split is not None:
        valid_data = process_data(data=valid_data, tokenizer=tokenizer, config=config)

    logging_steps = utils.configure_logging_steps(config, train_data, valid_data)
    training_args = utils.configure_training_args(config, logging_steps)
    # `config` is rebound here; everything below (model loading, the
    # tokenize/group partials, block_size in the progress text) sees the
    # returned config, presumably with block_size resolved -- verify in utils.
    config = utils.configure_block_size(config, tokenizer)
    args = TrainingArguments(**training_args)

    model = utils.get_model(config, tokenizer)

    tokenize_fn = partial(utils.tokenize, tokenizer=tokenizer, config=config)
    group_texts_fn = partial(utils.group_texts, config=config)

    has_validation = config.valid_split is not None
    train_data = _tokenize_and_group(train_data, tokenize_fn, group_texts_fn, config.block_size, "train")
    if has_validation:
        valid_data = _tokenize_and_group(valid_data, tokenize_fn, group_texts_fn, config.block_size, "validation")

    logger.info("creating trainer")
    trainer = Trainer(
        args=args,
        model=model,
        callbacks=utils.get_callbacks(config),
        train_dataset=train_data,
        eval_dataset=valid_data if has_validation else None,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
    )
    _align_module_dtypes(trainer.model, config.mixed_precision)

    trainer.remove_callback(PrinterCallback)
    trainer.train()
    utils.post_training_steps(config, trainer)