genai-for-marketing/notebooks/simple_news_summarization.ipynb

{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "1aa304e2", "metadata": {}, "outputs": [], "source": [ "# Copyright 2023 Google LLC\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "attachments": {}, "cell_type": "markdown", "id": "359697d5", "metadata": { "id": "359697d5" }, "source": [ "# News summarization with PaLM API\n", "\n", "## Overview\n", "\n", "This notebook illustrates how to use Vertex AI PaLM text models for news summarization. You will discover the most popular Google Search terms and summarize news articles related to those terms. A system like that could be beneficial in a variety of business situations, including marketing, political analysis, and more.\n", "\n", "Trending search terms are retrieved from [Google Trends dataset](https://pantheon.corp.google.com/marketplace/product/bigquery-public-datasets/google-search-trends?project=jk-mlops-dev) and news articles from [the GDELT database](https://www.gdeltproject.org/). \n", "The Google Trends dataset contains the top 25 overall and top 25 rising queries from Google Trends in the past 30 days. The dataset is hosted on Google BigQuery as part of the Google Cloud Datasets initiative.\n", "\n", "The GDELT Project, which is supported by [Google Jigsaw](https://jigsaw.google.com/), monitors the world's broadcast, print, and web news from nearly every corner of every country in over 100 languages. The GDELT database is free to use and accessible via a variety of interfaces, including Google BigQuery and the REST API. In this notebook, we will be using the REST API.\n", "\n", "The notebook is structured as follows:\n", "- You will begin by installing the necessary packages and configuring your GCP environment.\n", "- You will query Google Trends dataset to bring top search terms\n", "- You will query GDELT API to bring news related to top search terms\n", "- Finally, you will summarize these news articles\n", "\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "4d20f86c", "metadata": {}, "source": [ "## Install pre-requisites\n", "\n", "Install the following python packages." ] }, { "cell_type": "code", "execution_count": null, "id": "8f10d9d6", "metadata": {}, "outputs": [], "source": [ "! pip install -U google-cloud-aiplatform\n", "! pip install -U python-dateutil\n", "! 
pip install -U newspaper3k" ] }, { "attachments": {}, "cell_type": "markdown", "id": "73c6c4af", "metadata": {}, "source": [ "---\n", "\n", "#### ⚠️ Do not forget to RESTART THE RUNTIME before continue.\n", "\n", "---" ] }, { "attachments": {}, "cell_type": "markdown", "id": "1e6a68f9", "metadata": {}, "source": [ "## Configure Google Cloud environment settings\n", "\n", "Set the following constants to reflect your GCP environment.\n", "- `PROJECT_ID`: Your Google Cloud Project ID.\n", "- `REGION`: The region to use for Vertex AI" ] }, { "cell_type": "code", "execution_count": null, "id": "5976cf1b", "metadata": {}, "outputs": [], "source": [ "PROJECT_ID = '<YOUR PROJECT ID HERE>'\n", "REGION = 'us-central1'" ] }, { "attachments": {}, "cell_type": "markdown", "id": "723e634b", "metadata": {}, "source": [ "Initialize the SDK and import some modules." ] }, { "cell_type": "code", "execution_count": null, "id": "7543d35e", "metadata": {}, "outputs": [], "source": [ "import logging\n", "import os\n", "import requests\n", "import vertexai\n", "\n", "from newspaper import Article\n", "from newspaper import ArticleException\n", "\n", "from dateutil.parser import parse as parse_date\n", "from datetime import date, timedelta, datetime\n", "from google.cloud import bigquery\n", "from vertexai.preview.language_models import TextGenerationModel\n", "from typing import Any, Dict, List\n", "\n", "\n", "logging.basicConfig(level = logging.INFO)\n", "vertexai.init(project=PROJECT_ID, location=REGION)\n", "\n", "bq_client = bigquery.Client(project=PROJECT_ID)\n", "llm = TextGenerationModel.from_pretrained(\"text-bison@001\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "xScdQhw-qpBl", "metadata": { "id": "xScdQhw-qpBl" }, "source": [ "### Google Trends lookup tool\n", "\n", "Returns top (rank 1) search term(s) for a given date." ] }, { "cell_type": "code", "execution_count": null, "id": "602610b8", "metadata": {}, "outputs": [], "source": [ "class GoogleTrends:\n", " \"\"\"Get Trends from BQ dataset\n", " Useful for when you need to find top search terms on a given date. 
\n", " Input is a JSON object that has the field date.\n", " \"\"\"\n", "\n", " def __init__(\n", " self, \n", " project_id: str, \n", " bq_client: Any):\n", " self.project_id = project_id\n", " self.bq_client = bq_client\n", "\n", "\n", " def run(self, json_params: Dict):\n", " refresh_date = self._parse_date(json_params)\n", "\n", " if refresh_date:\n", " df = self._query_top_terms(refresh_date)\n", " terms = df.loc[0].values[0]\n", " terms = terms.split(' ')\n", " else:\n", " terms = []\n", "\n", " return terms\n", "\n", "\n", " def _query_top_terms(self, date: str):\n", " \"\"\"Retrieve top terms from Google Trends.\"\"\"\n", " query = f\"\"\"\n", " SELECT term, rank FROM `bigquery-public-data.google_trends.top_terms`\n", " WHERE refresh_date = '{date}'\n", " GROUP BY 1,2\n", " ORDER by rank ASC\n", " \"\"\"\n", " query_job = self.bq_client.query(\n", " query,\n", " location=\"US\",\n", " )\n", "\n", " df = query_job.to_dataframe()\n", " return df\n", "\n", "\n", " def _parse_date(self, json_params: Dict):\n", " \"\"\"Parse date.\"\"\"\n", " params = json_params\n", "\n", " if 'date' in params:\n", " try:\n", " dt = parse_date(params['date'])\n", " dt = dt.date()\n", " except:\n", " dt = date.today()\n", " else:\n", " dt = date.today()\n", "\n", " dt_str = dt.strftime('%Y-%m-%d')\n", " if dt >= date.today() or dt <= date.today() - timedelta(days=30):\n", " dt_str = \"\"\n", " else:\n", " dt_str = dt.strftime('%Y-%m-%d')\n", "\n", " return dt_str" ] }, { "cell_type": "code", "execution_count": null, "id": "Z6ONfOeo3GZ0", "metadata": { "id": "Z6ONfOeo3GZ0" }, "outputs": [], "source": [ "# Google Trends dataset in BigQuery only stores data from the past month.\n", "# Change to a valid date.\n", "\n", "google_trends_tool = GoogleTrends(project_id=PROJECT_ID, bq_client=bq_client)\n", "google_trends_tool.run({'date':'11-24-2023'})" ] }, { "attachments": {}, "cell_type": "markdown", "id": "9287c69a", "metadata": {}, "source": [ "### GDELT Retriever\n", "\n", "The `GDELT Retriever` obtains information about articles that match a set of keywords. The class takes a JSON object as input, with the following format:\n", "\n", "```\n", "{\n", " \"date\": \"05-16-2023\",\n", " \"keywords\": [\"Real\", \"Madrid\"],\n", " \"tone\": \"positive\"\n", "}\n", "```\n", "\n", "The script will search for articles published between the dates `date - time_window_days` and `date + time_window_days`, where `time_window_days` is a configurable parameter. The `tone` field can be either `positive`, `negative`, or `unknown`. Please refer to the [GDELT documentation](https://blog.gdeltproject.org/gdelt-doc-2-0-api-debuts/#:~:text=theme%3ATERROR-,Tone,-.%20Allows%20you%20to) for more information on the tone setting. We define `positive` as a tone value greater than 5 and `negative` as a tone value less than 5.\n", "\n", "Besides the time_window, there are other configuration parameters that control the results, including the maximum number of returned records (`max_records`) and the maximum distance between the keywords in an article (`n_near_words`).\n", "\n", "The script returns a text that is a compilation of article titles and their content.\n", "\n", "Let's start by defining the retriever." 
] }, { "cell_type": "code", "execution_count": null, "id": "780329bd", "metadata": {}, "outputs": [], "source": [ "class GDELTRetriever:\n", " def __init__(self, max_records:int = 10, tone: str = 'positive'):\n", " self.gdelt_api_url: str = 'https://api.gdeltproject.org/api/v2/doc/doc'\n", " self.mode: str = 'ArtList'\n", " self.format: str = 'json'\n", " self.max_records: int = max_records\n", " self.n_near_words: int = 50\n", " self.source_country: str = 'US'\n", " self.source_lang: str = 'english'\n", " self.time_window_days: int = 3\n", " \n", " if tone == 'positive':\n", " self.tone = 'tone>5'\n", " elif tone == 'negative':\n", " self.tone = 'tone<-5'\n", "\n", "\n", " def _get_articles_info(\n", " self, \n", " keywords: list[str], \n", " startdatetime: datetime, \n", " enddatetime: datetime) -> Dict:\n", "\n", " startdatetime_str = startdatetime.strftime('%Y%m%d%H%M%S')\n", " enddatetime_str = enddatetime.strftime('%Y%m%d%H%M%S')\n", "\n", " query = f'near{self.n_near_words}:\"{\" \".join(keywords)}\" '\n", " query += f'sourcecountry:{self.source_country} sourcelang:{self.source_lang} '\n", " query += f'{self.tone}'\n", " params = {'query': query,\n", " 'format': self.format,\n", " 'mode': self.mode,\n", " 'maxrecords': str(self.max_records),\n", " 'startdatetime': startdatetime_str,\n", " 'enddatetime': enddatetime_str}\n", "\n", " response = requests.get(self.gdelt_api_url, params=params)\n", " response.raise_for_status()\n", " return response.json()\n", "\n", "\n", " def _parse_article(self, url: str) -> Article:\n", " article = Article(url)\n", "\n", " try:\n", " article.download()\n", " article.parse()\n", " except ArticleException:\n", " return Article(\"www.google.com\")\n", " else:\n", " return article\n", "\n", "\n", " def _get_documents(self, articles: Dict) -> List[Dict]:\n", " documents = []\n", " unique_docs = set()\n", "\n", " for article in articles['articles']:\n", " parsed_article = self._parse_article(article['url'])\n", " if parsed_article and parsed_article.text and (article['title'] not in unique_docs):\n", " unique_docs.add(article['title'])\n", " document = {\n", " 'page_content': parsed_article.text,\n", " 'title': article['title'],\n", " 'url': article['url'],\n", " 'domain': article['domain'],\n", " 'date': article['seendate']\n", " }\n", " documents.append(document)\n", " return documents\n", "\n", "\n", " def get_relevant_documents(self, query: str) -> List[Dict]:\n", " query_params = query\n", " keywords = query_params['keywords']\n", " event_date = parse_date(query_params['date'])\n", " startdatetime = datetime.combine(event_date - timedelta(days=self.time_window_days), datetime.min.time())\n", " enddatetime = datetime.combine(event_date + timedelta(days=self.time_window_days), datetime.min.time())\n", " articles = self._get_articles_info(keywords, startdatetime, enddatetime)\n", " documents = self._get_documents(articles)\n", "\n", " return documents" ] }, { "cell_type": "code", "execution_count": null, "id": "FEDN9iByOvLl", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FEDN9iByOvLl", "outputId": "01b9b241-6aeb-4791-dda9-d1a9922c8a59" }, "outputs": [], "source": [ "retriever = GDELTRetriever(max_records=10)\n", "documents = retriever.get_relevant_documents(query={'keywords': ['Cowboys'], 'date': '11-24-2023'})\n", "\n", "print('=> Sample article')\n", "print(documents[0]['page_content'][:200] + '...\\n')\n", "print('=> Total number of documents: ' + str(len(documents)))" ] }, { "attachments": {}, "cell_type": "markdown", "id": 
"Qc1kenKQ-rY2", "metadata": { "id": "Qc1kenKQ-rY2" }, "source": [ "#### Summarize news articles\n", "\n", "This is a function that processes each document in a list using a configured LLM and a prompt. The script returns the list of responses together with original documents." ] }, { "cell_type": "code", "execution_count": null, "id": "1019712f", "metadata": {}, "outputs": [], "source": [ "def summarize_news_article(document: Dict, llm):\n", " prompt_template = f\"\"\"Write a one sentence summary of the following article delimited by triple backticks:\n", "\n", " ```{document['page_content']}```\n", " \"\"\"\n", " document['summary'] = llm.predict(prompt_template).text\n", " return document\n", "\n", "\n", "def summarize_documents(documents: Dict, llm) -> List:\n", " summaries = []\n", "\n", " for document in documents:\n", " summaries.append(\n", " summarize_news_article(document, llm)\n", " )\n", "\n", " return summaries" ] }, { "cell_type": "code", "execution_count": null, "id": "CW2dfTPzro22", "metadata": { "id": "CW2dfTPzro22" }, "outputs": [], "source": [ "summarize_news_article(documents[0], llm)" ] }, { "cell_type": "code", "execution_count": null, "id": "a3afb57f", "metadata": {}, "outputs": [], "source": [ "summarize_documents(documents=documents, llm=llm)" ] }, { "cell_type": "code", "execution_count": null, "id": "2b7fbcdb", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }