course/en/chapter3/section3.ipynb (206 lines of code) (raw):
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fine-tuning a model with the Trainer API"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install the Transformers, Datasets, and Evaluate libraries to run this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Use the %pip magic so the packages are installed into the running kernel's environment.\n",
"%pip install datasets evaluate transformers[sentencepiece]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from transformers import AutoTokenizer, DataCollatorWithPadding\n",
"\n",
"checkpoint = \"bert-base-uncased\"\n",
"raw_datasets = load_dataset(\"glue\", \"mrpc\")\n",
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
"\n",
"\n",
"def tokenize_function(example):\n",
"    \"\"\"Pass the `sentence1` and `sentence2` fields of `example` to the tokenizer together, with truncation enabled.\"\"\"\n",
"    first, second = example[\"sentence1\"], example[\"sentence2\"]\n",
"    return tokenizer(first, second, truncation=True)\n",
"\n",
"\n",
"# batched=True hands `tokenize_function` several examples per call.\n",
"tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n",
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import TrainingArguments\n",
"\n",
"training_args = TrainingArguments(output_dir=\"test-trainer\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForSequenceClassification\n",
"\n",
"model = AutoModelForSequenceClassification.from_pretrained(\n",
"    checkpoint,\n",
"    num_labels=2,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import Trainer\n",
"\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    data_collator=data_collator,\n",
"    train_dataset=tokenized_datasets[\"train\"],\n",
"    eval_dataset=tokenized_datasets[\"validation\"],\n",
"    processing_class=tokenizer,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(408, 2) (408,)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Run the trained model over the validation split and report the shapes of its outputs.\n",
"validation_split = tokenized_datasets[\"validation\"]\n",
"predictions = trainer.predict(validation_split)\n",
"print(predictions.predictions.shape, predictions.label_ids.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Keep, for every row, the index of the largest value along the last axis.\n",
"class_scores = predictions.predictions\n",
"preds = np.argmax(class_scores, axis=-1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'accuracy': 0.8578431372549019, 'f1': 0.8996539792387542}"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import evaluate\n",
"\n",
"metric = evaluate.load(\"glue\", \"mrpc\")\n",
"reference_labels = predictions.label_ids\n",
"metric.compute(references=reference_labels, predictions=preds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def compute_metrics(eval_preds):\n",
"    \"\"\"Compute the GLUE MRPC metric for a `(logits, labels)` pair.\n",
"\n",
"    The predicted class is the arg-max of the logits over their last axis.\n",
"    \"\"\"\n",
"    glue_mrpc = evaluate.load(\"glue\", \"mrpc\")\n",
"    scores, gold = eval_preds\n",
"    return glue_mrpc.compute(\n",
"        predictions=np.argmax(scores, axis=-1),\n",
"        references=gold,\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# `eval_strategy` is the current name of this option; the older `evaluation_strategy`\n",
"# spelling is deprecated in recent Transformers releases.\n",
"training_args = TrainingArguments(\"test-trainer\", eval_strategy=\"epoch\")\n",
"# Reload the checkpoint so this run starts from a new model object.\n",
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
"\n",
"trainer = Trainer(\n",
"    model,\n",
"    training_args,\n",
"    train_dataset=tokenized_datasets[\"train\"],\n",
"    eval_dataset=tokenized_datasets[\"validation\"],\n",
"    data_collator=data_collator,\n",
"    processing_class=tokenizer,\n",
"    compute_metrics=compute_metrics,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.train()"
]
}
],
"metadata": {
"colab": {
"name": "Fine-tuning a model with the Trainer API or Keras",
"provenance": []
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 4
}