colab-enterprise/gen-ai-demo/Menu-A-B-Testing-Generate-Insight-GenAI.ipynb (1,249 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "k6eIqerFOzyj" }, "source": [ "## <img src=\"https://lh3.googleusercontent.com/mUTbNK32c_DTSNrhqETT5aQJYFKok2HB1G2nk2MZHvG5bSs0v_lmDm_ArW7rgd6SDGHXo0Ak2uFFU96X6Xd0GQ=w160-h128\" width=\"45\" valign=\"top\" alt=\"BigQuery\"> Generating A/B Menu items using Gemini Pro and ImageGen2.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### License" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "##################################################################################\n", "# Copyright 2024 Google LLC\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "# \n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "# \n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "###################################################################################" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Notebook Overview" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- This notebook will create new menu desciption and images that allows different trucks to test different menu descriptions and images. The notebook also shows how we can create the new menu table by using Gemini Pro to read an ERD.\n", "\n", "- Notebook Logic:\n", " 1. Create our new A/B menu table by using Gemini Pro to read our ERD and create our table along with primary keys. \n", " 1. First we ened to download a picture of our ERD.\n", " 2. 
We need to construct a LLM prompt telling it to read the ERD and construct our SQL.\n", " 3. Execute the SQL to create the new table and primary keys on the table.\n", " 2. Select some locations (trucks)\n", " 3. Run a query that sums Oct and Nov sales data by menu item.\n", " - Rank the sales (high drop off) for each city / location / menu item.\n", " - We will generate a new menu purmutation for the item that has the highest drop in sales\n", " 4. Create our LLM prompt:\n", " - We want to generate new data for each of the menu items based upon our BigQuery results\n", " - The prompt will:\n", " - Provide the stating primary key\n", " - Ask for synthetic data to be generated\n", " - Provide a Google Cloud Storage pattern of the image of the menu item\n", " - We will ask the LLM to write our ImageGen2 LLM prompt\n", " - We will pass in the schema of the table\n", " 5. Execute the SQL adding rows to our BigQuery table\n", " 6. Query the new menu items, retrieving the LLM image prompt\n", " 7. Generate the menu item image using ImageGen2.\n", " 8. Upload the image to GCS. The path will match what the LLM generated for the path." 
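, "\n", "The flow above can be summarized with a short Python outline (illustrative pseudocode only; `run_query`, `gemini_prompt`, `imagen_generate`, and `upload_to_gcs` are hypothetical stand-ins for the helper functions defined later in this notebook):\n", "\n", "```python\n", "for item in run_query(ranked_sales_sql):           # step 3: worst-selling item per location\n", "    insert_sql = gemini_prompt(item)               # step 4: the LLM writes the INSERT statement\n", "    run_query(insert_sql)                          # step 5: add the new A/B menu row\n", "    image = imagen_generate(item['image_prompt'])  # step 7: render the menu item image\n", "    upload_to_gcs(image, item['image_url'])        # step 8: the path matches the LLM-generated URL\n", "```\n"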
] }, { "cell_type": "markdown", "metadata": { "id": "8zy0eEJmHxRZ" }, "source": [ "## Initialize Python" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3wiruT266H3e" }, "outputs": [], "source": [ "project_id=\"${project_id}\"\n", "location=\"us-central1\"\n", "model_id = \"imagegeneration@005\"\n", "\n", "# No need to set these\n", "city_names=[\"New York City\", \"London\", \"Tokyo\", \"San Francisco\"]\n", "city_ids=[1,2,3,4]\n", "city_languages=[\"American English\", \"British English\", \"Japanese\", \"American English\"]\n", "number_of_coffee_trucks = \"4\"\n", "\n", "dataset_id = \"data_beans_synthetic_data\"\n", "\n", "gcs_storage_bucket = \"${data_beans_curated_bucket}\"\n", "gcs_storage_path = \"data-beans/menu-images-a-b-testing/\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "z4NpP0pCH0pj" }, "outputs": [], "source": [ "from PIL import Image\n", "from IPython.display import HTML\n", "import IPython.display\n", "import google.auth\n", "import requests\n", "import json\n", "import uuid\n", "import base64\n", "import os\n", "import cv2\n", "\n", "from google.cloud import bigquery\n", "client = bigquery.Client()" ] }, { "cell_type": "markdown", "metadata": { "id": "YtZuFgjbOjso" }, "source": [ "## ImageGen2 / Gemini Pro / Gemini Pro Vision (Helper Functions)" ] }, { "cell_type": "markdown", "metadata": { "id": "xUolPsMFOjpZ" }, "source": [ "#### ImageGen2" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LPf6NurhNi2l" }, "outputs": [], "source": [ "def ImageGen(prompt):\n", " creds, project = google.auth.default()\n", " auth_req = google.auth.transport.requests.Request() # required to acess access token\n", " creds.refresh(auth_req)\n", " access_token=creds.token\n", "\n", " headers = {\n", " \"Content-Type\" : \"application/json\",\n", " \"Authorization\" : \"Bearer \" + access_token\n", " }\n", "\n", " # 
https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/image-generation\n", " url = f\"https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/publishers/google/models/imagegeneration:predict\"\n", "\n", " payload = {\n", " \"instances\": [\n", " {\n", " \"prompt\": prompt\n", " }\n", " ],\n", " \"parameters\": {\n", " \"sampleCount\": 1\n", " }\n", " }\n", "\n", " response = requests.post(url, json=payload, headers=headers)\n", "\n", " if response.status_code == 200:\n", " image_data = json.loads(response.content)[\"predictions\"][0][\"bytesBase64Encoded\"]\n", " image_data = base64.b64decode(image_data)\n", " filename= str(uuid.uuid4()) + \".png\"\n", " with open(filename, \"wb\") as f:\n", " f.write(image_data)\n", " print(f\"Image generated OK.\")\n", " return filename\n", " else:\n", " error = f\"Error with prompt:'{prompt}' Status:'{response.status_code}' Text:'{response.text}'\"\n", " raise RuntimeError(error)" ] }, { "cell_type": "markdown", "metadata": { "id": "E5CFSdK3HxYm" }, "source": [ "#### Gemini Pro LLM" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9jTBzcSIMbwg" }, "outputs": [], "source": [ "def GeminiProLLM(prompt, temperature = .8, topP = .8, topK = 40):\n", "\n", " if temperature < 0:\n", " temperature = 0\n", "\n", " creds, project = google.auth.default()\n", " auth_req = google.auth.transport.requests.Request() # required to access the access token\n", " creds.refresh(auth_req)\n", " access_token=creds.token\n", "\n", " headers = {\n", " \"Content-Type\" : \"application/json\",\n", " \"Authorization\" : \"Bearer \" + access_token\n", " }\n", "\n", " # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini\n", " url = f\"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers/google/models/gemini-2.0-flash:streamGenerateContent\"\n", "\n", " payload = {\n", " \"contents\": {\n", " \"role\": \"user\",\n", " 
\"parts\": {\n", " \"text\": prompt\n", " },\n", " },\n", " \"safety_settings\": {\n", " \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n", " \"threshold\": \"BLOCK_LOW_AND_ABOVE\"\n", " },\n", " \"generation_config\": {\n", " \"temperature\": temperature,\n", " \"topP\": topP,\n", " \"topK\": topK,\n", " \"maxOutputTokens\": 8192,\n", " \"candidateCount\": 1\n", " }\n", " }\n", "\n", " response = requests.post(url, json=payload, headers=headers)\n", "\n", " if response.status_code == 200:\n", " json_response = json.loads(response.content)\n", " llm_response = \"\"\n", " for item in json_response:\n", " try:\n", " llm_response = llm_response + item[\"candidates\"][0][\"content\"][\"parts\"][0][\"text\"]\n", " except Exception as err:\n", " print(f\"response.content: {response.content}\")\n", " raise RuntimeError(err)\n", "\n", " # Remove some typically response characters (if asking for a JSON reply)\n", " llm_response = llm_response.replace(\"```json\",\"\")\n", " llm_response = llm_response.replace(\"```\",\"\")\n", "\n", " # print(f\"llm_response:\\n{llm_response}\")\n", " return llm_response\n", " else:\n", " error = f\"Error with prompt:'{prompt}' Status:'{response.status_code}' Text:'{response.text}'\"\n", " raise RuntimeError(error)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "-L93udtrH1Oz" }, "source": [ "#### Gemini Pro Vision LLM" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ecvrUyp0BcXg" }, "outputs": [], "source": [ "# Use the Gemini with Vision\n", "def GeminiProVisionLLM(prompt, imageBase64, temperature = .4, topP = 1, topK = 32):\n", "\n", " if temperature < 0:\n", " temperature = 0\n", "\n", " creds, project = google.auth.default()\n", " auth_req = google.auth.transport.requests.Request() # required to acess access token\n", " creds.refresh(auth_req)\n", " access_token=creds.token\n", "\n", " headers = {\n", " \"Content-Type\" : \"application/json\",\n", " \"Authorization\" : \"Bearer \" + access_token\n", 
" }\n", "\n", " # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini\n", " url = f\"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers/google/models/gemini-2.0-flash:streamGenerateContent\"\n", "\n", " payload = {\n", " \"contents\": [\n", " {\n", " \"role\": \"user\",\n", " \"parts\": [\n", " {\n", " \"text\": prompt\n", " },\n", " {\n", " \"inlineData\": {\n", " \"mimeType\": \"image/png\",\n", " \"data\": f\"{imageBase64}\"\n", " }\n", " }\n", " ]\n", " }\n", " ],\n", " \"safety_settings\": {\n", " \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n", " \"threshold\": \"BLOCK_LOW_AND_ABOVE\"\n", " },\n", " \"generation_config\": {\n", " \"temperature\": temperature,\n", " \"topP\": topP,\n", " \"topK\": topK,\n", " \"maxOutputTokens\": 2048,\n", " \"candidateCount\": 1\n", " }\n", " }\n", "\n", " response = requests.post(url, json=payload, headers=headers)\n", "\n", " if response.status_code == 200:\n", " json_response = json.loads(response.content)\n", " llm_response = \"\"\n", " for item in json_response:\n", " llm_response = llm_response + item[\"candidates\"][0][\"content\"][\"parts\"][0][\"text\"]\n", "\n", " # Remove some typically response characters (if asking for a JSON reply)\n", " llm_response = llm_response.replace(\"```json\",\"\")\n", " llm_response = llm_response.replace(\"```\",\"\")\n", "\n", " # print(f\"llm_response:\\n{llm_response}\")\n", " return llm_response\n", " else:\n", " error = f\"Error with prompt:'{prompt}' Status:'{response.status_code}' Text:'{response.text}'\"\n", " raise RuntimeError(error)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QNz6pofvfXDS" }, "outputs": [], "source": [ "# Use the Gemini with Vision\n", "def GeminiProVisionMultipleFileLLM(prompt, image_prompt, temperature = .4, topP = 1, topK = 32):\n", " creds, project = google.auth.default()\n", " auth_req = google.auth.transport.requests.Request() # required 
to acess access token\n", " creds.refresh(auth_req)\n", " access_token=creds.token\n", "\n", " headers = {\n", " \"Content-Type\" : \"application/json\",\n", " \"Authorization\" : \"Bearer \" + access_token\n", " }\n", "\n", " # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini\n", " url = f\"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers/google/models/gemini-2.0-flash:streamGenerateContent\"\n", "\n", "\n", " parts = []\n", " new_item = {\n", " \"text\": prompt\n", " }\n", " parts.append(new_item)\n", "\n", " for item in image_prompt:\n", " new_item = {\n", " \"text\": f\"Image Name: {item['llm_image_filename']}:\\n\"\n", " }\n", " parts.append(new_item)\n", " new_item = {\n", " \"inlineData\": {\n", " \"mimeType\": \"image/png\",\n", " \"data\": item[\"llm_image_base64\"]\n", " }\n", " }\n", " parts.append(new_item)\n", "\n", " payload = {\n", " \"contents\": [\n", " {\n", " \"role\": \"user\",\n", " \"parts\": parts\n", " }\n", " ],\n", " \"safety_settings\": {\n", " \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n", " \"threshold\": \"BLOCK_LOW_AND_ABOVE\"\n", " },\n", " \"generation_config\": {\n", " \"temperature\": temperature,\n", " \"topP\": topP,\n", " \"topK\": topK,\n", " \"maxOutputTokens\": 2048,\n", " \"candidateCount\": 1\n", " }\n", " }\n", "\n", " response = requests.post(url, json=payload, headers=headers)\n", "\n", " if response.status_code == 200:\n", " json_response = json.loads(response.content)\n", " llm_response = \"\"\n", " for item in json_response:\n", " llm_response = llm_response + item[\"candidates\"][0][\"content\"][\"parts\"][0][\"text\"]\n", "\n", " # Remove some typically response characters (if asking for a JSON reply)\n", " llm_response = llm_response.replace(\"```json\",\"\")\n", " llm_response = llm_response.replace(\"```\",\"\")\n", "\n", " # print(f\"llm_response:\\n{llm_response}\")\n", " return llm_response\n", " else:\n", " error = 
f\"Error with prompt:'{prompt}' Status:'{response.status_code}' Text:'{response.text}'\"\n", " raise RuntimeError(error)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "rVCY93IyXPoO" }, "source": [ "#### SQL Functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jHCtXYuRNU0p" }, "outputs": [], "source": [ "def RunQuery(sql):\n", " import time\n", "\n", " if (sql.startswith(\"SELECT\") or sql.startswith(\"WITH\")):\n", " df_result = client.query(sql).to_dataframe()\n", " return df_result\n", " else:\n", " job_config = bigquery.QueryJobConfig(priority=bigquery.QueryPriority.INTERACTIVE)\n", " query_job = client.query(sql, job_config=job_config)\n", "\n", " # Check on the progress by getting the job's updated state.\n", " query_job = client.get_job(\n", " query_job.job_id, location=query_job.location\n", " )\n", " print(\"Job {} is currently in state {} with error result of {}\".format(query_job.job_id, query_job.state, query_job.error_result))\n", "\n", " while query_job.state != \"DONE\":\n", " time.sleep(2)\n", " query_job = client.get_job(\n", " query_job.job_id, location=query_job.location\n", " )\n", " print(\"Job {} is currently in state {} with error result of {}\".format(query_job.job_id, query_job.state, query_job.error_result))\n", "\n", " if query_job.error_result == None:\n", " return True\n", " else:\n", " return False" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "p9j1GdAwNifB" }, "outputs": [], "source": [ "def GetNextPrimaryKey(fully_qualified_table_name, field_name):\n", " sql = f\"\"\"\n", " SELECT IFNULL(MAX({field_name}),0) AS result\n", " FROM `{fully_qualified_table_name}`\n", " \"\"\"\n", " # print(sql)\n", " df_result = client.query(sql).to_dataframe()\n", " # display(df_result)\n", " return df_result['result'].iloc[0] + 1" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ws3i8HVGkk5p" }, "outputs": [], "source": [ "def GetTableSchema(dataset_name, 
table_name):\n", " import io\n", "\n", " dataset_ref = client.dataset(dataset_name, project=project_id)\n", " table_ref = dataset_ref.table(table_name)\n", " table = client.get_table(table_ref)\n", "\n", " f = io.StringIO(\"\")\n", " client.schema_to_json(table.schema, f)\n", " return f.getvalue()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Qx0q3yUNl63e" }, "outputs": [], "source": [ "def GetStartingValue(dataset_name, table_name, field_name):\n", " sql = f\"\"\"\n", " SELECT IFNULL(MAX({field_name}),0) + 1 AS result\n", " FROM `{project_id}.{dataset_name}.{table_name}`\n", " \"\"\"\n", " #print(sql)\n", " df_result = client.query(sql).to_dataframe()\n", " #display(df_result)\n", " return df_result['result'].iloc[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6z5Fc5Mel6sH" }, "outputs": [], "source": [ "def GetForeignKeys(dataset_name, table_name, field_name):\n", " sql = f\"\"\"\n", " SELECT STRING_AGG(CAST({field_name} AS STRING), \",\" ORDER BY {field_name}) AS result\n", " FROM `{project_id}.{dataset_name}.{table_name}`\n", " \"\"\"\n", " #print(sql)\n", " df_result = client.query(sql).to_dataframe()\n", " #display(df_result)\n", " return df_result['result'].iloc[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9-6THGgjnDKg" }, "outputs": [], "source": [ "def GetDistinctValues(dataset_name, table_name, field_name):\n", " sql = f\"\"\"\n", " SELECT STRING_AGG(DISTINCT {field_name}, \",\" ) AS result\n", " FROM `{project_id}.{dataset_name}.{table_name}`\n", " \"\"\"\n", " #print(sql)\n", " df_result = client.query(sql).to_dataframe()\n", " #display(df_result)\n", " return df_result['result'].iloc[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Run DDL\n", "# This has been done to make the code re-runnable for the demo\n", "# We need to add some logic for the PK and FKs and the prompt will be updated to code gen it\n", 
"\n", "def RunDDL(sql):\n", " import time\n", "\n", " sql = f\"\"\"CREATE SCHEMA IF NOT EXISTS `{project_id}.{dataset_id}`;\n", "\n", "CREATE TABLE IF NOT EXISTS `{project_id}.{dataset_id}.menu_a_b_testing`\n", "(\n", " menu_a_b_testing_id INT64 NOT NULL OPTIONS(description=\"Primary key. Menu A/B Testing table.\"),\n", " menu_id INT64 NOT NULL OPTIONS(description=\"Foreign key: Menu table.\"),\n", " location_id INT64 NOT NULL OPTIONS(description=\"Foreign key: Location table.\"),\n", " item_name STRING NOT NULL OPTIONS(description=\"The name of the item.\"),\n", " item_description STRING NOT NULL OPTIONS(description=\"The description of the item.\"),\n", " item_size STRING NOT NULL OPTIONS(description=\"The size of the item.\"),\n", " item_price FLOAT64 NOT NULL OPTIONS(description=\"The price of the item.\"),\n", " llm_item_description_prompt STRING OPTIONS(description=\"The prompt used to generate the LLM item description.\"),\n", " llm_item_description STRING OPTIONS(description=\"The LLM generated description of the item.\"),\n", " llm_item_image_prompt STRING OPTIONS(description=\"The prompt used to generate the LLM item image.\"),\n", " llm_item_image_url STRING OPTIONS(description=\"The LLM generated image url of the item.\"),\n", " create_date TIMESTAMP NOT NULL OPTIONS(description=\"The date the item was created.\"),\n", " llm_marketing_prompt STRING OPTIONS(description=\"The prompt used to generate the LLM marketing response.\"),\n", " llm_marketing_response JSON OPTIONS(description=\"The LLM generated marketing response.\"),\n", " llm_marketing_parsed_response JSON OPTIONS(description=\"The parsed LLM generated marketing response.\"),\n", " html_generated BOOLEAN OPTIONS(description=\"True if the HTML was generated.\"),\n", " html_filename STRING OPTIONS(description=\"The name of the HTML file.\"),\n", " html_url STRING OPTIONS(description=\"The URL of the HTML file.\")\n", ")\n", "CLUSTER BY menu_a_b_testing_id;\n", "\n", "\n", "ALTER TABLE 
`{project_id}.{dataset_id}.menu_a_b_testing` DROP PRIMARY KEY IF EXISTS;\n", "ALTER TABLE `{project_id}.{dataset_id}.menu_a_b_testing` ADD PRIMARY KEY (menu_a_b_testing_id) NOT ENFORCED;\n", " \"\"\"\n", "\n", " # To see the constraints in BigQuery\n", " # SELECT * FROM `{project_id}.{dataset_id}.INFORMATION_SCHEMA.TABLE_CONSTRAINTS`\n", "\n", "\n", " job_config = bigquery.QueryJobConfig(priority=bigquery.QueryPriority.INTERACTIVE)\n", " query_job = client.query(sql, job_config=job_config)\n", "\n", " # Check on the progress by getting the job's updated state.\n", " query_job = client.get_job(\n", " query_job.job_id, location=query_job.location\n", " )\n", " print(\"Job {} is currently in state {} with error result of {}\".format(query_job.job_id, query_job.state, query_job.error_result))\n", "\n", " while query_job.state != \"DONE\":\n", " time.sleep(2)\n", " query_job = client.get_job(\n", " query_job.job_id, location=query_job.location\n", " )\n", " print(\"Job {} is currently in state {} with error result of {}\".format(query_job.job_id, query_job.state, query_job.error_result))\n", "\n", " if query_job.error_result == None:\n", " return True\n", " else:\n", " return False" ] }, { "cell_type": "markdown", "metadata": { "id": "BlxddNzpmAgp" }, "source": [ "#### Helper Functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "E1yrPjvVXNCz" }, "outputs": [], "source": [ "def convert_png_to_base64(image_path):\n", " image = cv2.imread(image_path)\n", "\n", " # Convert the image to a base64 string.\n", " _, buffer = cv2.imencode('.png', image)\n", " base64_string = base64.b64encode(buffer).decode('utf-8')\n", "\n", " return base64_string" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ylq2crklNuCB" }, "outputs": [], "source": [ "# This was generated by GenAI\n", "\n", "def copy_file_to_gcs(local_file_path, bucket_name, destination_blob_name):\n", " \"\"\"Copies a file from a local drive to a GCS bucket.\n", "\n",
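 " Example (hypothetical bucket and object names):\n", " copy_file_to_gcs(\"menu.png\", \"my-demo-bucket\", \"data-beans/menu-images-a-b-testing/menu.png\")\n", "\n",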
" Args:\n", " local_file_path: The full path to the local file.\n", " bucket_name: The name of the GCS bucket to upload to.\n", " destination_blob_name: The desired name of the uploaded file in the bucket.\n", "\n", " Returns:\n", " None\n", " \"\"\"\n", "\n", " import os\n", " from google.cloud import storage\n", "\n", " # Ensure the file exists locally\n", " if not os.path.exists(local_file_path):\n", " raise FileNotFoundError(f\"Local file '{local_file_path}' not found.\")\n", "\n", " # Create a storage client\n", " storage_client = storage.Client()\n", "\n", " # Get a reference to the bucket\n", " bucket = storage_client.bucket(bucket_name)\n", "\n", " # Create a blob object with the desired destination path\n", " blob = bucket.blob(destination_blob_name)\n", "\n", " # Upload the file from the local filesystem\n", " content_type = \"\"\n", " if local_file_path.endswith(\".html\"):\n", " content_type = \"text/html; charset=utf-8\"\n", "\n", " if local_file_path.endswith(\".json\"):\n", " content_type = \"application/json; charset=utf-8\"\n", "\n", " if content_type == \"\":\n", " blob.upload_from_filename(local_file_path)\n", " else:\n", " blob.upload_from_filename(local_file_path, content_type = content_type)\n", "\n", " print(f\"File '{local_file_path}' uploaded to GCS bucket '{bucket_name}' as '{destination_blob_name}. 
Content-Type: {content_type}'.\")" ] }, { "cell_type": "markdown", "metadata": { "id": "MOnML8jpdwzg" }, "source": [ "## Menu Synthetic Data and Image Generation" ] }, { "cell_type": "markdown", "metadata": { "id": "xk2c2lvsnHZI" }, "source": [ "#### Download the ERD image and Generate our Database Schema DDL" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0E29ZF_newIa" }, "outputs": [], "source": [ "import requests\n", "import shutil\n", "\n", "menu_erd_filename = \"Data-Beans-Menu-A-B-Testing-ERD.png\"\n", "\n", "# Specify the image URL\n", "img_url = f\"https://storage.googleapis.com/data-analytics-golden-demo/data-beans/v1/colab-supporting-images/{menu_erd_filename}\"\n", "\n", "# Send a GET request to fetch the image\n", "response = requests.get(img_url, stream=True)\n", "\n", "# Check for successful download\n", "if response.status_code == 200:\n", " # Set decode_content to True to prevent encoding errors\n", " response.raw.decode_content = True\n", "\n", " # Open a local file in binary write mode\n", " with open(menu_erd_filename, \"wb\") as f:\n", " # Copy image data to the local file in chunks\n", " shutil.copyfileobj(response.raw, f)\n", "\n", " print(\"Image downloaded successfully!\")\n", "else:\n", " print(\"Image download failed with status code:\", response.status_code)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JZm_BtyiflQh" }, "outputs": [], "source": [ "print(f\"Filename: {menu_erd_filename}\")\n", "img = Image.open(menu_erd_filename)\n", "img.thumbnail([504,700]) # width, height\n", "IPython.display.display(img)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MJS1jALgeBL-" }, "outputs": [], "source": [ "llm_erd_prompt=f\"\"\"Use BigQuery SQL commands to create the following:\n", "- Create a new BigQuery schema named \"{dataset_id}\".\n", "- Use only BigQuery datatypes. 
Double and triple check this since it causes a lot of errors.\n", "- Create the BigQuery DDLs for the attached ERD.\n", "- Create primary keys for each table using the ALTER command. Use the \"NOT ENFORCED\" keyword.\n", "- DO NOT Create foreign keys for each table.\n", "- For each field add an OPTIONS for the description.\n", "- Cluster the table by the primary key.\n", "- For columns that can be null do not add \"NULL\" to the create table statement. BigQuery leaves this blank. For example this is INCORRECT: STRING NULL OPTIONS.\n", "- All ALTER TABLE statements should be at the bottom of the generated script.\n", "- The ALTER TABLE statements should be ordered with the primary key statements first and then the foreign key statements. Order matters!\n", "- Double check your work especially that you used ONLY BigQuery data types.\n", "- Double check that NULL was not specified for NULLABLE fields.\n", "- Only create the tables shown in the diagram. Do not create foreign key tables.\n", "\n", "Previous errors that have been generated by this script. Be sure to check your work to avoid encountering these.\n", "- Query error: Type not found: FLOAT at [6:12]\n", "- Query error: Table test.company does not have Primary Key constraints\n", "- Query error: Syntax error: Expected \")\" or \",\" but got keyword NULL\n", "\n", "Example:\n", "CREATE TABLE IF NOT EXISTS `{project_id}.{dataset_id}.customer`\n", "(\n", " customer_id INTEGER NOT NULL OPTIONS(description=\"Primary key. 
Customer table.\"),\n", " country_id INTEGER NOT NULL OPTIONS(description=\"Foreign key: Country table.\"),\n", " customer_llm_summary STRING NOT NULL OPTIONS(description=\"LLM generated summary of customer data.\"),\n", " customer_lifetime_value STRING NOT NULL OPTIONS(description=\"Total sales for this customer.\"),\n", " customer_cluster_id FLOAT NOT NULL OPTIONS(description=\"Clustering algorithm id.\"),\n", " customer_review_llm_summary STRING OPTIONS(description=\"LLM summary are all of the customer reviews.\"),\n", " customer_survey_llm_summary STRING OPTIONS(description=\"LLM summary are all of the customer surveys.\")\n", ")\n", "CLUSTER BY customer_id;\n", "\n", "CREATE TABLE IF NOT EXISTS `{project_id}.{dataset_id}.country`\n", "(\n", "country_id INTEGER NOT NULL OPTIONS(description=\"Primary key. Country table.\"),\n", "country_name STRING NOT NULL OPTIONS(description=\"The name of the country.\")\n", ")\n", "CLUSTER BY country_id;\n", "\n", "\n", "ALTER TABLE `{project_id}.{dataset_id}.customer` ADD PRIMARY KEY (customer_id) NOT ENFORCED;\n", "ALTER TABLE `{project_id}.{dataset_id}.country` ADD PRIMARY KEY (country_id) NOT ENFORCED;\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Am9jrbXVWfv0" }, "outputs": [], "source": [ "# Run the LLM to generate the DDL and run the DDL\n", "imageBase64 = convert_png_to_base64(menu_erd_filename)\n", "\n", "llm_success = False\n", "temperature=.2\n", "while llm_success == False:\n", " try:\n", " sql = GeminiProVisionLLM(llm_erd_prompt, imageBase64, temperature=temperature, topP=1, topK=32)\n", " # Need to prompt this\n", " sql = sql.replace(\"STRING NULL OPTIONS\",\"STRING OPTIONS\")\n", " sql = sql.replace(\"JSON NULL OPTIONS\",\"JSON OPTIONS\")\n", " sql = sql.replace(\"BOOLEAN NULL OPTIONS\",\"BOOLEAN OPTIONS\")\n", " print(f\"SQL: {sql}\")\n", " llm_success = RunDDL(sql)\n", " except:\n", " # Reduce the temperature for more accurate generation\n", " temperature = temperature 
- .05\n", " print(\"Regenerating...\")" ] }, { "cell_type": "markdown", "metadata": { "id": "BbKwJYMV0hVO" }, "source": [ "### Find the location you want to update" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "l14KCyGS0sQG" }, "outputs": [], "source": [ "# Set your location ids\n", "\n", "location_ids = [2,43,10,30]\n", "# location_id = 1 # CitySips Roaming Cafe(New York City)\n", "# location_id = 2 # JavaJourney Express(London)\n", "# location_id = 3 # UrbanCaffeine Cruiser(London)\n", "# location_id = 4 # CafeWheels(Tokyo)\n", "# location_id = 5 # Golden Gate Grind Mobile(San Francisco)\n", "# location_id = 6 # Sunshine Sips on Wheels(San Francisco)\n", "# location_id = 7 # Big Apple Brew Bus(London)\n", "# location_id = 8 # Bay Brew Hauler(San Francisco)\n", "# location_id = 9 # Magic City Mocha Mobile(Tokyo)\n", "# location_id = 10 # Metropolis Mug Mover(Tokyo)\n", "# location_id = 11 # Nectar Nomad(Tokyo)\n", "# location_id = 12 # Street Sips(New York City)\n", "# location_id = 13 # MiaMornings Mobile(London)\n", "# location_id = 14 # CityBeans Roam-uccino(New York City)\n", "# location_id = 15 # Sunrise City Sipper(Tokyo)\n", "# location_id = 16 # Gotham Grind on Wheels(New York City)\n", "# location_id = 17 # Bay Area Bean Bus(San Francisco)\n", "# location_id = 18 # Mia Mochaccino Mobile(London)\n", "# location_id = 19 # Cityscape Sip Stop(Tokyo)\n", "# location_id = 20 # Transit Brew Buggy(London)\n", "# location_id = 21 # Fog City Fueler(London)\n", "# location_id = 22 # Metro Mugs(London)\n", "# location_id = 23 # Espresso Express(Tokyo)\n", "# location_id = 24 # Sunny Side Sips Shuttle(London)\n", "# location_id = 25 # Empire Espresso Explorer(New York City)\n", "# location_id = 26 # SF Sidewalk Sipper(San Francisco)\n", "# location_id = 27 # Beachside Brew Bounder(San Francisco)\n", "# location_id = 28 # Urban Sipper's Shuttle(London)\n", "# location_id = 29 # Nomadic Nectar(London)\n", "# location_id = 30 # Golden Bridge 
Brewmobile(San Francisco)\n", "# location_id = 31 # Sunny State Sipster(San Francisco)\n", "# location_id = 32 # Cafe Cruiser Central(Tokyo)\n", "# location_id = 33 # Neighborhood Nectar(Tokyo)\n", "# location_id = 34 # Frisco Fuel on Wheels(Tokyo)\n", "# location_id = 35 # MiaMug Mobility(New York City)\n", "# location_id = 36 # Metropolitan Mochaccino(London)\n", "# location_id = 37 # CitySips Street Surfer(New York City)\n", "# location_id = 38 # Golden Gate Gourmet Glide(San Francisco)\n", "# location_id = 39 # Beach Breeze Brew Bus(San Francisco)\n", "# location_id = 40 # City Roast Cruiser(Tokyo)\n", "# location_id = 41 # Urban Uplifter(New York City)\n", "# location_id = 42 # Frisco Fresh Brews(New York City)\n", "# location_id = 43 # Magic Mugs(London)\n", "# location_id = 44 # Coffee Cart Connection(New York City)\n", "# location_id = 45 # Empire City Espresso Explorer(New York City)\n", "# location_id = 46 # Golden Glow Grind Rover(San Francisco)\n", "# location_id = 47 # Sun-Kissed Sip & Go(San Francisco)\n", "# location_id = 48 # CityLife Latte Lorry(New York City)\n", "# location_id = 49 # Cityscape Sipper Shuttle(Tokyo)\n", "# location_id = 50 # Golden Grind Getter(San Francisco)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yNVfPUft0lAF" }, "outputs": [], "source": [ "location_id_string = ' '.join(map(str, location_ids))\n", "\n", "sql = f\"\"\"WITH oct_data AS\n", "(\n", "SELECT order_item.menu_id,\n", " order_t.location_id,\n", " CAST(SUM(order_item.item_total) AS INT64) AS sum_item_total\n", " FROM `${project_id}.data_beans_curated.order_item` AS order_item\n", " INNER JOIN `${project_id}.data_beans_curated.order` AS order_t\n", " ON order_t.order_id = order_item.order_id\n", " AND order_t.order_datetime BETWEEN '2023-10-01'\n", " AND '2023-10-31'\n", "GROUP BY 1,2\n", ")\n", ", nov_data AS\n", "(\n", "SELECT order_item.menu_id,\n", " order_t.location_id,\n", " CAST(SUM(order_item.item_total) AS INT64) AS 
sum_item_total\n", " FROM `${project_id}.data_beans_curated.order_item` AS order_item\n", " INNER JOIN `${project_id}.data_beans_curated.order` AS order_t\n", " ON order_t.order_id = order_item.order_id\n", " AND order_t.order_datetime BETWEEN '2023-11-01'\n", " AND '2023-11-30'\n", "GROUP BY 1,2\n", ")\n", ", results AS\n", "(\n", "SELECT city.city_name,\n", " nov_data.location_id,\n", " location.location_name,\n", " menu.*,\n", " nov_data.sum_item_total AS current_month_sales,\n", " oct_data.sum_item_total AS prior_month_sales,\n", " (oct_data.sum_item_total - nov_data.sum_item_total) AS sales_drop_off,\n", " ROW_NUMBER() OVER (PARTITION BY city.city_name, nov_data.location_id ORDER BY (oct_data.sum_item_total - nov_data.sum_item_total) DESC) AS ranking -- rank 1 = highest drop in sales\n", " FROM nov_data\n", " INNER JOIN `${project_id}.data_beans_curated.menu` AS menu\n", " ON menu.menu_id = nov_data.menu_id\n", " INNER JOIN `${project_id}.data_beans_curated.location` AS location\n", " ON location.location_id = nov_data.location_id\n", " INNER JOIN `${project_id}.data_beans_curated.city` AS city\n", " ON city.city_id = location.city_id\n", " LEFT JOIN oct_data\n", " ON nov_data.menu_id = oct_data.menu_id\n", " AND nov_data.location_id = oct_data.location_id\n", ")\n", "SELECT *\n", "FROM results AS t\n", "WHERE ranking = 1\n", " AND location_id IN ({str(location_ids).replace('[','').replace(']','') }) -- comment this out to see all locations\n", "ORDER BY city_name, location_id, ranking;\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mG4zxCvg082N" }, "outputs": [], "source": [ "df_low_sales_item = client.query(sql).to_dataframe()\n", "items_to_generate = []\n", "\n", "for row in df_low_sales_item.itertuples():\n", " items_to_generate.append ({\n", " \"city_name\" : row.city_name,\n", " \"current_month_sales\" : row.current_month_sales,\n", " \"prior_month_sales\" : row.prior_month_sales,\n", " \"sales_drop_off\" : row.sales_drop_off,\n", " \"menu_id\" : 
row.menu_id,\n", " \"item_name\" : row.item_name,\n", " \"company_id\" : row.company_id,\n", " \"location_id\" : row.location_id,\n", " \"location_name\" : row.location_name,\n", " \"item_price\" : row.item_price,\n", " \"item_description\" : row.item_description,\n", " \"item_size\" : row.item_size,\n", " \"llm_item_description_prompt\" : row.llm_item_description_prompt,\n", " \"llm_item_description\" : row.llm_item_description,\n", " \"llm_item_image_prompt\" : row.llm_item_image_prompt,\n", " \"llm_item_image_url\" : row.llm_item_image_url,\n", " })\n", "\n", "for item in items_to_generate:\n", " print(f\"location_id: {item['location_id']}\")\n", " print(f\"location_name: {item['location_name']}\")\n", " print(f\"city_name: {item['city_name']}\")\n", " print(f\"current_month_sales: {item['current_month_sales']}\")\n", " print(f\"prior_month_sales: {item['prior_month_sales']}\")\n", " print(f\"sales_drop_off: {item['sales_drop_off']}\")\n", " print(\"\")\n", " print(f\"menu_id: {item['menu_id']}\")\n", " print(f\"item_name: {item['item_name']}\")\n", " print(f\"company_id: {item['company_id']}\")\n", " print(f\"item_price: {item['item_price']}\")\n", " print(f\"item_description: {item['item_description']}\")\n", " print(f\"item_size: {item['item_size']}\")\n", " print(f\"llm_item_description_prompt: {item['llm_item_description_prompt']}\")\n", " print(f\"llm_item_description: {item['llm_item_description']}\")\n", " print(f\"llm_item_image_prompt: {item['llm_item_image_prompt']}\")\n", " print(f\"llm_item_image_url: {item['llm_item_image_url']}\")\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ZcLWRmMy0Mpm" }, "source": [ "### Generate new A/B menu items and images" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PnSFlpIa4A73" }, "outputs": [], "source": [ "table_name = \"menu_a_b_testing\"\n", "primary_key = \"menu_a_b_testing_id\"\n", "\n", "schema = GetTableSchema(dataset_id, table_name)\n", 
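"\n", "# Hypothetical sanity check (not part of the original notebook): GetTableSchema\n", "# is assumed to return the table schema as a non-empty string. Fail fast here\n", "# rather than sending an empty schema into the LLM prompt below.\n", "assert schema, f\"No schema found for {dataset_id}.{table_name}\"\n", "\n", 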
"starting_value = GetStartingValue(dataset_id, table_name, primary_key)\n", "current_value = starting_value\n", "\n", "for item in items_to_generate:\n", " menu_names_sql_prompt=f\"\"\"\n", " You are a database engineer and need to generate data for a table for the below schema.\n", " - The schema is for a Google Cloud BigQuery Table.\n", " - The table name is \"{project_id}.{dataset_id}.{table_name}\".\n", " - Read the description of each field for valid values.\n", " - Do not preface the response with any special characters or 'sql'.\n", " - Generate 1 row of data for this table.\n", " - The starting value of the field {primary_key} is {current_value}.\n", " - Only generate a single statement, not multiple INSERTs.\n", " - Hardcode the field \"menu_id\" to the value of \"{item['menu_id']}\".\n", " - Hardcode the field \"location_id\" to the value of \"{item['location_id']}\".\n", " - Hardcode the field \"item_name\" to the value of \"{item['item_name']}\".\n", " - Hardcode the field \"company_id\" to the value of \"{item['company_id']}\".\n", " - Hardcode the field \"item_price\" to the value of \"{item['item_price']}\".\n", " - Hardcode the field \"item_size\" to the value of \"{item['item_size']}\".\n", " - For the field \"llm_item_image_prompt\", limit the text to 256 characters.\n", " - For the field \"llm_item_image_prompt\": think outside the box and encourage unconventional approaches and fresh perspectives based upon the item description.\n", " - For the field \"llm_item_image_url\" use the following pattern and replace [[menu_a_b_testing_id]] with the generated menu id: https://storage.cloud.google.com/{gcs_storage_bucket}/{gcs_storage_path}[[menu_a_b_testing_id]].png\n", "\n", " Example 1: INSERT INTO `my-project.my-dataset.my-table` (field_1, field_2) VALUES (1, 'Sample'),(2, 'Sample');\n", " Example 2: INSERT INTO `my-project.my-dataset.my-table` (field_1, field_2) VALUES 
(1, 'Data'),(2, 'Data'),(3, 'Data');\n", "\n", " Schema: {schema}\n", " \"\"\"\n", "\n", " llm_success = False\n", " temperature = .8\n", " while not llm_success:\n", " try:\n", " sql = GeminiProLLM(menu_names_sql_prompt, temperature=temperature, topP=.8, topK=40)\n", " print(f\"SQL: {sql}\")\n", " llm_success = RunQuery(sql)\n", " if llm_success:\n", " current_value = current_value + 1\n", " except Exception as ex:\n", " # Reduce the temperature for more accurate generation (floored so it stays valid)\n", " temperature = max(temperature - .05, .1)\n", " print(f\"Regenerating... ({ex})\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VmOKq9nzrbhj" }, "outputs": [], "source": [ "# Query to get a list of menu items and the prompts for the images\n", "\n", "sql = f\"\"\"SELECT menu_a_b_testing_id,\n", " item_name,\n", " llm_item_image_prompt\n", " FROM `{project_id}.{dataset_id}.{table_name}`\n", " WHERE menu_a_b_testing_id BETWEEN {starting_value} AND {current_value - 1}\n", " ORDER BY menu_a_b_testing_id\"\"\"\n", "\n", "print(f\"SQL: {sql}\")\n", "df_process = client.query(sql).to_dataframe()\n", "image_files = []\n", "\n", "for row in df_process.itertuples():\n", " menu_a_b_testing_id = row.menu_a_b_testing_id\n", " item_name = row.item_name\n", " llm_item_image_prompt = row.llm_item_image_prompt\n", "\n", " print(f\"item_name: {item_name}\")\n", " print(f\"llm_item_image_prompt: {llm_item_image_prompt}\")\n", " try:\n", " image_file = ImageGen(llm_item_image_prompt)\n", " image_files.append({\n", " \"menu_a_b_testing_id\" : menu_a_b_testing_id,\n", " \"item_name\" : item_name,\n", " \"llm_item_image_prompt\" : llm_item_image_prompt,\n", " \"gcs_storage_bucket\" : gcs_storage_bucket,\n", " \"gcs_storage_path\" : gcs_storage_path,\n", " \"llm_image_filename\" : image_file\n", " })\n", " except Exception as ex:\n", " print(f\"Image failed to generate: {ex}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0NiOCtQEs76B" }, "outputs": [], "source": [ "# View the results\n", "for item in 
image_files:\n", " print(f\"menu_a_b_testing_id: {item['menu_a_b_testing_id']}\")\n", " print(f\"item_name: {item['item_name']}\")\n", " print(f\"llm_item_image_prompt: {item['llm_item_image_prompt']}\")\n", " img = Image.open(item[\"llm_image_filename\"])\n", " img.thumbnail([500,500])\n", " IPython.display.display(img)" ] }, { "cell_type": "markdown", "metadata": { "id": "9879tAKZCxvJ" }, "source": [ "#### Save the results to storage" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aRMySmnxDgea" }, "outputs": [], "source": [ "# When we generated the sample data for our table, we also asked the LLM to generate the matching GCS / HTTP path\n", "\n", "# Copy all image files to storage\n", "for item in image_files:\n", " copy_file_to_gcs(item[\"llm_image_filename\"], item[\"gcs_storage_bucket\"], item[\"gcs_storage_path\"] + str(item['menu_a_b_testing_id']) + \".png\")" ] } ], "metadata": { "colab": { "collapsed_sections": [ "k6eIqerFOzyj", "8zy0eEJmHxRZ", "YtZuFgjbOjso", "xUolPsMFOjpZ", "E5CFSdK3HxYm", "-L93udtrH1Oz", "rVCY93IyXPoO", "BlxddNzpmAgp" ], "name": "BigQuery table", "private_outputs": true, "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }