in fairseq/data/legacy/block_pair_dataset.py [0:0]
def _generate_sentence_pair(self, doc, doc_id, max_num_tokens, sizes):
    """
    Go through a single document and generate sentence pairs from it.

    Consecutive sentences of ``doc`` are accumulated into a chunk until the
    chunk reaches the target length or the document ends. The chunk is then
    split into a leading part (``sent_a``) and a trailing part. With a
    random next-sentence label of 1 the trailing part becomes ``sent_b``;
    with label 0 ``sent_b`` is drawn from another document and the trailing
    part is put back so that it is visited again by the loop.

    Args:
        doc: sequence of sentence ids of one document; every id is used as
            an index into ``sizes``.
        doc_id: id of ``doc``. It is passed to ``self._skip_sampling`` as
            ``[doc_id]`` (presumably so the random document differs from
            this one -- confirm against ``_skip_sampling``).
        max_num_tokens: hard length threshold forwarded to
            ``self._truncate_sentences``. It is never modified here; only
            the local target length may be shortened.
        sizes: per-sentence lengths. Must support indexing by a single
            sentence id and by a list of ids (e.g. a numpy array).

    Returns:
        None. One ``(sent_a, sent_b, next_sent_label)`` tuple per generated
        pair is appended to ``self.sent_pairs`` and the matching length to
        ``self.sizes``.
    """
    current_chunk = []
    # Running total of sizes[...] over current_chunk; kept in sync with the
    # chunk instead of re-summing the whole chunk for every sentence.
    current_length = 0
    curr = 0
    # To provide more randomness, the target sequence length is shortened
    # for a fraction of calls (controlled by self.short_seq_prob). Note that
    # max_num_tokens is the hard threshold for batching and is never changed.
    target_seq_length = max_num_tokens
    if np.random.random() < self.short_seq_prob:
        # np.random.randint needs low < high, so this assumes
        # max_num_tokens > 2.
        target_seq_length = np.random.randint(2, max_num_tokens)
    # loop through all sentences in the document
    while curr < len(doc):
        sent_id = doc[curr]
        current_chunk.append(sent_id)
        current_length += sizes[sent_id]
        # split the chunk and generate a pair once target_seq_length is
        # reached or the last sentence of the document has been consumed
        if curr == len(doc) - 1 or current_length >= target_seq_length:
            # split the chunk into 2 parts
            a_end = 1
            if len(current_chunk) > 2:
                # NOTE(review): the upper bound of np.random.randint is
                # exclusive, so a_end lies in [1, len(current_chunk) - 2]
                # and the second part always keeps at least two sentences
                # here. Left unchanged to preserve the sampling behaviour.
                a_end = np.random.randint(1, len(current_chunk) - 1)
            sent_a = current_chunk[:a_end]
            len_a = sum(sizes[sent_a])
            # generate the next sentence label; if there is only 1 sentence
            # in the current chunk the label is always 0. The random draw
            # is evaluated first, so it is consumed in either case.
            next_sent_label = (
                1 if np.random.rand() > 0.5 and len(current_chunk) != 1 else 0
            )
            if not next_sent_label:
                # label 0: sample sent_b from a random document
                target_b_length = target_seq_length - len_a
                rand_doc_id = self._skip_sampling(len(self.block_indices), [doc_id])
                random_doc = self.block_indices[rand_doc_id]
                random_start = np.random.randint(0, len(random_doc))
                sent_b = []
                len_b = 0
                # take sentences from random_start onwards until the
                # remaining length budget is filled or the document ends;
                # at least one sentence is always taken
                for j in range(random_start, len(random_doc)):
                    sent_b.append(random_doc[j])
                    len_b += sizes[random_doc[j]]
                    if len_b >= target_b_length:
                        break
                # give back the second part of the chunk since it was not
                # used: rewind curr so those sentences are processed again
                num_unused_segments = len(current_chunk) - a_end
                curr -= num_unused_segments
            else:
                # label 1: use the second part of the chunk as sent_b
                sent_b = current_chunk[a_end:]
            # sent_a and sent_b together may be longer than max_num_tokens;
            # truncate them. The helper's results replace the id lists
            # (per the original comment: block idx and offsets).
            sent_a, sent_b = self._truncate_sentences(
                sent_a, sent_b, max_num_tokens
            )
            self.sent_pairs.append((sent_a, sent_b, next_sent_label))
            # Element 3 of each truncated segment is presumably its length
            # and the constant 3 presumably covers special tokens --
            # confirm against _truncate_sentences.
            self.sizes.append(3 + sent_a[3] + sent_b[3])
            # start a fresh chunk
            current_chunk = []
            current_length = 0
        curr += 1