in utils_nlp/models/transformers/abstractive_summarization_bertsum.py [0:0]
def collate(self, data, block_size, device, train_mode=True):
    """Collate function that formats the data passed to the data loader.

    Tokenization happens here, batch by batch, so the whole dataset never has
    to be kept tokenized in memory.

    Args:
        data (list of dict): input examples to be loaded. Each example has a
            "src" field (list of source sentences) and, for training, a "tgt"
            field (list of target sentences).
        block_size (int): length each encoded sequence is padded or truncated
            to before being passed to the model.
        device (torch.device): A PyTorch device.
        train_mode (bool, optional): Training mode flag. Defaults to True.

    Returns:
        namedtuple: a namedtuple containing the input ids, segment ids, and
            masks for the input ids. If train_mode is True and targets are
            provided, it also contains the target ids, the number of tokens in
            each target, and the source and target text. Returns None if every
            example in the batch has an empty "src".
    """
    # Drop examples with an empty source.
    data = [x for x in data if len(x["src"]) != 0]
    if len(data) == 0:
        return None
    stories = [" ".join(d["src"]) for d in data]
    if train_mode and "tgt" in data[0]:
        summaries = [" ".join(d["tgt"]) for d in data]
        encoded_text = [self.preprocess(d["src"], d["tgt"]) for d in data]
    else:
        encoded_text = [self.preprocess(d["src"], None) for d in data]

    # Pad or truncate every encoded story to exactly block_size tokens.
    encoded_stories = torch.tensor(
        [
            fit_to_block_size(story, block_size, self.tokenizer.pad_token_id)
            for story, _ in encoded_text
        ]
    )
    encoder_token_type_ids = compute_token_type_ids(
        encoded_stories, self.tokenizer.cls_token_id
    )
    encoder_mask = build_mask(encoded_stories, self.tokenizer.pad_token_id)
    if train_mode and "tgt" in data[0]:
        # Reserve two positions for the BOS/EOS tokens wrapped around each summary.
        encoded_summaries = torch.tensor(
            [
                [self.tgt_bos]
                + fit_to_block_size(
                    summary, block_size - 2, self.tokenizer.pad_token_id
                )
                + [self.tgt_eos]
                for _, summary in encoded_text
            ]
        )
        # Count the non-padding tokens in each target sequence.
        summary_num_tokens = [
            encoded_summary.ne(self.tokenizer.pad_token_id).sum()
            for encoded_summary in encoded_summaries
        ]
        Batch = namedtuple(
            "Batch",
            [
                "src",
                "segs",
                "mask_src",
                "tgt",
                "tgt_num_tokens",
                "src_str",
                "tgt_str",
            ],
        )
        batch = Batch(
            src=encoded_stories.to(device),
            segs=encoder_token_type_ids.to(device),
            mask_src=encoder_mask.to(device),
            tgt_num_tokens=torch.stack(summary_num_tokens).to(device),
            tgt=encoded_summaries.to(device),
            src_str=stories,
            tgt_str=summaries,
        )
    else:
        Batch = namedtuple("Batch", ["src", "segs", "mask_src"])
        batch = Batch(
            src=encoded_stories.to(device),
            segs=encoder_token_type_ids.to(device),
            mask_src=encoder_mask.to(device),
        )

    return batch
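

# --- Illustrative sketch (not part of the original module) -------------------
# `collate` relies on a helper `fit_to_block_size` that, judging from its usage
# above, truncates an encoded sequence longer than `block_size` and right-pads a
# shorter one with `pad_token_id`, so every row of the batch tensor has the same
# length. A minimal stand-in with that assumed behavior:
def _fit_to_block_size_sketch(sequence, block_size, pad_token_id):
    # Truncate sequences that exceed the block size.
    if len(sequence) > block_size:
        return sequence[:block_size]
    # Right-pad shorter sequences with the tokenizer's padding id.
    return sequence + [pad_token_id] * (block_size - len(sequence))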
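

# --- Illustrative usage sketch (not part of the original module) -------------
# Shows how `collate` is typically plugged into a torch DataLoader as its
# `collate_fn`. The processor class name `BertSumAbsProcessor`, its no-argument
# constructor, the block size of 512, and the toy dataset below are assumptions
# for illustration only.
def _example_collate_usage():
    from functools import partial

    import torch
    from torch.utils.data import DataLoader

    processor = BertSumAbsProcessor()  # assumed enclosing class of `collate`
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = [
        {
            "src": ["the first source sentence .", "another source sentence ."],
            "tgt": ["a short reference summary ."],
        }
    ]
    loader = DataLoader(
        dataset,
        batch_size=2,
        collate_fn=partial(processor.collate, block_size=512, device=device),
    )
    for batch in loader:
        if batch is None:  # every example in this batch had an empty "src"
            continue
        print(batch.src.shape, batch.segs.shape, batch.mask_src.shape)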