def encode_single()

in utils_nlp/models/transformers/extractive_summarization.py


    def encode_single(self, d, block_size, train_mode=True):
        """ Enocde a single sample.
            Args:
                d (dict): s data sample from SummarizationDataset.
                block_size (int): maximum input length for the model.

            Returns:
                Tuple of encoded data.

        """

        src = d["src"]

        if len(src) == 0:
            raise ValueError("source doesn't have any sentences")

        original_src_txt = [" ".join(s) for s in src]
        # keep every sentence; length-based filtering is applied only in train mode
        idxs = list(range(len(src)))
        src = [src[i] for i in idxs]

        tgt_txt = None
        labels = None
        if train_mode and "oracle_ids" in d and "tgt" in d and "tgt_txt" in d:
            labels = [0] * len(src)
            # mark each oracle-selected sentence as a positive label
            for oracle_idx in d["oracle_ids"]:
                labels[oracle_idx] = 1

            # length-based source filtering is applied only during training:
            # drop sentences that are too short, truncate long ones, and cap
            # the total number of sentences
            idxs = [i for i, s in enumerate(src) if (len(s) > self.min_src_ntokens)]
            src = [src[i][: self.max_src_ntokens] for i in idxs]
            src = src[: self.max_nsents]
            labels = [labels[i] for i in idxs]
            labels = labels[: self.max_nsents]

            if len(src) < self.min_nsents:
                return None
            if len(labels) == 0:
                return None
            tgt_txt = "".join([" ".join(tt) for tt in d["tgt"]])

        src_txt = [" ".join(sent) for sent in src]
        text = " [SEP] [CLS] ".join(src_txt)
        src_subtokens = self.tokenizer.tokenize(text)
        src_subtokens = (
            ["[CLS]"]
            # pad with the pad *token* string rather than its id: the list
            # still holds token strings at this point, and convert_tokens_to_ids
            # below would map a raw integer id to [UNK]
            + fit_to_block_size(
                src_subtokens, block_size - 2, self.tokenizer.pad_token
            )
            + ["[SEP]"]
        )
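        # fit_to_block_size truncates or pads to block_size - 2 tokens, so the
        # outer [CLS] and [SEP] bring the sequence to exactly block_size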
        src_subtoken_idxs = self.tokenizer.convert_tokens_to_ids(src_subtokens)
        _segs = [-1] + [i for i, t in enumerate(src_subtoken_idxs) if t == self.sep_vid]
        segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]
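        # worked example: if [SEP] tokens sit at positions 5 and 11, then
        # _segs = [-1, 5, 11] and segs = [6, 6]; the loop below then emits
        # alternating segment ids, segments_ids = [0]*6 + [1]*6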
        segments_ids = []
        for i, s in enumerate(segs):
            if i % 2 == 0:
                segments_ids += s * [0]
            else:
                segments_ids += s * [1]
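        # positions of the [CLS] tokens, one per sentence; the extractive
        # model scores each sentence at its corresponding [CLS] position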
        cls_ids = [i for i, t in enumerate(src_subtoken_idxs) if t == self.cls_vid]
        if labels:
            labels = labels[: len(cls_ids)]
        src_txt = [original_src_txt[i] for i in idxs]
        return src_subtoken_idxs, labels, segments_ids, cls_ids, src_txt, tgt_txt
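
A minimal usage sketch follows. It assumes this method lives on a processor
object (shown here as `processor`) whose construction is out of scope for this
section; the sample dict shape and the returned tuple follow from the code
above, while the length thresholds (`min_src_ntokens`, `min_nsents`) depend on
processor configuration, so short samples may come back as None in train mode.

    # hypothetical sample shaped like a SummarizationDataset item
    sample = {
        "src": [
            ["this", "is", "the", "first", "source", "sentence", "."],
            ["here", "is", "another", "reasonably", "long", "sentence", "."],
            ["and", "a", "third", "sentence", "closes", "the", "example", "."],
        ],
        "oracle_ids": [0, 2],  # greedy-oracle sentence indices
        "tgt": [["the", "first", "and", "third", "sentences", "matter", "."]],
        "tgt_txt": "the first and third sentences matter .",
    }

    encoded = processor.encode_single(sample, block_size=512, train_mode=True)
    if encoded is None:
        # filtered out: too few sentences survived the length thresholds
        pass
    else:
        src_subtoken_idxs, labels, segments_ids, cls_ids, src_txt, tgt_txt = encoded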