# File: movie_search_metadata/demo_app/backend/scene_search.py

"""Scene search over video metadata with a Gemini model.

``search_scenes`` retrieves documents with ``search_documents_by_query``,
downloads each hit's metadata text from Cloud Storage, and asks a Gemini
model (prompted with ``PROMPT_CONTENT_SEARCH``) for the matching scenes.
"""
import itertools
import json
import logging
import os
from typing import List

import vertexai
import vertexai.preview.generative_models as generative_models
from vertexai.preview.generative_models import GenerativeModel, GenerationConfig
from google.cloud import storage

from search_document import search_documents_by_query
from prompt_content_search import PROMPT_CONTENT_SEARCH
from utils import get_bucket_and_blobnames
from utils import generate_download_signed_url_v4
# global variables
from utils import PROJECT_ID

vertexai.init(project=PROJECT_ID, location='us-central1')
model_flash = GenerativeModel('gemini-1.5-flash')

logger = logging.getLogger(__name__)

# Retry schedule for scene extraction: attempt k (0-based) samples at
# _INITIAL_TEMPERATURE + k * _TEMPERATURE_STEP, i.e. 0.40, 0.45, ..., 0.95,
# so every attempt stays below a temperature of 1.0.
_INITIAL_TEMPERATURE = 0.4
_TEMPERATURE_STEP = 0.05
_MAX_ATTEMPTS = 12


def generate_text(prompt: str, model: GenerativeModel = model_flash,
                  temperature: float = 0.4, top_p: float = 0.4) -> List[dict]:
    """Ask ``model`` for scenes and return the parsed JSON response.

    The request sets ``response_mime_type='application/json'`` and a response
    schema describing an array of objects with required string fields
    ``Timestamp`` and ``Description``, so the parsed value is expected to be a
    list of such dicts (callers should still validate the shape).

    Args:
        prompt: Full prompt text sent to the model.
        model: Gemini model to call; defaults to the module's ``model_flash``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The value decoded from the model's JSON response text.

    Raises:
        json.JSONDecodeError: if the response text is not valid JSON.
        Any exception raised by ``model.generate_content`` or by reading the
        response's ``text``.
    """
    # JSON array of {"Timestamp": str, "Description": str} objects.
    response_schema = {
        'type': 'array',
        'items': {
            'type': 'object',
            'properties': {
                'Timestamp': {
                    'type': 'string',
                },
                'Description': {
                    'type': 'string',
                },
            },
            'required': ['Timestamp', 'Description'],
        },
    }
    safety_settings = {
        generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH:
            generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT:
            generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT:
            generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT:
            generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    }
    response = model.generate_content(
        prompt,
        generation_config=GenerationConfig(
            max_output_tokens=8192,
            temperature=temperature,
            top_p=top_p,
            response_mime_type='application/json',
            response_schema=response_schema,
        ),
        safety_settings=safety_settings,
    )
    return json.loads(response.text)


def _extract_scenes(prompt: str, model: GenerativeModel, title: str,
                    signed_url: str) -> List[dict]:
    """Call ``generate_text`` with a rising temperature until it yields scenes.

    Each returned scene dict is tagged with ``signed_url`` and ``title``.
    Returns an empty list (after logging an error) if every attempt fails, so
    one bad document does not abort the whole search.
    """
    for attempt in range(_MAX_ATTEMPTS):
        # Derive the temperature from the attempt index instead of adding the
        # step repeatedly, so float error does not accumulate across retries.
        temperature = round(_INITIAL_TEMPERATURE + attempt * _TEMPERATURE_STEP, 6)
        try:
            scenes = generate_text(prompt, model=model, temperature=temperature)
            if not isinstance(scenes, list):
                raise TypeError(
                    f'expected a JSON array, got {type(scenes).__name__}')
            # Tag inside the try: an element that is not a dict fails here and
            # is retried like any other failure.
            for scene in scenes:
                scene['signed_url'] = signed_url
                scene['title'] = title
        except Exception:  # deliberate best-effort boundary: log, then retry
            logger.warning(
                'Scene extraction for %r failed (attempt %d/%d, temperature %.2f)',
                title, attempt + 1, _MAX_ATTEMPTS, temperature, exc_info=True)
            continue
        return scenes
    logger.error('Giving up on scene extraction for %r after %d attempts',
                 title, _MAX_ATTEMPTS)
    return []


def search_scenes(query: str, top_n: int = 1,
                  model: GenerativeModel = model_flash) -> List[dict]:
    """Find scenes matching ``query`` in the top ``top_n`` search hits.

    For each hit, the metadata text is downloaded from Cloud Storage and
    ``model`` is asked (via ``PROMPT_CONTENT_SEARCH``) for matching scenes.
    Every returned scene is tagged with the hit's title and a signed download
    URL for its video blob.

    Args:
        query: Free-text query; used for the document search and the prompt.
        top_n: Maximum number of search hits to inspect (<= 0 inspects none).
        model: Gemini model used for scene extraction.

    Returns:
        Scene dicts (``Timestamp``, ``Description``) each extended with
        ``signed_url`` and ``title``. A hit whose extraction fails on every
        attempt contributes no scenes; the failure is logged.
    """
    storage_client = storage.Client()
    response = search_documents_by_query(query, show_summary=False)

    results = []
    # islice stops early when there are fewer than top_n hits; max() keeps a
    # non-positive top_n from raising ValueError inside islice.
    for hit in itertools.islice(response.results, max(top_n, 0)):
        data = hit.document.derived_struct_data
        title = data['title']
        meta_uri = data['link']

        bucket_name, blob_metadata, blob_mp4 = get_bucket_and_blobnames(meta_uri)
        signed_url = generate_download_signed_url_v4(bucket_name, blob_mp4)
        metatext = storage_client.bucket(bucket_name).blob(blob_metadata).download_as_text()
        prompt = PROMPT_CONTENT_SEARCH.format(query=query, metatext=metatext)

        results.extend(_extract_scenes(prompt, model, title, signed_url))
    return results