colab-enterprise/rideshare_llm_step_06_driver_quantitative_analysis.ipynb (351 lines of code) (raw):

{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fw_xXPkQm-Qe"
      },
      "source": [
        "# Create Driver Summary (Quantitative Analysis)\n",
        "- This notebook takes about 5 to 10 minutes to execute\n",
        "- Extract quantitative data from the Trips data\n",
        "    - How many pick up locations\n",
        "    - Do they drive to the airport\n",
        "    - Do they cross state lines\n",
        "    - Do they work only certain days of the week\n",
        "- Create an LLM summary of the extracted data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P0mDZ_HwnMo9"
      },
      "source": [
        "## Create Summary Prompt and run through LLM"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TRJYApx9ZUd8"
      },
      "outputs": [],
      "source": [
        "%%bigquery\n",
        "\n",
        "-- OPTIONAL: Reset all the fields to null\n",
        "-- If you need to reset your data back to fresh data run the stored procedure: CALL `${project_id}.${bigquery_rideshare_llm_curated_dataset}.sp_reset_demo`();\n",
        "\n",
        "/*\n",
        "UPDATE `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver`\n",
        "   SET driver_quantitative_analysis_prompt = NULL,\n",
        "       llm_driver_quantitative_analysis_json = NULL,\n",
        "       llm_driver_quantitative_analysis = NULL\n",
        " WHERE TRUE;\n",
        "*/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "45e-7wxzoFVE"
      },
      "outputs": [],
      "source": [
        "%%bigquery\n",
        "\n",
        "-- Create the LLM prompt.\n",
        "-- The prompt is one instruction line followed by one '- ...' bullet per habit.\n",
        "-- Every bullet ends with a newline so that consecutive bullets never run together;\n",
        "-- an unrecognised habit value contributes an empty string.\n",
        "-- CONCAT returns NULL when ANY argument is NULL, which would null the whole prompt\n",
        "-- (and the driver would then be skipped by the LLM step), so optional values are guarded.\n",
        "UPDATE `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver` AS driver\n",
        "   SET driver_quantitative_analysis_prompt =\n",
        "       CONCAT('Write a 3 to 8 sentence summary of the following attributes of a driver in third person gender neutral form:\\n',\n",
        "\n",
        "              -- Pickup locations\n",
        "              CASE WHEN pickup_location_habit = 'few-1-3-pickup-locations' AND pickup_location_count = 1\n",
        "                   THEN CONCAT('- The driver only picks up customers at ',\n",
        "                               CAST(pickup_location_count AS STRING),\n",
        "                               ' pickup location. The location is ',\n",
        "                               distinct_pickup_location_zones,\n",
        "                               '.\\n')\n",
        "                   WHEN pickup_location_habit = 'few-1-3-pickup-locations'\n",
        "                   THEN CONCAT('- The driver only picks up customers at ',\n",
        "                               CAST(pickup_location_count AS STRING),\n",
        "                               ' pickup locations. These locations are: ',\n",
        "                               distinct_pickup_location_zones,\n",
        "                               '.\\n')\n",
        "                   WHEN pickup_location_habit = 'average-4-6-pickup-locations'\n",
        "                   THEN CONCAT('- The driver picks up customers at ',\n",
        "                               CAST(pickup_location_count AS STRING),\n",
        "                               ' pickup locations. This is an average number. These locations are: ',\n",
        "                               distinct_pickup_location_zones,\n",
        "                               '.\\n')\n",
        "                   WHEN pickup_location_habit = 'many-7-9-pickup-locations'\n",
        "                   THEN CONCAT('- The driver picks up customers at many pickup locations. ',\n",
        "                               'This is an above average number. These locations include: ',\n",
        "                               distinct_pickup_location_zones,\n",
        "                               '.\\n')\n",
        "                   WHEN pickup_location_habit = 'any-pickup-locations'\n",
        "                   THEN '- The driver will pickup customers at a large number of locations.\\n'\n",
        "                   ELSE ''\n",
        "              END,\n",
        "\n",
        "              -- Dropoff locations\n",
        "              CASE WHEN dropoff_location_habit = 'few-1-3-dropoff-locations' AND dropoff_location_count = 1\n",
        "                   THEN CONCAT('- The driver only drops off customers at ',\n",
        "                               CAST(dropoff_location_count AS STRING),\n",
        "                               ' dropoff location. The location is ',\n",
        "                               distinct_dropoff_location_zones,\n",
        "                               '.\\n')\n",
        "                   WHEN dropoff_location_habit = 'few-1-3-dropoff-locations'\n",
        "                   THEN CONCAT('- The driver only drops off customers at ',\n",
        "                               CAST(dropoff_location_count AS STRING),\n",
        "                               ' dropoff locations. These locations are: ',\n",
        "                               distinct_dropoff_location_zones,\n",
        "                               '.\\n')\n",
        "                   WHEN dropoff_location_habit = 'average-4-6-dropoff-locations'\n",
        "                   THEN CONCAT('- The driver drops off customers at ',\n",
        "                               CAST(dropoff_location_count AS STRING),\n",
        "                               ' dropoff locations. This is an average number. These locations are: ',\n",
        "                               distinct_dropoff_location_zones,\n",
        "                               '.\\n')\n",
        "                   WHEN dropoff_location_habit = 'many-7-9-dropoff-locations'\n",
        "                   THEN CONCAT('- The driver drops off customers at many locations. ',\n",
        "                               'This is an above average number. These locations include: ',\n",
        "                               distinct_dropoff_location_zones,\n",
        "                               '.\\n')\n",
        "                   -- NOTE(review): assumes the dropoff habit has an 'any-dropoff-locations' bucket\n",
        "                   -- mirroring the pickup one; this branch is a no-op if it does not - confirm.\n",
        "                   WHEN dropoff_location_habit = 'any-dropoff-locations'\n",
        "                   THEN '- The driver will dropoff customers at a large number of locations.\\n'\n",
        "                   ELSE ''\n",
        "              END,\n",
        "\n",
        "              -- Crossing state lines\n",
        "              CASE WHEN cross_state_habit = 'crosses-state-line'        THEN '- The driver is willing to pickup or dropoff customers across state lines.\\n'\n",
        "                   WHEN cross_state_habit = 'does-not-cross-state-line' THEN '- The driver is not willing to drive across state lines.\\n'\n",
        "                   ELSE ''\n",
        "              END,\n",
        "\n",
        "              -- Airport trips\n",
        "              CASE WHEN airport_habit = 'airport-driver'     THEN '- The driver has a high preference for picking up and dropping off customers at the airport.\\n'\n",
        "                   WHEN airport_habit = 'non-airport-driver' THEN '- The driver typically does not pickup or dropoff at the airport.\\n'\n",
        "                   ELSE ''\n",
        "              END,\n",
        "\n",
        "              -- Days of the week\n",
        "              CASE WHEN day_of_week = 'weekend-driver' THEN '- The driver only works on weekends.\\n'\n",
        "                   WHEN day_of_week = 'weekday-driver' THEN '- The driver only works on weekdays.\\n'\n",
        "                   ELSE '- The driver works a variety of days not targeting specific days of the week.\\n'\n",
        "              END,\n",
        "\n",
        "              -- Hours of the day\n",
        "              CASE WHEN hour_of_day = 'night-hour-driver' THEN '- The driver likes to work late at night.\\n'\n",
        "                   WHEN hour_of_day = 'rush-hour-driver'  THEN '- The driver likes to work a split shift which appears to target rush hour.\\n'\n",
        "                   ELSE '- The driver does not appear to have a set schedule for the hours they work.\\n'\n",
        "              END,\n",
        "\n",
        "              -- Daily pay. The deviation is only mentioned when it is available, because a NULL\n",
        "              -- stddev_amt inside CONCAT would otherwise turn the entire prompt into NULL.\n",
        "              CASE WHEN average_daily_pay IS NOT NULL AND stddev_amt IS NOT NULL\n",
        "                   THEN CONCAT('- The driver appears to target a specific amount of income per day.\\n',\n",
        "                               '- The driver likes their daily amount to be approximately $',\n",
        "                               CAST(ROUND(average_daily_pay,2) AS STRING),\n",
        "                               ' with a ',\n",
        "                               CAST(ROUND(stddev_amt,2) AS STRING),\n",
        "                               '% standard deviation.\\n')\n",
        "                   WHEN average_daily_pay IS NOT NULL\n",
        "                   THEN CONCAT('- The driver appears to target a specific amount of income per day.\\n',\n",
        "                               '- The driver likes their daily amount to be approximately $',\n",
        "                               CAST(ROUND(average_daily_pay,2) AS STRING),\n",
        "                               '.\\n')\n",
        "                   ELSE ''\n",
        "              END\n",
        "             )\n",
        "  FROM `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver_quantitative_analysis` AS driver_quantitative_analysis\n",
        " WHERE driver.driver_id = driver_quantitative_analysis.driver_id\n",
        ";\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EnZ44v7eypQG"
      },
      "outputs": [],
      "source": [
        "%%bigquery\n",
        "\n",
        "-- Spot-check a few of the generated prompts\n",
        "SELECT driver_quantitative_analysis_prompt\n",
        "  FROM `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver` AS driver\n",
        " WHERE driver_quantitative_analysis_prompt IS NOT NULL\n",
        " LIMIT 10;"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n_8WMAP7yzqc"
      },
      "source": [
        "## Run the LLM to generate a Driver Summary on Quantitative Analysis"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UV9N-LwVzre_"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "\n",
        "from google.cloud import bigquery\n",
        "import pandas as pd\n",
        "\n",
        "client = bigquery.Client()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9nuVE4BlzgaG"
      },
      "outputs": [],
      "source": [
        "# Process in batches\n",
        "batch_size = 100\n",
        "\n",
        "# Give up after this many consecutive batches that leave the backlog unchanged\n",
        "# (for example the model keeps returning no text for the same rows), so that the\n",
        "# scoring loop below cannot spin forever.\n",
        "max_stalled_batches = 3\n",
        "\n",
        "# Sampling parameters passed to ML.GENERATE_TEXT.\n",
        "# Lower temperature / top_p / top_k give more deterministic, less creative responses.\n",
        "# NOTE(review): a temperature of 0.80 is on the creative side; lower it if repeatable summaries are wanted.\n",
        "llm_temperature = .80\n",
        "llm_max_output_tokens = 1024\n",
        "llm_top_p = .70\n",
        "llm_top_k = 25\n",
        "\n",
        "# Rows that still need an LLM result: no JSON stored yet, or JSON without any generated text.\n",
        "# Defined once and shared by the UPDATE below and by the remaining-record count in the\n",
        "# scoring loop, so the two can never disagree about what is pending.\n",
        "pending_filter = \"\"\"(llm_driver_quantitative_analysis_json IS NULL\n",
        "                        OR\n",
        "                        JSON_VALUE(llm_driver_quantitative_analysis_json, '$.candidates[0].content.parts[0].text') IS NULL\n",
        "                       )\n",
        "                   AND include_in_llm_processing = TRUE\n",
        "                   AND driver_quantitative_analysis_prompt IS NOT NULL\"\"\"\n",
        "\n",
        "update_sql = \"\"\"\n",
        "UPDATE `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver` AS driver\n",
        "   SET llm_driver_quantitative_analysis_json = child.ml_generate_text_result\n",
        "  FROM (SELECT *\n",
        "          FROM ML.GENERATE_TEXT(MODEL `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.gemini_model`,\n",
        "               (SELECT driver_id,\n",
        "                       driver_quantitative_analysis_prompt AS prompt\n",
        "                  FROM `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver`\n",
        "                 WHERE {pending_filter}\n",
        "                 LIMIT {batch_size}),\n",
        "               STRUCT(\n",
        "                 {llm_temperature} AS temperature,\n",
        "                 {llm_max_output_tokens} AS max_output_tokens,\n",
        "                 {llm_top_p} AS top_p,\n",
        "                 {llm_top_k} AS top_k\n",
        "                 ))\n",
        "       ) AS child\n",
        " WHERE driver.driver_id = child.driver_id\n",
        "\"\"\".format(pending_filter = pending_filter,\n",
        "           batch_size = batch_size,\n",
        "           llm_temperature = llm_temperature,\n",
        "           llm_max_output_tokens = llm_max_output_tokens,\n",
        "           llm_top_p = llm_top_p,\n",
        "           llm_top_k = llm_top_k)\n",
        "\n",
        "print(\"SQL: {update_sql}\".format(update_sql=update_sql))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZB8YzCbdzlEl"
      },
      "outputs": [],
      "source": [
        "# Score while records remain\n",
        "# score in groups of batch_size records (we can do up to 10,000 at a time)\n",
        "\n",
        "count_sql = \"\"\"\n",
        "SELECT COUNT(*) AS cnt\n",
        "  FROM `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver`\n",
        " WHERE {pending_filter};\n",
        "\"\"\".format(pending_filter = pending_filter)\n",
        "\n",
        "\n",
        "def count_pending_records():\n",
        "    \"\"\"Return the number of driver rows that still match pending_filter.\"\"\"\n",
        "    rows = client.query(count_sql).result()\n",
        "    return next(iter(rows))[\"cnt\"]\n",
        "\n",
        "\n",
        "# https://github.com/googleapis/python-bigquery/tree/master/samples\n",
        "job_config = bigquery.QueryJobConfig(priority=bigquery.QueryPriority.INTERACTIVE)\n",
        "\n",
        "original_record_count = count_pending_records()\n",
        "remaining = original_record_count\n",
        "stalled_batches = 0\n",
        "\n",
        "while True:\n",
        "    print(\"Remaining records to process: \", remaining, \" out of\", original_record_count, \" batch_size: \", batch_size)\n",
        "    if remaining == 0:\n",
        "        break\n",
        "\n",
        "    started = time.monotonic()\n",
        "    query_job = client.query(update_sql, job_config=job_config)\n",
        "    print(\"Job {} is currently in state {}\".format(query_job.job_id, query_job.state))\n",
        "\n",
        "    # result() blocks until the job has finished and raises if the job failed,\n",
        "    # so a broken UPDATE stops the notebook instead of being retried forever.\n",
        "    query_job.result()\n",
        "    print(\"Job {} is currently in state {} after {:.0f} seconds\".format(\n",
        "        query_job.job_id, query_job.state, time.monotonic() - started))\n",
        "\n",
        "    previous_remaining = remaining\n",
        "    remaining = count_pending_records()\n",
        "\n",
        "    if remaining < previous_remaining:\n",
        "        stalled_batches = 0\n",
        "        continue\n",
        "\n",
        "    # The batch ran but the backlog did not shrink: the selected rows came back\n",
        "    # without generated text and are therefore still pending.\n",
        "    stalled_batches += 1\n",
        "    if stalled_batches >= max_stalled_batches:\n",
        "        print(\"Stopping: \", remaining, \" records still have no LLM text after\", max_stalled_batches,\n",
        "              \" consecutive batches without progress. Inspect llm_driver_quantitative_analysis_json for these rows.\")\n",
        "        break\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hcpsyXXH0Do_"
      },
      "source": [
        "## Parse the LLM JSON results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1FiLFULh0Ep1"
      },
      "outputs": [],
      "source": [
        "%%bigquery\n",
        "\n",
        "-- Copy the generated text out of the raw LLM JSON for rows that have not been parsed yet\n",
        "UPDATE `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver` driver\n",
        "   SET llm_driver_quantitative_analysis = JSON_VALUE(llm_driver_quantitative_analysis_json, '$.candidates[0].content.parts[0].text')\n",
        " WHERE llm_driver_quantitative_analysis_json IS NOT NULL\n",
        "   AND llm_driver_quantitative_analysis IS NULL;"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eJ9ImdzT0FNN"
      },
      "outputs": [],
      "source": [
        "%%bigquery\n",
        "\n",
        "-- Review the prompt, the raw LLM JSON and the parsed summary side by side\n",
        "SELECT driver_id, driver_quantitative_analysis_prompt, llm_driver_quantitative_analysis_json, llm_driver_quantitative_analysis\n",
        "  FROM `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver`\n",
        " WHERE llm_driver_quantitative_analysis_json IS NOT NULL\n",
        "   AND driver_quantitative_analysis_prompt IS NOT NULL\n",
        " LIMIT 20;\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "BigQuery table",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}