optimum/graphcore/data/data_collator.py (70 lines of code) (raw):

# Copyright 2021 The HuggingFace Team. All rights reserved. # Copyright (c) 2022 Graphcore Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from functools import wraps from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from torch import Tensor from transformers.data import DataCollatorForLanguageModeling from transformers.data.data_collator import DataCollator from transformers.tokenization_utils_base import BatchEncoding def pad_on_batch_axis(batch_size: int) -> Callable[[DataCollator], DataCollator]: """ Creates a `DataCollator` wrapper that pads the batches generated by `DataCollator` on the batch axis to generate fixed size batches. It implements the padding by repeating elements of the batch to reach the padded sized. """ def pad_tensor(x: Tensor) -> Tensor: if batch_size != x.size(0): repeat_dims = torch.ones(x.ndim, dtype=int, requires_grad=False) num_repeats = batch_size // x.size(0) + 1 repeat_dims[0] = num_repeats return x.repeat(*repeat_dims.tolist())[:batch_size] else: return x def decorator(data_collator: DataCollator) -> DataCollator: @wraps(data_collator) def wrapper(*args, **kwargs): batch = data_collator(*args, **kwargs) for k, v in batch.items(): batch[k] = pad_tensor(v) return batch return wrapper return decorator class DataCollatorForLanguageModelingWithMaxTokensMasked(DataCollatorForLanguageModeling): def __init__(self, max_seq_length, *args, **kwargs): super().__init__(*args, **kwargs) self.max_seq_length = max_seq_length self.max_num_masked_tokens = self._calculate_max_num_masked_tokens(max_seq_length) def _calculate_max_num_masked_tokens(self, max_seq_length): """ Gets the maximum number of masked tokens. The number of masked tokens follows a binomial distribution. We approximate the binomial distribution with a Gaussian distribution and cap the maximum number of masked tokens to two standard deviations above the mean. """ import math mean = max_seq_length * self.mlm_probability var = max_seq_length * self.mlm_probability * (1 - self.mlm_probability) std = math.sqrt(var) max_num_masked_tokens = mean + 2 * std # Round up to a multiple of 16 max_num_masked_tokens = math.ceil(max_num_masked_tokens / 16) * 16 # Cap to max_seq_length max_num_masked_tokens = min(max_num_masked_tokens, max_seq_length) return max_num_masked_tokens def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: batch = super().torch_call(examples) # Necessary for poptorch.DataLoaderMode.AsyncRebatched which can handle dictionaries but not BatchEncoding. if isinstance(batch, BatchEncoding): batch = dict(batch) return batch def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]: """ Prepare masked token inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """ import torch labels = inputs.clone() # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) probability_matrix = torch.full(labels.shape, self.mlm_probability) if special_tokens_mask is None: special_tokens_mask = [ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() ] special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) else: special_tokens_mask = special_tokens_mask.bool() probability_matrix.masked_fill_(special_tokens_mask, value=0.0) masked_indices = torch.bernoulli(probability_matrix) # Making sure there are at most max_num_masked_tokens that are masked for each example. # torch_mask_tokens is called after padding so labels should be of fixed shape. # Adding a small noise to -masked_indices to make the torch.topk selection of the ones to delete stochastic. small_noise = torch.rand(masked_indices.size()) _, indices = torch.topk( -masked_indices + small_noise, k=self.max_seq_length - self.max_num_masked_tokens, dim=1 ) masked_indices.scatter_(1, indices, 0) masked_indices = masked_indices.bool() labels[~masked_indices] = -100 # We only compute loss on masked tokens # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) # 10% of the time, we replace masked input tokens with random word indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long) inputs[indices_random] = random_words[indices_random] # The rest of the time (10% of the time) we keep the masked input tokens unchanged return inputs, labels