subtitles/en/68_data-collators-a-tour.srt
1
00:00:00,670 --> 00:00:01,503
(whooshing sound)
2
00:00:01,503 --> 00:00:02,469
(sticker popping)
3
00:00:02,469 --> 00:00:05,302
(whooshing sound)
4
00:00:06,240 --> 00:00:08,220
In a lot of our examples,
5
00:00:08,220 --> 00:00:12,150
you're going to see DataCollators
popping up over and over.
6
00:00:12,150 --> 00:00:16,020
They're used in both PyTorch
and TensorFlow workflows,
7
00:00:16,020 --> 00:00:17,460
and maybe even in JAX,
8
00:00:17,460 --> 00:00:20,130
but no-one really knows
what's happening in JAX.
9
00:00:20,130 --> 00:00:21,840
We do have a research
team working on it though,
10
00:00:21,840 --> 00:00:23,970
so maybe they'll tell us soon.
11
00:00:23,970 --> 00:00:25,620
But coming back to the topic.
12
00:00:25,620 --> 00:00:27,600
What are data collators?
13
00:00:27,600 --> 00:00:30,480
Data collators collate data.
14
00:00:30,480 --> 00:00:31,800
That's not that helpful.
15
00:00:31,800 --> 00:00:35,023
But to be more specific, they
put together a list of samples
16
00:00:35,023 --> 00:00:37,830
into a single training minibatch.
17
00:00:37,830 --> 00:00:38,910
For some tasks,
18
00:00:38,910 --> 00:00:41,670
the data collator can
be very straightforward.
19
00:00:41,670 --> 00:00:44,820
For example, when you're
doing sequence classification,
20
00:00:44,820 --> 00:00:47,010
all you really need
from your data collator
21
00:00:47,010 --> 00:00:49,860
is that it pads your
samples to the same length
22
00:00:49,860 --> 00:00:52,413
and concatenates them
into a single Tensor.
23
00:00:53,340 --> 00:00:57,750
But for other workflows, data
collators can be quite complex
24
00:00:57,750 --> 00:00:59,910
as they handle some of the preprocessing
25
00:00:59,910 --> 00:01:02,340
needed for that particular task.
26
00:01:02,340 --> 00:01:04,800
So, if you want to use a data collator,
27
00:01:04,800 --> 00:01:07,860
for PyTorch users, you
usually pass the data collator
28
00:01:07,860 --> 00:01:09,780
to your Trainer object.
29
00:01:09,780 --> 00:01:11,310
In TensorFlow, it's a bit different.
30
00:01:11,310 --> 00:01:12,960
The easiest way to use a data collator
31
00:01:12,960 --> 00:01:16,860
is to pass it to the to_tf_dataset
method of your dataset.
32
00:01:16,860 --> 00:01:20,198
And this will give you
a tf.data.Dataset
33
00:01:20,198 --> 00:01:22,743
that you can then pass to model.fit.
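Here is a minimal sketch of both routes; model, training_args, tokenized_dataset, and tf_model are placeholders assumed to already exist:

from transformers import AutoTokenizer, DataCollatorWithPadding, Trainer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# PyTorch route: hand the collator to the Trainer, which uses it to build batches
trainer = Trainer(
    model=model,                       # placeholder: your PyTorch model
    args=training_args,                # placeholder: your TrainingArguments
    train_dataset=tokenized_dataset,   # placeholder: your tokenized dataset
    data_collator=data_collator,
)

# TensorFlow route: pass it as collate_fn to to_tf_dataset, then call model.fit
tf_dataset = tokenized_dataset.to_tf_dataset(
    columns=["input_ids", "attention_mask"],
    label_cols=["labels"],
    batch_size=16,
    shuffle=True,
    collate_fn=data_collator,
)
tf_model.fit(tf_dataset)               # placeholder: your TF model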
34
00:01:23,580 --> 00:01:25,890
You'll see these approaches
used in the examples
35
00:01:25,890 --> 00:01:28,068
and notebooks throughout this course.
36
00:01:28,068 --> 00:01:30,180
Also note that all of our collators
37
00:01:30,180 --> 00:01:32,610
take a return_tensors argument.
38
00:01:32,610 --> 00:01:35,737
You can set this to "pt"
to get PyTorch Tensors,
39
00:01:35,737 --> 00:01:37,920
"tf" to get TensorFlow Tensors,
40
00:01:37,920 --> 00:01:40,404
or "np" to get Numpy arrays.
41
00:01:40,404 --> 00:01:42,450
For backward compatibility reasons,
42
00:01:42,450 --> 00:01:44,460
the default value is "pt",
43
00:01:44,460 --> 00:01:47,160
so PyTorch users don't even
have to set this argument
44
00:01:47,160 --> 00:01:48,270
most of the time.
45
00:01:48,270 --> 00:01:50,820
And as a result, they're
often totally unaware
46
00:01:50,820 --> 00:01:52,713
that this argument even exists.
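As a quick illustration of that argument, here is a minimal sketch with made-up sample sentences:

from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# The same collator with three output formats
collator_pt = DataCollatorWithPadding(tokenizer, return_tensors="pt")  # PyTorch tensors
collator_tf = DataCollatorWithPadding(tokenizer, return_tensors="tf")  # TensorFlow tensors
collator_np = DataCollatorWithPadding(tokenizer, return_tensors="np")  # NumPy arrays

samples = [tokenizer("Hello there!"), tokenizer("A somewhat longer example sentence.")]
batch = collator_np(samples)  # dict of NumPy arrays, padded to the same length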
47
00:01:53,730 --> 00:01:55,050
We can learn something from this
48
00:01:55,050 --> 00:01:57,120
which is that the
beneficiaries of privilege
49
00:01:57,120 --> 00:01:59,793
are often the most blind to its existence.
50
00:02:00,690 --> 00:02:01,920
But okay, coming back.
51
00:02:01,920 --> 00:02:06,540
Let's see how some specific
data collators work in action.
52
00:02:06,540 --> 00:02:08,070
Although again, remember if none
53
00:02:08,070 --> 00:02:09,900
of the built-in data
collators do what you need,
54
00:02:09,900 --> 00:02:13,650
you can always write your own
and they're often quite short.
55
00:02:13,650 --> 00:02:16,950
So first, we'll see the
"basic" data collators.
56
00:02:16,950 --> 00:02:20,433
These are DefaultDataCollator
and DataCollatorWithPadding.
57
00:02:21,420 --> 00:02:22,830
These are the ones you should use
58
00:02:22,830 --> 00:02:24,720
if your labels are straightforward
59
00:02:24,720 --> 00:02:27,300
and your data doesn't need
any special processing
60
00:02:27,300 --> 00:02:29,673
before being ready for training.
61
00:02:29,673 --> 00:02:31,272
Notice that because different models
62
00:02:31,272 --> 00:02:33,690
have different padding tokens,
63
00:02:33,690 --> 00:02:37,170
DataCollatorWithPadding will
need your model's Tokenizer
64
00:02:37,170 --> 00:02:40,150
so it knows how to pad sequences properly.
65
00:02:40,150 --> 00:02:44,790
The default data collator
doesn't need a Tokenizer to work,
66
00:02:44,790 --> 00:02:46,710
but as a result it will throw an error
67
00:02:46,710 --> 00:02:48,900
unless all of your sequences
are the same length.
68
00:02:48,900 --> 00:02:50,500
So, you should be aware of that.
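In code, that difference looks roughly like this (a minimal sketch with made-up inputs):

from transformers import AutoTokenizer, DataCollatorWithPadding, DefaultDataCollator

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# DataCollatorWithPadding: needs the tokenizer so it knows the pad token
padding_collator = DataCollatorWithPadding(tokenizer=tokenizer)
samples = [tokenizer("short"), tokenizer("a noticeably longer input sequence")]
batch = padding_collator(samples)  # input_ids padded to the longest sample

# DefaultDataCollator: no tokenizer and no padding, so it only works
# if every sample already has exactly the same length
default_collator = DefaultDataCollator()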
69
00:02:51,480 --> 00:02:52,860
Moving on though.
70
00:02:52,860 --> 00:02:54,300
A lot of the other data collators
71
00:02:54,300 --> 00:02:56,130
aside from the basic two,
72
00:02:56,130 --> 00:02:59,490
are usually designed to
handle one specific task.
73
00:02:59,490 --> 00:03:01,050
And so, I'm going to show a couple here.
74
00:03:01,050 --> 00:03:04,320
These are
DataCollatorForTokenClassification
75
00:03:04,320 --> 00:03:06,447
and DataCollatorForSeq2Seq.
76
00:03:06,447 --> 00:03:09,540
And the reason these tasks
need special collators
77
00:03:09,540 --> 00:03:12,600
is because their labels
are variable in length.
78
00:03:12,600 --> 00:03:15,960
In token classification there's
one label for each token,
79
00:03:15,960 --> 00:03:17,400
and so the length of the labels
80
00:03:17,400 --> 00:03:18,993
is the length of the sequence.
81
00:03:20,280 --> 00:03:23,520
While in Seq2Seq, the labels
are a sequence of tokens
82
00:03:23,520 --> 00:03:24,780
that can be variable length,
83
00:03:24,780 --> 00:03:25,800
and can be very different
84
00:03:25,800 --> 00:03:28,200
from the length of the input sequence.
85
00:03:28,200 --> 00:03:32,880
So in both of these cases, we
handle collating that batch
86
00:03:32,880 --> 00:03:35,280
by padding the labels as well,
87
00:03:35,280 --> 00:03:37,410
as you can see here in this example.
88
00:03:37,410 --> 00:03:40,770
So, inputs and the labels
will need to be padded
89
00:03:40,770 --> 00:03:43,860
if we want to join
samples of variable length
90
00:03:43,860 --> 00:03:45,120
into the same minibatch.
91
00:03:45,120 --> 00:03:47,520
That's exactly what these
92
00:03:47,520 --> 00:03:50,460
data collators will do for us
93
00:03:50,460 --> 00:03:52,383
for these particular tasks.
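A minimal sketch of creating those two collators; the seq2seq_model variable is a placeholder assumed to already exist:

from transformers import (
    AutoTokenizer,
    DataCollatorForTokenClassification,
    DataCollatorForSeq2Seq,
)

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# Pads the labels along with the inputs; padded label positions get -100
# so the loss function ignores them
token_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

# Pads label sequences too, and can use the model to prepare decoder inputs
seq2seq_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=seq2seq_model)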
94
00:03:53,820 --> 00:03:56,070
So, there's one final data collator
95
00:03:56,070 --> 00:03:58,560
I want to show you as
well just in this lecture.
96
00:03:58,560 --> 00:04:00,473
And that's the
DataCollatorForLanguageModeling.
97
00:04:01,410 --> 00:04:03,390
So, it's very important, firstly
98
00:04:03,390 --> 00:04:05,820
because language models
are just so foundational
99
00:04:05,820 --> 00:04:09,720
to everything we
do with NLP these days.
100
00:04:09,720 --> 00:04:12,060
But secondly, because it has two modes
101
00:04:12,060 --> 00:04:14,760
that do two very different things.
102
00:04:14,760 --> 00:04:19,230
So you choose which mode you
want with the mlm argument.
103
00:04:19,230 --> 00:04:22,470
Set it to True for
masked language modeling,
104
00:04:22,470 --> 00:04:26,190
and set it to False for
causal language modeling.
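For example, a minimal sketch of constructing the collator in both modes:

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# mlm=True: tokens are randomly masked; labels are the original tokens at those positions
mlm_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

# mlm=False: causal language modeling; labels are a copy of the inputs
clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)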
105
00:04:26,190 --> 00:04:28,620
So, collating data for
causal language modeling
106
00:04:28,620 --> 00:04:30,750
is actually quite straightforward.
107
00:04:30,750 --> 00:04:32,640
The model is just making predictions
108
00:04:32,640 --> 00:04:35,460
for what token comes
next, and so your labels
109
00:04:35,460 --> 00:04:37,800
are more or less just
a copy of your inputs,
110
00:04:37,800 --> 00:04:39,090
and the collator will handle that
111
00:04:39,090 --> 00:04:42,240
and ensure that the inputs and
labels are padded correctly.
112
00:04:42,240 --> 00:04:44,910
When you set mlm to True though,
113
00:04:44,910 --> 00:04:46,786
you get quite different behavior,
114
00:04:46,786 --> 00:04:49,200
unlike that of
any other data collator,
115
00:04:49,200 --> 00:04:51,660
and that's because setting mlm to True
116
00:04:51,660 --> 00:04:53,550
means masked language modeling
117
00:04:53,550 --> 00:04:55,680
and that means the inputs
118
00:04:55,680 --> 00:04:58,080
need to be masked.
119
00:04:58,080 --> 00:05:00,093
So, what does that look like?
120
00:05:01,050 --> 00:05:03,900
So, recall that in
masked language modeling,
121
00:05:03,900 --> 00:05:06,570
the model is not predicting the next word,
122
00:05:06,570 --> 00:05:09,240
instead we randomly mask out some tokens
123
00:05:09,240 --> 00:05:11,130
and the model predicts
all of them at once.
124
00:05:11,130 --> 00:05:12,780
So, it tries to kinda fill in the blanks
125
00:05:12,780 --> 00:05:14,790
for those masked tokens.
126
00:05:14,790 --> 00:05:18,210
But the process of random
masking is surprisingly complex.
127
00:05:18,210 --> 00:05:21,330
If we follow the protocol
from the original BERT paper,
128
00:05:21,330 --> 00:05:23,970
we need to replace some
tokens with a mask token,
129
00:05:23,970 --> 00:05:26,190
some other tokens with a random token,
130
00:05:26,190 --> 00:05:29,820
and then keep a third
set of tokens unchanged.
131
00:05:29,820 --> 00:05:30,840
Yeah, this is not the lecture
132
00:05:30,840 --> 00:05:33,903
to go into the specifics
of that or why we do it.
133
00:05:33,903 --> 00:05:36,660
You can always check out
the original BERT paper
134
00:05:36,660 --> 00:05:37,493
if you're curious.
135
00:05:37,493 --> 00:05:39,620
It's well written. It's
easy to understand.
136
00:05:40,650 --> 00:05:44,190
The main thing to know here
is that it can be a real pain
137
00:05:44,190 --> 00:05:46,770
and quite complex to
implement that yourself.
138
00:05:46,770 --> 00:05:49,740
But DataCollatorForLanguageModeling
will do it for you
139
00:05:49,740 --> 00:05:51,750
when you set mlm to True.
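To give a rough idea of what the collator handles for you, here is a purely illustrative sketch of the BERT masking recipe; the helper name and simple list-based handling are made up, and the real collator works on batched tensors and also avoids masking special tokens:

import random

def bert_style_mask(token_ids, mask_token_id, vocab_size, mlm_probability=0.15):
    # Hypothetical helper, for illustration only.
    # Labels are -100 (ignored by the loss) except at the chosen positions.
    inputs, labels = list(token_ids), [-100] * len(token_ids)
    for i, token in enumerate(token_ids):
        if random.random() < mlm_probability:
            labels[i] = token
            roll = random.random()
            if roll < 0.8:
                inputs[i] = mask_token_id                 # 80%: replace with the mask token
            elif roll < 0.9:
                inputs[i] = random.randrange(vocab_size)  # 10%: replace with a random token
            # remaining 10%: keep the original token unchanged
    return inputs, labels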
140
00:05:51,750 --> 00:05:54,690
And that's an example
of the more intricate
141
00:05:54,690 --> 00:05:57,870
preprocessing that some
of our data collators do.
142
00:05:57,870 --> 00:05:59,430
And that's it!
143
00:05:59,430 --> 00:06:01,920
So, this covers the most
commonly used data collators
144
00:06:01,920 --> 00:06:03,480
and the tasks they're used for.
145
00:06:03,480 --> 00:06:06,990
And hopefully, now you'll know
when to use data collators
146
00:06:06,990 --> 00:06:10,833
and which one to choose
for your specific task.
147
00:06:11,765 --> 00:06:14,598
(whooshing sound)