dataflow/gemma-flex-template/custom_model_gemma.py

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Iterable, Sequence
import json
import logging

import apache_beam as beam
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.options.pipeline_options import PipelineOptions
from gemma.config import get_config_for_2b
from gemma.config import get_config_for_7b
from gemma.model import GemmaForCausalLM
import torch


class GemmaPytorchModelHandler(
    ModelHandler[str, PredictionResult, GemmaForCausalLM]
):
    def __init__(
        self,
        model_variant: str,
        checkpoint_path: str,
        tokenizer_path: str,
        device: str | None = "cpu",
    ):
        """Implementation of the ModelHandler interface for Gemma on PyTorch
        using text as input.

        Example Usage::

          pcoll | RunInference(GemmaPytorchModelHandler())

        Args:
          model_variant: The Gemma model name.
          checkpoint_path: The path to a local copy of the Gemma model weights.
          tokenizer_path: The path to a local copy of the Gemma tokenizer.
          device: Optional. The device to run inference on. Can be either
            'cpu' or 'gpu'; defaults to 'cpu'.
        """
        model_config = (
            get_config_for_2b() if "2b" in model_variant else get_config_for_7b()
        )
        model_config.tokenizer = tokenizer_path
        model_config.quant = "quant" in model_variant
        self._model_config = model_config
        self._checkpoint_path = checkpoint_path
        if device is not None and device.lower() == "gpu":
            logging.info("Device is set to CUDA")
            self._device = torch.device("cuda")
        else:
            logging.info("Device is set to CPU")
            self._device = torch.device("cpu")
        self._env_vars = {}

    def share_model_across_processes(self) -> bool:
        """Allows us to load a model only once per worker VM, decreasing the
        pipeline's memory requirements.
        """
        return True

    def load_model(self) -> GemmaForCausalLM:
        """Loads and initializes a model for processing."""
        torch.set_default_dtype(self._model_config.get_dtype())
        model = GemmaForCausalLM(self._model_config)
        model.load_weights(self._checkpoint_path)
        model = model.to(self._device).eval()
        return model

    def run_inference(
        self,
        batch: Sequence[str],
        model: GemmaForCausalLM,
        inference_args: dict | None = None,
    ) -> Iterable[PredictionResult]:
        """Runs inference on a batch of text strings.

        Args:
          batch: A sequence of examples as text strings.
          model: The Gemma model being used.
          inference_args: Any additional arguments for an inference.

        Returns:
          An Iterable of type PredictionResult.
""" result = model.generate(prompts=batch, device=self._device) predictions = [result] return [PredictionResult(x, y) for x, y in zip(batch, predictions)] if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument( "--messages_subscription", required=True, help="Pub/Sub subscription for input text messages", ) parser.add_argument( "--responses_topic", required=True, help="Pub/Sub topic for output text responses", ) parser.add_argument( "--model_variant", required=False, default="gemma-2b-it", help="name of the gemma variant being used", ) parser.add_argument( "--checkpoint_path", required=False, default="pytorch_model/gemma-2b-it.ckpt", help="path to the Gemma model weights in the custom worker container", ) parser.add_argument( "--tokenizer_path", required=False, default="pytorch_model/tokenizer.model", help="path to the Gemma tokenizer in the custom worker container", ) parser.add_argument( "--device", required=False, default="cpu", help="device to run the model on", ) args, beam_args = parser.parse_known_args() config = get_config_for_2b() logging.getLogger().setLevel(logging.INFO) beam_options = PipelineOptions( beam_args, save_main_session=True, streaming=True, ) handler = GemmaPytorchModelHandler( model_variant=args.model_variant, checkpoint_path=args.checkpoint_path, tokenizer_path=args.tokenizer_path, device=args.device, ) with beam.Pipeline(options=beam_options) as pipeline: _ = ( pipeline | "Subscribe to Pub/Sub" >> beam.io.ReadFromPubSub(subscription=args.messages_subscription) | "Decode" >> beam.Map(lambda msg: msg.decode("utf-8")) | "RunInference Gemma" >> RunInference(handler) | "Format output" >> beam.Map( lambda response: json.dumps( {"input": response.example, "outputs": response.inference} ) ) | "Encode" >> beam.Map(lambda msg: msg.encode("utf-8")) | "Publish to Pub/Sub" >> beam.io.gcp.pubsub.WriteToPubSub(topic=args.responses_topic) )