optimum/graphcore/data/data_collator.py
# Copyright 2021 The HuggingFace Team. All rights reserved.
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from transformers.data import DataCollatorForLanguageModeling
from transformers.data.data_collator import DataCollator
from transformers.tokenization_utils_base import BatchEncoding


def pad_on_batch_axis(batch_size: int) -> Callable[[DataCollator], DataCollator]:
"""
Creates a `DataCollator` wrapper that pads the batches generated by `DataCollator` on the batch axis to generate
fixed size batches. It implements the padding by repeating elements of the batch to reach the padded sized.
"""
def pad_tensor(x: Tensor) -> Tensor:
if batch_size != x.size(0):
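            # Descriptive note: repeat the tensor along the batch axis until it has at least `batch_size` rows,
            # then truncate to exactly `batch_size`.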
repeat_dims = torch.ones(x.ndim, dtype=int, requires_grad=False)
num_repeats = batch_size // x.size(0) + 1
repeat_dims[0] = num_repeats
return x.repeat(*repeat_dims.tolist())[:batch_size]
else:
return x
def decorator(data_collator: DataCollator) -> DataCollator:
@wraps(data_collator)
def wrapper(*args, **kwargs):
batch = data_collator(*args, **kwargs)
for k, v in batch.items():
batch[k] = pad_tensor(v)
return batch
return wrapper
return decorator
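

# Illustrative usage sketch (not part of this module): wrapping a toy collator so that the batches it produces are
# padded on the batch axis to a fixed size of 8. The collator and feature layout below are assumptions made only for
# this example.
#
#     def toy_collator(features):
#         return {"input_ids": torch.stack([f["input_ids"] for f in features])}
#
#     padded_collator = pad_on_batch_axis(8)(toy_collator)
#     batch = padded_collator([{"input_ids": torch.tensor([101, 2023, 102])} for _ in range(5)])
#     # batch["input_ids"].shape == (8, 3): the 5 real rows are repeated and truncated to fill the batch.

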
class DataCollatorForLanguageModelingWithMaxTokensMasked(DataCollatorForLanguageModeling):
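    """
    `DataCollatorForLanguageModeling` variant that caps the number of masked tokens per example at a fixed maximum
    derived from `max_seq_length` and `mlm_probability`.
    """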
def __init__(self, max_seq_length, *args, **kwargs):
super().__init__(*args, **kwargs)
self.max_seq_length = max_seq_length
self.max_num_masked_tokens = self._calculate_max_num_masked_tokens(max_seq_length)
def _calculate_max_num_masked_tokens(self, max_seq_length):
"""
Gets the maximum number of masked tokens.
The number of masked tokens follows a binomial distribution. We approximate
the binomial distribution with a Gaussian distribution and cap the maximum number of masked tokens to two standard
deviations above the mean.
"""
mean = max_seq_length * self.mlm_probability
var = max_seq_length * self.mlm_probability * (1 - self.mlm_probability)
std = math.sqrt(var)
max_num_masked_tokens = mean + 2 * std
# Round up to a multiple of 16
max_num_masked_tokens = math.ceil(max_num_masked_tokens / 16) * 16
# Cap to max_seq_length
max_num_masked_tokens = min(max_num_masked_tokens, max_seq_length)
return max_num_masked_tokens
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
batch = super().torch_call(examples)
        # Necessary for poptorch.DataLoaderMode.AsyncRebatched, which can handle dictionaries but not BatchEncoding
        # objects.
if isinstance(batch, BatchEncoding):
batch = dict(batch)
return batch
def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
"""
Prepare masked token inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
"""
labels = inputs.clone()
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
probability_matrix = torch.full(labels.shape, self.mlm_probability)
if special_tokens_mask is None:
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
else:
special_tokens_mask = special_tokens_mask.bool()
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix)
        # Make sure that at most max_num_masked_tokens tokens are masked in each example.
        # torch_mask_tokens is called after padding, so labels has a fixed shape.
        # Small noise is added to -masked_indices so that torch.topk selects which positions to un-mask stochastically.
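        # How the cap works: topk keeps the (max_seq_length - max_num_masked_tokens) largest entries of
        # -masked_indices + noise, which are predominantly positions that are already unmasked (value 0). Zeroing
        # the selected positions therefore leaves at most max_num_masked_tokens masked positions per row, and the
        # noise randomizes which surplus masked tokens are dropped when the cap would be exceeded.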
small_noise = torch.rand(masked_indices.size())
_, indices = torch.topk(
-masked_indices + small_noise, k=self.max_seq_length - self.max_num_masked_tokens, dim=1
)
masked_indices.scatter_(1, indices, 0)
masked_indices = masked_indices.bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        # 10% of the time, we replace masked input tokens with a random word
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels
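

# Illustrative construction sketch (the checkpoint name and hyperparameters below are assumptions, not part of this
# module):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     collator = DataCollatorForLanguageModelingWithMaxTokensMasked(
#         max_seq_length=128,
#         tokenizer=tokenizer,
#         mlm_probability=0.15,
#     )
#     # `pad_on_batch_axis` can then be applied so that the batch axis is also fixed in size:
#     fixed_size_collator = pad_on_batch_axis(16)(collator)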