in src/evaluate/module.py [0:0]
def add_batch(self, *, predictions=None, references=None, **kwargs):
    """Add a batch of predictions and references to the evaluation module's stack.

    The batch is restricted to the module's feature names, encoded with the
    selected feature format and written with ``self.writer``. On the first call
    (when no writer exists yet) the feature format is inferred from the batch
    and the writer is initialised.

    Args:
        predictions (`list/array/tensor`, *optional*):
            Predictions.
        references (`list/array/tensor`, *optional*):
            References.
        **kwargs:
            Additional module inputs. Every key must be one of the module's
            feature names.

    Raises:
        ValueError: If ``kwargs`` contains a name that is not a feature of the
            module, or if encoding/writing the batch fails with
            ``pa.ArrowInvalid`` or ``TypeError`` (mismatched column lengths or
            inputs that do not match the expected format).

    Example:

    ```py
    >>> import evaluate
    >>> accuracy = evaluate.load("accuracy")
    >>> for refs, preds in zip([[0,1],[0,1]], [[1,0],[0,1]]):
    ...     accuracy.add_batch(references=refs, predictions=preds)
    ```
    """
    bad_inputs = [input_name for input_name in kwargs if input_name not in self._feature_names()]
    if bad_inputs:
        raise ValueError(
            f"Bad inputs for evaluation module: {bad_inputs}. All required inputs are {list(self._feature_names())}"
        )
    batch = {"predictions": predictions, "references": references, **kwargs}
    # Keep only the columns the module declares as features.
    batch = {input_name: batch[input_name] for input_name in self._feature_names()}
    if self.writer is None:
        # First batch: derive the feature format from the data, then create the writer.
        self.selected_feature_format = self._infer_feature_from_batch(batch)
        self._init_writer()
    try:
        # Check the first element of each non-empty column against its feature
        # (presumably raising TypeError on a type mismatch, reported below).
        for key, column in batch.items():
            if len(column) > 0:
                self._enforce_nested_string_type(self.selected_feature_format[key], column[0])
        batch = self.selected_feature_format.encode_batch(batch)
        self.writer.write_batch(batch)
    except (pa.ArrowInvalid, TypeError):
        # Build the diagnostic without assuming every column supports len():
        # an input left as None (e.g. references not passed) can itself be the
        # source of the TypeError caught here, and calling len() on it again
        # would replace the intended ValueError with an unrelated TypeError.
        column_lengths = {}
        for name, column in batch.items():
            try:
                column_lengths[name] = len(column)
            except TypeError:
                continue
        lengths = list(column_lengths.values())
        if any(length != lengths[0] for length in lengths):
            col0 = next(iter(column_lengths))
            bad_col = next(c for c, length in column_lengths.items() if length != column_lengths[col0])
            error_msg = (
                f"Mismatch in the number of {col0} ({column_lengths[col0]}) and {bad_col} ({column_lengths[bad_col]})"
            )
        elif set(self.selected_feature_format) != {"references", "predictions"}:
            # Custom module inputs: list every declared input in the message.
            error_msg = (
                f"Module inputs don't match the expected format.\n"
                f"Expected format: {self.selected_feature_format},\n"
            )
            error_msg_inputs = ",\n".join(
                f"Input {input_name}: {summarize_if_long_list(batch[input_name])}"
                for input_name in self.selected_feature_format
            )
            error_msg += error_msg_inputs
        else:
            error_msg = (
                f"Predictions and/or references don't match the expected format.\n"
                f"Expected format: {self.selected_feature_format},\n"
                f"Input predictions: {summarize_if_long_list(predictions)},\n"
                f"Input references: {summarize_if_long_list(references)}"
            )
        # Suppress the low-level Arrow/TypeError context; the message above is the diagnosis.
        raise ValueError(error_msg) from None