movie_search_metadata/demo_app/backend/scene_search.py (76 lines of code) (raw):
import json
import os
from typing import List
import vertexai
import vertexai.preview.generative_models as generative_models
from vertexai.preview.generative_models import GenerativeModel, GenerationConfig
from google.cloud import storage
from search_document import search_documents_by_query
from prompt_content_search import PROMPT_CONTENT_SEARCH
from utils import get_bucket_and_blobnames
from utils import generate_download_signed_url_v4
# global variables
from utils import PROJECT_ID
# Initialise the Vertex AI SDK once at import time for the configured project,
# pinned to the us-central1 region.
vertexai.init(project=PROJECT_ID, location='us-central1')
# Shared Gemini 1.5 Flash handle; generate_text and search_scenes use it as
# their default `model` argument.
model_flash = GenerativeModel('gemini-1.5-flash')
def generate_text(prompt: str, model: GenerativeModel = model_flash,
                  temperature: float = 0.4, top_p: float = 0.4) -> List[dict]:
    """Send `prompt` to a Gemini model and return its JSON answer, parsed.

    The request asks for an ``application/json`` response constrained to an
    array of objects, each with a required ``Timestamp`` and ``Description``
    string. Because the schema's top level is an array, the parsed value is
    a list of dicts rather than a single dict.

    Args:
        prompt: Prompt text passed to ``model.generate_content``.
        model: Model used for generation; defaults to the module-level
            ``model_flash``.
        temperature: Sampling temperature forwarded to ``GenerationConfig``.
        top_p: ``top_p`` value forwarded to ``GenerationConfig``.

    Returns:
        The decoded JSON payload of the model's response text.

    Raises:
        json.JSONDecodeError: If the response text is not valid JSON.
        Any error raised by ``generate_content`` or by reading the response
        text propagates unchanged; callers are expected to handle it.
    """
    # Shape the model must follow: [{"Timestamp": str, "Description": str}, ...]
    response_schema = {
        'type': 'array',
        'items': {
            'type': 'object',
            'properties': {
                'Timestamp': {
                    'type': 'string',
                },
                'Description': {
                    'type': 'string',
                },
            },
            'required': ['Timestamp', 'Description'],
        },
    }

    # Every harm category is blocked at the same threshold.
    harm_category = generative_models.HarmCategory
    block_level = generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
    safety_settings = {
        harm_category.HARM_CATEGORY_HATE_SPEECH: block_level,
        harm_category.HARM_CATEGORY_DANGEROUS_CONTENT: block_level,
        harm_category.HARM_CATEGORY_SEXUALLY_EXPLICIT: block_level,
        harm_category.HARM_CATEGORY_HARASSMENT: block_level,
    }

    response = model.generate_content(
        prompt,
        generation_config=GenerationConfig(
            max_output_tokens=8192,
            temperature=temperature,
            top_p=top_p,
            response_mime_type='application/json',
            response_schema=response_schema,
        ),
        safety_settings=safety_settings,
    )
    return json.loads(response.text)
def search_scenes(query: str, top_n: int = 1,
                  model: GenerativeModel = model_flash) -> List[dict]:
    """Find scenes matching `query` in the top search hits.

    For each of the first `top_n` documents returned by
    ``search_documents_by_query``, the document's metadata text is downloaded
    from Cloud Storage, inserted into ``PROMPT_CONTENT_SEARCH`` together with
    the query, and sent to ``generate_text``. Every scene dict returned by the
    model is tagged with the document's ``title`` and a ``signed_url``.

    Generation is best-effort: if a call fails, the error is printed and the
    call is retried with a higher temperature (0.40, 0.45, ..., 0.95). A
    document whose attempts all fail contributes no scenes.

    Args:
        query: Search text, used both for document search and in the prompt.
        top_n: Maximum number of search hits to process.
        model: Model forwarded to ``generate_text``; defaults to the
            module-level ``model_flash``.

    Returns:
        A flat list of scene dicts across all processed documents, each
        extended with ``signed_url`` and ``title`` keys.
    """
    storage_client = storage.Client()
    response = search_documents_by_query(query, show_summary=False)
    results = []
    # Index-based access is kept on purpose: only len() and integer indexing
    # of `response.results` are known to be supported here.
    for doc_id in range(min(top_n, len(response.results))):
        struct_data = response.results[doc_id].document.derived_struct_data
        title = struct_data['title']
        meta_uri = struct_data['link']
        # NOTE(review): presumably the metadata blob and its companion video
        # blob live in the same bucket -- confirm against utils.
        bucket_name, blob_metadata, blob_mp4 = get_bucket_and_blobnames(meta_uri)
        signed_url = generate_download_signed_url_v4(bucket_name, blob_mp4)
        bucket = storage_client.bucket(bucket_name)
        metatext = bucket.blob(blob_metadata).download_as_text()
        prompt = PROMPT_CONTENT_SEARCH.format(query=query, metatext=metatext)

        # Step in integer hundredths so each attempt uses an exact value
        # (0.40 ... 0.95) instead of accumulating float error via `+= 0.05`.
        for hundredths in range(40, 100, 5):
            temperature = hundredths / 100
            try:
                scenes = generate_text(prompt, model=model,
                                       temperature=temperature)
                for scene in scenes:
                    scene['signed_url'] = signed_url
                    scene['title'] = title
            except Exception as e:
                # Best-effort: report the failure and retry at a higher
                # temperature rather than aborting the whole search.
                print(e)
            else:
                results.extend(scenes)
                break
        else:
            # Every attempt failed; make the skipped document visible.
            print(f'No usable model response for "{title}"; skipping it.')
    return results