in pipelines/ml_ai_python/ml_ai_pipeline/model_handlers.py
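# The excerpt omits the module header; these are the imports the method below
# appears to rely on. Assumption: the handler is built on Apache Beam's
# RunInference framework (PredictionResult) and KerasNLP's Gemma model class,
# neither of which is shown in the excerpt itself.
from typing import Any, Iterable, Optional, Sequence

from apache_beam.ml.inference.base import PredictionResult
from keras_nlp.models import GemmaCausalLM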
def run_inference(
    self,
    batch: Sequence[str],
    model: GemmaCausalLM,
    unused: Optional[dict[str, Any]] = None) -> Iterable[PredictionResult]:
  """Runs inference on a batch of text strings.

  Args:
    batch: A sequence of examples as text strings.
    model: The Gemma model used to generate text.
    unused: Extra inference arguments; ignored, accepted only for
      compatibility with the ModelHandler interface.

  Returns:
    An Iterable of PredictionResult, one per input string.
  """
  _ = unused  # Accepted for interface compatibility with ModelHandler.
  # Generate text for each input string and yield a PredictionResult that
  # pairs the prompt with the generated output and the model name.
  for one_text in batch:
    result = model.generate(one_text, max_length=64)
    yield PredictionResult(one_text, result, self._model_name)
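# A minimal usage sketch, not part of the original file. It assumes the
# enclosing class is a custom Beam ModelHandler (called GemmaModelHandler here
# for illustration) whose constructor stores the model name used above as
# self._model_name and whose load_model() returns a GemmaCausalLM. The preset
# name "gemma_2b_en" and the prompt text are likewise illustrative values.
if __name__ == "__main__":
  import apache_beam as beam
  from apache_beam.ml.inference.base import RunInference

  with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | "Prompts" >> beam.Create(["Tell me about Apache Beam."])
        | "Generate" >> RunInference(GemmaModelHandler("gemma_2b_en"))
        | "Print" >> beam.Map(print))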