# ai-ml/spark-gemini-rag/main.py
import os
import time
import uuid
from google.cloud import bigquery
from google.cloud import firestore
from nicegui import app, run, ui
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
BatchSpanProcessor,
ConsoleSpanExporter,
)
import pandas as pd
from sentence_transformers import SentenceTransformer
import vertexai
from vertexai.generative_models import GenerativeModel
from vertexai.evaluation import EvalTask, PointwiseMetric
# Instantiate OpenTelemetry: batch spans and export them to the console
provider = TracerProvider()
processor = BatchSpanProcessor(ConsoleSpanExporter())
provider.add_span_processor(processor)
# Sets the global default tracer provider
trace.set_tracer_provider(provider)
# Creates a tracer from the global tracer provider
TRACER = trace.get_tracer("my.tracer.name")

# Define constants
PROJECT_ID = os.environ.get("PROJECT")
REGION = os.environ.get("REGION", "us-central1")
# Initialize VertexAI
vertexai.init(project=PROJECT_ID, location=REGION)
BIGQUERY_CLIENT = bigquery.Client(project=PROJECT_ID)
FIRESTORE_CLIENT = firestore.Client(project=PROJECT_ID, database="gemini-hackathon")
GEMINI_ENDPOINT = GenerativeModel("gemini-1.5-flash")
TRANSFORMER_MODEL = None
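
# The sentence-transformers embedding model is loaded lazily on first use,
# so module import (and app startup) stays fast.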
def get_embedding(prompt):
    global TRANSFORMER_MODEL
    if not TRANSFORMER_MODEL:
        TRANSFORMER_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = TRANSFORMER_MODEL.encode(prompt)
    return embeddings.tolist()
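
# Send a single-turn prompt to Gemini and return the generated text;
# errors are logged and re-raised so the caller can surface them.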
def make_gemini_prediction(prompt: str) -> str:
    try:
        return GEMINI_ENDPOINT.generate_content(prompt).text
    except Exception as e:
        print(e)
        raise
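
# Build the final prompt: fetch RAG context from BigQuery, then fill in a
# versioned prompt template stored in the Firestore "prompts" collection.
# The template is expected to contain {input} and {context} placeholders,
# e.g. (hypothetical document): {"prompt": "Answer {input} using {context}"}.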
def prompt_maker(input) -> tuple[str, str]:
    context = get_rag_context(input)
    version = os.environ.get("PROMPT_VERSION", "v20240926.1")
    ref = FIRESTORE_CLIENT.collection("prompts").document(document_id=version).get()
    prompt = ref.to_dict()["prompt"].format(input=input, context=context)
    return prompt, version
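
# Persist per-request metadata to the Firestore "requests" collection,
# keyed by the client's UUID.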
def write_to_database(client_id: str, data: dict):
    FIRESTORE_CLIENT.collection("requests").add(data, document_id=client_id)
def multiturn_quality(history, prompt, response):
    # Define a pointwise multi-turn chat quality metric
    pointwise_chat_quality_metric_prompt = """Evaluate the AI's contribution to a meaningful conversation, considering coherence, fluency, groundedness, and conciseness.
Review the chat history for context. Rate the response on a 1-5 scale, with explanations for each criterion and its overall impact.

# Conversation History
{history}

# Current User Prompt
{prompt}

# AI-generated Response
{response}
"""
    freeform_multi_turn_chat_quality_metric = PointwiseMetric(
        metric="multi_turn_chat_quality_metric",
        metric_prompt_template=pointwise_chat_quality_metric_prompt,
    )
    eval_dataset = pd.DataFrame(
        {
            "history": [history],
            "prompt": [prompt],
            "response": [response],
        }
    )
    # Run evaluation using the defined metric
    eval_task = EvalTask(
        dataset=eval_dataset,
        metrics=[freeform_multi_turn_chat_quality_metric],
    )
    result = eval_task.evaluate()
    return {
        "score": result.metrics_table["multi_turn_chat_quality_metric/score"].item(),
        "explanation": result.metrics_table["multi_turn_chat_quality_metric/explanation"].item(),
        "mean": result.summary_metrics["multi_turn_chat_quality_metric/mean"].item(),
    }
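
# Retrieve RAG context: embed the user input locally, run a BigQuery
# VECTOR_SEARCH over the precomputed embeddings table, and join the matches
# back to the source table. Assumes rag_data.rag_data exposes "id" and
# "body" columns and rag_data.embeddings stores vectors in "embeddings".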
def get_rag_context(input):
    embedding = get_embedding(input)
    query = f"""
    SELECT *
    FROM `{PROJECT_ID}.rag_data.rag_data`
    WHERE id IN (
        SELECT s.base.id
        FROM VECTOR_SEARCH(
            TABLE `{PROJECT_ID}.rag_data.embeddings`,
            "embeddings",
            (SELECT {embedding}),
            top_k => 5) as s);
    """
    print("waiting")
    rows = BIGQUERY_CLIENT.query_and_wait(query)
    print("waited")
    bodies = [row["body"] for row in rows]
    return " ### ".join(bodies)
@ui.page('/')
def index():
    async def update_prompt():
        print("prompt received")
        input_time = time.time()
        user_input = user_input_raw.value
        with chat_container:
            ui.chat_message(user_input, name='Me')
            print("spinner")
            spinner = ui.spinner('audio', size='lg', color='green')
            client_id = str(uuid.uuid4())
            request_id = f"{client_id}-{str(uuid.uuid4())[:8]}"
            prompt, prompt_version = await run.cpu_bound(prompt_maker, user_input)
            app.storage.client["count"] = app.storage.client.get("count", 0) + 1
            app.storage.client["history"] = app.storage.client.get("history", "") + "### User: " + prompt
            with TRACER.start_as_current_span("child") as span:
                span.set_attribute("operation.count", app.storage.client["count"])
                span.set_attribute("prompt", user_input)
                span.set_attribute("prompt_id", prompt_version)
                span.set_attribute("client_id", client_id)
                span.set_attribute("request_id", request_id)
                request_time = time.time()
                response = await run.io_bound(make_gemini_prediction, prompt)
                # response = make_prediction(user_input)
                response_time = time.time()
                app.storage.client["history"] = app.storage.client.get("history") + "### Agent: " + response
                span.set_attribute("response", response)
                spinner.delete()
                ui.chat_message(response,
                                name='Robot',
                                stamp='now',
                                avatar='https://robohash.org/ui') \
                    .style('font-family: Comic Sans, sans-serif; font-size: 16px;')
                request_data = {
                    "request_id": request_id,
                    "prompt": user_input,
                    "response": response,
                    "input_time": input_time,
                    "request_time": request_time,
                    "response_time": response_time,
                    "prompt_version": prompt_version,
                }
                print(f"Count: {app.storage.client['count']}")
                write_to_database(client_id, request_data)
                # print(multiturn_quality(
                #     app.storage.client.get("history"),
                #     prompt,
                #     response
                # ))

    ui.markdown("<h2>Welcome to predictions bot!</h2>")
    with ui.row().classes('flex flex-col h-screen'):
        chat_container = ui.column().classes('w-full max-w-3xl mx-auto my-6')
    with ui.footer().classes('bg-black'), ui.column().classes('w-full max-w-3xl mx-auto my-6'):
        with ui.row().classes('w-full no-wrap items-center'):
            user_input_raw = ui.input("Prompt").on('keydown.enter', update_prompt) \
                .props('rounded outlined input-class=mx-3').classes('flex-grow')

ui.run(host="0.0.0.0", port=int(os.environ.get("PORT", 8080)), storage_secret="1234", dark=True)
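
# To run locally (a sketch; the project ID and Firestore database name above
# are deployment-specific assumptions):
#   PROJECT=my-gcp-project REGION=us-central1 python main.py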