1
00:00:00,304 --> 00:00:01,285
(air whooshing)

2
00:00:01,285 --> 00:00:02,345
(air popping)

3
00:00:02,345 --> 00:00:05,698
(air whooshing)

4
00:00:05,698 --> 00:00:06,548
- The Trainer API.

5
00:00:08,070 --> 00:00:10,040
The Transformers Library provides a Trainer API

6
00:00:10,040 --> 00:00:13,320
that allows you to easily fine-tune transformer models

7
00:00:13,320 --> 00:00:14,193
on your own dataset.

8
00:00:15,150 --> 00:00:17,250
The Trainer class takes your datasets,

9
00:00:17,250 --> 00:00:19,900
your model as well as the training hyperparameters

10
00:00:20,820 --> 00:00:23,310
and can perform the training on any kind of setup

11
00:00:23,310 --> 00:00:26,654
(CPU, GPU, multiple GPUs, TPUs).

12
00:00:26,654 --> 00:00:28,680
It can also compute the predictions

13
00:00:28,680 --> 00:00:31,710
on any dataset and, if you provide metrics,

14
00:00:31,710 --> 00:00:33,813
evaluate your model on any dataset.

15
00:00:34,950 --> 00:00:36,930
It can also handle final data processing

16
00:00:36,930 --> 00:00:38,670
such as dynamic padding,

17
00:00:38,670 --> 00:00:40,377
as long as you provide the tokenizer

18
00:00:40,377 --> 00:00:42,693
or a given data collator.

19
00:00:43,572 --> 00:00:45,900
We will try this API on the MRPC dataset,

20
00:00:45,900 --> 00:00:48,492
since it's relatively small and easy to preprocess.

21
00:00:48,492 --> 00:00:49,325
As we saw in the Datasets overview video,

22
00:00:49,325 --> 00:00:54,325
here is how we can preprocess it.
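The preprocessing mentioned in the last cue appears only on screen, so here is a minimal sketch of what it might look like. The checkpoint name `bert-base-cased` and the helper names `tokenize_function` and `prepare_mrpc` are assumptions made for illustration; the `sentence1` and `sentence2` column names are those of the GLUE MRPC dataset. No padding is applied at this stage.

```python
def tokenize_function(examples, tokenizer):
    """Tokenize a batch of sentence pairs, truncating but not padding.

    Padding is deliberately left out so it can be applied per batch later.
    """
    return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True)


def prepare_mrpc(checkpoint="bert-base-cased"):
    """Load MRPC and tokenize it (hypothetical helper, checkpoint is an assumption)."""
    # Imported here so that tokenize_function can be reused on its own.
    from datasets import load_dataset
    from transformers import AutoTokenizer

    raw_datasets = load_dataset("glue", "mrpc")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # batched=True passes lists of sentences to the tokenizer, which is faster.
    tokenized_datasets = raw_datasets.map(
        lambda examples: tokenize_function(examples, tokenizer), batched=True
    )
    return tokenized_datasets, tokenizer
```

Calling `prepare_mrpc()` downloads the dataset and the tokenizer files, so it needs network access the first time it runs.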
23
00:00:54,511 --> 00:00:57,030
We do not apply padding during the preprocessing,

24
00:00:57,030 --> 00:00:58,590
as we will use dynamic padding

25
00:00:58,590 --> 00:01:00,083
with our DataCollatorWithPadding.

26
00:01:01,170 --> 00:01:02,790
Note that we don't do the final steps

27
00:01:02,790 --> 00:01:04,830
of renaming/removing columns

28
00:01:04,830 --> 00:01:06,873
or setting the format to torch tensors.

29
00:01:07,710 --> 00:01:10,560
The Trainer will do all of this automatically for us

30
00:01:10,560 --> 00:01:12,633
by analyzing the model signature.

31
00:01:14,054 --> 00:01:16,650
The last step before creating the Trainer is

32
00:01:16,650 --> 00:01:17,940
to define a model

33
00:01:17,940 --> 00:01:20,250
and some training hyperparameters.

34
00:01:20,250 --> 00:01:22,653
We saw how to do the first in the model API video.

35
00:01:23,734 --> 00:01:26,790
For the second we use the TrainingArguments class.

36
00:01:26,790 --> 00:01:28,710
It only takes a path to a folder

37
00:01:28,710 --> 00:01:30,900
where results and checkpoints will be saved,

38
00:01:30,900 --> 00:01:33,060
but you can also customize all the hyperparameters

39
00:01:33,060 --> 00:01:34,470
your Trainer will use,

40
00:01:34,470 --> 00:01:37,270
such as the learning rate, number of training epochs, etc.

41
00:01:38,190 --> 00:01:39,660
It's then very easy to create a Trainer

42
00:01:39,660 --> 00:01:41,400
and launch a training.
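The steps described in these cues (a model, a TrainingArguments object, a data collator for dynamic padding, then the Trainer itself) could be wired together roughly as follows. This is a sketch under stated assumptions: the helper name `build_trainer`, the checkpoint `bert-base-cased`, the output folder `test-trainer`, and the hyperparameter values are illustrative choices, not values taken from the video.

```python
def build_trainer(tokenized_datasets, tokenizer,
                  checkpoint="bert-base-cased", output_dir="test-trainer"):
    """Assemble a Trainer for MRPC (hypothetical helper, defaults are assumptions)."""
    from transformers import (
        AutoModelForSequenceClassification,
        DataCollatorWithPadding,
        Trainer,
        TrainingArguments,
    )

    # MRPC is a binary task: the two sentences are paraphrases or they are not.
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint, num_labels=2
    )

    # The output folder receives results and checkpoints; the other two values
    # only illustrate that hyperparameters can be customized here.
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=2e-5,
        num_train_epochs=3,
    )

    # Pads each batch to the length of its longest example (dynamic padding).
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    return Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        data_collator=data_collator,
    )
```

Given tokenized MRPC splits and their tokenizer, `build_trainer(tokenized_datasets, tokenizer).train()` would then launch the fine-tuning run.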
43
00:01:41,400 --> 00:01:43,170
This should display a progress bar

44
00:01:43,170 --> 00:01:45,900
and after a few minutes (if you're running on a GPU)

45
00:01:45,900 --> 00:01:48,000
you should have the training finished.

46
00:01:48,000 --> 00:01:50,790
The result will be rather anticlimactic, however,

47
00:01:50,790 --> 00:01:52,710
as you will only get a training loss

48
00:01:52,710 --> 00:01:54,300
which doesn't really tell you anything

49
00:01:54,300 --> 00:01:56,820
about how well your model is performing.

50
00:01:56,820 --> 00:01:58,977
This is because we didn't specify any metric

51
00:01:58,977 --> 00:02:00,273
for the evaluation.

52
00:02:01,200 --> 00:02:02,160
To get those metrics,

53
00:02:02,160 --> 00:02:03,810
we will first gather the predictions

54
00:02:03,810 --> 00:02:06,513
on the whole evaluation set using the predict method.

55
00:02:07,440 --> 00:02:10,020
It returns a namedtuple with three fields,

56
00:02:10,020 --> 00:02:12,990
predictions (which contains the model predictions),

57
00:02:12,990 --> 00:02:15,030
label_ids (which contains the labels

58
00:02:15,030 --> 00:02:16,800
if your dataset had them)

59
00:02:16,800 --> 00:02:18,570
and metrics (which is empty here).

60
00:02:18,570 --> 00:02:20,520
We're trying to do that.

61
00:02:20,520 --> 00:02:22,470
The predictions are the logits of the model

62
00:02:22,470 --> 00:02:24,143
for all the sentences in the dataset.

63
00:02:24,143 --> 00:02:27,513
So a NumPy array of shape 408 by 2.
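As a small illustration of the predict step, the sketch below unpacks the three fields of the returned namedtuple. The helper name `gather_predictions` is hypothetical; `trainer` is assumed to be an already trained Trainer and `dataset` the tokenized evaluation split.

```python
def gather_predictions(trainer, dataset):
    """Run predict on a dataset and unpack the three fields it returns."""
    output = trainer.predict(dataset)
    logits = output.predictions  # one row of logits per example
    labels = output.label_ids    # the labels, if the dataset has them
    metrics = output.metrics     # dictionary of metrics
    return logits, labels, metrics
```

For the MRPC validation split described in the transcript, `logits.shape` would be `(408, 2)`: one row per sentence pair and one column per class.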
64
00:02:28,500 --> 00:02:30,270
To match them with our labels,

65
00:02:30,270 --> 00:02:31,590
we need to take the maximum logit

66
00:02:31,590 --> 00:02:32,850
for each prediction

67
00:02:32,850 --> 00:02:35,820
(to know which of the two classes was predicted).

68
00:02:35,820 --> 00:02:37,683
We do this with the argmax function.

69
00:02:38,640 --> 00:02:41,550
Then we can use a metric from the Datasets library.

70
00:02:41,550 --> 00:02:43,500
It can be loaded as easily as a dataset

71
00:02:43,500 --> 00:02:45,360
with the load_metric function

72
00:02:45,360 --> 00:02:49,500
and it returns the evaluation metric used for the dataset.

73
00:02:49,500 --> 00:02:51,600
We can see our model did learn something

74
00:02:51,600 --> 00:02:54,363
as it is 85.7% accurate.

75
00:02:55,440 --> 00:02:57,870
To monitor the evaluation metrics during training,

76
00:02:57,870 --> 00:02:59,829
we need to define a compute_metrics function

77
00:02:59,829 --> 00:03:02,670
that does the same steps as before.

78
00:03:02,670 --> 00:03:04,728
It takes a namedtuple with predictions and labels

79
00:03:04,728 --> 00:03:06,327
and must return a dictionary

80
00:03:06,327 --> 00:03:08,427
with the metrics we want to keep track of.

81
00:03:09,360 --> 00:03:11,490
By passing the epoch evaluation strategy

82
00:03:11,490 --> 00:03:13,080
to our TrainingArguments,

83
00:03:13,080 --> 00:03:14,490
we tell the Trainer to evaluate

84
00:03:14,490 --> 00:03:15,903
at the end of every epoch.
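A compute_metrics function of the kind described here might look like the sketch below. It takes the argmax of the logits and returns a dictionary of metrics. Note the substitution: instead of loading the MRPC metric with load_metric as in the video, this version computes accuracy and F1 by hand with NumPy so that it stays self-contained.

```python
import numpy as np


def compute_metrics(eval_preds):
    """Turn (logits, labels) into a dictionary of metrics to track."""
    logits, labels = eval_preds
    labels = np.asarray(labels)

    # The predicted class is the index of the largest logit in each row.
    predictions = np.argmax(logits, axis=-1)

    accuracy = float((predictions == labels).mean())

    # F1 for the positive class (label 1), computed by hand as a stand-in
    # for the metric loaded with load_metric in the video.
    true_pos = int(((predictions == 1) & (labels == 1)).sum())
    false_pos = int(((predictions == 1) & (labels == 0)).sum())
    false_neg = int(((predictions == 0) & (labels == 1)).sum())
    denominator = 2 * true_pos + false_pos + false_neg
    f1 = 2 * true_pos / denominator if denominator else 0.0

    return {"accuracy": accuracy, "f1": f1}
```

To have these values reported during training, the function would be passed to the Trainer as `compute_metrics=compute_metrics`, together with the epoch evaluation strategy in TrainingArguments. The keyword for that strategy has, as far as I know, been spelled `evaluation_strategy` in older Transformers releases and `eval_strategy` in newer ones, so check the installed version.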
85
00:03:17,280 --> 00:03:18,587
Launching a training inside a notebook

86
00:03:18,587 --> 00:03:20,640
will then display a progress bar

87
00:03:20,640 --> 00:03:23,643
and complete the table you see here as you pass every epoch.

88
00:03:25,400 --> 00:03:28,249
(air whooshing)

89
00:03:28,249 --> 00:03:29,974
(air decrescendos)