transformers_doc/ja/token_classification.ipynb (1,007 lines of code) (raw):
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Token classification"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"hide_input": true
},
"outputs": [
{
"data": {
"text/html": [
"<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/wVHdVlPScxA?rel=0&controls=0&showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#@title\n",
"from IPython.display import HTML\n",
"\n",
"HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/wVHdVlPScxA?rel=0&controls=0&showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"トークン分類では、文内の個々のトークンにラベルを割り当てます。最も一般的なトークン分類タスクの 1 つは、固有表現認識 (NER) です。 NER は、人、場所、組織など、文内の各エンティティのラベルを見つけようとします。\n",
"\n",
"このガイドでは、次の方法を説明します。\n",
"\n",
"1. [WNUT 17](https://huggingface.co/datasets/wnut_17) データセットで [DistilBERT](https://huggingface.co/distilbert/distilbert-base-uncased) を微調整して、新しいエンティティを検出します。\n",
"2. 微調整されたモデルを推論に使用します。\n",
"\n",
"<Tip>\n",
"\n",
"このタスクと互換性のあるすべてのアーキテクチャとチェックポイントを確認するには、[タスクページ](https://huggingface.co/tasks/token-classification) を確認することをお勧めします。\n",
"\n",
"</Tip>\n",
"\n",
"始める前に、必要なライブラリがすべてインストールされていることを確認してください。\n",
"\n",
"```bash\n",
"pip install transformers datasets evaluate seqeval\n",
"```\n",
"モデルをアップロードしてコミュニティと共有できるように、Hugging Face アカウントにログインすることをお勧めします。プロンプトが表示されたら、トークンを入力してログインします。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import notebook_login\n",
"\n",
"notebook_login()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load WNUT 17 dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"まず、🤗 データセット ライブラリから WNUT 17 データセットをロードします。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"\n",
"wnut = load_dataset(\"wnut_17\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"次に、例を見てみましょう。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'id': '0',\n",
" 'ner_tags': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 8, 8, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" 'tokens': ['@paulwalk', 'It', \"'s\", 'the', 'view', 'from', 'where', 'I', \"'m\", 'living', 'for', 'two', 'weeks', '.', 'Empire', 'State', 'Building', '=', 'ESB', '.', 'Pretty', 'bad', 'storm', 'here', 'last', 'evening', '.']\n",
"}"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"wnut[\"train\"][0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`ner_tags`内の各数字はエンティティを表します。数値をラベル名に変換して、エンティティが何であるかを調べます。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[\n",
" \"O\",\n",
" \"B-corporation\",\n",
" \"I-corporation\",\n",
" \"B-creative-work\",\n",
" \"I-creative-work\",\n",
" \"B-group\",\n",
" \"I-group\",\n",
" \"B-location\",\n",
" \"I-location\",\n",
" \"B-person\",\n",
" \"I-person\",\n",
" \"B-product\",\n",
" \"I-product\",\n",
"]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"label_list = wnut[\"train\"].features[f\"ner_tags\"].feature.names\n",
"label_list"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"各 `ner_tag` の前に付く文字は、エンティティのトークンの位置を示します。\n",
"\n",
"- `B-` はエンティティの始まりを示します。\n",
"- `I-` は、トークンが同じエンティティ内に含まれていることを示します (たとえば、`State` トークンは次のようなエンティティの一部です)\n",
" `Empire State Building`)。\n",
"- `0` は、トークンがどのエンティティにも対応しないことを示します。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preprocess"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"hide_input": true
},
"outputs": [
{
"data": {
"text/html": [
"<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/iY2AZYdZAr0?rel=0&controls=0&showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#@title\n",
"from IPython.display import HTML\n",
"\n",
"HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/iY2AZYdZAr0?rel=0&controls=0&showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"次のステップでは、DistilBERT トークナイザーをロードして`tokens`フィールドを前処理します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"distilbert/distilbert-base-uncased\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"上の `tokens`フィールドの例で見たように、入力はすでにトークン化されているようです。しかし、実際には入力はまだトークン化されていないため、単語をサブワードにトークン化するには`is_split_into_words=True` を設定する必要があります。例えば:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['[CLS]', '@', 'paul', '##walk', 'it', \"'\", 's', 'the', 'view', 'from', 'where', 'i', \"'\", 'm', 'living', 'for', 'two', 'weeks', '.', 'empire', 'state', 'building', '=', 'es', '##b', '.', 'pretty', 'bad', 'storm', 'here', 'last', 'evening', '.', '[SEP]']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"example = wnut[\"train\"][0]\n",
"tokenized_input = tokenizer(example[\"tokens\"], is_split_into_words=True)\n",
"tokens = tokenizer.convert_ids_to_tokens(tokenized_input[\"input_ids\"])\n",
"tokens"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ただし、これによりいくつかの特別なトークン `[CLS]` と `[SEP]` が追加され、サブワードのトークン化により入力とラベルの間に不一致が生じます。 1 つのラベルに対応する 1 つの単語を 2 つのサブワードに分割できるようになりました。次の方法でトークンとラベルを再調整する必要があります。\n",
"\n",
"1. [`word_ids`](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.BatchEncoding.word_ids) メソッドを使用して、すべてのトークンを対応する単語にマッピングします。\n",
"2. 特別なトークン `[CLS]` と `[SEP]` にラベル `-100` を割り当て、それらが PyTorch 損失関数によって無視されるようにします ([CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html))。\n",
"3. 特定の単語の最初のトークンのみにラベルを付けます。同じ単語の他のサブトークンに `-100`を割り当てます。\n",
"\n",
"トークンとラベルを再調整し、シーケンスを DistilBERT の最大入力長以下に切り詰める関数を作成する方法を次に示します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tokenize_and_align_labels(examples):\n",
" tokenized_inputs = tokenizer(examples[\"tokens\"], truncation=True, is_split_into_words=True)\n",
"\n",
" labels = []\n",
" for i, label in enumerate(examples[f\"ner_tags\"]):\n",
" word_ids = tokenized_inputs.word_ids(batch_index=i) # Map tokens to their respective word.\n",
" previous_word_idx = None\n",
" label_ids = []\n",
" for word_idx in word_ids: # Set the special tokens to -100.\n",
" if word_idx is None:\n",
" label_ids.append(-100)\n",
" elif word_idx != previous_word_idx: # Only label the first token of a given word.\n",
" label_ids.append(label[word_idx])\n",
" else:\n",
" label_ids.append(-100)\n",
" previous_word_idx = word_idx\n",
" labels.append(label_ids)\n",
"\n",
" tokenized_inputs[\"labels\"] = labels\n",
" return tokenized_inputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"データセット全体に前処理関数を適用するには、🤗 Datasets `map` 関数を使用します。 `batched=True` を設定してデータセットの複数の要素を一度に処理することで、`map` 関数を高速化できます。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenized_wnut = wnut.map(tokenize_and_align_labels, batched=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"次に、`DataCollatorWithPadding` を使用してサンプルのバッチを作成します。データセット全体を最大長までパディングするのではなく、照合中にバッチ内の最長の長さまで文を *動的にパディング* する方が効率的です。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import DataCollatorForTokenClassification\n",
"\n",
"data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import DataCollatorForTokenClassification\n",
"\n",
"data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, return_tensors=\"tf\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"トレーニング中にメトリクスを含めると、多くの場合、モデルのパフォーマンスを評価するのに役立ちます。 🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) ライブラリを使用して、評価メソッドをすばやくロードできます。このタスクでは、[seqeval](https://huggingface.co/spaces/evaluate-metric/seqeval) フレームワークを読み込みます (🤗 Evaluate [クイック ツアー](https://huggingface.co/docs/evaluate/a_quick_tour) を参照してください) ) メトリクスの読み込みと計算の方法について詳しくは、こちらをご覧ください)。 Seqeval は実際に、精度、再現率、F1、精度などのいくつかのスコアを生成します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import evaluate\n",
"\n",
"seqeval = evaluate.load(\"seqeval\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"まず NER ラベルを取得してから、真の予測と真のラベルを `compute` に渡してスコアを計算する関数を作成します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"labels = [label_list[i] for i in example[f\"ner_tags\"]]\n",
"\n",
"\n",
"def compute_metrics(p):\n",
" predictions, labels = p\n",
" predictions = np.argmax(predictions, axis=2)\n",
"\n",
" true_predictions = [\n",
" [label_list[p] for (p, l) in zip(prediction, label) if l != -100]\n",
" for prediction, label in zip(predictions, labels)\n",
" ]\n",
" true_labels = [\n",
" [label_list[l] for (p, l) in zip(prediction, label) if l != -100]\n",
" for prediction, label in zip(predictions, labels)\n",
" ]\n",
"\n",
" results = seqeval.compute(predictions=true_predictions, references=true_labels)\n",
" return {\n",
" \"precision\": results[\"overall_precision\"],\n",
" \"recall\": results[\"overall_recall\"],\n",
" \"f1\": results[\"overall_f1\"],\n",
" \"accuracy\": results[\"overall_accuracy\"],\n",
" }"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"これで`compute_metrics`関数の準備が整いました。トレーニングをセットアップするときにこの関数に戻ります。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"モデルのトレーニングを開始する前に、`id2label`と`label2id`を使用して、予想される ID とそのラベルのマップを作成します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"id2label = {\n",
" 0: \"O\",\n",
" 1: \"B-corporation\",\n",
" 2: \"I-corporation\",\n",
" 3: \"B-creative-work\",\n",
" 4: \"I-creative-work\",\n",
" 5: \"B-group\",\n",
" 6: \"I-group\",\n",
" 7: \"B-location\",\n",
" 8: \"I-location\",\n",
" 9: \"B-person\",\n",
" 10: \"I-person\",\n",
" 11: \"B-product\",\n",
" 12: \"I-product\",\n",
"}\n",
"label2id = {\n",
" \"O\": 0,\n",
" \"B-corporation\": 1,\n",
" \"I-corporation\": 2,\n",
" \"B-creative-work\": 3,\n",
" \"I-creative-work\": 4,\n",
" \"B-group\": 5,\n",
" \"I-group\": 6,\n",
" \"B-location\": 7,\n",
" \"I-location\": 8,\n",
" \"B-person\": 9,\n",
" \"I-person\": 10,\n",
" \"B-product\": 11,\n",
" \"I-product\": 12,\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<Tip>\n",
"\n",
"[Trainer](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer) を使用したモデルの微調整に慣れていない場合は、[ここ](https://huggingface.co/docs/transformers/main/ja/tasks/../training#train-with-pytorch-trainer) の基本的なチュートリアルをご覧ください。\n",
"\n",
"</Tip>\n",
"\n",
"これでモデルのトレーニングを開始する準備が整いました。 [AutoModelForTokenClassification](https://huggingface.co/docs/transformers/main/ja/model_doc/auto#transformers.AutoModelForTokenClassification) を使用して、予期されるラベルの数とラベル マッピングを指定して DistilBERT を読み込みます。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer\n",
"\n",
"model = AutoModelForTokenClassification.from_pretrained(\n",
" \"distilbert/distilbert-base-uncased\", num_labels=13, id2label=id2label, label2id=label2id\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"この時点で残っているステップは 3 つだけです。\n",
"\n",
"1. [TrainingArguments](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.TrainingArguments) でトレーニング ハイパーパラメータを定義します。唯一の必須パラメータは、モデルの保存場所を指定する `output_dir` です。 `push_to_hub=True`を設定して、このモデルをハブにプッシュします (モデルをアップロードするには、Hugging Face にサインインする必要があります)。各エポックの終了時に、[Trainer](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer) は連続スコアを評価し、トレーニング チェックポイントを保存します。\n",
"2. トレーニング引数を、モデル、データセット、トークナイザー、データ照合器、および `compute_metrics` 関数とともに [Trainer](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer) に渡します。\n",
"3. [train()](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer.train) を呼び出してモデルを微調整します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"training_args = TrainingArguments(\n",
" output_dir=\"my_awesome_wnut_model\",\n",
" learning_rate=2e-5,\n",
" per_device_train_batch_size=16,\n",
" per_device_eval_batch_size=16,\n",
" num_train_epochs=2,\n",
" weight_decay=0.01,\n",
" eval_strategy=\"epoch\",\n",
" save_strategy=\"epoch\",\n",
" load_best_model_at_end=True,\n",
" push_to_hub=True,\n",
")\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=tokenized_wnut[\"train\"],\n",
" eval_dataset=tokenized_wnut[\"test\"],\n",
" processing_class=tokenizer,\n",
" data_collator=data_collator,\n",
" compute_metrics=compute_metrics,\n",
")\n",
"\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"トレーニングが完了したら、 [push_to_hub()](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer.push_to_hub) メソッドを使用してモデルをハブに共有し、誰もがモデルを使用できるようにします。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.push_to_hub()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<Tip>\n",
"\n",
"Keras を使用したモデルの微調整に慣れていない場合は、[こちら](https://huggingface.co/docs/transformers/main/ja/tasks/../training#train-a-tensorflow-model-with-keras) の基本的なチュートリアルをご覧ください。\n",
"\n",
"</Tip>\n",
"TensorFlow でモデルを微調整するには、オプティマイザー関数、学習率スケジュール、およびいくつかのトレーニング ハイパーパラメーターをセットアップすることから始めます。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import create_optimizer\n",
"\n",
"batch_size = 16\n",
"num_train_epochs = 3\n",
"num_train_steps = (len(tokenized_wnut[\"train\"]) // batch_size) * num_train_epochs\n",
"optimizer, lr_schedule = create_optimizer(\n",
" init_lr=2e-5,\n",
" num_train_steps=num_train_steps,\n",
" weight_decay_rate=0.01,\n",
" num_warmup_steps=0,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"次に、[TFAutoModelForTokenClassification](https://huggingface.co/docs/transformers/main/ja/model_doc/auto#transformers.TFAutoModelForTokenClassification) を使用して、予期されるラベルの数とラベル マッピングを指定して DistilBERT をロードできます。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import TFAutoModelForTokenClassification\n",
"\n",
"model = TFAutoModelForTokenClassification.from_pretrained(\n",
" \"distilbert/distilbert-base-uncased\", num_labels=13, id2label=id2label, label2id=label2id\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[prepare_tf_dataset()](https://huggingface.co/docs/transformers/main/ja/main_classes/model#transformers.TFPreTrainedModel.prepare_tf_dataset) を使用して、データセットを `tf.data.Dataset` 形式に変換します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tf_train_set = model.prepare_tf_dataset(\n",
" tokenized_wnut[\"train\"],\n",
" shuffle=True,\n",
" batch_size=16,\n",
" collate_fn=data_collator,\n",
")\n",
"\n",
"tf_validation_set = model.prepare_tf_dataset(\n",
" tokenized_wnut[\"validation\"],\n",
" shuffle=False,\n",
" batch_size=16,\n",
" collate_fn=data_collator,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[`compile`](https://keras.io/api/models/model_training_apis/#compile-method) を使用してトレーニング用のモデルを設定します。 Transformers モデルにはすべてデフォルトのタスク関連の損失関数があるため、次の場合を除き、損失関数を指定する必要はないことに注意してください。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"model.compile(optimizer=optimizer) # No loss argument!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"トレーニングを開始する前にセットアップする最後の 2 つのことは、予測から連続スコアを計算することと、モデルをハブにプッシュする方法を提供することです。どちらも [Keras コールバック](https://huggingface.co/docs/transformers/main/ja/tasks/../main_classes/keras_callbacks) を使用して行われます。\n",
"\n",
"`compute_metrics` 関数を [KerasMetricCallback](https://huggingface.co/docs/transformers/main/ja/main_classes/keras_callbacks#transformers.KerasMetricCallback) に渡します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers.keras_callbacks import KerasMetricCallback\n",
"\n",
"metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_validation_set)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[PushToHubCallback](https://huggingface.co/docs/transformers/main/ja/main_classes/keras_callbacks#transformers.PushToHubCallback) でモデルとトークナイザーをプッシュする場所を指定します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers.keras_callbacks import PushToHubCallback\n",
"\n",
"push_to_hub_callback = PushToHubCallback(\n",
" output_dir=\"my_awesome_wnut_model\",\n",
" tokenizer=tokenizer,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"次に、コールバックをまとめてバンドルします。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"callbacks = [metric_callback, push_to_hub_callback]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ついに、モデルのトレーニングを開始する準備が整いました。トレーニングおよび検証データセット、エポック数、コールバックを指定して [`fit`](https://keras.io/api/models/model_training_apis/#fit-method) を呼び出し、モデルを微調整します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.fit(x=tf_train_set, validation_data=tf_validation_set, epochs=3, callbacks=callbacks)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"トレーニングが完了すると、モデルは自動的にハブにアップロードされ、誰でも使用できるようになります。\n",
"\n",
"\n",
"<Tip>\n",
"\n",
"トークン分類のモデルを微調整する方法のより詳細な例については、対応するセクションを参照してください。\n",
"[PyTorch ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/token_classification.ipynb)\n",
"または [TensorFlow ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/token_classification-tf.ipynb)。\n",
"\n",
"\n",
"</Tip>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inference"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"モデルを微調整したので、それを推論に使用できるようになりました。\n",
"\n",
"推論を実行したいテキストをいくつか取得します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"The Golden State Warriors are an American professional basketball team based in San Francisco.\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"推論用に微調整されたモデルを試す最も簡単な方法は、それを [pipeline()](https://huggingface.co/docs/transformers/main/ja/main_classes/pipelines#transformers.pipeline) で使用することです。モデルを使用して NER の`pipeline`をインスタンス化し、テキストをそれに渡します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'entity': 'B-location',\n",
" 'score': 0.42658573,\n",
" 'index': 2,\n",
" 'word': 'golden',\n",
" 'start': 4,\n",
" 'end': 10},\n",
" {'entity': 'I-location',\n",
" 'score': 0.35856336,\n",
" 'index': 3,\n",
" 'word': 'state',\n",
" 'start': 11,\n",
" 'end': 16},\n",
" {'entity': 'B-group',\n",
" 'score': 0.3064001,\n",
" 'index': 4,\n",
" 'word': 'warriors',\n",
" 'start': 17,\n",
" 'end': 25},\n",
" {'entity': 'B-location',\n",
" 'score': 0.65523505,\n",
" 'index': 13,\n",
" 'word': 'san',\n",
" 'start': 80,\n",
" 'end': 83},\n",
" {'entity': 'B-location',\n",
" 'score': 0.4668663,\n",
" 'index': 14,\n",
" 'word': 'francisco',\n",
" 'start': 84,\n",
" 'end': 93}]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from transformers import pipeline\n",
"\n",
"classifier = pipeline(\"ner\", model=\"stevhliu/my_awesome_wnut_model\")\n",
"classifier(text)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"必要に応じて、`pipeline`の結果を手動で複製することもできます。\n",
"\n",
"テキストをトークン化して PyTorch テンソルを返します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"stevhliu/my_awesome_wnut_model\")\n",
"inputs = tokenizer(text, return_tensors=\"pt\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"入力をモデルに渡し、`logits`を返します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForTokenClassification\n",
"\n",
"model = AutoModelForTokenClassification.from_pretrained(\"stevhliu/my_awesome_wnut_model\")\n",
"with torch.no_grad():\n",
" logits = model(**inputs).logits"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"最も高い確率でクラスを取得し、モデルの `id2label` マッピングを使用してそれをテキスト ラベルに変換します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['O',\n",
" 'O',\n",
" 'B-location',\n",
" 'I-location',\n",
" 'B-group',\n",
" 'O',\n",
" 'O',\n",
" 'O',\n",
" 'O',\n",
" 'O',\n",
" 'O',\n",
" 'O',\n",
" 'O',\n",
" 'B-location',\n",
" 'B-location',\n",
" 'O',\n",
" 'O']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions = torch.argmax(logits, dim=2)\n",
"predicted_token_class = [model.config.id2label[t.item()] for t in predictions[0]]\n",
"predicted_token_class"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"テキストをトークン化し、TensorFlow テンソルを返します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"stevhliu/my_awesome_wnut_model\")\n",
"inputs = tokenizer(text, return_tensors=\"tf\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"入力をモデルに渡し、`logits`を返します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import TFAutoModelForTokenClassification\n",
"\n",
"model = TFAutoModelForTokenClassification.from_pretrained(\"stevhliu/my_awesome_wnut_model\")\n",
"logits = model(**inputs).logits"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"最も高い確率でクラスを取得し、モデルの `id2label` マッピングを使用してそれをテキスト ラベルに変換します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['O',\n",
" 'O',\n",
" 'B-location',\n",
" 'I-location',\n",
" 'B-group',\n",
" 'O',\n",
" 'O',\n",
" 'O',\n",
" 'O',\n",
" 'O',\n",
" 'O',\n",
" 'O',\n",
" 'O',\n",
" 'B-location',\n",
" 'B-location',\n",
" 'O',\n",
" 'O']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predicted_token_class_ids = tf.math.argmax(logits, axis=-1)\n",
"predicted_token_class = [model.config.id2label[t] for t in predicted_token_class_ids[0].numpy().tolist()]\n",
"predicted_token_class"
]
}
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 4
}