in src/setfit/data.py [0:0]
def collate_fn(self, batch):
    """Collate ``(feature, label)`` pairs into batched tensors.

    Args:
        batch: Iterable of ``(feature, label)`` pairs. ``feature`` is a
            mapping that holds ``"input_ids"`` and, when the tokenizer lists
            them in ``model_input_names``, ``"attention_mask"`` and
            ``"token_type_ids"``. Assumes every sequence in the batch already
            has the same length (tensor construction raises otherwise) --
            TODO confirm padding happens upstream.

    Returns:
        A ``(features, labels)`` tuple. ``features`` maps every name in
        ``self.tokenizer.model_input_names`` to an int32 tensor. ``labels``
        is an int64 tensor when it is one-dimensional and a float32 tensor
        otherwise.
    """
    features = {input_name: [] for input_name in self.tokenizer.model_input_names}
    labels = []
    for feature, label in batch:
        features["input_ids"].append(feature["input_ids"])
        if "attention_mask" in features:
            features["attention_mask"].append(feature["attention_mask"])
        if "token_type_ids" in features:
            features["token_type_ids"].append(feature["token_type_ids"])
        labels.append(label)

    # Build the tensors directly in their integer dtype. Going through the
    # float32 ``torch.Tensor`` constructor first would round any integer
    # above 2**24 before the cast back to an integer type.
    # NOTE: model input names other than the three handled above are never
    # populated, so they come out as empty tensors.
    features = {k: torch.tensor(v, dtype=torch.int32) for k, v in features.items()}
    labels = torch.tensor(labels)
    # One label per example -> class indices; nested labels -> float targets.
    labels = labels.long() if labels.dim() == 1 else labels.float()
    return features, labels