# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import apache_beam as beam 
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.ml.transforms.embeddings.vertex_ai import VertexAIImageEmbeddings
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
from apache_beam.io.jdbc import WriteToJdbc
from apache_beam.io import ReadFromPubSub
from apache_beam import coders
import typing
import tempfile
# Imports the Google Cloud client library
from google.cloud import storage

from google import genai
from google.genai import types
import base64
from vertexai.vision_models import Image
import argparse
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence


# TODO: set this to the Google Cloud project ID. The value is passed as
# ``project`` to the Gen AI client and to the Vertex AI embedding transform
# defined in this module.
google_project = ""

def generate():
    """Construct a Gen AI client in Vertex AI mode for ``google_project``.

    The client targets the ``us-central1`` location. It is neither stored nor
    returned, so this function always returns ``None``.
    """
    genai.Client(
        project=google_project,
        location="us-central1",
        vertexai=True,
    )


class GeminiModelHandler(ModelHandler):
    """Beam ``ModelHandler`` that sends text prompts to Gemini via Vertex AI.

    The "model" managed by this handler is a ``genai.Client``; each prompt in a
    batch is sent as a separate ``generate_content`` request.
    """

    def load_model(self) -> genai.Client:
        """Create the client that ``run_inference`` receives as ``model``.

        Returns:
            A ``genai.Client`` in Vertex AI mode for the module-level
            ``google_project``, located in ``us-central1``.
        """
        return genai.Client(
            vertexai=True,
            project=google_project,
            location="us-central1",
        )

    def run_inference(
        self,
        batch: Sequence[str],
        model: genai.Client,
        inference_args: Optional[Dict[str, Any]] = None
    ) -> Iterable[PredictionResult]:
        """Generate a Gemini response for every prompt in ``batch``.

        Args:
            batch: Prompt strings to send, one request per string.
            model: Client returned by ``load_model``.
            inference_args: Accepted for interface compatibility; currently
                not used by this handler.

        Returns:
            A list with one ``(prompt, generated_text)`` tuple per input, in
            input order. NOTE: despite the return annotation, the items are
            plain tuples rather than ``PredictionResult`` objects.
        """
        model_name = "gemini-2.0-flash-001"

        # The generation settings are identical for every prompt, so build
        # them once per batch.
        generate_content_config = types.GenerateContentConfig(
            temperature=1,
            top_p=0.95,
            max_output_tokens=8192,
            response_modalities=["TEXT"],
        )

        predictions = []
        for one_text in batch:
            contents = [
                types.Content(
                    role="user",
                    parts=[types.Part.from_text(text=one_text)],
                )
            ]
            # generate_content returns one complete response object, not a
            # stream of chunks. Iterating over it walks the response's fields
            # instead of the generated text, so read the text directly.
            response = model.models.generate_content(
                model=model_name,
                contents=contents,
                config=generate_content_config,
            )
            # ``text`` may be None when the response carries no text parts;
            # fall back to an empty string so the tuple always holds a str.
            predictions.append((one_text, response.text or ""))
        return predictions


# The embedding transform uses the same project as the rest of the module.
project_id = google_project

# Fresh temporary directory (created at import time) that is handed to
# MLTransform as its artifact location.
artifact_location = tempfile.mkdtemp(prefix='vertex_ai')

# Vertex AI multimodal embedding model and the transform that applies it.
# Only the 'image' column is embedded; 'contextual_text' is not included.
embedding_model_name = 'multimodalembedding@001'
mm_embedding_transform = VertexAIImageEmbeddings(
    model_name=embedding_model_name,
    columns=['image'],
    dimension=1408,
    project=project_id,
)

class EmbeddingRowSchema(typing.NamedTuple):
    """Row produced by the "ConvertToRows" step and passed to WriteToJdbc."""
    # Copied from the message's 'id' key.
    id: str
    # Copied from the message's 'image_path' key.
    image_path: str
    # str(...) of the element's 'image' value after the embedding step.
    image_embedding: str
    # Copied from the message's 'contextual_text' key.
    contextual_text: str

# Encode EmbeddingRowSchema values with Beam's RowCoder.
coders.registry.register_coder(EmbeddingRowSchema, coders.RowCoder)

class PrepImageFn(beam.DoFn):
    """Download the image referenced by each element from Cloud Storage.

    Reads the ``'image_bucket'`` and ``'image_path_split'`` keys of the
    incoming element and emits a copy of it with two extra keys:
    ``'image'`` (a ``vertexai`` ``Image`` built from the downloaded bytes) and
    ``'image_bytes'`` (the raw bytes).
    """

    def setup(self):
        """Create the Cloud Storage client once per DoFn instance.

        ``setup`` runs once before any bundle is processed, so the client is
        reused across bundles instead of being rebuilt for each one.
        """
        self.storage_client = storage.Client()

    def process(self, element):
        """Fetch the element's image and yield an enriched copy.

        Args:
            element: Mapping with ``'image_bucket'`` and
                ``'image_path_split'`` entries.

        Yields:
            A new dict containing the original entries plus ``'image'`` and
            ``'image_bytes'``.

        Raises:
            ValueError: If the referenced object does not exist in the bucket.
        """
        bucket_name = element['image_bucket']
        blob_path = element['image_path_split']

        bucket = self.storage_client.get_bucket(bucket_name)
        blob = bucket.get_blob(blob_path)
        if blob is None:
            # get_blob returns None for a missing object; fail with a message
            # that names the object instead of an AttributeError on None.
            raise ValueError(
                f"Image not found: gs://{bucket_name}/{blob_path}")
        blob_bytes = blob.download_as_bytes()

        # Beam requires DoFns to leave their input untouched (the same
        # element is consumed by other branches of the pipeline), so enrich a
        # shallow copy rather than the element itself.
        enriched = dict(element)
        enriched['image'] = Image(blob_bytes)
        enriched['image_bytes'] = blob_bytes
        yield enriched

class DecodePubsubMessageFn(beam.DoFn):
    """Decodes Pub/Sub messages from JSON to Python dictionaries."""
    def process(self, element):
        """Decodes a single Pub/Sub message.

        Args:
            element: Message payload; it is decoded as UTF-8 and parsed as
                JSON.

        Yields:
            The parsed JSON value on the main output, or, when decoding or
            parsing fails, the unmodified ``element`` on the ``'errors'``
            tagged output.
        """
        import json
        try:
            decoded_message = json.loads(element.decode('utf-8'))
        except Exception as e:
            # Log the payload with %r: decoding it again here would re-raise
            # for invalid UTF-8 and the element would never reach the
            # 'errors' output.
            logging.error("Error decoding message: %s, message: %r", e, element)
            yield beam.pvalue.TaggedOutput('errors', element)
        else:
            # Yield outside the try block so only decode/parse failures are
            # routed to the 'errors' output.
            yield decoded_message

def run(argv=None, save_main_session=True):
    """Parse the command-line flags, then build and run the Beam pipeline.

    The pipeline reads messages from a Pub/Sub subscription, decodes them and
    feeds two branches: one downloads each image, computes its embedding and
    writes a row through JDBC; the other sends ``contextual_text`` to Gemini
    via ``RunInference`` and prints the result.

    Args:
        argv: Argument list given to ``argparse``. Flags it does not
            recognise are forwarded to ``PipelineOptions``.
        save_main_session: Value assigned to
            ``SetupOptions.save_main_session``.
    """
    parser = argparse.ArgumentParser()

    # (flag name, extra add_argument keywords) in registration order.
    flag_specs = (
        ('alloydb_username', dict(required=True, help='AlloyDB username')),
        ('alloydb_password', dict(required=True, help='AlloyDB password')),
        ('alloydb_ip', dict(required=True, help='AlloyDB IP Address')),
        ('alloydb_port', dict(default="5432", help='AlloyDB Port')),
        ('alloydb_database', dict(required=True, help='AlloyDB Database name')),
        ('alloydb_table', dict(required=True, help='AlloyDB table name')),
        ('pubsub_subscription', dict(required=True)),
    )
    for flag_name, flag_kwargs in flag_specs:
        parser.add_argument(f'--{flag_name}', dest=flag_name, **flag_kwargs)

    known_args, pipeline_args = parser.parse_known_args(argv)
    pipeline_options = PipelineOptions(pipeline_args)
    pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

    jdbc_url = (
        f'jdbc:postgresql://{known_args.alloydb_ip}'
        f':{known_args.alloydb_port}/{known_args.alloydb_database}'
    )

    with beam.Pipeline(options=pipeline_options) as p:
        # Source: raw Pub/Sub payloads, decoded from JSON.
        pubsub_source = ReadFromPubSub(
            subscription=known_args.pubsub_subscription
        ).with_output_types(bytes)
        raw_messages = p | "ReadFromPubSub" >> pubsub_source
        pubsub_pcoll = (
            raw_messages
            | "DecodeMessage" >> beam.ParDo(DecodePubsubMessageFn())
        )
        # Debug sink: print every decoded message.
        pubsub_pcoll | "printPubsubstuff" >> beam.ParDo(lambda x: print(x))

        # Embedding branch.
        # TODO: explain what a turnkey transform (MLTransform) is and why it
        # is beneficial.
        images = pubsub_pcoll | beam.ParDo(PrepImageFn())
        embedded = images | "Embedding" >> MLTransform(
            write_artifact_location=artifact_location
        ).with_transform(mm_embedding_transform)
        embedding_pcoll = embedded | "ConvertToRows" >> beam.Map(
            lambda element: EmbeddingRowSchema(
                id=element['id'],
                image_path=element['image_path'],
                image_embedding=str(element['image']),
                contextual_text=element['contextual_text'],
            )
        ).with_output_types(EmbeddingRowSchema)

        # Gemini branch: run inference on the contextual text and print the
        # result.
        prompts = (
            pubsub_pcoll
            | "getText" >> beam.Map(lambda data: data['contextual_text'])
        )
        responses = (
            prompts | "performGeminiInf" >> RunInference(GeminiModelHandler())
        )
        responses | "printText" >> beam.Map(lambda x: print(x))

        # Sink: write the embedding rows to the target table over JDBC.
        embedding_pcoll | 'Write to jdbc' >> WriteToJdbc(
            driver_class_name='org.postgresql.Driver',
            table_name=known_args.alloydb_table,
            jdbc_url=jdbc_url,
            username=known_args.alloydb_username,
            password=known_args.alloydb_password,
            connection_properties='stringtype=unspecified',
        )

# Build and run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    run()
