# pipelines/ml_ai_python/ml_ai_pipeline/model_handlers.py
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Custom model handlers to be used with RunInference.
"""

from typing import Any, Iterable, Optional, Sequence

import keras_nlp
from apache_beam.ml.inference.base import ModelHandler, PredictionResult
from keras_nlp.models import GemmaCausalLM


class GemmaModelHandler(ModelHandler[str, PredictionResult, GemmaCausalLM]):
"""
A RunInference model handler for the Gemma model.
"""
def __init__(self, model_name: str = "gemma_2B"):
""" Implementation of the ModelHandler interface for Gemma using text as input.
Example Usage::
pcoll | RunInference(GemmaModelHandler())
Args:
model_name: The Gemma model name. Default is gemma_2B.
"""
super().__init__()
self._model_name = model_name
self._env_vars = {}

    def share_model_across_processes(self) -> bool:
        """Indicates if the model should be loaded once-per-VM rather than
        once-per-worker-process on a VM. Because Gemma is a large language model,
        this will always return True to avoid OOM errors.
        """
        return True

    def load_model(self) -> GemmaCausalLM:
        """Loads and initializes a model for processing."""
        return keras_nlp.models.GemmaCausalLM.from_preset(self._model_name)

    def run_inference(
        self,
        batch: Sequence[str],
        model: GemmaCausalLM,
        unused: Optional[dict[str, Any]] = None) -> Iterable[PredictionResult]:
        """Runs inference on a batch of text strings.

        Args:
          batch: A sequence of examples as text strings.
          model: The Gemma model being used.

        Returns:
          An Iterable of type PredictionResult.
        """
        _ = unused  # kept for interface compatibility with ModelHandler
        # Generate text for each prompt and wrap the prompt, the generated
        # text, and the model name in a PredictionResult.
        for one_text in batch:
            result = model.generate(one_text, max_length=64)
            yield PredictionResult(one_text, result, self._model_name)
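

# --- Usage sketch (illustrative, not from the original pipeline) ---
# A minimal example of wiring GemmaModelHandler into a Beam pipeline with
# RunInference, matching the "pcoll | RunInference(GemmaModelHandler())" usage
# noted in the class docstring. The prompt text, the default local runner, and
# the print step are assumptions for illustration only; a real Dataflow job
# would pass proper pipeline options and write results to a real sink.
if __name__ == "__main__":
    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference

    with beam.Pipeline() as pipeline:
        _ = (
            pipeline
            | "CreatePrompts" >> beam.Create(
                ["Tell me a fun fact about Apache Beam."])
            | "Gemma" >> RunInference(GemmaModelHandler(model_name="gemma_2B"))
            | "Print" >> beam.Map(
                lambda result: print(result.example, result.inference))
        )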