subtitles/zh-CN/24_the-trainer-api.srt
1
00:00:00,304 --> 00:00:01,285
(空气呼啸)
(air whooshing)
2
00:00:01,285 --> 00:00:02,345
(空气爆裂声)
(air popping)
3
00:00:02,345 --> 00:00:05,698
(空气呼啸)
(air whooshing)
4
00:00:05,698 --> 00:00:06,548
- Trainer API。
- The Trainer API.
5
00:00:08,070 --> 00:00:10,040
Transformers Library 提供了一个 Trainer API
The Transformers Library provides a Trainer API
6
00:00:10,040 --> 00:00:13,320
让你能够轻松地微调 Transformer 模型
that allows you to easily fine-tune transformer models
7
00:00:13,320 --> 00:00:14,193
在你自己的数据集上。
on your own dataset.
8
00:00:15,150 --> 00:00:17,250
Trainer 类接受你的数据集,
The Trainer class takes your datasets,
9
00:00:17,250 --> 00:00:19,900
你的模型以及训练超参数
your model as well as the training hyperparameters
10
00:00:20,820 --> 00:00:23,310
并且可以在任何类型的设置上进行训练,
and can perform the training on any kind of setup,
11
00:00:23,310 --> 00:00:26,654
(CPU、GPU、多个 GPU、TPUs)
(CPU, GPU, multiple GPUs, TPUs)
12
00:00:26,654 --> 00:00:28,680
也可以在任意数据集上进行推理
can also compute the predictions
13
00:00:28,680 --> 00:00:31,710
如果你提供了指标
on any dataset and if you provided metrics
14
00:00:31,710 --> 00:00:33,813
就可以在任意数据集上评估你的模型。
evaluate your model on any dataset.
15
00:00:34,950 --> 00:00:36,930
Trainer 也可以负责最后的数据处理
It can also handle final data processing
16
00:00:36,930 --> 00:00:38,670
比如动态填充,
such as dynamic padding,
17
00:00:38,670 --> 00:00:40,377
只要你提供 tokenizer
as long as you provide the tokenizer
18
00:00:40,377 --> 00:00:42,693
或给定 data collator。
or a given data collator.
19
00:00:43,572 --> 00:00:45,900
我们将在 MRPC 数据集上尝试这个 API,
We will try this API on the MRPC dataset,
20
00:00:45,900 --> 00:00:48,492
因为该数据集相对较小且易于预处理。
since it's relatively small and easy to preprocess.
21
00:00:48,492 --> 00:00:49,325
正如我们在 Datasets 概述视频中看到的那样,
As we saw in the Datasets overview video,
22
00:00:49,325 --> 00:00:54,325
我们可以像这样对其进行预处理。
here is how we can preprocess it.
23
00:00:54,511 --> 00:00:57,030
我们在预处理过程中不进行填充,
We do not apply padding during the preprocessing,
24
00:00:57,030 --> 00:00:58,590
因为我们将进行动态填充
as we will use dynamic padding
25
00:00:58,590 --> 00:01:00,083
使用我们的 DataCollatorWithPadding。
with our DataCollatorWithPadding.
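Dynamic padding means each batch is padded to the length of its own longest sample, instead of one fixed maximum length for the whole dataset; that is what DataCollatorWithPadding does at batching time. A toy plain-Python sketch of the idea (the pad_batch helper and pad id 0 are illustrative assumptions, not the real collator):

```python
# Toy sketch of dynamic padding: pad each batch to its own longest
# sequence instead of a single global maximum length.
def pad_batch(batch, pad_id=0):
    # The longest sequence in THIS batch sets the padded length.
    max_len = max(len(seq) for seq in batch)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in batch]

# Two batches of token ids with different maximum lengths.
batch1 = pad_batch([[101, 2023, 102], [101, 102]])
batch2 = pad_batch([[101, 2023, 2003, 1037, 3231, 102], [101, 102]])

print([len(s) for s in batch1])  # [3, 3] - padded to this batch's max
print([len(s) for s in batch2])  # [6, 6] - padded to this batch's max
```

Because short batches stay short, less compute is wasted on padding tokens than with fixed-length padding.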
26
00:01:01,170 --> 00:01:02,790
请注意,我们不执行一些最终的数据处理步骤
Note that we don't do the final steps
27
00:01:02,790 --> 00:01:04,830
像是重命名列/删除列
of renaming/removing columns
28
00:01:04,830 --> 00:01:06,873
或将格式设置为 torch 张量。
or set the format to torch tensors.
29
00:01:07,710 --> 00:01:10,560
Trainer 会自动为我们做这一切
The Trainer will do all of this automatically for us
30
00:01:10,560 --> 00:01:12,633
通过分析模型签名。
by analyzing the model signature.
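"Analyzing the model signature" means inspecting which argument names the model's forward method accepts, and dropping dataset columns that don't match. A minimal sketch of that mechanism with the standard inspect module (the toy forward function stands in for a real model's forward):

```python
import inspect

# Stand-in for a model's forward method: its parameter names are the
# dataset columns the Trainer would keep; everything else is dropped.
def forward(input_ids, attention_mask, token_type_ids=None, labels=None):
    pass

accepted = set(inspect.signature(forward).parameters)

# Columns produced by the MRPC preprocessing above.
columns = {"input_ids", "attention_mask", "labels",
           "sentence1", "sentence2", "idx"}

kept = columns & accepted      # columns matching the signature survive
dropped = columns - accepted   # raw text and ids are removed
print(sorted(kept))     # ['attention_mask', 'input_ids', 'labels']
print(sorted(dropped))  # ['idx', 'sentence1', 'sentence2']
```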
31
00:01:14,054 --> 00:01:16,650
实例化 Trainer 前的最后几步是
The last steps before creating the Trainer are
32
00:01:16,650 --> 00:01:17,940
定义模型
to define a model
33
00:01:17,940 --> 00:01:20,250
和一些训练超参数。
and some training hyperparameters.
34
00:01:20,250 --> 00:01:22,653
我们在 model API 视频中学会了如何定义模型。
We saw how to do the first in the model API video.
35
00:01:23,734 --> 00:01:26,790
对于第二点,我们使用 TrainingArguments 类。
For the second we use the TrainingArguments class.
36
00:01:26,790 --> 00:01:28,710
Trainer 只需要一个文件夹的路径
It only takes a path to a folder
37
00:01:28,710 --> 00:01:30,900
用以保存结果和检查点,
where results and checkpoints will be saved,
38
00:01:30,900 --> 00:01:33,060
但你也可以自定义你的 Trainer 会使用的
but you can also customize all the hyperparameters
39
00:01:33,060 --> 00:01:34,470
所有超参数
your Trainer will use,
40
00:01:34,470 --> 00:01:37,270
比如学习率,训练几个 epoch 等等。
learning rate, number of training epochs, etc.
41
00:01:38,190 --> 00:01:39,660
接下来实例化一个 Trainer 并开始训练
It's then very easy to create a Trainer
42
00:01:39,660 --> 00:01:41,400
就非常简单了。
and launch a training.
43
00:01:41,400 --> 00:01:43,170
这会显示一个进度条
This should display a progress bar
44
00:01:43,170 --> 00:01:45,900
几分钟后(如果你在 GPU 上运行)
and after a few minutes (if you're running on a GPU)
45
00:01:45,900 --> 00:01:48,000
你会完成训练。
you should have the training finished.
46
00:01:48,000 --> 00:01:50,790
然而,结果将是相当虎头蛇尾,
The result will be rather anticlimactic however,
47
00:01:50,790 --> 00:01:52,710
因为你只会得到训练损失
as you will only get a training loss
48
00:01:52,710 --> 00:01:54,300
这并没有真正告诉你
which doesn't really tell you anything
49
00:01:54,300 --> 00:01:56,820
你的模型表现如何。
about how well your model is performing.
50
00:01:56,820 --> 00:01:58,977
这是因为我们没有指定任何指标
This is because we didn't specify any metric
51
00:01:58,977 --> 00:02:00,273
用于评估。
for the evaluation.
52
00:02:01,200 --> 00:02:02,160
要获得这些指标,
To get those metrics,
53
00:02:02,160 --> 00:02:03,810
我们将首先使用 predict 方法
we will first gather the predictions
54
00:02:03,810 --> 00:02:06,513
在整个评估集上收集预测结果。
on the whole evaluation set using the predict method.
55
00:02:07,440 --> 00:02:10,020
它返回一个包含三个字段的命名元组,
It returns a namedtuple with three fields,
56
00:02:10,020 --> 00:02:12,990
predictions(其中包含模型的预测),
predictions (which contains the model predictions),
57
00:02:12,990 --> 00:02:15,030
label_ids(其中包含标签
label_ids (which contains the labels
58
00:02:15,030 --> 00:02:16,800
如果你的数据集有的话)
if your dataset had them)
59
00:02:16,800 --> 00:02:18,570
和指标(在本示例中是空的)
and metrics (which is empty here).
60
00:02:18,570 --> 00:02:20,520
我们接下来就会解决这个问题。
We will remedy that next.
61
00:02:20,520 --> 00:02:22,470
预测结果是模型
The predictions are the logits of the model
62
00:02:22,470 --> 00:02:24,143
对于数据集中的所有句子所输出的 logits。
for all the sentences in the dataset.
63
00:02:24,143 --> 00:02:27,513
所以是一个形状为 408 x 2 的 NumPy 数组。
So a NumPy array of shape 408 by 2.
64
00:02:28,500 --> 00:02:30,270
为了将它们与我们的标签相匹配,
To match them with our labels,
65
00:02:30,270 --> 00:02:31,590
我们需要取最大 logit 的索引
we need to take the index of the maximum logit
66
00:02:31,590 --> 00:02:32,850
对于每个预测
for each prediction
67
00:02:32,850 --> 00:02:35,820
(以确定预测的是两个类别中的哪一个)。
(to know which of the two classes was predicted).
68
00:02:35,820 --> 00:02:37,683
我们使用 argmax 函数来做到这一点。
We do this with the argmax function.
69
00:02:38,640 --> 00:02:41,550
然后我们可以使用 Datasets library 中的指标。
Then we can use a metric from the Datasets library.
70
00:02:41,550 --> 00:02:43,500
它可以像数据集一样被轻松地加载
It can be loaded as easily as a dataset
71
00:02:43,500 --> 00:02:45,360
使用 load_metric 函数
with the load_metric function
72
00:02:45,360 --> 00:02:49,500
并且返回用于该数据集的评估指标。
and it returns the evaluation metric used for the dataset.
73
00:02:49,500 --> 00:02:51,600
我们可以看到我们的模型确实学到了一些东西
We can see our model did learn something
74
00:02:51,600 --> 00:02:54,363
因为它有 85.7% 的准确率。
as it is 85.7% accurate.
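The shape bookkeeping above can be sketched with NumPy: random logits stand in for the predictions field returned by predict on MRPC's 408 validation examples, argmax turns them into class ids, and a hand-rolled accuracy stands in for the GLUE metric (which would also report F1):

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for Trainer.predict() output on the 408 validation examples.
logits = rng.normal(size=(408, 2))      # predictions field: raw logits
labels = rng.integers(0, 2, size=408)   # label_ids field

# argmax over the class axis converts logits to predicted class ids.
preds = np.argmax(logits, axis=-1)
print(preds.shape)  # (408,)

# Accuracy: fraction of predictions that match the labels.
accuracy = (preds == labels).mean()
```

With real model logits instead of random ones, this accuracy is the 85.7% figure reported by the metric.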
75
00:02:55,440 --> 00:02:57,870
为了在训练期间监控评估指标,
To monitor the evaluation metrics during training,
76
00:02:57,870 --> 00:02:59,829
我们需要定义一个 compute_metrics 函数
we need to define a compute_metrics function
77
00:02:59,829 --> 00:03:02,670
该函数执行与之前相同的步骤。
that does the same steps as before.
78
00:03:02,670 --> 00:03:04,728
它接收一个带有预测和标签的命名元组
It takes a namedtuple with predictions and labels
79
00:03:04,728 --> 00:03:06,327
并且必须返回一个字典
and must return a dictionary
80
00:03:06,327 --> 00:03:08,427
包含我们想要跟踪的指标。
with the metrics we want to keep track of.
81
00:03:09,360 --> 00:03:11,490
通过将评估策略设置为 epoch
By passing the epoch evaluation strategy
82
00:03:11,490 --> 00:03:13,080
对于我们的 TrainingArguments,
to our TrainingArguments,
83
00:03:13,080 --> 00:03:14,490
我们告诉 Trainer 去进行评估
we tell the Trainer to evaluate
84
00:03:14,490 --> 00:03:15,903
在每个 epoch 结束的时候。
at the end of every epoch.
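A minimal compute_metrics sketch along those lines: it receives the (predictions, labels) pair, argmaxes the logits, and returns a dictionary whose keys become tracked metrics. Accuracy is computed by hand here to keep the example self-contained; the video uses the metric loaded from the Datasets library, and passing this function to the Trainer together with evaluation_strategy="epoch" in TrainingArguments triggers the per-epoch evaluation:

```python
import numpy as np

def compute_metrics(eval_preds):
    # eval_preds is the (logits, labels) pair the Trainer passes in.
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    # Must return a dict: each key is a metric to track and log.
    return {"accuracy": float((predictions == labels).mean())}

# Quick check with toy logits: rows 0 and 2 are predicted correctly,
# row 1 predicts class 1 while its label is 0.
logits = np.array([[0.1, 0.9], [0.2, 0.8], [0.7, 0.3]])
labels = np.array([1, 0, 0])
result = compute_metrics((logits, labels))
print(result)  # {'accuracy': 0.666...}
```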
85
00:03:17,280 --> 00:03:18,587
在 notebook 中启动训练
Launching a training inside a notebook
86
00:03:18,587 --> 00:03:20,640
会显示一个进度条
will then display a progress bar
87
00:03:20,640 --> 00:03:23,643
并在每个 epoch 结束时填充你在这里看到的表格。
and fill in the table you see here as each epoch completes.
88
00:03:25,400 --> 00:03:28,249
(空气呼啸)
(air whooshing)
89
00:03:28,249 --> 00:03:29,974
(空气渐弱)
(air decrescendos)