subtitles/en/64_using-a-custom-loss-function.srt
1
00:00:00,573 --> 00:00:01,636
(air whooshing)
2
00:00:01,636 --> 00:00:02,594
(logo popping)
3
00:00:02,594 --> 00:00:05,550
(metal sliding)
4
00:00:05,550 --> 00:00:07,500
- In this video, we take
a look at setting up
5
00:00:07,500 --> 00:00:09,303
a custom loss function for training.
6
00:00:10,980 --> 00:00:13,260
In the default loss function, all samples,
7
00:00:13,260 --> 00:00:15,840
such as these code snippets,
are treated the same
8
00:00:15,840 --> 00:00:18,960
irrespective of their content,
but there are scenarios
9
00:00:18,960 --> 00:00:21,660
where it could make sense to
weight the samples differently.
10
00:00:21,660 --> 00:00:24,570
For example, one sample
might contain a lot of tokens
11
00:00:24,570 --> 00:00:26,160
that are of interest to us,
12
00:00:26,160 --> 00:00:29,910
or it might have a
favorable diversity of tokens.
13
00:00:29,910 --> 00:00:31,950
We can also implement other heuristics
14
00:00:31,950 --> 00:00:33,963
with pattern matching or other rules.
15
00:00:35,993 --> 00:00:39,150
For each sample, we get a
loss value during training
16
00:00:39,150 --> 00:00:41,850
and we can combine that
loss with a weight.
17
00:00:41,850 --> 00:00:43,860
Then we can create a weighted sum
18
00:00:43,860 --> 00:00:45,660
or average over all samples
19
00:00:45,660 --> 00:00:47,613
to get the final loss for the batch.
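As a minimal sketch of that idea in PyTorch (the loss and weight values below are made up purely for illustration):

import torch

# Hypothetical per-sample losses and weights for a batch of four samples
loss_per_sample = torch.tensor([2.1, 1.7, 3.0, 0.9])
weights = torch.tensor([1.0, 3.0, 1.0, 2.0])

# One scalar loss for the batch: weighted average over all samples
batch_loss = (loss_per_sample * weights).mean()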
20
00:00:48,690 --> 00:00:51,240
Let's have a look at a specific example.
21
00:00:51,240 --> 00:00:52,830
We want to set up a language model
22
00:00:52,830 --> 00:00:56,073
that helps us autocomplete
common data science code.
23
00:00:57,030 --> 00:01:01,830
For that task, we would like
to weight samples more heavily
24
00:01:01,830 --> 00:01:04,110
where tokens related to
the data science stack,
25
00:01:04,110 --> 00:01:07,353
such as pd or np, occur more frequently.
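One way to collect the ids of such key tokens is a small helper along these lines (a sketch, assuming a tokenizer is already loaded; the keyword list is illustrative, and keywords that split into more than one token are skipped):

# Hypothetical keyword list for a data science autocomplete model
keywords = ["pd", "np", "plt", "fit", "predict"]

keytoken_ids = []
for kw in keywords:
    ids = tokenizer(kw).input_ids
    # Keep only keywords that map to a single token
    if len(ids) == 1:
        keytoken_ids.append(ids[0])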
26
00:01:10,140 --> 00:01:13,080
Here you see a loss function
that does exactly that
27
00:01:13,080 --> 00:01:15,180
for causal language modeling.
28
00:01:15,180 --> 00:01:18,030
It takes the model's input
and predicted logits,
29
00:01:18,030 --> 00:01:20,343
as well as the key tokens, as arguments.
30
00:01:21,869 --> 00:01:25,113
First, the inputs and logits are aligned.
31
00:01:26,490 --> 00:01:29,310
Then the loss per sample is calculated,
32
00:01:29,310 --> 00:01:30,843
followed by the weights.
33
00:01:32,430 --> 00:01:35,583
Finally, the loss and the weights
are combined and returned.
34
00:01:36,540 --> 00:01:39,150
This is a pretty big function,
so let's take a closer look
35
00:01:39,150 --> 00:01:40,953
at the loss and the weight blocks.
36
00:01:43,380 --> 00:01:45,600
During the calculation
of the standard loss,
37
00:01:45,600 --> 00:01:48,930
the logits and labels are
flattened over the batch.
38
00:01:48,930 --> 00:01:52,590
With view, we unflatten
the tensor to get a matrix
39
00:01:52,590 --> 00:01:55,320
with a row for each sample
in the batch and a column
40
00:01:55,320 --> 00:01:57,723
for each position in the
sequence of the sample.
41
00:01:58,920 --> 00:02:00,600
We don't need the loss per position,
42
00:02:00,600 --> 00:02:04,083
so we average the loss over
all positions for each sample.
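A sketch of that loss block, assuming shift_logits and shift_labels are the aligned logits and labels from the earlier alignment step:

from torch.nn import CrossEntropyLoss

# Per-token loss, flattened over batch and sequence
loss_fct = CrossEntropyLoss(reduction="none")
loss = loss_fct(
    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
# Unflatten back to [batch_size, sequence_length],
# then average over the positions of each sample
loss_per_sample = loss.view(shift_logits.size(0), shift_logits.size(1)).mean(dim=1)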
43
00:02:06,150 --> 00:02:08,970
For the weights, we use
Boolean logic to get a tensor
44
00:02:08,970 --> 00:02:12,483
with 1s where a keyword
occurred and 0s elsewhere.
45
00:02:13,440 --> 00:02:15,690
This tensor has an additional dimension
46
00:02:15,690 --> 00:02:18,540
compared to the loss tensor we
just saw, because we get
47
00:02:18,540 --> 00:02:21,693
the information for each
keyword in a separate matrix.
48
00:02:22,770 --> 00:02:24,120
We only want to know
49
00:02:24,120 --> 00:02:26,880
how many times keywords
occurred per sample,
50
00:02:26,880 --> 00:02:30,693
so we can sum over all keywords
and all positions per sample.
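A sketch of the weight block, assuming inputs is the [batch_size, sequence_length] tensor of input ids and keytoken_ids is the list of keyword token ids:

import torch

# One 0/1 matrix per keyword, stacked into an extra leading dimension
matches = torch.stack([(inputs == kt).float() for kt in keytoken_ids])
# Sum over keywords (dim 0) and positions (dim 2)
# to get one keyword count per sample
weights = matches.sum(dim=[0, 2])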
51
00:02:33,450 --> 00:02:35,010
Now we're almost there.
52
00:02:35,010 --> 00:02:38,850
We only need to combine the
loss with the weight per sample.
53
00:02:38,850 --> 00:02:41,790
We do this with
element-wise multiplication
54
00:02:41,790 --> 00:02:45,233
and then average over all
samples in the batch.
55
00:02:45,233 --> 00:02:46,066
In the end,
56
00:02:46,066 --> 00:02:49,110
we have exactly one loss
value for the whole batch
57
00:02:49,110 --> 00:02:51,330
and this is all the logic necessary
58
00:02:51,330 --> 00:02:53,223
to create a custom weighted loss.
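Putting those pieces together, the whole function might look roughly like this sketch (the name keytoken_weighted_loss and the exact argument order are illustrative, not necessarily the code shown on screen):

import torch
from torch.nn import CrossEntropyLoss

def keytoken_weighted_loss(inputs, logits, keytoken_ids):
    # Align: logits at position n predict the token at position n + 1
    shift_labels = inputs[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()
    # Per-token loss, then average over positions for each sample
    loss_fct = CrossEntropyLoss(reduction="none")
    loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    )
    loss_per_sample = loss.view(
        shift_logits.size(0), shift_logits.size(1)
    ).mean(dim=1)
    # Per-sample weights from the keyword counts
    weights = torch.stack(
        [(inputs == kt).float() for kt in keytoken_ids]
    ).sum(dim=[0, 2])
    # Optionally shift the weights (e.g. 1.0 + weights) so samples
    # without any keyword still contribute to the loss.
    # Element-wise product, then average over the batch
    return (loss_per_sample * weights).mean()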
59
00:02:56,250 --> 00:02:59,010
Let's see how we can make
use of that custom loss
60
00:02:59,010 --> 00:03:00,753
with Accelerate and the Trainer.
61
00:03:01,710 --> 00:03:04,656
In Accelerate, we just pass the input_ids
62
00:03:04,656 --> 00:03:05,730
to the model to get the logits
63
00:03:05,730 --> 00:03:08,103
and then we can call the
custom loss function.
64
00:03:09,000 --> 00:03:11,310
After that, we continue with
the normal training loop
65
00:03:11,310 --> 00:03:13,083
by, for example, calling backward.
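A sketch of that Accelerate loop, assuming model, optimizer and train_dataloader have already been set up with accelerator.prepare, and keytoken_weighted_loss is the function sketched above:

for batch in train_dataloader:
    # Forward pass: the model returns the logits we need for the custom loss
    logits = model(batch["input_ids"]).logits
    loss = keytoken_weighted_loss(batch["input_ids"], logits, keytoken_ids)
    # Continue with the usual training steps
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()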
66
00:03:14,010 --> 00:03:15,570
For the Trainer, we can overwrite
67
00:03:15,570 --> 00:03:19,260
the compute_loss function
of the standard Trainer.
68
00:03:19,260 --> 00:03:20,970
We just need to make sure that we return
69
00:03:20,970 --> 00:03:24,450
the loss and the model
outputs in the same format.
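A sketch of such a Trainer subclass (the class name is illustrative, and the exact compute_loss signature can differ between transformers versions):

from transformers import Trainer

class KeytokenTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(inputs["input_ids"])
        loss = keytoken_weighted_loss(
            inputs["input_ids"], outputs.logits, keytoken_ids
        )
        # Return the loss (and outputs) in the same format as the standard Trainer
        return (loss, outputs) if return_outputs else loss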
70
00:03:24,450 --> 00:03:27,570
With that, you can integrate
your own awesome loss function
71
00:03:27,570 --> 00:03:29,763
with both the Trainer and Accelerate.
72
00:03:31,389 --> 00:03:34,056
(air whooshing)