colab-enterprise/rideshare_llm_step_06_driver_quantitative_analysis.ipynb (351 lines of code) (raw):
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Fw_xXPkQm-Qe"
},
"source": [
"# Create Driver Summary (Quantitative Analysis)\n",
"- This notebook takes about 5 to 10 minutes to execute\n",
"- Extract quantitative data from the Trips data\n",
" - How many pick up locations\n",
" - Do they drive to the airport\n",
" - Do they cross state lines\n",
" - Do they work only certain days of the week\n",
"- Create a LLM summary of the extracted data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P0mDZ_HwnMo9"
},
"source": [
"## Create Summary Prompt and run through LLM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TRJYApx9ZUd8"
},
"outputs": [],
"source": [
"%%bigquery\n",
"\n",
"-- OPTIONAL: Reset all the fields to null\n",
"-- If you need to reset your data back to fresh data, run the stored procedure: CALL `${project_id}.${bigquery_rideshare_llm_curated_dataset}.sp_reset_demo`();\n",
"\n",
"/*\n",
"UPDATE `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver`\n",
" SET driver_quantitative_analysis_prompt = NULL,\n",
" llm_driver_quantitative_analysis_json = NULL,\n",
" llm_driver_quantitative_analysis = NULL\n",
" WHERE TRUE;\n",
"*/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "45e-7wxzoFVE"
},
"outputs": [],
"source": [
"%%bigquery\n",
"\n",
"-- Create the LLM prompt: one instruction sentence followed by a bullet per driver attribute.\n",
"-- Every bullet starts with '- ' and ends with a newline so the bullets never run together.\n",
"-- CONCAT returns NULL if any argument is NULL, so stddev_amt is guarded with its own branch;\n",
"-- a prompt that still ends up NULL is skipped by the later 'prompt IS NOT NULL' filters.\n",
"UPDATE `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver` AS driver\n",
"   SET driver_quantitative_analysis_prompt =\n",
"       CONCAT('Write a 3 to 8 sentence summary of the following attributes of a driver in third person gender neutral form: ',\n",
"\n",
"              -- Pickup locations (singular wording when there is exactly one location)\n",
"              CASE WHEN pickup_location_habit = 'few-1-3-pickup-locations' AND pickup_location_count = 1\n",
"                        THEN CONCAT('- The driver only picks up customers at ',\n",
"                                    CAST(pickup_location_count AS STRING),\n",
"                                    ' pickup location. The location is ',\n",
"                                    distinct_pickup_location_zones,\n",
"                                    '.\\n')\n",
"                   WHEN pickup_location_habit = 'few-1-3-pickup-locations'\n",
"                        THEN CONCAT('- The driver only picks up customers at ',\n",
"                                    CAST(pickup_location_count AS STRING),\n",
"                                    ' pickup locations. These locations are: ',\n",
"                                    distinct_pickup_location_zones,\n",
"                                    '.\\n')\n",
"                   WHEN pickup_location_habit = 'average-4-6-pickup-locations'\n",
"                        THEN CONCAT('- The driver picks up customers at ',\n",
"                                    CAST(pickup_location_count AS STRING),\n",
"                                    ' pickup locations. This is an average number. These locations are: ',\n",
"                                    distinct_pickup_location_zones,\n",
"                                    '.\\n')\n",
"                   WHEN pickup_location_habit = 'many-7-9-pickup-locations'\n",
"                        THEN CONCAT('- The driver picks up customers at many pickup locations. ',\n",
"                                    'This is an above average number. These locations include: ',\n",
"                                    distinct_pickup_location_zones,\n",
"                                    '.\\n')\n",
"                   WHEN pickup_location_habit = 'any-pickup-locations'\n",
"                        THEN '- The driver will pick up customers at a large number of locations.\\n'\n",
"                   ELSE ''\n",
"              END,\n",
"\n",
"              -- Dropoff locations (singular wording when there is exactly one location)\n",
"              CASE WHEN dropoff_location_habit = 'few-1-3-dropoff-locations' AND dropoff_location_count = 1\n",
"                        THEN CONCAT('- The driver only drops off customers at ',\n",
"                                    CAST(dropoff_location_count AS STRING),\n",
"                                    ' dropoff location. The location is ',\n",
"                                    distinct_dropoff_location_zones,\n",
"                                    '.\\n')\n",
"                   WHEN dropoff_location_habit = 'few-1-3-dropoff-locations'\n",
"                        THEN CONCAT('- The driver only drops off customers at ',\n",
"                                    CAST(dropoff_location_count AS STRING),\n",
"                                    ' dropoff locations. These locations are: ',\n",
"                                    distinct_dropoff_location_zones,\n",
"                                    '.\\n')\n",
"                   WHEN dropoff_location_habit = 'average-4-6-dropoff-locations'\n",
"                        THEN CONCAT('- The driver drops off customers at ',\n",
"                                    CAST(dropoff_location_count AS STRING),\n",
"                                    ' dropoff locations. This is an average number. These locations are: ',\n",
"                                    distinct_dropoff_location_zones,\n",
"                                    '.\\n')\n",
"                   WHEN dropoff_location_habit = 'many-7-9-dropoff-locations'\n",
"                        THEN CONCAT('- The driver drops off customers at many locations. ',\n",
"                                    'This is an above average number. These locations include: ',\n",
"                                    distinct_dropoff_location_zones,\n",
"                                    '.\\n')\n",
"                   ELSE ''\n",
"              END,\n",
"\n",
"              -- Crossing state lines\n",
"              CASE WHEN cross_state_habit = 'crosses-state-line'\n",
"                        THEN '- The driver is willing to pick up or drop off customers across state lines.\\n'\n",
"                   WHEN cross_state_habit = 'does-not-cross-state-line'\n",
"                        THEN '- The driver is not willing to drive across state lines.\\n'\n",
"                   ELSE ''\n",
"              END,\n",
"\n",
"              -- Airport trips\n",
"              CASE WHEN airport_habit = 'airport-driver'\n",
"                        THEN '- The driver has a high preference for picking up and dropping off customers at the airport.\\n'\n",
"                   WHEN airport_habit = 'non-airport-driver'\n",
"                        THEN '- The driver typically does not pick up or drop off at the airport.\\n'\n",
"                   ELSE ''\n",
"              END,\n",
"\n",
"              -- Days of the week (any other value, including NULL, uses the generic sentence)\n",
"              CASE WHEN day_of_week = 'weekend-driver' THEN '- The driver only works on weekends.\\n'\n",
"                   WHEN day_of_week = 'weekday-driver' THEN '- The driver only works on weekdays.\\n'\n",
"                   ELSE '- The driver works a variety of days not targeting specific days of the week.\\n'\n",
"              END,\n",
"\n",
"              -- Hours of the day (any other value, including NULL, uses the generic sentence)\n",
"              CASE WHEN hour_of_day = 'night-hour-driver' THEN '- The driver likes to work late at night.\\n'\n",
"                   WHEN hour_of_day = 'rush-hour-driver' THEN '- The driver likes to work a split shift which appears to target rush hour.\\n'\n",
"                   ELSE '- The driver does not appear to have a set schedule for the hours they work.\\n'\n",
"              END,\n",
"\n",
"              -- Daily income target; the standard deviation is only mentioned when it is available\n",
"              CASE WHEN average_daily_pay IS NOT NULL AND stddev_amt IS NOT NULL\n",
"                        THEN CONCAT('- The driver appears to target a specific amount of income per day.\\n',\n",
"                                    '- The driver likes their daily amount to be approximately $',\n",
"                                    CAST(ROUND(average_daily_pay,2) AS STRING),\n",
"                                    ' with a ',\n",
"                                    CAST(ROUND(stddev_amt,2) AS STRING),\n",
"                                    '% standard deviation.\\n')\n",
"                   WHEN average_daily_pay IS NOT NULL\n",
"                        THEN CONCAT('- The driver appears to target a specific amount of income per day.\\n',\n",
"                                    '- The driver likes their daily amount to be approximately $',\n",
"                                    CAST(ROUND(average_daily_pay,2) AS STRING),\n",
"                                    '.\\n')\n",
"                   ELSE ''\n",
"              END\n",
"             )\n",
"  FROM `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver_quantitative_analysis` AS driver_quantitative_analysis\n",
" WHERE driver.driver_id = driver_quantitative_analysis.driver_id\n",
";\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EnZ44v7eypQG"
},
"outputs": [],
"source": [
"%%bigquery\n",
"\n",
"-- Preview up to 10 drivers that have a generated prompt\n",
"SELECT driver.driver_quantitative_analysis_prompt\n",
"  FROM `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver` AS driver\n",
" WHERE driver.driver_quantitative_analysis_prompt IS NOT NULL\n",
" LIMIT 10;"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "n_8WMAP7yzqc"
},
"source": [
"## Run the LLM to generate a Driver Summary on Quantitative Analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UV9N-LwVzre_"
},
"outputs": [],
"source": [
"from google.cloud import bigquery\n",
"import pandas as pd\n",
"\n",
"client = bigquery.Client()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9nuVE4BlzgaG"
},
"outputs": [],
"source": [
"# Number of drivers scored by each UPDATE statement\n",
"batch_size = 100\n",
"\n",
"# Sampling parameters handed to ML.GENERATE_TEXT; lower values give less random responses\n",
"llm_temperature = .80\n",
"llm_max_output_tokens = 1024\n",
"llm_top_p = .70\n",
"llm_top_k = 25\n",
"\n",
"# UPDATE that stores the raw LLM response JSON for up to batch_size drivers that have a prompt,\n",
"# are flagged include_in_llm_processing, and have no response text yet\n",
"update_sql = f\"\"\"\n",
"UPDATE `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver` AS driver\n",
" SET llm_driver_quantitative_analysis_json = child.ml_generate_text_result\n",
" FROM (SELECT *\n",
" FROM ML.GENERATE_TEXT(MODEL`${project_id}.${bigquery_rideshare_llm_enriched_dataset}.gemini_model`,\n",
" (SELECT driver_id,\n",
" driver_quantitative_analysis_prompt AS prompt\n",
" FROM `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver`\n",
" WHERE (llm_driver_quantitative_analysis_json IS NULL\n",
" OR\n",
" JSON_VALUE(llm_driver_quantitative_analysis_json, '$.candidates[0].content.parts[0].text') IS NULL\n",
" )\n",
" AND include_in_llm_processing = TRUE\n",
" AND driver_quantitative_analysis_prompt IS NOT NULL\n",
" LIMIT {batch_size}),\n",
" STRUCT(\n",
" {llm_temperature} AS temperature,\n",
" {llm_max_output_tokens} AS max_output_tokens,\n",
" {llm_top_p} AS top_p,\n",
" {llm_top_k} AS top_k\n",
" ))\n",
" ) AS child\n",
"WHERE driver.driver_id = child.driver_id\n",
" \"\"\"\n",
"\n",
"# Show the final statement so it can be inspected before the scoring loop runs it\n",
"print(\"SQL: \" + update_sql)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZB8YzCbdzlEl"
},
"outputs": [],
"source": [
"# Score while records remain\n",
"# score in groups of batch_size records (we can do up to 10,000 at a time)\n",
"import time\n",
"\n",
"# Stop after this many consecutive batches that do not reduce the remaining count.\n",
"# Without this guard, rows the LLM never returns text for would keep the loop running forever.\n",
"max_stalled_batches = 3\n",
"\n",
"# Counts the drivers that still need a usable LLM response (same filter as update_sql).\n",
"# The statement never changes, so it is built once outside the loop.\n",
"remaining_count_sql = \"\"\"\n",
"SELECT COUNT(*) AS cnt\n",
"  FROM `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver`\n",
" WHERE (llm_driver_quantitative_analysis_json IS NULL\n",
"        OR\n",
"        JSON_VALUE(llm_driver_quantitative_analysis_json, '$.candidates[0].content.parts[0].text') IS NULL\n",
"       )\n",
"   AND include_in_llm_processing = TRUE\n",
"   AND driver_quantitative_analysis_prompt IS NOT NULL;\n",
"\"\"\"\n",
"\n",
"original_record_count = None  # count seen on the first pass, used only for the progress message\n",
"previous_cnt = None           # count seen on the previous pass, used to detect a stalled loop\n",
"stalled_batches = 0\n",
"\n",
"while True:\n",
"    # Get the count of records to score\n",
"    df_record_count = client.query(remaining_count_sql).to_dataframe()\n",
"    cnt = df_record_count['cnt'].head(1).item()\n",
"    if original_record_count is None:\n",
"        original_record_count = cnt\n",
"\n",
"    print(\"Remaining records to process: \", cnt, \" out of\", original_record_count, \" batch_size: \", batch_size)\n",
"\n",
"    if cnt == 0:\n",
"        break\n",
"\n",
"    # A batch that leaves the count unchanged made no progress; give up after a few of those in a row\n",
"    if previous_cnt is not None and cnt >= previous_cnt:\n",
"        stalled_batches += 1\n",
"        if stalled_batches >= max_stalled_batches:\n",
"            print(\"Stopping: {} records still have no LLM text after {} batches without progress\".format(cnt, stalled_batches))\n",
"            break\n",
"    else:\n",
"        stalled_batches = 0\n",
"    previous_cnt = cnt\n",
"\n",
"    # https://github.com/googleapis/python-bigquery/tree/master/samples\n",
"    job_config = bigquery.QueryJobConfig(priority=bigquery.QueryPriority.INTERACTIVE)\n",
"    query_job = client.query(update_sql, job_config=job_config)\n",
"\n",
"    # Check on the progress by getting the job's updated state.\n",
"    query_job = client.get_job(\n",
"        query_job.job_id, location=query_job.location\n",
"    )\n",
"    print(\"Job {} is currently in state {}\".format(query_job.job_id, query_job.state))\n",
"\n",
"    while query_job.state != \"DONE\":\n",
"        time.sleep(5)\n",
"        query_job = client.get_job(\n",
"            query_job.job_id, location=query_job.location\n",
"        )\n",
"        print(\"Job {} is currently in state {}\".format(query_job.job_id, query_job.state))\n",
"\n",
"    # DONE only means the job finished; surface a failure instead of resubmitting the same statement\n",
"    if query_job.error_result:\n",
"        raise RuntimeError(\"Job {} failed: {}\".format(query_job.job_id, query_job.error_result))\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hcpsyXXH0Do_"
},
"source": [
"## Parse the LLM JSON results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1FiLFULh0Ep1"
},
"outputs": [],
"source": [
"%%bigquery\n",
"\n",
"-- Copy the generated text out of the raw LLM JSON for drivers whose text column is still empty\n",
"UPDATE `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver` AS driver\n",
"   SET llm_driver_quantitative_analysis =\n",
"       JSON_VALUE(driver.llm_driver_quantitative_analysis_json, '$.candidates[0].content.parts[0].text')\n",
" WHERE driver.llm_driver_quantitative_analysis IS NULL\n",
"   AND driver.llm_driver_quantitative_analysis_json IS NOT NULL;"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eJ9ImdzT0FNN"
},
"outputs": [],
"source": [
"%%bigquery\n",
"\n",
"-- Show up to 20 drivers with their prompt, the raw LLM JSON and the parsed summary\n",
"SELECT driver.driver_id,\n",
"       driver.driver_quantitative_analysis_prompt,\n",
"       driver.llm_driver_quantitative_analysis_json,\n",
"       driver.llm_driver_quantitative_analysis\n",
"  FROM `${project_id}.${bigquery_rideshare_llm_enriched_dataset}.driver` AS driver\n",
" WHERE driver.driver_quantitative_analysis_prompt IS NOT NULL\n",
"   AND driver.llm_driver_quantitative_analysis_json IS NOT NULL\n",
" LIMIT 20;\n"
]
}
],
"metadata": {
"colab": {
"name": "BigQuery table",
"private_outputs": true,
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}