in jat/processing_jat.py [0:0]
from typing import Any, Dict, List

import torch


def pad(encoding: Dict[str, List[List[Any]]], target_length: int) -> Dict[str, List[List[Any]]]:
    """
    Pad the sequences in the encoding to the specified maximum length.

    This function is designed to process a batch of sequences represented in the encoding dictionary.
    The padding value is set to be the first element of each sequence.

    Args:
        encoding (`Mapping`):
            A dictionary where each key-value pair consists of a feature name and its corresponding batch of
            sequences. The sequences are expected to be lists.
        target_length (`int`):
            The desired length for the sequences.

    Returns:
        `Dict[str, List[List[Any]]]`:
            A dictionary with the same keys as the input `encoding`, containing the padded batch of sequences.
            An additional key `attention_mask` is added to the dictionary to indicate the positions of the
            non-padding elements with 1s and the padding elements with 0s. If the input `encoding` already
            contains an `attention_mask` key, the corresponding mask is updated so that the original masking is
            preserved and the newly added padding elements are masked with 0s. In other words, the resulting
            `attention_mask` is a logical "AND" between the provided mask and the mask created by padding,
            ensuring that any element masked originally remains masked.

    Example:
        >>> encoding = {"feature1": [[1, 2], [3, 4, 5]]}
        >>> pad(encoding, 4)
        {'feature1': [[1, 2, 1, 1], [3, 4, 5, 3]], 'attention_mask': [[1, 1, 0, 0], [1, 1, 1, 0]]}

        >>> encoding = {"feature1": [[1, 2], [3, 4, 5]], "attention_mask": [[1, 0], [0, 1, 1]]}
        >>> pad(encoding, 4)
        {'feature1': [[1, 2, 1, 1], [3, 4, 5, 3]], 'attention_mask': [[1, 0, 0, 0], [0, 1, 1, 0]]}
    """
    padded_encoding = {}
    for key, sequences in encoding.items():
        if not all(isinstance(seq, (list, torch.Tensor)) for seq in sequences):
            raise TypeError(f"All sequences under key {key} should be of type list or tensor.")
        if key == "attention_mask":  # attention_mask is handled separately
            continue
        padded_sequences = []
        pad_mask = []
        for seq in sequences:
            # Pad with the first element of the sequence; positions beyond the original
            # length are marked with 0 in the padding mask.
            pad_len = target_length - len(seq)
            padded_seq = list(seq) + [seq[0]] * max(0, pad_len)
            mask = [1] * len(seq) + [0] * max(0, pad_len)
            padded_sequences.append(padded_seq)
            pad_mask.append(mask)
        padded_encoding[key] = padded_sequences
    if "attention_mask" in encoding:
        # Combine the provided mask with the padding mask (logical "AND") so that
        # originally masked elements stay masked and padded positions are masked too.
        padded_encoding["attention_mask"] = [
            [a * (b[i] if i < len(b) else 0) for i, a in enumerate(row)]
            for row, b in zip(pad_mask, encoding["attention_mask"])
        ]
    else:
        padded_encoding["attention_mask"] = pad_mask
    return padded_encoding
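
# Minimal usage sketch (illustrative only, not part of the original file). The feature
# name "observations" below is made up; `pad` behaves the same for any key.
if __name__ == "__main__":
    batch = {
        "observations": [[10, 11], [20, 21, 22]],
        "attention_mask": [[1, 0], [1, 1, 1]],
    }
    padded = pad(batch, target_length=4)
    # Sequences are padded with their first element, and the provided attention_mask
    # is AND-ed with the padding mask:
    # {'observations': [[10, 11, 10, 10], [20, 21, 22, 20]],
    #  'attention_mask': [[1, 0, 0, 0], [1, 1, 1, 0]]}
    print(padded)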