pipelines/ml_ai_python/ml_ai_pipeline/pipeline.py (29 lines of code) (raw):

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A machine learning streaming inference pipeline for the Dataflow Solution Guides.
"""

from apache_beam import Pipeline, PCollection
from apache_beam.ml.inference import RunInference
from apache_beam.io.gcp import pubsub

import apache_beam as beam
from apache_beam.ml.inference.base import PredictionResult

from .model_handlers import GemmaModelHandler
from .options import MyPipelineOptions


def _format_output(element: PredictionResult) -> str:
  """Render one inference result as human-readable text.

  Args:
    element: A `PredictionResult` emitted by `RunInference`.

  Returns:
    A string containing the model input (`element.example`) followed by the
    model output (`element.inference`).
  """
  return f"Input: \n{element.example}, \n\n\nOutput: \n{element.inference}"


@beam.ptransform_fn
def _extract(p: Pipeline, subscription: str) -> PCollection[str]:
  """Read messages from Pub/Sub and decode each payload as UTF-8 text.

  Args:
    p: The pipeline input this transform is applied to.
    subscription: Pub/Sub subscription to read from (passed unchanged to
      `ReadFromPubSub`).

  Returns:
    A PCollection of decoded message payloads. A payload that is not valid
    UTF-8 makes `bytes.decode` raise `UnicodeDecodeError`.
  """
  msgs: PCollection[bytes] = p | "Read subscription" >> beam.io.ReadFromPubSub(
      subscription=subscription)
  return msgs | "Parse" >> beam.Map(lambda x: x.decode("utf-8"))


@beam.ptransform_fn
def _transform(msgs: PCollection[str], model_path: str) -> PCollection[str]:
  """Run Gemma inference on every input string and format the results.

  Args:
    msgs: Input strings sent to the model.
    model_path: Model location, passed unchanged to `GemmaModelHandler`.

  Returns:
    A PCollection of strings produced by `_format_output`.
  """
  preds: PCollection[PredictionResult] = (
      msgs | "RunInference-Gemma" >> RunInference(GemmaModelHandler(model_path)))
  return preds | "Format Output" >> beam.Map(_format_output)


@beam.ptransform_fn
def _load(responses: PCollection[str], topic: str):
  """Publish every response string to a Pub/Sub topic as UTF-8 bytes.

  Args:
    responses: Formatted response strings to publish.
    topic: Pub/Sub topic to publish to (passed unchanged to `WriteToPubSub`).

  Returns:
    The result of applying `pubsub.WriteToPubSub`.
  """
  # `WriteStringsToPubSub` is deprecated in Beam; `WriteToPubSub` expects
  # bytes, so the UTF-8 encoding step is done explicitly here. The explicit
  # output type keeps Beam's type checking aware that elements are bytes.
  # NOTE(review): the step labels below are intended to mirror the ones used
  # inside the old string transform; verify before relying on Dataflow
  # streaming-job update compatibility.
  encoded: PCollection[bytes] = responses | "EncodeString" >> beam.Map(
      lambda s: s.encode("utf-8")).with_output_types(bytes)
  return encoded | "WriteToPubSub" >> pubsub.WriteToPubSub(topic=topic)


def create_pipeline(options: MyPipelineOptions) -> Pipeline:
  """Create the pipeline object.

  Builds, but does not run, the inference pipeline:
    1. Extract: read messages from `options.messages_subscription`.
    2. Transform: run Gemma inference using `options.model_path`.
    3. Load: publish formatted results to `options.responses_topic`.

  Args:
    options: The pipeline options, with type `MyPipelineOptions`.

  Returns:
    The pipeline object, ready to be run by the caller.
  """
  pipeline = beam.Pipeline(options=options)

  # Extract
  msgs: PCollection[str] = pipeline | "Read" >> _extract(
      subscription=options.messages_subscription)

  # Transform
  responses: PCollection[str] = msgs | "Transform" >> _transform(
      model_path=options.model_path)

  # Load
  responses | "Publish Result" >> _load(topic=options.responses_topic)

  return pipeline