def main()

in ml/trainer.py
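
End-to-end entry point for KTO fine-tuning: loads a base model plus a frozen reference copy, attaches (or resumes) a LoRA adapter, trains with TRL's KTOTrainer, evaluates, logs metrics to wandb, and saves the adapter with metadata, optionally pushing it to the Hugging Face Hub.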


import wandb
from datetime import datetime

from peft import get_peft_model
from trl import KTOTrainer, get_peft_config

# load_model_and_tokenizer, load_latest_adapter, get_adapter_path, and
# AdapterMetadata are project-local helpers; model_args, script_args, and
# training_args are the parsed configs, all defined elsewhere in ml/trainer.py.


def main():
    # Initialize wandb for logging
    wandb.init(project="kto")

    # Get timestamp at start of training
    training_timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    print("Loading base model and tokenizer...")
    model, tokenizer = load_model_and_tokenizer(model_args)
    ref_model, _ = load_model_and_tokenizer(model_args)
    print("Models and tokenizer loaded.")

    # Load existing adapter or create new one
    loaded_model, previous_timestamp = load_latest_adapter(
        model,
        model_args.model_name,
        script_args.language
    )

    if loaded_model is not None:
        model = loaded_model
        print(f"Loaded existing adapter trained at {previous_timestamp}")
    else:
        # Initialize new LoRA adapter
        peft_config = get_peft_config(model_args)
        model = get_peft_model(model, peft_config)
        print("Initialized new adapter")

    # -----------------------------
    # Data Preparation and Training
    # -----------------------------
    print("Processing dataset...")
    dataset = script_args.process_dataset_func(script_args.language)
    print("Dataset processed.")

    print("Initializing trainer...")
    trainer = KTOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        processing_class=tokenizer,
        # peft_config is omitted here: the model is already PEFT-wrapped in
        # both branches above (loaded adapter or get_peft_model), and passing
        # a config again would have the trainer try to wrap it a second time.
        # It also avoided a NameError: peft_config is only bound in the
        # new-adapter branch.
    )

    # Training
    print("Starting training...")
    trainer.train()
    print("Training completed.")

    # Evaluation
    print("Evaluating model...")
    metrics = trainer.evaluate()
    print(f"Metrics: {metrics}")
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    # Log evaluation metrics to wandb. trainer.evaluate() prefixes its keys
    # with "eval_"; training-log fields (epoch, grad_norm, learning_rate,
    # step) are not part of the returned eval dict, so they are not logged.
    wandb.log({
        "loss": metrics.get("eval_loss"),
        "kl": metrics.get("eval_kl"),
        "logits/chosen": metrics.get("eval_logits/chosen"),
        "logits/rejected": metrics.get("eval_logits/rejected"),
        "logps/chosen": metrics.get("eval_logps/chosen"),
        "logps/rejected": metrics.get("eval_logps/rejected"),
        "rewards/chosen": metrics.get("eval_rewards/chosen"),
        "rewards/margins": metrics.get("eval_rewards/margins"),
        "rewards/rejected": metrics.get("eval_rewards/rejected"),
    })

    # Save the adapter
    adapter_path = get_adapter_path(
        model_args.model_name,
        script_args.language,
        training_timestamp
    )
    adapter_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Saving adapter to: {adapter_path}")
    model.save_pretrained(adapter_path)

    # Save metadata
    metadata = AdapterMetadata(
        training_timestamp=training_timestamp,
        model_name=model_args.model_name,
        language=script_args.language,
    )
    metadata.save(adapter_path / "metadata.json")

    if script_args.push_to_hub:
        # Hub repo ids allow only a single "/" (namespace/name), so flatten
        # the model name and language into the repository name.
        repo_id = f"feel-fl/adapters_{model_args.model_name.replace('/', '_')}_{script_args.language}"
        print(f"Pushing adapter to Hugging Face Hub at {repo_id}...")
        model.push_to_hub(repo_id=repo_id)

    print("Process completed.")

    # Finish wandb run
    wandb.finish()
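
The helpers used above (load_model_and_tokenizer, load_latest_adapter, get_adapter_path, AdapterMetadata) are project-local and not shown in this section. As a rough sketch of the on-disk contract main() relies on — the directory layout and field names here are assumptions for illustration, not the project's actual implementation:

import json
from dataclasses import asdict, dataclass
from pathlib import Path

ADAPTER_ROOT = Path("adapters")  # assumed root directory; the real path may differ


def get_adapter_path(model_name: str, language: str, timestamp: str) -> Path:
    # One directory per (model, language, training run), so a loader can
    # pick the newest timestamp when resuming.
    return ADAPTER_ROOT / model_name.replace("/", "_") / language / timestamp


@dataclass
class AdapterMetadata:
    training_timestamp: str
    model_name: str
    language: str

    def save(self, path: Path) -> None:
        # Serialize the metadata fields as JSON next to the adapter weights.
        path.write_text(json.dumps(asdict(self), indent=2))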