def collate_fn()

in src/setfit/data.py [0:0]


    def collate_fn(self, batch):
        """Collate ``(feature, label)`` pairs into batched tensors.

        Args:
            batch: Sequence of ``(feature, label)`` pairs. Each ``feature`` is a
                mapping that must contain every name listed in
                ``self.tokenizer.model_input_names``. Its values are assumed to
                be equal-length integer sequences (e.g. tokenizer output padded
                to a fixed length) -- ragged sequences cannot be stacked.

        Returns:
            A ``(features, labels)`` tuple:

            - ``features`` maps each model input name to an ``int32`` tensor
              stacked over the batch.
            - ``labels`` is an ``int64`` tensor when the stacked labels are 1-D
              (one scalar per example). Otherwise it is a ``float32`` tensor
              (e.g. multi-hot multi-label targets).

        Raises:
            KeyError: if a feature lacks one of the tokenizer's model input names.
        """
        features = {input_name: [] for input_name in self.tokenizer.model_input_names}

        labels = []
        for feature, label in batch:
            # Collect every input the tokenizer declares rather than a hard-coded
            # subset, so extra inputs are not silently emitted as empty tensors.
            for input_name, values in features.items():
                values.append(feature[input_name])
            labels.append(label)

        # Build integer tensors directly: routing through ``torch.Tensor``
        # (float32) first would round values above 2**24 before the int cast.
        features = {k: torch.tensor(v, dtype=torch.int) for k, v in features.items()}
        labels = torch.tensor(labels)
        # 1-D labels are class indices; anything else is treated as float targets.
        labels = labels.long() if labels.dim() == 1 else labels.float()
        return features, labels