def training_function(args)

in notebooks/text-classification/scripts/train.py


from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)


def training_function(args):
    # Set the random seed for reproducibility
    set_seed(args.seed)

    # Load the dataset
    emotions = load_dataset("dair-ai/emotion")

    # Load the tokenizer
    model_id = args.model_id
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Tokenize the dataset, padding/truncating each example to a fixed length
    def tokenize_function(example):
        return tokenizer(
            example["text"],
            padding="max_length",
            truncation=True,
            max_length=args.train_max_length,
        )

    tokenized_emotions = emotions.map(tokenize_function, batched=True)

    # Number of target classes, read from the dataset's label feature
    num_labels = len(emotions["train"].features["label"].names)

    # Load the model
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        num_labels=num_labels,
    )

    if args.output_dir:
        output_dir = args.output_dir
    else:
        output_dir = f"{model_id}-finetuned"

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        num_train_epochs=args.epochs,
        max_steps=args.max_steps,  # a positive value overrides num_train_epochs
        do_train=True,
        bf16=True,  # requires hardware with bfloat16 support (e.g. Ampere+ GPUs)
        logging_dir=f"{output_dir}/logs",
        logging_strategy="steps",
        logging_steps=500,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        # Push-to-Hub parameters (only active when a repository ID is given)
        push_to_hub=bool(args.repository_id),
        hub_strategy="every_save",
        hub_model_id=args.repository_id or None,
        hub_token=args.hf_token,
    )

    # Initialize the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_emotions["train"],
        eval_dataset=tokenized_emotions["validation"],
        processing_class=tokenizer,  # newer Trainer API; replaces the deprecated tokenizer= argument
    )
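
    # Note: no compute_metrics is passed, so trainer.evaluate() reports only
    # the loss and runtime statistics; an accuracy sketch follows this listing.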

    # Train the model
    train_result = trainer.train()
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)

    # Evaluate on the validation split and log the eval metrics separately
    eval_metrics = trainer.evaluate(eval_dataset=tokenized_emotions["validation"])
    trainer.log_metrics("eval", eval_metrics)

    # Save the final model, write a model card, and optionally push to the Hub
    trainer.save_model(output_dir)
    trainer.create_model_card()
    if args.repository_id:
        # The repository and token are already configured via hub_model_id and
        # hub_token in TrainingArguments; Trainer.push_to_hub takes no
        # repository_id keyword.
        trainer.push_to_hub()
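
As written, trainer.evaluate() reports only the loss because no compute_metrics
function is passed to the Trainer. A minimal sketch of one that adds accuracy;
the helper below is illustrative, not part of the original script:

import numpy as np

def compute_metrics(eval_pred):
    # eval_pred is a (logits, labels) pair produced during Trainer.evaluate
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float((predictions == labels).mean())}

The excerpt also omits the script's argument parsing. A minimal sketch of the
CLI wiring that training_function assumes, with argument names taken from the
attributes accessed above; the defaults are illustrative only:

import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", default="bert-base-uncased")  # assumed default
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--train_max_length", type=int, default=128)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--per_device_train_batch_size", type=int, default=16)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=16)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--max_steps", type=int, default=-1)  # -1 lets epochs drive training
    parser.add_argument("--output_dir", default=None)
    parser.add_argument("--repository_id", default=None)
    parser.add_argument("--hf_token", default=None)
    return parser.parse_args()

if __name__ == "__main__":
    training_function(parse_args())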