# transform() — extracted from datasets.py

    def transform(self, idx, max_length=220):
        """Tokenize the text at position *idx* and return a stacked feature tensor.

        Looks up ``self.text_array[int(idx)]``, runs it through
        ``self.tokenizer`` (padded/truncated to ``max_length``), and packs the
        three token-level outputs into a single tensor.

        Args:
            idx: Index into ``self.text_array``; coerced with ``int()`` so
                tensor/float indices are accepted.
            max_length: Sequence length to pad/truncate to. Defaults to 220,
                the value previously hard-coded here; existing callers are
                unaffected.

        Returns:
            A tensor of shape ``(max_length, 3)`` whose last dimension holds,
            in order: ``input_ids``, ``attention_mask``, ``token_type_ids``.
            (Each tokenizer output is ``(1, max_length)``; stacking on dim=2
            gives ``(1, max_length, 3)`` and squeezing dim 0 drops the batch
            axis.)

        NOTE(review): assumes the tokenizer emits ``token_type_ids`` (i.e. a
        BERT-style tokenizer) — confirm; tokenizers without segment ids would
        raise ``KeyError`` here.
        """
        text = self.text_array[int(idx)]

        tokens = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )

        return torch.squeeze(
            torch.stack(
                (
                    tokens["input_ids"],
                    tokens["attention_mask"],
                    tokens["token_type_ids"],
                ),
                dim=2,
            ),
            dim=0,
        )