# data-analytics/next25-turbocharge-ecomm/main.py

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.ml.transforms.embeddings.vertex_ai import VertexAIImageEmbeddings
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
from apache_beam.io.jdbc import WriteToJdbc
from apache_beam.io import ReadFromPubSub
from apache_beam import coders
import typing
import tempfile

# Imports the Google Cloud client library
from google.cloud import storage
from google import genai
from google.genai import types
import base64
from vertexai.vision_models import Image
import argparse
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence

### TODO
google_project = ""


def generate():
    client = genai.Client(
        vertexai=True,
        project=google_project,
        location="us-central1",
    )


class GeminiModelHandler(ModelHandler):

    def load_model(self) -> genai.Client:
        """Loads and initializes a model for processing."""
        client = genai.Client(
            vertexai=True,
            project=google_project,
            location="us-central1",
        )
        return client

    def run_inference(
        self,
        batch: Sequence[str],
        model: genai.Client,
        inference_args: Optional[Dict[str, Any]] = None
    ) -> Iterable[PredictionResult]:
        model_name = "gemini-2.0-flash-001"
        generate_content_config = types.GenerateContentConfig(
            temperature=1,
            top_p=0.95,
            max_output_tokens=8192,
            response_modalities=["TEXT"],
        )
        # Loop over each text string and collect one inference result per input.
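        # (Beam's RunInference hands run_inference one batch of inputs at a
        # time. The streaming call below, generate_content_stream from the
        # google-genai client, yields partial responses whose text is
        # concatenated; each input/output pair is wrapped in a
        # PredictionResult to match the declared return type.)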
        predictions = []
        for one_text in batch:
            contents = [
                types.Content(
                    role="user",
                    parts=[
                        types.Part.from_text(text=one_text)
                    ]
                )
            ]
            result = ""
            for chunk in model.models.generate_content_stream(
                model=model_name,
                contents=contents,
                config=generate_content_config,
            ):
                result = result + (chunk.text or "")
            predictions.append(PredictionResult(example=one_text, inference=result))
        return predictions


project_id = google_project
artifact_location = tempfile.mkdtemp(prefix='vertex_ai')
embedding_model_name = 'multimodalembedding@001'
mm_embedding_transform = VertexAIImageEmbeddings(
    model_name=embedding_model_name,
    # columns=['image', 'contextual_text'],
    columns=['image'],
    dimension=1408,
    project=project_id)


class EmbeddingRowSchema(typing.NamedTuple):
    id: str
    image_path: str
    image_embedding: str
    contextual_text: str  # Add the user_prompt field


coders.registry.register_coder(EmbeddingRowSchema, coders.RowCoder)


class PrepImageFn(beam.DoFn):

    # def setup(self):
    def start_bundle(self):
        # Instantiates a client
        self.storage_client = storage.Client()
        # logging.info("GCS client initialized in start_bundle.")

    def process(self, element):
        # logging.info(f"Processing element: {element}")
        bucket = self.storage_client.get_bucket(element['image_bucket'])
        blob = bucket.get_blob(element['image_path_split'])
        blob_bytes = blob.download_as_bytes()
        image = Image(blob_bytes)
        element['image'] = image
        element['image_bytes'] = blob_bytes
        yield element


class DecodePubsubMessageFn(beam.DoFn):
    """Decodes Pub/Sub messages from JSON to Python dictionaries."""

    def process(self, element):
        """Decodes a single Pub/Sub message."""
        import json
        try:
            decoded_message = json.loads(element.decode('utf-8'))
            # logging.debug(f"Decoded Pub/Sub message: {decoded_message}")
            yield decoded_message
        except Exception as e:
            error_message = f"Error decoding message: {e}, message: {element.decode('utf-8')}"
            logging.error(error_message)
            yield beam.pvalue.TaggedOutput('errors', element)


def run(argv=None, save_main_session=True):
    """Runs the pipeline."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--alloydb_username',
                        dest='alloydb_username',
                        required=True,
                        help='AlloyDB username')
    parser.add_argument('--alloydb_password',
                        dest='alloydb_password',
                        required=True,
                        help='AlloyDB password')
    parser.add_argument('--alloydb_ip',
                        dest='alloydb_ip',
                        required=True,
                        help='AlloyDB IP Address')
    parser.add_argument('--alloydb_port',
                        dest='alloydb_port',
                        default="5432",
                        help='AlloyDB Port')
    parser.add_argument('--alloydb_database',
                        dest='alloydb_database',
                        required=True,
                        help='AlloyDB Database name')
    parser.add_argument('--alloydb_table',
                        dest='alloydb_table',
                        required=True,
                        help='AlloyDB table name')
    parser.add_argument('--pubsub_subscription',
                        dest='pubsub_subscription',
                        required=True,
                        help='Pub/Sub subscription to read from')
    known_args, pipeline_args = parser.parse_known_args(argv)
    pipeline_options = PipelineOptions(pipeline_args)
    pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

    with beam.Pipeline(options=pipeline_options) as p:
        pubsub_pcoll = (
            p
            | "ReadFromPubSub" >> ReadFromPubSub(
                subscription=known_args.pubsub_subscription
            ).with_output_types(bytes)
            | "DecodeMessage" >> beam.ParDo(DecodePubsubMessageFn())
        )

        pubsub_pcoll | "printPubsubstuff" >> beam.ParDo(lambda x: print(x))

        # MLTransform is Beam's turnkey transform for data preparation and
        # embedding generation: it wraps the Vertex AI multimodal embedding
        # call and records its configuration in write_artifact_location.
        embedding_pcoll = (
            pubsub_pcoll
            | beam.ParDo(PrepImageFn())
            | "Embedding" >> MLTransform(
                write_artifact_location=artifact_location
            ).with_transform(mm_embedding_transform)
            | "ConvertToRows" >> beam.Map(
                lambda element: EmbeddingRowSchema(
                    id=element['id'],
                    image_path=element['image_path'],
                    image_embedding=str(element['image']),
                    contextual_text=element['contextual_text'],
                )).with_output_types(EmbeddingRowSchema)
        )
        # embedding_pcoll | "printMLTransformResults" >> beam.Map(lambda x: print(x))
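
        # The Gemini branch runs independently of the embedding branch: it
        # pulls each message's contextual_text, runs it through the custom
        # GeminiModelHandler defined above, and, as written, only prints the
        # results rather than persisting them.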
        inference_pcoll = (
            pubsub_pcoll
            | "getText" >> beam.Map(lambda data: data['contextual_text'])
            | "performGeminiInf" >> RunInference(GeminiModelHandler())
            | "printText" >> beam.Map(lambda x: print(x))
        )

        # stringtype=unspecified makes the PostgreSQL JDBC driver send the
        # stringified embedding as an untyped parameter, so the server casts
        # it to the table's vector column type.
        embedding_pcoll | 'Write to jdbc' >> WriteToJdbc(
            driver_class_name='org.postgresql.Driver',
            table_name=known_args.alloydb_table,
            jdbc_url=(f'jdbc:postgresql://{known_args.alloydb_ip}:'
                      f'{known_args.alloydb_port}'
                      f'/{known_args.alloydb_database}'),
            username=known_args.alloydb_username,
            password=known_args.alloydb_password,
            connection_properties='stringtype=unspecified'
        )


if __name__ == "__main__":
    run()
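
# A minimal launch sketch (every value below is a placeholder, not something
# defined in this repo; DataflowRunner is one option, and DirectRunner works
# for local testing). google_project at the top of the file must also be set
# (see the TODO):
#
#   python main.py \
#     --runner=DataflowRunner \
#     --project=<your-project> \
#     --region=us-central1 \
#     --temp_location=gs://<your-bucket>/tmp \
#     --alloydb_username=<user> \
#     --alloydb_password=<password> \
#     --alloydb_ip=<alloydb-ip> \
#     --alloydb_database=<database> \
#     --alloydb_table=<table> \
#     --pubsub_subscription=projects/<your-project>/subscriptions/<subscription>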