subtitles/en/24_the-trainer-api.srt
1
00:00:00,304 --> 00:00:01,285
(air whooshing)
2
00:00:01,285 --> 00:00:02,345
(air popping)
3
00:00:02,345 --> 00:00:05,698
(air whooshing)
4
00:00:05,698 --> 00:00:06,548
- So, the Trainer API.
5
00:00:08,070 --> 00:00:10,040
The Transformers library
provides a Trainer API
6
00:00:10,040 --> 00:00:13,320
that allows you to easily
fine-tune Transformer models
7
00:00:13,320 --> 00:00:14,193
on your dataset.
8
00:00:15,150 --> 00:00:17,250
The Trainer class takes your datasets,
9
00:00:17,250 --> 00:00:19,900
your model as well as the
training hyperparameters
10
00:00:20,820 --> 00:00:23,310
and can perform the training
on any kind of setup,
11
00:00:23,310 --> 00:00:26,654
CPU, GPU, multiple GPUs, or TPUs.
12
00:00:26,654 --> 00:00:28,680
It can also compute the predictions
13
00:00:28,680 --> 00:00:31,710
on any dataset, and, if you provided metrics,
14
00:00:31,710 --> 00:00:33,813
evaluate your model on any dataset.
15
00:00:34,950 --> 00:00:36,930
It can also handle the final data processing
16
00:00:36,930 --> 00:00:38,670
such as dynamic padding,
17
00:00:38,670 --> 00:00:40,377
as long as you provide the tokenizer
18
00:00:40,377 --> 00:00:42,693
or a given data collator.
19
00:00:43,572 --> 00:00:45,900
We will try this API on the MRPC dataset,
20
00:00:45,900 --> 00:00:48,492
since it's relatively small
and easy to preprocess.
21
00:00:48,492 --> 00:00:49,325
As we saw in the Datasets overview video,
22
00:00:49,325 --> 00:00:54,325
here is how we can preprocess it.
23
00:00:54,511 --> 00:00:57,030
We do not apply padding
during the preprocessing,
24
00:00:57,030 --> 00:00:58,590
as we will use dynamic padding
25
00:00:58,590 --> 00:01:00,083
with the DataCollatorWithPadding.
26
00:01:01,170 --> 00:01:02,790
Note that we don't do the final steps
27
00:01:02,790 --> 00:01:04,830
of renaming or removing columns
28
00:01:04,830 --> 00:01:06,873
or setting the format to torch tensors.
29
00:01:07,710 --> 00:01:10,560
The Trainer will do all of
this automatically for us
30
00:01:10,560 --> 00:01:12,633
by analyzing the model signature.
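As a companion to this part of the video, here is a minimal sketch of the preprocessing, following the Hugging Face course conventions; the checkpoint name is an assumption, not fixed by the video:

```python
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

raw_datasets = load_dataset("glue", "mrpc")
# Assumption: the course typically uses bert-base-uncased here.
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

def tokenize_function(examples):
    # No padding here: the DataCollatorWithPadding pads each batch dynamically.
    return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True)

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
```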
31
00:01:14,054 --> 00:01:16,650
The last steps before
creating the Trainer are
32
00:01:16,650 --> 00:01:17,940
to define a model
33
00:01:17,940 --> 00:01:20,250
and some training hyperparameters.
34
00:01:20,250 --> 00:01:22,653
We saw how to do the first
in the model API video.
35
00:01:23,734 --> 00:01:26,790
For the second we use the
TrainingArguments class.
36
00:01:26,790 --> 00:01:28,710
It only takes a path to a folder
37
00:01:28,710 --> 00:01:30,900
where results and
checkpoints will be saved,
38
00:01:30,900 --> 00:01:33,060
but you can also customize
all the hyperparameters
39
00:01:33,060 --> 00:01:34,470
your Trainer will use,
40
00:01:34,470 --> 00:01:37,270
learning rate, number of
training epochs, et cetera.
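A minimal sketch of those training arguments; only the output folder is required, and the hyperparameter values shown here are illustrative, not from the video:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    "test-trainer",        # folder where results and checkpoints are saved
    learning_rate=5e-5,    # illustrative values; the defaults work too
    num_train_epochs=3,
)
```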
41
00:01:38,190 --> 00:01:39,660
It's then very easy to create a Trainer
42
00:01:39,660 --> 00:01:41,400
and launch a training.
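A minimal sketch of creating the Trainer and launching training, reusing the objects defined in the sketches above (checkpoint, tokenized_datasets, data_collator, training_args):

```python
from transformers import AutoModelForSequenceClassification, Trainer

# MRPC is a binary classification task, hence num_labels=2.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)
trainer.train()
```

Note that recent versions of Transformers prefer the processing_class argument over tokenizer.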
43
00:01:41,400 --> 00:01:43,170
This should display a progress bar
44
00:01:43,170 --> 00:01:45,900
and after a few minutes
if you're running on a GPU
45
00:01:45,900 --> 00:01:48,000
the training should be finished.
46
00:01:48,000 --> 00:01:50,790
The result will be rather
anticlimactic however,
47
00:01:50,790 --> 00:01:52,710
as you will only get a training loss
48
00:01:52,710 --> 00:01:54,300
which doesn't really tell you anything
49
00:01:54,300 --> 00:01:56,820
about how well your model is performing.
50
00:01:56,820 --> 00:01:58,977
This is because we
didn't specify any metric
51
00:01:58,977 --> 00:02:00,273
for the evaluation.
52
00:02:01,200 --> 00:02:02,160
To get those metrics,
53
00:02:02,160 --> 00:02:03,810
we will first gather the predictions
54
00:02:03,810 --> 00:02:06,513
on the whole evaluation set
using the predict method.
55
00:02:07,440 --> 00:02:10,020
It returns a namedtuple with three fields,
56
00:02:10,020 --> 00:02:12,990
predictions, which contains
the model predictions,
57
00:02:12,990 --> 00:02:15,030
label_ids, which contains the labels
58
00:02:15,030 --> 00:02:16,800
if your dataset had them
59
00:02:16,800 --> 00:02:18,570
and metrics, which is empty here.
60
00:02:18,570 --> 00:02:20,520
That's what we are trying to compute here.
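A minimal sketch of gathering the predictions:

```python
predictions = trainer.predict(tokenized_datasets["validation"])
print(predictions.predictions.shape)
# (408, 2), the shape mentioned in the video
```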
61
00:02:20,520 --> 00:02:22,470
The predictions are the
logits of the model
62
00:02:22,470 --> 00:02:24,143
for all the sentences in the dataset.
63
00:02:24,143 --> 00:02:27,513
So a NumPy array of shape 408 by 2.
64
00:02:28,500 --> 00:02:30,270
To match them with our labels,
65
00:02:30,270 --> 00:02:31,590
we need to take the maximum logit
66
00:02:31,590 --> 00:02:32,850
for each prediction
67
00:02:32,850 --> 00:02:35,820
to know which of the two
classes was predicted.
68
00:02:35,820 --> 00:02:37,683
We do this with the argmax function.
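In code, that step looks like this:

```python
import numpy as np

# Take the index of the highest logit in each row: the predicted class.
preds = np.argmax(predictions.predictions, axis=-1)
```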
69
00:02:38,640 --> 00:02:41,550
Then we can use a metric
from the Datasets library.
70
00:02:41,550 --> 00:02:43,500
It can be loaded as easily as a dataset
71
00:02:43,500 --> 00:02:45,360
with the load_metric function
72
00:02:45,360 --> 00:02:49,500
and it returns the evaluation
metric used for the dataset.
73
00:02:49,500 --> 00:02:51,600
We can see our model did learn something
74
00:02:51,600 --> 00:02:54,363
as it is 85.7% accurate.
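A minimal sketch of computing the metric as shown in the video; note that load_metric has since been deprecated in favor of the separate evaluate library:

```python
from datasets import load_metric

metric = load_metric("glue", "mrpc")
metric.compute(predictions=preds, references=predictions.label_ids)
# e.g. {'accuracy': 0.857, 'f1': ...}, the accuracy from the video; yours may differ
```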
75
00:02:55,440 --> 00:02:57,870
To monitor the evaluation
metrics during training,
76
00:02:57,870 --> 00:02:59,829
we need to define a
compute_metrics function
77
00:02:59,829 --> 00:03:02,670
that does the same steps as before.
78
00:03:02,670 --> 00:03:04,728
It takes a namedtuple with
predictions and labels
79
00:03:04,728 --> 00:03:06,327
and must return a dictionary
80
00:03:06,327 --> 00:03:08,427
with the metrics we want to keep track of.
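A minimal sketch of such a compute_metrics function, reusing the metric loaded above:

```python
def compute_metrics(eval_preds):
    logits, labels = eval_preds  # the namedtuple unpacks into predictions and label_ids
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
```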
81
00:03:09,360 --> 00:03:11,490
By passing the epoch evaluation strategy
82
00:03:11,490 --> 00:03:13,080
to our training arguments,
83
00:03:13,080 --> 00:03:14,490
we tell the Trainer to evaluate
84
00:03:14,490 --> 00:03:15,903
at the end of every epoch.
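Putting it together, a minimal sketch; note that evaluation_strategy has been renamed eval_strategy in recent versions of Transformers:

```python
training_args = TrainingArguments("test-trainer", evaluation_strategy="epoch")

trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
```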
85
00:03:17,280 --> 00:03:18,587
Launching a training inside a notebook
86
00:03:18,587 --> 00:03:20,640
will then display a progress bar
87
00:03:20,640 --> 00:03:23,643
and complete the table you see
here as you pass every epoch.
88
00:03:25,400 --> 00:03:28,249
(air whooshing)
89
00:03:28,249 --> 00:03:29,974
(air decrescendos)