use_cases/archive/document_analysis/notebooks/01-get-embeddings.ipynb (369 lines of code) (raw):
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Get Embeddings"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"from IPython.core.interactiveshell import InteractiveShell\n",
"InteractiveShell.ast_node_interactivity = \"all\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set up Azure OpenAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"import openai\n",
"from dotenv import load_dotenv\n",
"\n",
"# Set up Azure OpenAI.\n",
"# load_dotenv() loads variables from a local .env file into the environment, so neither\n",
"# the API key nor the endpoint has to be written into the notebook itself.\n",
"load_dotenv()\n",
"openai.api_type = \"azure\"\n",
"# Api base is the 'Endpoint' which can be found in Azure Portal where Azure OpenAI is created.\n",
"# It looks like https://xxxxxx.openai.azure.com/ - provide it as OPENAI_API_BASE (.env or environment).\n",
"# Falls back to an empty string when the variable is not set.\n",
"openai.api_base = os.getenv(\"OPENAI_API_BASE\", \"\")\n",
"openai.api_version = \"2022-12-01\"\n",
"openai.api_key = os.getenv(\"OPENAI_API_KEY\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>category</th>\n",
" <th>filename</th>\n",
" <th>title</th>\n",
" <th>content</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>business</td>\n",
" <td>001.txt</td>\n",
" <td>Ad sales boost Time Warner profit</td>\n",
" <td>Quarterly profits at US media giant TimeWarne...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>business</td>\n",
" <td>002.txt</td>\n",
" <td>Dollar gains on Greenspan speech</td>\n",
" <td>The dollar has hit its highest level against ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>business</td>\n",
" <td>003.txt</td>\n",
" <td>Yukos unit buyer faces loan claim</td>\n",
" <td>The owners of embattled Russian oil giant Yuk...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>business</td>\n",
" <td>004.txt</td>\n",
" <td>High fuel prices hit BA's profits</td>\n",
" <td>British Airways has blamed high fuel prices f...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>business</td>\n",
" <td>005.txt</td>\n",
" <td>Pernod takeover talk lifts Domecq</td>\n",
" <td>Shares in UK drinks and food firm Allied Dome...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2220</th>\n",
" <td>tech</td>\n",
" <td>397.txt</td>\n",
" <td>BT program to beat dialler scams</td>\n",
" <td>BT is introducing two initiatives to help bea...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2221</th>\n",
" <td>tech</td>\n",
" <td>398.txt</td>\n",
" <td>Spam e-mails tempt net shoppers</td>\n",
" <td>Computer users across the world continue to i...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2222</th>\n",
" <td>tech</td>\n",
" <td>399.txt</td>\n",
" <td>Be careful how you code</td>\n",
" <td>A new European directive could put software w...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2223</th>\n",
" <td>tech</td>\n",
" <td>400.txt</td>\n",
" <td>US cyber security chief resigns</td>\n",
" <td>The man making sure US computer networks are ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2224</th>\n",
" <td>tech</td>\n",
" <td>401.txt</td>\n",
" <td>Losing yourself in online gaming</td>\n",
" <td>Online role playing games are time-consuming,...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2225 rows × 4 columns</p>\n",
"</div>"
],
"text/plain": [
" category filename title \\\n",
"0 business 001.txt Ad sales boost Time Warner profit \n",
"1 business 002.txt Dollar gains on Greenspan speech \n",
"2 business 003.txt Yukos unit buyer faces loan claim \n",
"3 business 004.txt High fuel prices hit BA's profits \n",
"4 business 005.txt Pernod takeover talk lifts Domecq \n",
"... ... ... ... \n",
"2220 tech 397.txt BT program to beat dialler scams \n",
"2221 tech 398.txt Spam e-mails tempt net shoppers \n",
"2222 tech 399.txt Be careful how you code \n",
"2223 tech 400.txt US cyber security chief resigns \n",
"2224 tech 401.txt Losing yourself in online gaming \n",
"\n",
" content \n",
"0 Quarterly profits at US media giant TimeWarne... \n",
"1 The dollar has hit its highest level against ... \n",
"2 The owners of embattled Russian oil giant Yuk... \n",
"3 British Airways has blamed high fuel prices f... \n",
"4 Shares in UK drinks and food firm Allied Dome... \n",
"... ... \n",
"2220 BT is introducing two initiatives to help bea... \n",
"2221 Computer users across the world continue to i... \n",
"2222 A new European directive could put software w... \n",
"2223 The man making sure US computer networks are ... \n",
"2224 Online role playing games are time-consuming,... \n",
"\n",
"[2225 rows x 4 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"# The dataset is tab-separated. df_orig keeps the frame exactly as read from disk;\n",
"# all further work happens on an independent copy.\n",
"df_orig = pd.read_csv(\"../data/bbc-news-data.csv\", sep=\"\\t\")\n",
"df = df_orig.copy(deep=True)\n",
"df"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Deploy a model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Look for an existing deployment that succeeded and whose model supports embeddings.\n",
"deployment_id = None\n",
"result = openai.Deployment.list()\n",
"\n",
"for deployment in result.data:\n",
"    if deployment[\"status\"] != \"succeeded\":\n",
"        continue\n",
"\n",
"    # Retrieve the model behind the deployment to inspect its capabilities.\n",
"    deployed_model = openai.Model.retrieve(deployment[\"model\"])\n",
"    if not deployed_model[\"capabilities\"][\"embeddings\"]:\n",
"        continue\n",
"\n",
"    # The first suitable deployment is enough.\n",
"    deployment_id = deployment[\"id\"]\n",
"    break\n",
"\n",
"# If no suitable deployment exists, create one.\n",
"if not deployment_id:\n",
"    print('No succeeded deployment that supports embeddings found.')\n",
"    model = \"text-similarity-davinci-001\"\n",
"\n",
"    # Now let's create the deployment\n",
"    print(f'Creating a new deployment with model: {model}')\n",
"    result = openai.Deployment.create(model=model, scale_settings={\"scale_type\":\"standard\"})\n",
"    deployment_id = result[\"id\"]\n",
"    print(f'Successfully created {model} with deployment_id {deployment_id}')\n",
"else:\n",
"    print(f'Found a succeeded deployment that supports embeddings with id: {deployment_id}.')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get Embeddings\n",
"ref: https://learn.microsoft.com/en-us/azure/cognitive-services/openai/tutorials/embeddings?tabs=bash"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"12288"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity check: embed one sample string with the selected deployment.\n",
"embedding = openai.Embedding.create(deployment_id=deployment_id, input=\"Your text goes here\")\n",
"\n",
"# Length of the returned embedding vector\n",
"len(embedding[\"data\"][0][\"embedding\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Embed every article. A failed request must not abort the whole run: the row keeps ''\n",
"# as its embedding and its index label is recorded in failed_rows for a later retry.\n",
"embeddings = []\n",
"failed_rows = []\n",
"\n",
"for idx, content in df['content'].items():\n",
"    try:\n",
"        response = openai.Embedding.create(input=content, deployment_id=deployment_id)\n",
"        embeddings.append(response['data'][0]['embedding'])\n",
"    except Exception as err:\n",
"        embeddings.append('')\n",
"        failed_rows.append(idx)\n",
"        print(f\"Row {idx}: unexpected {err=}, {type(err)=}\")\n",
"\n",
"# Assign the column once. An explicit object-dtype Series keeps each vector in a single cell\n",
"# and avoids chained-indexing writes (df['embedding'][i] = ...), which may land on a copy.\n",
"df['embedding'] = pd.Series(embeddings, index=df.index, dtype=object)\n",
"\n",
"if failed_rows:\n",
"    print(f\"{len(failed_rows)} of {len(df)} rows could not be embedded: {failed_rows}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save embeddings\n",
"The following notebooks require the embeddings generated in this notebook."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Write the articles together with their embeddings as a tab-separated file, without the index column.\n",
"# NOTE(review): the embedding column holds Python lists, which CSV stores as text -\n",
"# downstream readers presumably have to parse them back; confirm in the following notebooks.\n",
"df.to_csv(\n",
"    \"../data/bbc-news-data-embedding.csv\",\n",
"    sep=\"\\t\",\n",
"    index=False,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "azureml_py38",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "6d65a8c07f5b6469e0fc613f182488c0dccce05038bbda39e5ac9075c0454d11"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}