in pytext/legacy/data/field.py [0:0]
def pad(self, minibatch):
"""Pad a batch of examples using this field.
If ``self.nesting_field.sequential`` is ``False``, each example in the batch must
be a list of string tokens, and pads them as if by a ``Field`` with
``sequential=True``. Otherwise, each example must be a list of list of tokens.
Using ``self.nesting_field``, pads the list of tokens to
``self.nesting_field.fix_length`` if provided, or otherwise to the length of the
longest list of tokens in the batch. Next, using this field, pads the result by
filling short examples with ``self.nesting_field.pad_token``.
Example:
>>> import pprint
>>> pp = pprint.PrettyPrinter(indent=4)
>>>
>>> nesting_field = Field(pad_token='<c>', init_token='<w>', eos_token='</w>')
>>> field = NestedField(nesting_field, init_token='<s>', eos_token='</s>')
>>> minibatch = [
... [list('john'), list('loves'), list('mary')],
... [list('mary'), list('cries')],
... ]
>>> padded = field.pad(minibatch)
>>> pp.pprint(padded)
[ [ ['<w>', '<s>', '</w>', '<c>', '<c>', '<c>', '<c>'],
['<w>', 'j', 'o', 'h', 'n', '</w>', '<c>'],
['<w>', 'l', 'o', 'v', 'e', 's', '</w>'],
['<w>', 'm', 'a', 'r', 'y', '</w>', '<c>'],
['<w>', '</s>', '</w>', '<c>', '<c>', '<c>', '<c>']],
[ ['<w>', '<s>', '</w>', '<c>', '<c>', '<c>', '<c>'],
['<w>', 'm', 'a', 'r', 'y', '</w>', '<c>'],
['<w>', 'c', 'r', 'i', 'e', 's', '</w>'],
['<w>', '</s>', '</w>', '<c>', '<c>', '<c>', '<c>'],
['<c>', '<c>', '<c>', '<c>', '<c>', '<c>', '<c>']]]
Arguments:
        minibatch (list): Each element is a list of strings if
            ``self.nesting_field.sequential`` is ``False``, and a list of
            lists of strings otherwise.
Returns:
        list: The padded minibatch, or the tuple ``(padded, sentence_lengths,
            word_lengths)`` if ``self.include_lengths`` is ``True``.
"""
minibatch = list(minibatch)
if not self.nesting_field.sequential:
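        # Tokens are not nested, so plain ``Field`` padding is all we need.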
return super(NestedField, self).pad(minibatch)
# Save values of attributes to be monkeypatched
old_pad_token = self.pad_token
old_init_token = self.init_token
old_eos_token = self.eos_token
old_fix_len = self.nesting_field.fix_length
# Monkeypatch the attributes
if self.nesting_field.fix_length is None:
max_len = max(len(xs) for ex in minibatch for xs in ex)
fix_len = (
max_len
+ 2
- (self.nesting_field.init_token, self.nesting_field.eos_token).count(
None
)
)
self.nesting_field.fix_length = fix_len
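        # E.g. in the docstring example the longest word is "loves" (5 chars)
        # and both nesting-field init/eos tokens are set, so
        # fix_length = 5 + 2 - 0 = 7.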
self.pad_token = [self.pad_token] * self.nesting_field.fix_length
if self.init_token is not None:
# self.init_token = self.nesting_field.pad([[self.init_token]])[0]
self.init_token = [self.init_token]
if self.eos_token is not None:
# self.eos_token = self.nesting_field.pad([[self.eos_token]])[0]
self.eos_token = [self.eos_token]
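    # After these patches, an outer padding position is a whole pad "word"
    # (e.g. ['<c>'] * 7 in the docstring example), while init/eos become
    # one-token words that ``self.nesting_field.pad`` below pads like any
    # other word, presumably so their true lengths end up in word_lengths
    # (unlike the upstream variants left commented out above).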
    # Do padding; per-word lengths are needed for the fixup below even when
    # the caller did not ask for them, so temporarily force include_lengths
    # on both fields (restored afterwards).
    old_include_lengths = self.include_lengths
    old_nesting_include_lengths = self.nesting_field.include_lengths
    self.include_lengths = True
    self.nesting_field.include_lengths = True
padded, sentence_lengths = super(NestedField, self).pad(minibatch)
padded_with_lengths = [self.nesting_field.pad(ex) for ex in padded]
word_lengths = []
final_padded = []
max_sen_len = len(padded[0])
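    # ``super().pad()`` filled short examples with whole pad words, but
    # ``self.nesting_field.pad()`` has just re-padded those rows as if they
    # were real words and reported nonzero lengths for them. Rewrite those
    # positions back to pure pad words and zero out their word lengths.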
    for (pad, lens), sentence_len in zip(padded_with_lengths, sentence_lengths):
        n_pad_words = max_sen_len - sentence_len
        if n_pad_words > 0:
            if self.pad_first:
                lens[:n_pad_words] = [0] * n_pad_words
                pad[:n_pad_words] = [self.pad_token] * n_pad_words
            else:
                lens[-n_pad_words:] = [0] * n_pad_words
                pad[-n_pad_words:] = [self.pad_token] * n_pad_words
        word_lengths.append(lens)
        final_padded.append(pad)
padded = final_padded
    # Restore monkeypatched attributes
    self.nesting_field.fix_length = old_fix_len
    self.nesting_field.include_lengths = old_nesting_include_lengths
    self.pad_token = old_pad_token
    self.init_token = old_init_token
    self.eos_token = old_eos_token
    self.include_lengths = old_include_lengths
if self.include_lengths:
return padded, sentence_lengths, word_lengths
return padded
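
# Usage sketch (an illustration, not part of the library; assumes the
# docstring's ``field`` and ``minibatch``, and that the outer field was
# constructed with ``include_lengths=True``):
#
#     padded, sentence_lengths, word_lengths = field.pad(minibatch)
#     assert sentence_lengths == [5, 4]  # outer init/eos tokens are counted
#     assert word_lengths[1][-1] == 0    # all-padding row of the 2nd example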