in notebooks/src/code/data/mlm.py [0:0]
def torch_call(self, examples: List[TextractLayoutLMExampleForLM]) -> Dict[str, Any]:
    # Tokenize and pad the words:
    batch = self.tokenizer(
        [example.word_texts for example in examples],
        is_split_into_words=True,
        return_attention_mask=True,
        padding=bool(self.pad_to_multiple_of),
        pad_to_multiple_of=self.pad_to_multiple_of,
        return_tensors=self.return_tensors,
    )
    # Map the word bounding boxes through to the generated tokens:
    # We do this by augmenting the list of word boxes with the special-token boxes, editing
    # the word_ids mapping from tokens->words so that special tokens point to their special
    # boxes (instead of None), and then applying this set of indexes to produce the
    # token-wise boxes including special tokens.
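    # Worked illustration (assuming a RoBERTa-style tokenizer): with 2 words tokenized as
    # <s> tok0a tok0b tok1 </s> <pad>, word_ids() gives [None, 0, 0, 1, None, None]. With
    # n_real_boxes=2 and _special_token_boxes ordered [bos_box, pad_box, sep_box] (per the
    # offsets used below), the index list becomes [2, 0, 0, 1, 4, 3], so every token selects
    # a real row of the augmented box tensor.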
    bbox_tensors_by_example = []
    for ixex in range(len(examples)):
        box_ids = batch.word_ids(ixex)  # List[Union[int, None]], =None at special tokens
        n_real_boxes = len(examples[ixex].word_boxes_normalized)
        augmented_example_word_boxes = torch.cat(
            (
                torch.LongTensor(examples[ixex].word_boxes_normalized),
                self._special_token_boxes,
            ),
            dim=0,
        )
        if box_ids[0] is None:  # Shortcut as <bos> should appear only at the start
            box_ids[0] = n_real_boxes  # bos_token_box, per _special_token_boxes
        # Torch tensors don't support None, but numpy float ndarrays do (as NaN):
        box_ids_np = np.array(box_ids, dtype=float)
        box_ids_np = np.where(
            batch.input_ids[ixex, :] == self.tokenizer.pad_token_id,
            n_real_boxes + 1,  # pad_token_box, per _special_token_boxes
            box_ids_np,
        )
        box_ids_np = np.where(
            batch.input_ids[ixex, :] == self.tokenizer.sep_token_id,
            n_real_boxes + 2,  # sep_token_box, per _special_token_boxes
            box_ids_np,
        )
        bbox_tensors_by_example.append(
            torch.index_select(
                augmented_example_word_boxes,
                0,
                # By this point all NaNs from special tokens should be resolved, so we can cast:
                torch.LongTensor(box_ids_np.astype(int)),
            )
        )
    batch["bbox"] = torch.stack(bbox_tensors_by_example)
    # From here the implementation follows the superclass, but we can't call super() because
    # the first part of its method expects batching not to have been done yet:
    # If the special tokens mask has been preprocessed, pop it from the dict.
    special_tokens_mask = batch.pop("special_tokens_mask", None)
    if self.mlm:
        batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
            batch["input_ids"], special_tokens_mask=special_tokens_mask
        )
    else:
        labels = batch["input_ids"].clone()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
    return batch
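
For context, a minimal usage sketch follows. The example fields (word_texts, word_boxes_normalized) are taken from the method above, but the dataclass definition, tokenizer checkpoint, and collator construction are illustrative assumptions rather than the repository's actual API:

from dataclasses import dataclass
from typing import List

from transformers import AutoTokenizer


@dataclass
class TextractLayoutLMExampleForLM:
    # Assumed minimal shape; the real class likely carries additional metadata.
    word_texts: List[str]
    word_boxes_normalized: List[List[int]]  # one 0-1000 normalized [x0, y0, x1, y1] per word


tokenizer = AutoTokenizer.from_pretrained("roberta-base")  # placeholder checkpoint

example = TextractLayoutLMExampleForLM(
    word_texts=["Invoice", "Total:", "$42.00"],
    word_boxes_normalized=[[60, 40, 180, 70], [60, 500, 150, 530], [160, 500, 260, 530]],
)

# collator = <the collator class defining torch_call above>(
#     tokenizer=tokenizer, mlm=True, mlm_probability=0.15, pad_to_multiple_of=8, return_tensors="pt"
# )
# batch = collator([example])
# -> batch["input_ids"], batch["attention_mask"], batch["bbox"], batch["labels"]

Note that bbox is built before torch_mask_tokens is called, so masked tokens keep their original bounding boxes while only input_ids and labels reflect the masking.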