# src/autotrain/trainers/clm/train_clm_default.py
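# Imports assumed by this snippet; the autotrain-internal module paths are
# inferred from the package layout and may need adjusting.
from functools import partial

import torch
from peft.tuners.lora import LoraLayer
from transformers import Trainer, TrainingArguments, default_data_collator
from transformers.trainer_callback import PrinterCallback

from autotrain import logger
from autotrain.trainers.clm import utils
from autotrain.trainers.clm.params import LLMTrainingParams
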
def train(config):
    logger.info("Starting default/generic CLM training...")
    if isinstance(config, dict):
        config = LLMTrainingParams(**config)
    train_data, valid_data = utils.process_input_data(config)
    tokenizer = utils.get_tokenizer(config)
    train_data, valid_data = utils.process_data_with_chat_template(config, tokenizer, train_data, valid_data)
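
    # process_data is defined elsewhere in this module; it prepares the raw text
    # column (e.g. appending the EOS token when configured) before tokenization.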
    train_data = process_data(
        data=train_data,
        tokenizer=tokenizer,
        config=config,
    )
    if config.valid_split is not None:
        valid_data = process_data(
            data=valid_data,
            tokenizer=tokenizer,
            config=config,
        )
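
    # Derive the logging cadence, the TrainingArguments kwargs, and the packing
    # block size from the config and tokenizer (exact behaviour lives in the
    # utils helpers).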
    logging_steps = utils.configure_logging_steps(config, train_data, valid_data)
    training_args = utils.configure_training_args(config, logging_steps)
    config = utils.configure_block_size(config, tokenizer)
    args = TrainingArguments(**training_args)
    model = utils.get_model(config, tokenizer)
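
    # Pre-bind tokenizer and config so datasets.map() receives single-argument callables.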
    tokenize_fn = partial(utils.tokenize, tokenizer=tokenizer, config=config)
    group_texts_fn = partial(utils.group_texts, config=config)
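
    # Tokenize the raw text; the original columns are removed so only the
    # model inputs remain.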
    train_data = train_data.map(
        tokenize_fn,
        batched=True,
        num_proc=1,
        remove_columns=list(train_data.features),
        desc="Running tokenizer on train dataset",
    )
    if config.valid_split is not None:
        valid_data = valid_data.map(
            tokenize_fn,
            batched=True,
            num_proc=1,
            remove_columns=list(valid_data.features),
            desc="Running tokenizer on validation dataset",
        )
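
    # Concatenate and re-split the tokenized corpus into fixed-length chunks of
    # config.block_size tokens (standard packing for causal LM training).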
    train_data = train_data.map(
        group_texts_fn,
        batched=True,
        num_proc=4,
        desc=f"Grouping texts in chunks of {config.block_size}",
    )
    if config.valid_split is not None:
        valid_data = valid_data.map(
            group_texts_fn,
            batched=True,
            num_proc=4,
            desc=f"Grouping texts in chunks of {config.block_size}",
        )
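
    # Build the Hugging Face Trainer with the default collator; the grouped
    # examples are expected to already carry labels for the causal LM loss.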
logger.info("creating trainer")
callbacks = utils.get_callbacks(config)
trainer_args = dict(
args=args,
model=model,
callbacks=callbacks,
)
trainer = Trainer(
**trainer_args,
train_dataset=train_data,
eval_dataset=valid_data if config.valid_split is not None else None,
tokenizer=tokenizer,
data_collator=default_data_collator,
)
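
    # Mixed-precision dtype fixup in the usual QLoRA style: LoRA layers to bf16,
    # norm layers kept in fp32, and embedding/output heads cast to bf16 only if
    # they are still in fp32.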
    for name, module in trainer.model.named_modules():
        if isinstance(module, LoraLayer):
            if config.mixed_precision == "bf16":
                module = module.to(torch.bfloat16)
        if "norm" in name:
            module = module.to(torch.float32)
        if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
            if hasattr(module, "weight"):
                if config.mixed_precision == "bf16" and module.weight.dtype == torch.float32:
                    module = module.to(torch.bfloat16)
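
    # Drop the default stdout printer callback, run training, and hand off to
    # utils.post_training_steps (model saving and any configured post-processing).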
    trainer.remove_callback(PrinterCallback)
    trainer.train()
    utils.post_training_steps(config, trainer)