notebooks/generative_ai/sentiment_analysis/sentiment_analysis_movie_reviews.ipynb (560 lines of code) (raw):
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "7b9f55ae-6bf5-4e91-bd30-64e07ea85450",
"metadata": {},
"outputs": [],
"source": [
"# Copyright 2023 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"id": "fe656833-3d7c-4fde-87f5-c7c7ad21a106",
"metadata": {},
"source": [
"# Sentiment analysis on large-scale data using an LLM (Gemini)"
]
},
{
"cell_type": "markdown",
"id": "f5528149-6b62-4ac2-b4b5-c30a381ad480",
"metadata": {},
"source": [
"<table align=\"left\">\n",
"<td>\n",
"<a href=\"https://github.com/GoogleCloudPlatform/ai-ml-recipes/blob/main/notebooks/generative_ai/sentiment_analysis/sentiment_analysis_movie_reviews.ipynb\">\n",
"<img src=\"https://cloud.google.com/ml-engine/images/github-logo-32px.png\" alt=\"GitHub logo\">\n",
"View on GitHub\n",
"</a>\n",
"</td>\n",
"<td>\n",
"<a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/ai-ml-recipes/main/notebooks/generative_ai/sentiment_analysis/sentiment_analysis_movie_reviews.ipynb\">\n",
"<img src=\"https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32\" alt=\"Vertex AI logo\">\n",
"Open in Vertex AI Workbench\n",
"</a>\n",
"</td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "e6fd7856-e138-4831-801e-6318f856efe3",
"metadata": {},
"source": [
"## Overview"
]
},
{
"cell_type": "markdown",
"id": "e854e6c7-14f3-4b8a-933e-9091b24140af",
"metadata": {
"tags": []
},
"source": [
"This notebook shows how to perform sentiment analysis on large-scale data using an LLM.\n",
"The dataset used is a public dataset from BigQuery Public Datasets.\n",
"\n",
"#### **Steps**\n",
"Using Spark, this notebook:\n",
"1) Reads data from the BigQuery public dataset **bigquery-public-data.imdb.reviews**\n",
"2) Calls the [Vertex AI Gemini API](https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart#try_text_prompts) to find the sentiment of each review (positive vs negative)\n",
"3) Compares the predicted sentiment with the dataset label side by side\n",
"4) Trims whitespace from the predictions, then measures AUC-ROC and accuracy\n",
"\n",
"#### Related content\n",
"\n",
"- [Text Prompt](https://cloud.google.com/vertex-ai/docs/generative-ai/text/text-prompts)\n",
"- [Content Classification](https://cloud.google.com/vertex-ai/docs/generative-ai/text/text-prompts#content-classification)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95c47182-9bc3-41e5-ab3c-f10f01727b99",
"metadata": {},
"outputs": [],
"source": [
"from pyspark.sql import SparkSession\n",
"from pyspark.sql.functions import col, trim, udf, when\n",
"from pyspark.sql.types import StructType, StructField, StringType, ArrayType\n",
"from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
"from pyspark.ml.feature import StringIndexer\n",
"\n",
"import google.auth\n",
"import google.auth.transport.requests\n",
"import requests\n",
"\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9eb15ebe-c6aa-4f92-92d3-6fa8ec5dc5b3",
"metadata": {},
"outputs": [],
"source": [
"# When using Dataproc Serverless, installed packages are automatically available on all nodes\n",
"!pip install --upgrade google-cloud-aiplatform google-cloud-vision\n",
"# When using a Dataproc cluster, you will need to install these packages during cluster creation: https://cloud.google.com/dataproc/docs/tutorials/python-configuration"
]
},
{
"cell_type": "markdown",
"id": "50ebd0b3-c4b5-4df5-8d11-0f721a147467",
"metadata": {},
"source": [
"#### Get credentials to authenticate with Google APIs\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf715064-3b4b-44d2-9dcb-2c25143affa9",
"metadata": {},
"outputs": [],
"source": [
"credentials, project_id = google.auth.default()\n",
"auth_req = google.auth.transport.requests.Request()\n",
"credentials.refresh(auth_req)"
]
},
{
"cell_type": "markdown",
"id": "483fa597-0127-4077-b1f3-4bcb9ff11c65",
"metadata": {},
"source": [
"### Create Spark Session for the notebook"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3e4dced-00af-4324-8d00-2855f6152f31",
"metadata": {},
"outputs": [],
"source": [
"spark = SparkSession.builder \\\n",
"    .appName(\"Sentiment Analysis using Dataproc and Vertex LLM\") \\\n",
"    .getOrCreate()"
]
},
{
"cell_type": "markdown",
"id": "97b310e5-ede5-4ac5-a5dd-df0732f61c8e",
"metadata": {},
"source": [
"### Read data from BigQuery Public Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "02d7cad8-6ef6-44bb-baa9-78d5f10dc288",
"metadata": {},
"outputs": [],
"source": [
"movie_reviews = spark.read.format(\"bigquery\").option(\"table\", \"bigquery-public-data.imdb.reviews\").load()"
]
},
{
"cell_type": "markdown",
"id": "b76336bd-14a3-4d1f-a8b5-0e500eba39ef",
"metadata": {},
"source": [
"| review|split| label| movie_id|reviewer_rating| movie_url|title|\n",
"|----------------------------------------------------------------------------------------------------|-----|--------|---------|---------------|------------------------------------|-----|\n",
"|I had to see this on the British Airways plane. It was terribly bad acting and a dumb story. Not ...| test|Negative|tt0158887| 2|http://www.imdb.com/title/tt0158887/| null|\n",
"|This is a family movie that was broadcast on my local ITV station at 1.00 am a couple of nights a...| test|Negative|tt0158887| 4|http://www.imdb.com/title/tt0158887/| null|\n",
"|I would like to comment on how the girls are chosen. why is that their are always more white wome...| test|Negative|tt0391576| 2|http://www.imdb.com/title/tt0391576/| null|\n",
"|Tyra & the rest of the modeling world needs to know that real women like myself and my daughter d...| test|Negative|tt0391576| 3|http://www.imdb.com/title/tt0391576/| null|"
]
},
{
"cell_type": "markdown",
"id": "8efeb371-c8ae-4a20-b5e6-a21c5723e737",
"metadata": {},
"source": [
"### Get Positive Reviews from Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "523c5cf6-5165-4ca6-b2b4-447bb15762b4",
"metadata": {},
"outputs": [],
"source": [
"positive_movie_reviews = movie_reviews.select(col(\"review\"), col(\"reviewer_rating\"), col(\"movie_id\"), col(\"label\")).where(col(\"label\") == \"Positive\").limit(100)"
]
},
{
"cell_type": "markdown",
"id": "d5846a0c-1ca6-483f-b9f1-52873bd50312",
"metadata": {},
"source": [
"### Get Negative Reviews from Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec0a1557-cf01-4f86-9785-47ecb06ba752",
"metadata": {},
"outputs": [],
"source": [
"negative_movie_reviews = movie_reviews.select(col(\"review\"), col(\"reviewer_rating\"), col(\"movie_id\"), col(\"label\")).where(col(\"label\") == \"Negative\").limit(100)"
]
},
{
"cell_type": "markdown",
"id": "14843f6a-8f88-49b8-b2a0-db930d3a5c37",
"metadata": {},
"source": [
"### Mix positive and negative\n",
"Take the union of the positive and negative reviews to get a balanced, mixed dataset. For the purposes of this notebook, each class has 100 rows."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eda456fa-b95a-4e95-835c-cbc64be1fe11",
"metadata": {},
"outputs": [],
"source": [
"movie_reviews_mixed = positive_movie_reviews.union(negative_movie_reviews)"
]
},
{
"cell_type": "markdown",
"id": "d39af261-8416-4b2f-b830-f90a7379f859",
"metadata": {},
"source": [
"| review|reviewer_rating| movie_id| label|\n",
"|--------------------|---------------|---------|--------|\n",
"|This movie is ama...| 10|tt0187123|Positive|\n",
"|THE HAND OF DEATH...| 10|tt0187123|Positive|\n",
"|The Hand of Death...| 7|tt0187123|Positive|\n",
"|Just as a reminde...| 10|tt0163955|Positive|\n",
"|Like an earlier c...| 9|tt0163955|Positive|"
]
},
{
"cell_type": "markdown",
"id": "558c6709-e265-4e21-a15b-4e5980303686",
"metadata": {},
"source": [
"### The final count is 200, as shown below"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a4bfebe8-6491-4cc4-b8b5-d6b408e67196",
"metadata": {},
"outputs": [],
"source": [
"movie_reviews_mixed.count()"
]
},
{
"cell_type": "markdown",
"id": "2951e16c-5dd3-49b7-abf1-0fbb930453a2",
"metadata": {},
"source": [
"### Create a UDF to get predictions from the Gemini model\n",
"The function below takes a prompt containing the text whose sentiment is to be predicted and returns the model's response."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5288e2ee-af06-47ee-a6ce-977852bb700e",
"metadata": {},
"outputs": [],
"source": [
"import vertexai\n",
"from vertexai.generative_models import GenerativeModel, Part, HarmCategory, HarmBlockThreshold\n",
"\n",
"vertexai.init(project=project_id, location=\"us-central1\")\n",
"\n",
"def gemini_predict(prompt):\n",
"    # Model, generation parameters, and safety thresholds used for every request\n",
"    gemini_pro_model = GenerativeModel(\"gemini-1.0-pro\")\n",
"    config = {\"max_output_tokens\": 2048, \"temperature\": 0.4, \"top_p\": 1, \"top_k\": 32}\n",
"    safety_config = {\n",
"        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,\n",
"        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,\n",
"    }\n",
"\n",
"    # Stream the response so the text can be collected chunk by chunk\n",
"    prediction = gemini_pro_model.generate_content(\n",
"        [prompt],\n",
"        generation_config=config,\n",
"        safety_settings=safety_config,\n",
"        stream=True,\n",
"    )\n",
"\n",
"    text_responses = []\n",
"    try:\n",
"        for response in prediction:\n",
"            text_responses.append(response.text)\n",
"    except Exception:\n",
"        # Reading a chunk can fail (for example, if a response is blocked); keep the text collected so far\n",
"        pass\n",
"    return \"\".join(text_responses)"
]
},
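{
"cell_type": "markdown",
"id": "added-retry-sketch-note",
"metadata": {},
"source": [
"Calls to a hosted model from a Spark UDF can fail transiently, for example because of quota limits or timeouts. The cell below is a minimal sketch, not part of the original notebook, of a retry wrapper with exponential backoff. `call_with_retries`, `max_attempts`, `base_delay`, and `flaky_predict` are hypothetical names chosen for illustration. Because `gemini_predict` catches exceptions itself, a wrapper like this only helps around a call that lets errors propagate."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-retry-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"\n",
"def call_with_retries(fn, *args, max_attempts=3, base_delay=1.0, sleep=time.sleep):\n",
"    # Hypothetical helper (not part of the Vertex AI SDK): call fn(*args) and\n",
"    # retry on any exception, waiting longer after each failed attempt.\n",
"    for attempt in range(1, max_attempts + 1):\n",
"        try:\n",
"            return fn(*args)\n",
"        except Exception:\n",
"            if attempt == max_attempts:\n",
"                raise  # Give up after the last attempt\n",
"            # Wait base_delay, 2 * base_delay, 4 * base_delay, ... between attempts\n",
"            sleep(base_delay * 2 ** (attempt - 1))\n",
"\n",
"\n",
"# Stand-in for a model call that fails twice before succeeding\n",
"calls = {'count': 0}\n",
"\n",
"def flaky_predict(prompt):\n",
"    calls['count'] += 1\n",
"    if calls['count'] < 3:\n",
"        raise RuntimeError('transient error')\n",
"    return 'Positive'\n",
"\n",
"result = call_with_retries(flaky_predict, 'some review', base_delay=0.0)\n",
"print(result, calls['count'])"
]
},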
{
"cell_type": "code",
"execution_count": null,
"id": "385d6430-6415-4e16-97c3-b51cd1527395",
"metadata": {},
"outputs": [],
"source": [
"def find_sentiment_zero_shot(text):\n",
"    # Zero-shot prompt: only the task description and the two allowed classes, no examples\n",
"    prompt = f\"\"\"For the given text below, provide the sentiment classification from the two classes mentioned below:\n",
"    The two classes are: Negative, Positive.\n",
"    Always choose exactly one of them (the most appropriate one).\n",
"    Text: {text}\n",
"    Sentiment:\"\"\"\n",
"\n",
"    sentiment = gemini_predict(prompt)\n",
"    return sentiment\n",
"\n",
"find_sentiment_zero_shot_udf = udf(find_sentiment_zero_shot)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4fede9f0-a2dc-40bb-9af5-bdfcf7365f8f",
"metadata": {},
"outputs": [],
"source": [
"movie_reviews_mixed.printSchema()"
]
},
{
"cell_type": "markdown",
"id": "3642ecc0-2734-4027-b9d3-5ed32ff6d867",
"metadata": {},
"source": [
"### Get prediction from the LLM using the UDF on the movie reviews"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7239a8d-f60a-40f5-95ee-a18b9b5d3a6c",
"metadata": {},
"outputs": [],
"source": [
"movie_review_sentiment_pred = movie_reviews_mixed.withColumn(\"predicted_sentiment\", find_sentiment_zero_shot_udf(movie_reviews_mixed[\"review\"]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8623a94c-1006-4dbd-9c82-fb9604fc412b",
"metadata": {},
"outputs": [],
"source": [
"# Trim leading and trailing whitespace from the predictions and the labels\n",
"trimmed_movie_review_sentiment_pred = movie_review_sentiment_pred.withColumn(\"predicted_sentiment\", trim(col(\"predicted_sentiment\"))).withColumn(\"label\", trim(col(\"label\")))"
]
},
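{
"cell_type": "markdown",
"id": "added-normalize-sketch-note",
"metadata": {},
"source": [
"Trimming whitespace is enough only when the model answers with exactly `Positive` or `Negative`. The cell below is a minimal sketch, not part of the original notebook, that assumes the model may also return variations such as lowercase text, a trailing period, or a short prefix. The hypothetical helper `normalize_sentiment` maps a raw response to one of the two labels, or to `Unknown` when the answer is ambiguous. It could be wrapped with `udf` and applied to the `predicted_sentiment` column."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-normalize-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"def normalize_sentiment(raw):\n",
"    # Hypothetical helper: map a raw model response to 'Positive', 'Negative', or 'Unknown'\n",
"    if raw is None:\n",
"        return 'Unknown'\n",
"    text = raw.strip().lower()\n",
"    has_positive = 'positive' in text\n",
"    has_negative = 'negative' in text\n",
"    # Accept the answer only when exactly one of the two labels appears\n",
"    if has_positive and not has_negative:\n",
"        return 'Positive'\n",
"    if has_negative and not has_positive:\n",
"        return 'Negative'\n",
"    return 'Unknown'\n",
"\n",
"\n",
"for raw in ['  Positive  ', 'negative.', 'Sentiment: Positive', '', None]:\n",
"    print(repr(raw), '->', normalize_sentiment(raw))"
]
},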
{
"cell_type": "markdown",
"id": "92c545e1-8d4f-47b7-a595-c03e01efbacf",
"metadata": {},
"source": [
"### Let's check the predicted sentiment and compare it with the actual label"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e178b7f-f747-4093-a69e-50019ea3529e",
"metadata": {},
"outputs": [],
"source": [
"trimmed_movie_review_sentiment_pred.select(col(\"predicted_sentiment\"), col(\"label\")).show(200,100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5cc405d9-a775-40c9-93e1-3a79a5251837",
"metadata": {},
"outputs": [],
"source": [
"trimmed_movie_review_sentiment_pred.cache()"
]
},
{
"cell_type": "markdown",
"id": "7fa62d52-795b-4c27-b3c0-cb725a93d473",
"metadata": {},
"source": [
"### Evaluation\n",
"\n",
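"Let's use `StringIndexer` to convert the classes Negative and Positive in both columns to numeric indices (0 and 1). `StringIndexer` indexes each column independently and, by default, orders values by frequency, so check that the same class receives the same index in both columns."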
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11aaf2c1-e64c-459e-9439-2f9ac18e2d4c",
"metadata": {},
"outputs": [],
"source": [
"inputs = [\"predicted_sentiment\", \"label\"]\n",
"outputs = [\"predicted_sentiment_indexed\", \"label_indexed\"]\n",
"\n",
"stringIndexer = StringIndexer(inputCols=inputs, outputCols=outputs)\n",
"indexer = stringIndexer.fit(trimmed_movie_review_sentiment_pred)\n",
"\n",
"movie_review_sentiment_pred_indexed = indexer.transform(trimmed_movie_review_sentiment_pred)"
]
},
{
"cell_type": "markdown",
"id": "3c19120e-d06f-4743-ad76-6d89c027f5af",
"metadata": {},
"source": [
"Then use `BinaryClassificationEvaluator` to compute the area under the ROC curve (AUC-ROC)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7d0a765d-91f6-40b2-82d4-f44fcef2a179",
"metadata": {},
"outputs": [],
"source": [
"evaluator = BinaryClassificationEvaluator()\n",
"evaluator.setRawPredictionCol(\"predicted_sentiment_indexed\")\n",
"evaluator.setLabelCol(\"label_indexed\")\n",
"\n",
"area_under_roc = evaluator.evaluate(movie_review_sentiment_pred_indexed, {evaluator.metricName: \"areaUnderROC\"})\n",
"\n",
"# The evaluator returns a fraction between 0 and 1, not a percentage\n",
"print(\"area_under_roc: \", area_under_roc)"
]
},
{
"cell_type": "markdown",
"id": "36daec26-213d-4a94-9c30-caa4a419ea13",
"metadata": {},
"source": [
"Without any fine-tuning or prompt chaining, using only a zero-shot prompt, the model predicted the sentiments with an AUC-ROC of 94% on this sample."
]
},
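{
"cell_type": "markdown",
"id": "added-auc-sketch-note",
"metadata": {},
"source": [
"Because `predicted_sentiment_indexed` holds hard 0/1 predictions rather than scores, the ROC curve has a single operating point between (0, 0) and (1, 1). Assuming both columns contain only the two classes, the area under that curve reduces to the average of the true positive rate and the true negative rate (balanced accuracy). The cell below is a minimal sketch of that arithmetic in plain Python; the confusion-matrix counts are hypothetical and are not results from this notebook, and `auc_from_hard_labels` is an illustrative name."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-auc-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"def auc_from_hard_labels(tp, fp, tn, fn):\n",
"    # Area under a ROC curve built from a single (FPR, TPR) operating point\n",
"    tpr = tp / (tp + fn)\n",
"    fpr = fp / (fp + tn)\n",
"    # Sum of the two trapezoids under (0, 0) -> (fpr, tpr) -> (1, 1)\n",
"    return fpr * tpr / 2 + (1 - fpr) * (tpr + 1) / 2\n",
"\n",
"\n",
"# Hypothetical counts, for illustration only\n",
"tp, fp, tn, fn = 90, 5, 95, 10\n",
"auc = auc_from_hard_labels(tp, fp, tn, fn)\n",
"balanced_accuracy = (tp / (tp + fn) + tn / (tn + fp)) / 2\n",
"print(round(auc, 4), round(balanced_accuracy, 4))"
]
},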
{
"cell_type": "markdown",
"id": "14b574d7-42e2-40f1-9108-74f9239aa0e3",
"metadata": {},
"source": [
"#### Count the number of incorrect predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5c30334-869e-405d-83a8-4f8a36837c4a",
"metadata": {},
"outputs": [],
"source": [
"match_predictions_df = movie_review_sentiment_pred_indexed.withColumn(\"if_match\", when((col(\"predicted_sentiment_indexed\")==col(\"label_indexed\")),1).otherwise(0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2363c33a-20b1-45ca-b46c-08d936233a8d",
"metadata": {},
"outputs": [],
"source": [
"match_predictions_df.where(col(\"if_match\")==0).count()"
]
},
{
"cell_type": "markdown",
"id": "724aa40e-b047-4e4a-9a67-f300202acadc",
"metadata": {},
"source": [
"#### Percentage accuracy\n",
"\n",
"* Total rows = 200\n",
"* Mislabeled rows = 10\n",
"* Accuracy = (True positives + True negatives) / (True positives + True negatives + False positives + False negatives)\n",
"* Percentage accuracy = 190 / 200 = 95%"
]
},
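{
"cell_type": "markdown",
"id": "added-accuracy-sketch-note",
"metadata": {},
"source": [
"The accuracy arithmetic can be checked with a few lines of Python. This is a minimal sketch that uses the counts stated above (200 rows, 10 of them mislabeled); the variable names are illustrative and not part of the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-accuracy-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"total_rows = 200\n",
"mislabeled_rows = 10\n",
"\n",
"# Correct predictions are the rows that were not mislabeled\n",
"correct_rows = total_rows - mislabeled_rows\n",
"accuracy = correct_rows / total_rows\n",
"print(f'{correct_rows} / {total_rows} = {accuracy:.0%}')"
]
},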
{
"cell_type": "markdown",
"id": "632afa55-edb1-4431-b78d-4765d527d838",
"metadata": {},
"source": [
"#### Check the mismatched predictions\n",
"Find the mismatched rows and show them."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e3f8d6c-7fab-4157-b60f-85784bc9fb74",
"metadata": {},
"outputs": [],
"source": [
"mismatch_df = match_predictions_df.where(col(\"if_match\")==0).select(col('predicted_sentiment'),col('label'),col('review'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "366b633a-46a1-4739-8758-58e483451365",
"metadata": {},
"outputs": [],
"source": [
"mismatch_df.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "runtime-galileo on Serverless Spark (Remote)",
"language": "python",
"name": "9c39b79e5d2e7072beb4bd59-runtime-galileo"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}