# ml/trainer.py
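# NOTE: a hedged sketch of the imports this excerpt relies on. The third-party
# imports below are standard; the project-local helpers (load_model_and_tokenizer,
# load_latest_adapter, get_adapter_path, AdapterMetadata) and the parsed
# model_args / script_args / training_args are assumed to be defined elsewhere in
# this module or package and are not reproduced here.
from datetime import datetime

import wandb
from peft import get_peft_model
from trl import KTOTrainer, get_peft_config
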
def main():
    # Initialize wandb for logging
    wandb.init(project="kto")

    # Get timestamp at start of training
    training_timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    print("Loading base model and tokenizer...")
    model, tokenizer = load_model_and_tokenizer(model_args)
    ref_model, _ = load_model_and_tokenizer(model_args)
    print("Models and tokenizer loaded.")
    # Load the latest existing adapter for this model/language, or create a new one
    loaded_model, previous_timestamp = load_latest_adapter(
        model,
        model_args.model_name,
        script_args.language,
    )
    if loaded_model is not None:
        model = loaded_model
        # The loaded model is already a PeftModel, so the trainer below does not
        # need a fresh PEFT config
        peft_config = None
        print(f"Loaded existing adapter trained at {previous_timestamp}")
    else:
        # Initialize new LoRA adapter
        peft_config = get_peft_config(model_args)
        model = get_peft_model(model, peft_config)
        print("Initialized new adapter")
    # -----------------------------
    # Data Preparation and Training
    # -----------------------------
    print("Processing dataset...")
    dataset = script_args.process_dataset_func(script_args.language)
    print("Dataset processed.")
print("Initializing trainer...")
trainer = KTOTrainer(
model=model,
ref_model=ref_model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
processing_class=tokenizer,
peft_config=peft_config,
)
    # Training
    print("Starting training...")
    trainer.train()
    print("Training completed.")

    # Evaluation
    print("Evaluating model...")
    metrics = trainer.evaluate()
    print(f"Metrics: {metrics}")
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
    # Log evaluation metrics to wandb. trainer.evaluate() prefixes its keys with
    # "eval_", and training-only values (grad_norm, learning_rate, step) are not
    # part of the evaluation metrics, so they are omitted here.
    wandb.log({
        "epoch": metrics.get("epoch"),
        "kl": metrics.get("eval_kl"),
        "logits/chosen": metrics.get("eval_logits/chosen"),
        "logits/rejected": metrics.get("eval_logits/rejected"),
        "logps/chosen": metrics.get("eval_logps/chosen"),
        "logps/rejected": metrics.get("eval_logps/rejected"),
        "loss": metrics.get("eval_loss"),
        "rewards/chosen": metrics.get("eval_rewards/chosen"),
        "rewards/margins": metrics.get("eval_rewards/margins"),
        "rewards/rejected": metrics.get("eval_rewards/rejected"),
    })
    # Save the adapter, versioned by the training timestamp
    adapter_path = get_adapter_path(
        model_args.model_name,
        script_args.language,
        training_timestamp,
    )
    adapter_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"Saving adapter to: {adapter_path}")
    model.save_pretrained(adapter_path)

    # Save metadata
    metadata = AdapterMetadata(
        training_timestamp=training_timestamp,
        model_name=model_args.model_name,
        language=script_args.language,
    )
    metadata.save(adapter_path / "metadata.json")
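    # At this point adapter_path holds only lightweight artifacts: the adapter
    # weights and adapter_config.json written by save_pretrained (a PeftModel
    # saves just the adapter, not the full base model), plus metadata.json.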
    if script_args.push_to_hub:
        # Hub repository ids allow at most one "/" (namespace/name), so the model
        # name and language are flattened into the repository name
        repo_id = f"feel-fl/adapters_{model_args.model_name.replace('/', '_')}_{script_args.language}"
        print(f"Pushing adapter to Hugging Face Hub at {repo_id}...")
        model.push_to_hub(repo_id=repo_id)
    print("Process completed.")

    # Finish wandb run
    wandb.finish()
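
# Assumed entry point; model_args, script_args, and training_args are expected to
# have been built before main() runs (e.g. parsed at module level with TRL's
# TrlParser), which is not shown in this excerpt.
if __name__ == "__main__":
    main()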