def main()

in jamba1.5-retriever/scripts/train.py [0:0]
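main() is the training entry point of this script: it parses hyperparameters, loads the Jamba 1.5 Mini tokenizer and model, prepares the stsb_multi_mt sentence-pair dataset, and runs training plus a final evaluation through a custom Trainer. It relies on an import block and several helpers defined earlier in train.py that are not shown in this excerpt; a minimal sketch of the imports it assumes:

import argparse

from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer, TrainingArguments

# Assumed to be defined elsewhere in train.py:
# preprocess_function, CustomDataCollatorWithPadding, CustomTrainer, compute_metrics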


def main():

    # Command-line arguments for hyperparameters
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--eval_batch_size", type=int, default=32)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--model_name", type=str, default="ai21labs/AI21-Jamba-1.5-Mini")
    parser.add_argument("--output_dir", type=str, default="/opt/ml/model")
    parser.add_argument("--log_dir", type=str, default="/opt/ml/output")
    parser.add_argument("--cache_dir_ds", type=str, default="/opt/ml/dataset_cache")
    parser.add_argument("--cache_dir_model", type=str, default="/opt/ml/model_cache")
    parser.add_argument("--huggingface_token", type=str, default="<myToken>")
    parser.add_argument("--dataset_name", type=str, default="stsb_multi_mt")

    args = parser.parse_args()
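    # Example launch with the defaults above (hypothetical token value):
    #   python scripts/train.py --epochs 3 --train_batch_size 32 --learning_rate 2e-5 \
    #       --model_name ai21labs/AI21-Jamba-1.5-Mini --huggingface_token hf_xxx
    # The "<myToken>" default is only a placeholder; pass a real Hugging Face token at launch time.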

    print("Processing Datasets and building Training Configurations" )

    print("load Tokenizer")
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.cache_dir_model, token=args.huggingface_token)
    print("After tokenizer / Before model load")

    print("load Model")
    model = AutoModel.from_pretrained(args.model_name,cache_dir=args.cache_dir_model, token=args.huggingface_token)
    print("After model load / Before dataset load")

    # Load and preprocess dataset
    train_ds, test_ds, dev_ds = load_dataset(args.dataset_name, 'en', split=['train[:20%]','test[:20%]','dev[:20%]'], cache_dir=args.cache_dir_ds)
    train_ds_size = train_ds.num_rows
    test_ds_size = test_ds.num_rows
    dev_ds_size = dev_ds.num_rows
    print(f"After dataset load. # of rows: train {train_ds_size}, test {test_ds_size}, dev {dev_ds_size}")

    print("Tokenizing and formatting Datasets")
    tokenized_train_ds = train_ds.map(lambda examples: preprocess_function(examples, tokenizer, max_lenght=400), batched=True)
    tokenized_test_ds = test_ds.map(lambda examples: preprocess_function(examples, tokenizer, max_lenght=400), batched=True)
    tokenized_dev_ds = dev_ds.map(lambda examples: preprocess_function(examples, tokenizer, max_lenght=400), batched=True)

    print("Here are the first row for each dataset split after processing: Train, Test and Dev")
    print(tokenized_train_ds[0])
    print(tokenized_test_ds[0])
    print(tokenized_dev_ds[0])

    # Define Hugging Face Datasets compatible format
    tokenized_train_ds.set_format(type='torch', columns=['input_ids1', 'attention_mask1', 'input_ids2', 'attention_mask2', 'labels'])
    tokenized_test_ds.set_format(type='torch', columns=['input_ids1', 'attention_mask1', 'input_ids2', 'attention_mask2', 'labels'])
    tokenized_dev_ds.set_format(type='torch', columns=['input_ids1', 'attention_mask1', 'input_ids2', 'attention_mask2', 'labels'])
    print("First rows for each dataset tensors: ")
    print(tokenized_train_ds[0])
    print(tokenized_test_ds[0])
    print(tokenized_dev_ds[0])

    # Initialize the custom data collator
    data_collator = CustomDataCollatorWithPadding(tokenizer=tokenizer)
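    # CustomDataCollatorWithPadding is defined elsewhere in train.py; it is assumed to pad
    # input_ids1/attention_mask1 and input_ids2/attention_mask2 to the longest sequence in each batch.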

    # Define Step related metrics to drive training loop
    steps_per_epoch = train_ds_size // args.train_batch_size
    num_saves_per_epoch = 2
    total_steps = steps_per_epoch * args.epochs
    warmup_steps = int(0.1 * total_steps)
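    # With num_saves_per_epoch = 2, checkpoints and evaluations run every half epoch,
    # and the learning rate warms up over the first 10% of all training steps.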

    # Define the training arguments
    training_args = TrainingArguments(

        # Output and Checkpointing
        output_dir=args.output_dir,
        save_strategy="steps",
        save_steps=steps_per_epoch // num_saves_per_epoch,
        save_total_limit=2,  
        load_best_model_at_end=True,

        # Training Control
        do_train=True,
        do_eval=True,
        do_predict=False,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        num_train_epochs=args.epochs,
        max_steps=-1,
        fp16=True,
        gradient_checkpointing=False,
        gradient_accumulation_steps=1,

        # Logging and Reporting
        logging_dir=args.log_dir,
        logging_steps=1,
        logging_first_step=True,
        report_to="tensorboard",

        # Evaluation Control
        evaluation_strategy="steps",
        eval_steps=steps_per_epoch // num_saves_per_epoch,
        eval_accumulation_steps=None,
        batch_eval_metrics=True,

        # Optimization
        learning_rate=args.learning_rate,
        warmup_steps=warmup_steps,
        lr_scheduler_type="linear",
        weight_decay=0.01,

        # Model Evaluation
        metric_for_best_model='eval_loss',
        greater_is_better=False,

        # Other
        remove_unused_columns=False,
        label_smoothing_factor=0.0
    )

    # Initialize the Trainer from the custom Trainer subclass, which provides the compute_loss and prediction_step overrides
    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train_ds,
        eval_dataset=tokenized_dev_ds,
        compute_metrics=compute_metrics,
        data_collator=data_collator
    )

    print("Starting Training")
    # Train the model
    trainer.train()
    print("Training is done")

    print("Start Final Evaluation")
    trainer.evaluate(eval_dataset=tokenized_test_ds)
    print("Final Evaluation done")

    # Save the model and tokenizers
    print("Saving the Model and Tokenizer")
    tokenizer.save_pretrained(args.output_dir)
    trainer.save_model(args.output_dir)

    print("Model is ready for deployment")