def get_parallel_minibatch()

in utils_nlp/models/gensen/utils.py [0:0]


    def get_parallel_minibatch(
        self, index, batch_size, sent_type="train", device=None
    ):
        """Prepare a minibatch of sentence pairs for NLI training.

        Lines ``index`` .. ``index + batch_size - 1`` are taken from the
        selected split. Each line is read as ``(sentence1, sentence2,
        label_text)``. Both sentences are whitespace-tokenized and wrapped
        in ``<s>`` / ``</s>``. Each side is then sorted by length, longest
        first, independently of the other side. Tokens are mapped through
        ``self.word2id``, with unknown tokens mapped to ``"<unk>"``. Each
        sentence is right-padded with ``"<pad>"`` to the longest sentence
        on its side.

        Args:
            index(int): The index of the first line of the batch.
            batch_size(int): Maximum number of lines in the batch. The last
                batch of a split may be smaller.
            sent_type(str): "train" or "dev" select ``self.train_lines`` /
                ``self.dev_lines``. Any other value selects
                ``self.test_lines``.
            device(torch.device or str, optional): Where to place the
                returned tensors. ``None`` (default) calls ``.cuda()``.

        Returns:
            dict for batch training, with keys:
                "sent1", "sent2": padded token-id LongTensors, one row per
                    pair, rows sorted by length (longest first).
                "sent1_lens", "sent2_lens": 1-D LongTensors of the unpadded
                    lengths, in the same sorted order.
                "rev_sent1", "rev_sent2": 1-D LongTensors. Indexing the
                    sorted rows with them restores the original batch
                    order. The two sides are sorted independently, so this
                    is needed to realign pairs.
                "labels": 1-D LongTensor of ``self.text2label`` ids, in the
                    original (unsorted) batch order.
                "type": the string "nli".

        Raises:
            ValueError: If the requested slice of lines is empty.
        """
        if sent_type == "train":
            lines = self.train_lines
        elif sent_type == "dev":
            lines = self.dev_lines
        else:
            lines = self.test_lines

        batch = lines[index : index + batch_size]
        if not batch:
            raise ValueError(
                "Empty minibatch: index={}, batch_size={}, but the '{}' split "
                "has {} lines.".format(index, batch_size, sent_type, len(lines))
            )

        vocab = self.word2id
        pad_id = vocab["<pad>"]

        def _encode(sentences):
            # Sort longest-first, map tokens to ids and right-pad. Returns
            # (padded ids, sorted lengths, permutation restoring input order).
            lengths = [len(sent) for sent in sentences]
            order = np.argsort(lengths)[::-1]
            max_len = max(lengths)
            ids = [
                [vocab[w] if w in vocab else vocab["<unk>"] for w in sentences[i]]
                + [pad_id] * (max_len - lengths[i])
                for i in order
            ]
            sorted_lens = [lengths[i] for i in order]
            return ids, sorted_lens, np.argsort(order).tolist()

        def _to_tensor(values):
            # Build the LongTensor with its natural rank. Deliberately no
            # .squeeze(): it would collapse a batch of one to a 0-d tensor.
            tensor = Variable(torch.LongTensor(values), requires_grad=False)
            return tensor.cuda() if device is None else tensor.to(device)

        sent1, sent1_lens, rev_sent1 = _encode(
            [["<s>"] + line[0].split() + ["</s>"] for line in batch]
        )
        sent2, sent2_lens, rev_sent2 = _encode(
            [["<s>"] + line[1].split() + ["</s>"] for line in batch]
        )
        labels = [self.text2label[line[2]] for line in batch]

        return {
            "sent1": _to_tensor(sent1),
            "sent2": _to_tensor(sent2),
            "sent1_lens": _to_tensor(sent1_lens),
            "sent2_lens": _to_tensor(sent2_lens),
            "rev_sent1": _to_tensor(rev_sent1),
            "rev_sent2": _to_tensor(rev_sent2),
            "labels": _to_tensor(labels),
            "type": "nli",
        }