subtitles/en/64_using-a-custom-loss-function.srt

1
00:00:00,573 --> 00:00:01,636
(air whooshing)

2
00:00:01,636 --> 00:00:02,594
(logo popping)

3
00:00:02,594 --> 00:00:05,550
(metal sliding)

4
00:00:05,550 --> 00:00:07,500
- In this video, we take a look at setting up

5
00:00:07,500 --> 00:00:09,303
a custom loss function for training.

6
00:00:10,980 --> 00:00:13,260
In the default loss function, all samples,

7
00:00:13,260 --> 00:00:15,840
such as these code snippets, are treated the same

8
00:00:15,840 --> 00:00:18,960
irrespective of their content, but there are scenarios

9
00:00:18,960 --> 00:00:21,660
where it could make sense to weight the samples differently,

10
00:00:21,660 --> 00:00:24,570
if, for example, one sample contains a lot of tokens

11
00:00:24,570 --> 00:00:26,160
that are of interest to us,

12
00:00:26,160 --> 00:00:29,910
or if a sample has a favorable diversity of tokens.

13
00:00:29,910 --> 00:00:31,950
We can also implement other heuristics

14
00:00:31,950 --> 00:00:33,963
with pattern matching or other rules.

15
00:00:35,993 --> 00:00:39,150
For each sample, we get a loss value during training,

16
00:00:39,150 --> 00:00:41,850
and we can combine that loss with a weight.

17
00:00:41,850 --> 00:00:43,860
Then we can create a weighted sum

18
00:00:43,860 --> 00:00:45,660
or average over all samples

19
00:00:45,660 --> 00:00:47,613
to get the final loss for the batch.

20
00:00:48,690 --> 00:00:51,240
Let's have a look at a specific example.

21
00:00:51,240 --> 00:00:52,830
We want to set up a language model

22
00:00:52,830 --> 00:00:56,073
that helps us autocomplete common data science code.

23
00:00:57,030 --> 00:01:01,830
For that task, we would like to weight samples more strongly

24
00:01:01,830 --> 00:01:04,110
where tokens related to the data science stack,

25
00:01:04,110 --> 00:01:07,353
such as pd or np, occur more frequently.

26
00:01:10,140 --> 00:01:13,080
Here you see a loss function that does exactly that

27
00:01:13,080 --> 00:01:15,180
for causal language modeling.

28
00:01:15,180 --> 00:01:18,030
It takes the model's input and predicted logits,

29
00:01:18,030 --> 00:01:20,343
as well as the key tokens, as input.

30
00:01:21,869 --> 00:01:25,113
First, the inputs and logits are aligned.

31
00:01:26,490 --> 00:01:29,310
Then the loss per sample is calculated,

32
00:01:29,310 --> 00:01:30,843
followed by the weights.

33
00:01:32,430 --> 00:01:35,583
Finally, the loss and the weights are combined and returned.

34
00:01:36,540 --> 00:01:39,150
This is a pretty big function, so let's take a closer look

35
00:01:39,150 --> 00:01:40,953
at the loss and the weight blocks.

36
00:01:43,380 --> 00:01:45,600
During the calculation of the standard loss,

37
00:01:45,600 --> 00:01:48,930
the logits and labels are flattened over the batch.

38
00:01:48,930 --> 00:01:52,590
With the view operation, we unflatten the tensor to get a matrix

39
00:01:52,590 --> 00:01:55,320
with a row for each sample in the batch and a column

40
00:01:55,320 --> 00:01:57,723
for each position in the sequence of the sample.

41
00:01:58,920 --> 00:02:00,600
We don't need the loss per position,

42
00:02:00,600 --> 00:02:04,083
so we average the loss over all positions for each sample.

43
00:02:06,150 --> 00:02:08,970
For the weights, we use Boolean logic to get a tensor

44
00:02:08,970 --> 00:02:12,483
with 1s where a keyword occurred and 0s where not.
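The function itself only appears on the video's slide, so here is a minimal sketch of what it could look like in PyTorch. The name keytoken_weighted_loss, the keytoken_ids argument, and the alpha scaling factor are illustrative assumptions, not names shown in the subtitles.

```python
import torch
from torch.nn import CrossEntropyLoss


def keytoken_weighted_loss(inputs, logits, keytoken_ids, alpha=1.0):
    # Align inputs and logits: the logits at position n predict token n + 1.
    shift_labels = inputs[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()

    # Standard causal LM loss per token, flattened over the batch (no reduction).
    loss_fct = CrossEntropyLoss(reduction="none")
    loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    )

    # Unflatten with view: one row per sample, one column per position,
    # then average over all positions of each sample.
    loss_per_sample = loss.view(
        shift_logits.size(0), shift_logits.size(1)
    ).mean(axis=1)

    # Boolean logic gives 1s where a keyword occurs and 0s elsewhere; one
    # matrix per keyword. Summing over keywords and positions counts
    # keyword occurrences per sample.
    weights = torch.stack(
        [(inputs == kt).float() for kt in keytoken_ids]
    ).sum(axis=[0, 2])
    weights = alpha * (1.0 + weights)

    # Combine loss and weights element-wise and average over the batch.
    return (loss_per_sample * weights).mean()
```

The `1.0 + weights` term is one possible design choice so that samples without any keyword still contribute to the loss; it is not prescribed by the video.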
45
00:02:13,440 --> 00:02:15,690
This tensor has one additional dimension

46
00:02:15,690 --> 00:02:18,540
compared to the loss tensor we just saw, because we get

47
00:02:18,540 --> 00:02:21,693
the information for each keyword in a separate matrix.

48
00:02:22,770 --> 00:02:24,120
We only want to know

49
00:02:24,120 --> 00:02:26,880
how many times keywords occurred per sample,

50
00:02:26,880 --> 00:02:30,693
so we can sum over all keywords and all positions per sample.

51
00:02:33,450 --> 00:02:35,010
Now we're almost there.

52
00:02:35,010 --> 00:02:38,850
We only need to combine the loss with the weight per sample.

53
00:02:38,850 --> 00:02:41,790
We do this with element-wise multiplication

54
00:02:41,790 --> 00:02:45,233
and then average over all samples in the batch.

55
00:02:45,233 --> 00:02:46,066
In the end,

56
00:02:46,066 --> 00:02:49,110
we have exactly one loss value for the whole batch,

57
00:02:49,110 --> 00:02:51,330
and this is all the logic necessary

58
00:02:51,330 --> 00:02:53,223
to create a custom weighted loss.

59
00:02:56,250 --> 00:02:59,010
Let's see how we can make use of that custom loss

60
00:02:59,010 --> 00:03:00,753
with Accelerate and the Trainer.

61
00:03:01,710 --> 00:03:04,656
In Accelerate, we just pass the input_ids

62
00:03:04,656 --> 00:03:05,730
to the model to get the logits,

63
00:03:05,730 --> 00:03:08,103
and then we can call the custom loss function.

64
00:03:09,000 --> 00:03:11,310
After that, we continue with the normal training loop

65
00:03:11,310 --> 00:03:13,083
by, for example, calling backward.

66
00:03:14,010 --> 00:03:15,570
For the Trainer, we can override

67
00:03:15,570 --> 00:03:19,260
the compute_loss function of the standard Trainer.

68
00:03:19,260 --> 00:03:20,970
We just need to make sure that we return

69
00:03:20,970 --> 00:03:24,450
the loss and the model outputs in the same format.

70
00:03:24,450 --> 00:03:27,570
With that, you can integrate your own awesome loss function

71
00:03:27,570 --> 00:03:29,763
with both the Trainer and Accelerate.

72
00:03:31,389 --> 00:03:34,056
(air whooshing)
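As a rough sketch of the two integration paths just described: the names keytoken_weighted_loss, keytoken_ids, and WeightedLossTrainer carry over from the earlier sketch and are assumptions, and model, optimizer, train_dataloader, and accelerator are assumed to have already been prepared as in a standard Accelerate training script.

```python
from transformers import Trainer

# With Accelerate: pass the input_ids to the model to get the logits,
# call the custom loss function, then continue the normal training loop
# by calling backward (here via the accelerator).
for batch in train_dataloader:
    logits = model(batch["input_ids"]).logits
    loss = keytoken_weighted_loss(batch["input_ids"], logits, keytoken_ids)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()


# With the Trainer: override compute_loss and return the loss
# (and optionally the model outputs) in the expected format.
class WeightedLossTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(inputs["input_ids"])
        loss = keytoken_weighted_loss(
            inputs["input_ids"], outputs.logits, keytoken_ids
        )
        return (loss, outputs) if return_outputs else loss
```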