def collate()

in utils_nlp/models/transformers/abstractive_summarization_bertsum.py [0:0]


    def collate(self, data, block_size, device, train_mode=True):
        """Collate formats the data passed to the data loader.

        In particular we tokenize the data batch after batch to avoid keeping
        them all in memory.

        Args:
            data (list of dict): input examples. Each item must have a "src"
                key holding an iterable of strings (joined with spaces to
                form the source text) and, when targets are wanted, a "tgt"
                key of the same form.
            block_size (int): length every encoded source and target sequence
                is fitted to before being stacked into a tensor.
            device (torch.device): A PyTorch device the tensors are moved to.
            train_mode (bool, optional): Training mode flag. When truthy and
                the first example has a "tgt" key, target tensors are built
                as well. Defaults to True.

        Returns:
            namedtuple or None: ``None`` if no example has a non-empty "src".
                Otherwise a namedtuple containing input ids, segment ids,
                masks for the input ids and the source text. When targets are
                built it also contains the target ids, the number of
                non-padding tokens in each target and the target text.
        """
        # Drop examples whose source is empty (e.g. empty files).
        data = [x for x in data if len(x["src"]) != 0]
        if not data:
            return None

        # Decide once whether targets are built, so the preprocessing step and
        # the tensor-building step below always agree. Previously these used
        # different tests (`is True` vs. truthiness), which broke for truthy
        # non-bool values of train_mode. Only the first example is inspected,
        # i.e. the batch is assumed to be homogeneous.
        has_target = bool(train_mode) and "tgt" in data[0]

        stories = [" ".join(d["src"]) for d in data]
        if has_target:
            summaries = [" ".join(d["tgt"]) for d in data]
            encoded_text = [self.preprocess(d["src"], d["tgt"]) for d in data]
        else:
            encoded_text = [self.preprocess(d["src"], None) for d in data]

        # self.preprocess returns (source_ids, target_ids) pairs; pad/truncate
        # each source to block_size so they can be stacked into one tensor.
        encoded_stories = torch.tensor(
            [
                fit_to_block_size(story, block_size, self.tokenizer.pad_token_id)
                for story, _ in encoded_text
            ]
        )
        encoder_token_type_ids = compute_token_type_ids(
            encoded_stories, self.tokenizer.cls_token_id
        )
        encoder_mask = build_mask(encoded_stories, self.tokenizer.pad_token_id)

        if has_target:
            # Fit targets to block_size - 2 so that adding the BOS and EOS
            # tokens yields sequences of exactly block_size.
            encoded_summaries = torch.tensor(
                [
                    [self.tgt_bos]
                    + fit_to_block_size(
                        summary, block_size - 2, self.tokenizer.pad_token_id
                    )
                    + [self.tgt_eos]
                    for _, summary in encoded_text
                ]
            )
            # Count non-padding tokens per target row (BOS/EOS included).
            summary_num_tokens = encoded_summaries.ne(
                self.tokenizer.pad_token_id
            ).sum(dim=1)

            Batch = namedtuple(
                "Batch",
                [
                    "src",
                    "segs",
                    "mask_src",
                    "tgt",
                    "tgt_num_tokens",
                    "src_str",
                    "tgt_str",
                ],
            )
            batch = Batch(
                src=encoded_stories.to(device),
                segs=encoder_token_type_ids.to(device),
                mask_src=encoder_mask.to(device),
                tgt_num_tokens=summary_num_tokens.to(device),
                tgt=encoded_summaries.to(device),
                src_str=stories,
                tgt_str=summaries,
            )
        else:
            Batch = namedtuple("Batch", ["src", "segs", "mask_src"])
            batch = Batch(
                src=encoded_stories.to(device),
                segs=encoder_token_type_ids.to(device),
                mask_src=encoder_mask.to(device),
            )

        return batch