in optimum/habana/sentence_transformers/st_gaudi_data_collator.py [0:0]
# Imports this function relies on:
import math
from typing import Any, Dict, List

import torch

def st_gaudi_data_collator_call(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    """Collator for a SentenceTransformers model."""
    column_names = list(features[0].keys())

    # We should always be able to return a loss, label or not:
    batch = {"return_loss": True}

    if "dataset_name" in column_names:
        column_names.remove("dataset_name")
        batch["dataset_name"] = features[0]["dataset_name"]
    if tuple(column_names) not in self._warned_columns:
        self.maybe_warn_about_column_order(column_names)

    # Extract the label column if it exists
    for label_column in self.valid_label_columns:
        if label_column in column_names:
            batch["label"] = torch.tensor([row[label_column] for row in features])
            column_names.remove(label_column)
            break
    # Extract the feature columns
    cnt = 0
    cnt1 = 0
    power2_len = [0, 0]
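    # cnt1 alternates between 0 and 1, so power2_len holds the padded ("bucketed")
    # sequence lengths of the current and the previous feature column.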
    for column_name in column_names:
        # If the prompt length has been set, we should add it to the batch
        if column_name.endswith("_prompt_length") and column_name[: -len("_prompt_length")] in column_names:
            batch[column_name] = torch.tensor([row[column_name] for row in features], dtype=torch.int)
            continue
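
        # Tokenize the whole batch for this column; each tokenizer output
        # (e.g. input_ids, attention_mask) is a [batch_size, seq_len] tensor.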
        tokenized = self.tokenize_fn([row[column_name] for row in features])
        for key, value in tokenized.items():
            curr_tokenize_len = value.shape
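            # Bucket the sequence length so the device only sees a limited set of
            # input shapes: round up to the next multiple of 128 beyond 4096 tokens,
            # otherwise to the next power of two.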
            if curr_tokenize_len[1] > 4096:
                power2_len[cnt1] = math.ceil(curr_tokenize_len[1] / 128) * 128
            else:
                power2_len[cnt1] = 2 ** math.ceil(math.log2(curr_tokenize_len[1]))
            additional_pad_len = power2_len[cnt1] - curr_tokenize_len[1]
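            # If this column lands in the same bucket as the previous one, pad one
            # extra position so the two feature tensors get distinct sequence lengths.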
            if (cnt1 == 1) and (power2_len[0] == power2_len[1]):
                additional_pad_len += 1
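            # Right-pad the tokenized tensor with zeros along the sequence dimension
            # up to the bucketed length.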
batch[f"{column_name}_{key}"] = torch.cat(
(
value,
torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
),
-1,
)
        cnt += 1
        cnt1 = cnt & 1
    return batch
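

# A minimal sketch (not part of the library) of the padding rule applied above,
# using a hypothetical helper name `_bucketed_length`: sequence lengths above 4096
# are rounded up to the next multiple of 128, shorter ones to the next power of two.
def _bucketed_length(seq_len: int) -> int:
    if seq_len > 4096:
        return math.ceil(seq_len / 128) * 128
    return 2 ** math.ceil(math.log2(seq_len))


# For example, _bucketed_length(37) == 64 and _bucketed_length(4100) == 4224, so a
# batch tokenized to 37 tokens is zero-padded to 64 columns before being returned.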