dataflow/run-inference/main.py (86 lines of code) (raw):

# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Runs a streaming RunInference Language Model pipeline.""" from __future__ import annotations import logging import apache_beam as beam from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.base import RunInference from apache_beam.ml.inference.pytorch_inference import make_tensor_model_fn from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor from apache_beam.options.pipeline_options import PipelineOptions import torch from transformers import AutoConfig from transformers import AutoModelForSeq2SeqLM from transformers import AutoTokenizer from transformers.tokenization_utils import PreTrainedTokenizer MAX_RESPONSE_TOKENS = 256 def to_tensors(input_text: str, tokenizer: PreTrainedTokenizer) -> torch.Tensor: """Encodes input text into token tensors. Args: input_text: Input text for the language model. tokenizer: Tokenizer for the language model. Returns: Tokenized input tokens. """ return tokenizer(input_text, return_tensors="pt").input_ids[0] def decode_response(result: PredictionResult, tokenizer: PreTrainedTokenizer) -> str: """Decodes output token tensors into text. Args: result: Prediction results from the RunInference transform. tokenizer: Tokenizer for the language model. Returns: The model's response as text. """ output_tokens = result.inference return tokenizer.decode(output_tokens, skip_special_tokens=True) class AskModel(beam.PTransform): """Asks an language model a prompt message and gets its responses. Attributes: model_name: HuggingFace model name compatible with AutoModelForSeq2SeqLM. state_dict_path: File path to the model's state_dict, can be in Cloud Storage. max_response_tokens: Maximum number of tokens for the model to generate. """ def __init__( self, model_name: str, state_dict_path: str, max_response_tokens: int = MAX_RESPONSE_TOKENS, ) -> None: self.model_handler = PytorchModelHandlerTensor( state_dict_path=state_dict_path, model_class=AutoModelForSeq2SeqLM.from_config, model_params={"config": AutoConfig.from_pretrained(model_name)}, inference_fn=make_tensor_model_fn("generate"), ) self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.max_response_tokens = max_response_tokens def expand(self, pcollection: beam.PCollection[str]) -> beam.PCollection[str]: return ( pcollection | "To tensors" >> beam.Map(to_tensors, self.tokenizer) | "RunInference" >> RunInference( self.model_handler, inference_args={"max_new_tokens": self.max_response_tokens}, ) | "Get response" >> beam.Map(decode_response, self.tokenizer) ) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument( "--messages-topic", required=True, help="Pub/Sub topic for input text messages", ) parser.add_argument( "--responses-topic", required=True, help="Pub/Sub topic for output text responses", ) parser.add_argument( "--model-name", required=True, help="HuggingFace model name compatible with AutoModelForSeq2SeqLM", ) parser.add_argument( "--state-dict-path", required=True, help="File path to the model's state_dict, can be in Cloud Storage", ) args, beam_args = parser.parse_known_args() logging.getLogger().setLevel(logging.INFO) beam_options = PipelineOptions( beam_args, pickle_library="cloudpickle", streaming=True, ) simple_name = args.model_name.split("/")[-1] pipeline = beam.Pipeline(options=beam_options) _ = ( pipeline | "Read from Pub/Sub" >> beam.io.ReadFromPubSub(args.messages_topic) | "Decode bytes" >> beam.Map(lambda msg: msg.decode("utf-8")) | f"Ask {simple_name}" >> AskModel(args.model_name, args.state_dict_path) | "Encode bytes" >> beam.Map(lambda msg: msg.encode("utf-8")) | "Write to Pub/Sub" >> beam.io.WriteToPubSub(args.responses_topic) ) pipeline.run()