subtitles/en/24_the-trainer-api.srt

1
00:00:00,304 --> 00:00:01,285
(air whooshing)

2
00:00:01,285 --> 00:00:02,345
(air popping)

3
00:00:02,345 --> 00:00:05,698
(air whooshing)

4
00:00:05,698 --> 00:00:06,548
- The Trainer API.

5
00:00:08,070 --> 00:00:10,040
The Transformers library provides a Trainer API

6
00:00:10,040 --> 00:00:13,320
that allows you to easily fine-tune Transformers models

7
00:00:13,320 --> 00:00:14,193
on your dataset.

8
00:00:15,150 --> 00:00:17,250
The Trainer class takes your datasets,

9
00:00:17,250 --> 00:00:19,900
your model as well as the training hyperparameters,

10
00:00:20,820 --> 00:00:23,310
and can perform the training on any kind of setup:

11
00:00:23,310 --> 00:00:26,654
CPU, GPU, multiple GPUs, TPUs.

12
00:00:26,654 --> 00:00:28,680
It can also compute the predictions

13
00:00:28,680 --> 00:00:31,710
on any dataset, and if you provided metrics,

14
00:00:31,710 --> 00:00:33,813
evaluate your model on any dataset.

15
00:00:34,950 --> 00:00:36,930
It can also handle final data processing

16
00:00:36,930 --> 00:00:38,670
such as dynamic padding,

17
00:00:38,670 --> 00:00:40,377
as long as you provide the tokenizer

18
00:00:40,377 --> 00:00:42,693
or a given data collator.

19
00:00:43,572 --> 00:00:45,900
We will try this API on the MRPC dataset,

20
00:00:45,900 --> 00:00:48,492
since it's relatively small and easy to preprocess.

21
00:00:48,492 --> 00:00:49,325
Here is how we can preprocess it,

22
00:00:49,325 --> 00:00:54,325
as we saw in the Datasets overview video.

23
00:00:54,511 --> 00:00:57,030
We do not apply padding during the preprocessing,

24
00:00:57,030 --> 00:00:58,590
as we will use dynamic padding

25
00:00:58,590 --> 00:01:00,083
with the DataCollatorWithPadding.

26
00:01:01,170 --> 00:01:02,790
Note that we don't do the final steps

27
00:01:02,790 --> 00:01:04,830
of renaming and removing columns

28
00:01:04,830 --> 00:01:06,873
or setting the format to torch tensors.

29
00:01:07,710 --> 00:01:10,560
The Trainer will do all of this automatically for us

30
00:01:10,560 --> 00:01:12,633
by analyzing the model signature.

31
00:01:14,054 --> 00:01:16,650
The last steps before creating the Trainer are

32
00:01:16,650 --> 00:01:17,940
to define a model

33
00:01:17,940 --> 00:01:20,250
and some training hyperparameters.

34
00:01:20,250 --> 00:01:22,653
We saw how to do the first in the model API video.

35
00:01:23,734 --> 00:01:26,790
For the second, we use the TrainingArguments class.

36
00:01:26,790 --> 00:01:28,710
It only takes a path to a folder

37
00:01:28,710 --> 00:01:30,900
where results and checkpoints will be saved,

38
00:01:30,900 --> 00:01:33,060
but you can also customize all the hyperparameters

39
00:01:33,060 --> 00:01:34,470
your Trainer will use:

40
00:01:34,470 --> 00:01:37,270
learning rate, number of training epochs, etc.

41
00:01:38,190 --> 00:01:39,660
It's then very easy to create a Trainer

42
00:01:39,660 --> 00:01:41,400
and launch a training.

43
00:01:41,400 --> 00:01:43,170
This will display a progress bar,

44
00:01:43,170 --> 00:01:45,900
and after a few minutes, if you're running on a GPU,

45
00:01:45,900 --> 00:01:48,000
you should have the training finished.

46
00:01:48,000 --> 00:01:50,790
The result will be rather anticlimactic however,

47
00:01:50,790 --> 00:01:52,710
as you will only get a training loss,

48
00:01:52,710 --> 00:01:54,300
which doesn't really tell you anything

49
00:01:54,300 --> 00:01:56,820
about how well your model is performing.
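The preprocessing and training workflow narrated in the cues above can be condensed into a short script. This is a minimal sketch rather than the video's exact notebook: the "bert-base-cased" checkpoint and the "test-trainer" output folder are assumptions made for illustration.

# A minimal sketch of the workflow described above, assuming the
# "bert-base-cased" checkpoint; the output folder name and defaults
# are illustrative, not the video's exact values.
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

raw_datasets = load_dataset("glue", "mrpc")
checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

def tokenize_function(examples):
    # No padding here: dynamic padding is applied per batch by the
    # DataCollatorWithPadding the Trainer builds from the tokenizer.
    return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True)

# No renaming/removing of columns and no set_format to torch tensors:
# the Trainer does this automatically by analyzing the model signature.
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)

model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

# Only the output folder is required; learning rate, number of
# training epochs, etc. all have defaults you can override.
training_args = TrainingArguments("test-trainer")

trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
)
trainer.train()  # reports only the training loss, since no metrics were given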
50
00:01:56,820 --> 00:01:58,977
This is because we didn't specify any metric

51
00:01:58,977 --> 00:02:00,273
for the evaluation.

52
00:02:01,200 --> 00:02:02,160
To get those metrics,

53
00:02:02,160 --> 00:02:03,810
we will first gather the predictions

54
00:02:03,810 --> 00:02:06,513
on the whole evaluation set using the predict method.

55
00:02:07,440 --> 00:02:10,020
It returns a namedtuple with three fields:

56
00:02:10,020 --> 00:02:12,990
predictions, which contains the model predictions,

57
00:02:12,990 --> 00:02:15,030
label_ids, which contains the labels

58
00:02:15,030 --> 00:02:16,800
if your dataset had them,

59
00:02:16,800 --> 00:02:18,570
and metrics, which is empty here

60
00:02:18,570 --> 00:02:20,520
since we didn't provide any.

61
00:02:20,520 --> 00:02:22,470
The predictions are the logits of the model

62
00:02:22,470 --> 00:02:24,143
for all the sentences in the dataset,

63
00:02:24,143 --> 00:02:27,513
so a NumPy array of shape 408 by 2.

64
00:02:28,500 --> 00:02:30,270
To match them with our labels,

65
00:02:30,270 --> 00:02:31,590
we need to take the maximum logit

66
00:02:31,590 --> 00:02:32,850
for each prediction

67
00:02:32,850 --> 00:02:35,820
to know which of the two classes was predicted.

68
00:02:35,820 --> 00:02:37,683
We do this with the argmax function.

69
00:02:38,640 --> 00:02:41,550
Then we can use a metric from the Datasets library.

70
00:02:41,550 --> 00:02:43,500
It can be loaded as easily as a dataset

71
00:02:43,500 --> 00:02:45,360
with the load_metric function,

72
00:02:45,360 --> 00:02:49,500
and it returns the evaluation metric used for the dataset.

73
00:02:49,500 --> 00:02:51,600
We can see our model did learn something,

74
00:02:51,600 --> 00:02:54,363
as it is 85.7% accurate.

75
00:02:55,440 --> 00:02:57,870
To monitor the evaluation metrics during training,

76
00:02:57,870 --> 00:02:59,829
we need to define a compute_metrics function

77
00:02:59,829 --> 00:03:02,670
that does the same steps as before.

78
00:03:02,670 --> 00:03:04,728
It takes a namedtuple with predictions and labels

79
00:03:04,728 --> 00:03:06,327
and must return a dictionary

80
00:03:06,327 --> 00:03:08,427
with the metrics we want to keep track of.

81
00:03:09,360 --> 00:03:11,490
By passing the epoch evaluation strategy

82
00:03:11,490 --> 00:03:13,080
to our training arguments,

83
00:03:13,080 --> 00:03:14,490
we tell the Trainer to evaluate

84
00:03:14,490 --> 00:03:15,903
at the end of every epoch.

85
00:03:17,280 --> 00:03:18,587
Launching a training inside a notebook

86
00:03:18,587 --> 00:03:20,640
will then display a progress bar

87
00:03:20,640 --> 00:03:23,643
and complete the table you see here as each epoch completes.

88
00:03:25,400 --> 00:03:28,249
(air whooshing)

89
00:03:28,249 --> 00:03:29,974
(air decrescendos)
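The evaluation steps narrated in cues 50-87 can likewise be sketched in a few lines, reusing the trainer, model, tokenizer, and tokenized_datasets names from the previous sketch; pairing load_metric with "glue"/"mrpc" is assumed to be the metric the video loads for this dataset.

# A minimal sketch of the evaluation steps described above, reusing
# names from the previous sketch (trainer, model, tokenizer, ...).
import numpy as np
from datasets import load_metric

# Gather predictions on the whole validation set with the predict method.
predictions = trainer.predict(tokenized_datasets["validation"])
print(predictions.predictions.shape)  # logits, shape (408, 2) for MRPC

# Take the highest logit per example to get the predicted class.
preds = np.argmax(predictions.predictions, axis=-1)

# Load the evaluation metric associated with the dataset and compute it.
metric = load_metric("glue", "mrpc")
print(metric.compute(predictions=preds, references=predictions.label_ids))

def compute_metrics(eval_preds):
    # Receives a namedtuple of (logits, labels) and must return
    # a dictionary with the metrics we want to keep track of.
    logits, labels = eval_preds
    return metric.compute(predictions=np.argmax(logits, axis=-1), references=labels)

# Evaluate at the end of every epoch by setting the strategy
# on the training arguments and passing compute_metrics.
training_args = TrainingArguments("test-trainer", evaluation_strategy="epoch")
trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()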