{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Use AI21's Azure Client with Jamba 1.5 Large and Jamba 1.5 Mini through Azure AI Models-as-a-Service\n",
"\n",
"Use AI21's Azure client to consume Jamba 1.5 Large or Jamba 1.5 Mini deployments in Azure AI and Azure ML through serverless API endpoints delivered through Models-as-a-Service (MaaS).\n",
"\n",
"> Review the documentation for Jamba 1.5 models for [AI Studio](https://aka.ms/ai21-jamba-1.5-large-ai-studio-docs) and for [ML Studio](https://aka.ms/ai21-jamba-1.5-large-ml-studio-docs) for details on how to provision inference endpoints, regional availability, pricing and inference schema reference.\n",
"\n",
"The below samples are seen on [AI21's GitHub](https://github.com/AI21Labs/ai21-python/tree/main/examples/studio) and shared here for ease of use with their Azure client."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prerequisites\n",
"\n",
"Before we start, there are certain steps we need to take to deploy the models:\n",
"\n",
"* Register for a valid Azure account with subscription \n",
"* Make sure you have access to [Azure AI Studio](https://learn.microsoft.com/en-us/azure/ai-studio/what-is-ai-studio?tabs=home)\n",
"* Create a project and resource group\n",
"* Select `AI21 Jamba 1.5 Large` or `AI21 Jamba 1.5 Mini`\n",
"\n",
" > Notice that some models may not be available in all the regions in Azure AI and Azure Machine Learning. On those cases, you can create a workspace or project in the region where the models are available and then consume it with a connection from a different one. To learn more about using connections see [Consume models with connections](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deployments-connections)\n",
"\n",
"* Deploy with \"Pay-as-you-go\"\n",
"\n",
"Once deployed successfully, you should be assigned for an API endpoint and a security key for inference.\n",
"\n",
"For more information, you should consult Azure's official documentation [here](https://aka.ms/ai21-jamba-1.5-large-azure-ai-studio-docs) for model deployment and inference.\n",
"\n",
"To complete this tutorial, you will need to: \n",
"\n",
"* Install `ai21`:\n",
"\n",
" ```bash\n",
" pip install -U \"ai21>=2.13.0\"\n",
" ```\n",
"* If it's not working on your first go, try restarting the kernel and then run the pip install again."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## General Example\n",
"\n",
"The following is an example about how to use `ai21`'s client on Azure and leveraging this for AI21 Jamba 1.5 Large through MaaS."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -U \"ai21>=2.13.0\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install asyncio aiohttp"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"is_executing": true
},
"name": "imports"
},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import uuid\n",
"from enum import Enum\n",
"import asyncio\n",
"from pydantic import BaseModel\n",
"from ai21 import AsyncAI21Client\n",
"from ai21 import AI21AzureClient\n",
"from ai21.models.chat import (\n",
" ChatMessage,\n",
" ResponseFormat,\n",
" ToolMessage,\n",
" FunctionToolDefinition,\n",
" DocumentSchema,\n",
")\n",
"from ai21.models.chat.chat_message import SystemMessage, UserMessage, AssistantMessage\n",
"from ai21.models.chat.function_tool_definition import FunctionToolDefinition\n",
"from ai21.models.chat.tool_defintions import ToolDefinition\n",
"from ai21.models.chat.tool_parameters import ToolParameters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use `ai21`, create a client and configure it as follows:\n",
"\n",
"- `endpoint`: Use the endpoint URL from your deployment. Include `/v1` at the end of the endpoint.\n",
"- `api_key`: Use your API key."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"base_url = \"<your-maas-endpoint>\"\n",
"api_key = \"<your-api-key>\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"name": "chat_client"
},
"outputs": [],
"source": [
"client = AI21AzureClient(base_url=base_url, api_key=api_key)\n",
"# async_client = AsyncAI21Client(base_url=base_url, api_key=api_key)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = (\n",
" \"jamba-1.5-large\" # Change to \"jamba-1.5-mini\" if you'd like to try the Mini model\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, you can set an environment variable for your API key:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"AI21_API_KEY\"] = \"<your-api-key>\"\n",
"client = AI21AzureClient(\n",
" base_url=\"<your-maas-endpoint>\", api_key=os.environ.get(\"AI21_API_KEY\")\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the client to create chat completions requests:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"system = \"You're a support engineer in a SaaS company\"\n",
"\n",
"messages = [\n",
" SystemMessage(content=system, role=\"system\"),\n",
" UserMessage(content=\"Hello, I need help with a signup process.\", role=\"user\"),\n",
" AssistantMessage(\n",
" content=\"Hi Alice, I can help you with that. What seems to be the problem?\",\n",
" role=\"assistant\",\n",
" ),\n",
" UserMessage(\n",
" content=\"I am having trouble signing up for your product with my Google account.\",\n",
" role=\"user\",\n",
" ),\n",
"]\n",
"\n",
"chat_completions = client.chat.completions.create(\n",
" model=model,\n",
" messages=messages,\n",
" temperature=1.0, # Setting =1 allows for greater variability per API call.\n",
" top_p=1.0, # Setting =1 allows full sample of tokens to be considered per API call.\n",
" max_tokens=100,\n",
" stop=[\"\\n\"],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The generated text can be accessed as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"name": "chat_response"
},
"outputs": [],
"source": [
"print(chat_completions.to_json())"
]
},
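{
"cell_type": "markdown",
"metadata": {},
"source": [
"The response object also exposes individual fields directly; for example, the text of the first generated choice (a minimal sketch, assuming at least one choice was returned):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Text of the first generated choice:\n",
"print(chat_completions.choices[0].message.content)"
]
},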
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Incorporating Chat response formatting"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class TicketType(Enum):\n",
" ADULT = \"adult\"\n",
" CHILD = \"child\"\n",
"\n",
"\n",
"class ZooTicket(BaseModel):\n",
" ticket_type: TicketType\n",
" quantity: int\n",
"\n",
"\n",
"class ZooTicketsOrder(BaseModel):\n",
" date: str\n",
" tickets: list[ZooTicket]\n",
"\n",
"\n",
"messages = [\n",
" ChatMessage(\n",
" role=\"user\",\n",
" content=\"Please create a JSON object for ordering zoo tickets for September 22, 2024, \"\n",
" f\"for myself and two kids, based on the following JSON schema: {ZooTicketsOrder.schema()}.\",\n",
" )\n",
"]\n",
"\n",
"response = client.chat.completions.create(\n",
" messages=messages,\n",
" model=model,\n",
" max_tokens=800,\n",
" temperature=0,\n",
" response_format=ResponseFormat(type=\"json_object\"),\n",
")\n",
"\n",
"zoo_order_json = json.loads(response.choices[0].message.content)\n",
"print(zoo_order_json)"
]
},
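{
"cell_type": "markdown",
"metadata": {},
"source": [
"To confirm the returned JSON actually conforms to the schema, it can be parsed back into the pydantic model; `parse_obj` raises a `ValidationError` if it does not (a minimal sketch using the pydantic v1-style API, matching `ZooTicketsOrder.schema()` above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Raises a ValidationError if the model's output does not match the schema:\n",
"order = ZooTicketsOrder.parse_obj(zoo_order_json)\n",
"print(order.date, [ticket.quantity for ticket in order.tickets])"
]
},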
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Chat Function calling"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_order_delivery_date(order_id: str) -> str:\n",
" print(f\"Retrieving the delivery date for order ID: {order_id} from the database...\")\n",
" return \"2025-05-04\"\n",
"\n",
"\n",
"messages = [\n",
" ChatMessage(\n",
" role=\"system\",\n",
" content=\"You are a helpful customer support assistant. Use the supplied tools to assist the user.\",\n",
" ),\n",
" ChatMessage(\n",
" role=\"user\", content=\"Hi, can you tell me the delivery date for my order?\"\n",
" ),\n",
" ChatMessage(\n",
" role=\"assistant\",\n",
" content=\"Hi there! I can help with that. Can you please provide your order ID?\",\n",
" ),\n",
" ChatMessage(role=\"user\", content=\"i think it is order_12345\"),\n",
"]\n",
"\n",
"tool_definition = ToolDefinition(\n",
" type=\"function\",\n",
" function=FunctionToolDefinition(\n",
" name=\"get_order_delivery_date\",\n",
" description=\"Retrieve the delivery date associated with the specified order ID\",\n",
" parameters=ToolParameters(\n",
" type=\"object\",\n",
" properties={\n",
" \"order_id\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The customer's order ID.\",\n",
" }\n",
" },\n",
" required=[\"order_id\"],\n",
" ),\n",
" ),\n",
")\n",
"\n",
"tools = [tool_definition]\n",
"\n",
"response = client.chat.completions.create(\n",
" messages=messages, model=\"jamba-1.5-large\", tools=tools\n",
")\n",
"\n",
"\"\"\" AI models can be error-prone, it's crucial to ensure that the tool calls align with the expectations.\n",
"The below code snippet demonstrates how to handle tool calls in the response and invoke the tool function\n",
"to get the delivery date for the user's order. After retrieving the delivery date, we pass the response back\n",
"to the AI model to continue the conversation, using the ToolMessage object. \"\"\"\n",
"\n",
"assistant_message = response.choices[0].message\n",
"messages.append(assistant_message) # Adding the assistant message to the chat history\n",
"\n",
"delivery_date = None\n",
"tool_calls = assistant_message.tool_calls\n",
"if tool_calls:\n",
" tool_call = tool_calls[0]\n",
" if tool_call.function.name == \"get_order_delivery_date\":\n",
" func_arguments = tool_call.function.arguments\n",
" func_args_dict = json.loads(func_arguments)\n",
"\n",
" if \"order_id\" in func_args_dict:\n",
" delivery_date = get_order_delivery_date(func_args_dict[\"order_id\"])\n",
" else:\n",
" print(\"order_id not found in function arguments\")\n",
" else:\n",
" print(f\"Unexpected tool call found - {tool_call.function.name}\")\n",
"else:\n",
" print(\"No tool calls found\")\n",
"\n",
"if delivery_date is not None:\n",
" \"\"\"Continue the conversation by passing the delivery date back to the model\"\"\"\n",
"\n",
" tool_message = ToolMessage(\n",
" role=\"tool\", tool_call_id=tool_calls[0].id, content=delivery_date\n",
" )\n",
" messages.append(tool_message)\n",
"\n",
" response = client.chat.completions.create(\n",
" messages=messages, model=\"jamba-1.5-large\", tools=tools\n",
" )\n",
" print(response.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Chat with Document Schema "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"schnoodel = DocumentSchema(\n",
" id=str(uuid.uuid4()),\n",
" content=\"Schnoodel Inc. Annual Report - 2024. Schnoodel Inc., a leader in innovative culinary technology, saw a \"\n",
" \"15% revenue growth this year, reaching $120 million. The launch of SchnoodelChef Pro has significantly \"\n",
" \"contributed, making up 35% of total sales. We've expanded into the Asian market, notably Japan, \"\n",
" \"and increased our global presence. Committed to sustainability, we reduced our carbon footprint \"\n",
" \"by 20%. Looking ahead, we plan to integrate more advanced machine learning features and expand \"\n",
" \"into South America.\",\n",
" metadata={\"topic\": \"revenue\"},\n",
")\n",
"shnokel = DocumentSchema(\n",
" id=str(uuid.uuid4()),\n",
" content=\"Shnokel Corp. Annual Report - 2024. Shnokel Corp., a pioneer in renewable energy solutions, \"\n",
" \"reported a 20% increase in revenue this year, reaching $200 million. The successful deployment of \"\n",
" \"our advanced solar panels, SolarFlex, accounted for 40% of our sales. We entered new markets in Europe \"\n",
" \"and have plans to develop wind energy projects next year. Our commitment to reducing environmental \"\n",
" \"impact saw a 25% decrease in operational emissions. Upcoming initiatives include a significant \"\n",
" \"investment in R&D for sustainable technologies.\",\n",
" metadata={\"topic\": \"revenue\"},\n",
")\n",
"\n",
"documents = [schnoodel, shnokel]\n",
"\n",
"messages = [\n",
" ChatMessage(\n",
" role=\"system\",\n",
" content=\"You are a helpful assistant that receives revenue documents and answers related questions\",\n",
" ),\n",
" ChatMessage(\n",
" role=\"user\",\n",
" content=\"Hi, which company earned more during 2024 - Schnoodel or Shnokel?\",\n",
" ),\n",
"]\n",
"\n",
"response = client.chat.completions.create(\n",
" messages=messages, model=\"jamba-1.5-mini\", documents=documents\n",
")\n",
"\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Stream Chat Completions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"system = \"You're a support engineer in a SaaS company\"\n",
"messages = [\n",
" ChatMessage(content=system, role=\"system\"),\n",
" ChatMessage(content=\"Hello, I need help with a signup process.\", role=\"user\"),\n",
" ChatMessage(\n",
" content=\"Hi Alice, I can help you with that. What seems to be the problem?\",\n",
" role=\"assistant\",\n",
" ),\n",
" ChatMessage(\n",
" content=\"I am having trouble signing up for your product with my Google account.\",\n",
" role=\"user\",\n",
" ),\n",
"]\n",
"\n",
"response = client.chat.completions.create(\n",
" messages=messages,\n",
" model=model,\n",
" max_tokens=100,\n",
" stream=True,\n",
")\n",
"for chunk in response:\n",
" print(chunk.choices[0].delta.content, end=\"\")"
]
},
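{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the full reply is needed as a single string, the streamed deltas can be accumulated instead of printed (a minimal sketch; it re-creates the stream, since the iterator above has already been consumed):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = client.chat.completions.create(\n",
"    messages=messages,\n",
"    model=model,\n",
"    max_tokens=100,\n",
"    stream=True,\n",
")\n",
"# Join the per-chunk deltas into the complete message text:\n",
"full_text = \"\".join(\n",
"    chunk.choices[0].delta.content or \"\" for chunk in response\n",
")\n",
"print(full_text)"
]
},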
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Chat function calling with multiple tools"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_weather(place: str, date: str) -> str:\n",
" \"\"\"\n",
" Retrieve the expected weather for a specified location and date.\n",
" \"\"\"\n",
" print(f\"Fetching expected weather for {place} on {date}...\")\n",
" return \"32 celsius\"\n",
"\n",
"\n",
"def get_sunset_hour(place: str, date: str) -> str:\n",
" \"\"\"\n",
" Fetch the expected sunset time for a given location and date.\n",
" \"\"\"\n",
" print(f\"Fetching expected sunset time for {place} on {date}...\")\n",
" return \"7 pm\"\n",
"\n",
"\n",
"messages = [\n",
" ChatMessage(\n",
" role=\"system\",\n",
" content=\"You are a helpful assistant. Use the supplied tools to assist the user.\",\n",
" ),\n",
" ChatMessage(\n",
" role=\"user\",\n",
" content=\"Hello, could you help me find out the weather forecast and sunset time for London?\",\n",
" ),\n",
" ChatMessage(\n",
" role=\"assistant\", content=\"Hi there! I can help with that. On which date?\"\n",
" ),\n",
" ChatMessage(role=\"user\", content=\"At 2024-08-27\"),\n",
"]\n",
"\n",
"get_sunset_tool = ToolDefinition(\n",
" type=\"function\",\n",
" function=FunctionToolDefinition(\n",
" name=\"get_sunset_hour\",\n",
" description=\"Fetch the expected sunset time for a given location and date.\",\n",
" parameters=ToolParameters(\n",
" type=\"object\",\n",
" properties={\n",
" \"place\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The location for which the weather is being queried.\",\n",
" },\n",
" \"date\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The date for which the weather is being queried.\",\n",
" },\n",
" },\n",
" required=[\"place\", \"date\"],\n",
" ),\n",
" ),\n",
")\n",
"\n",
"get_weather_tool = ToolDefinition(\n",
" type=\"function\",\n",
" function=FunctionToolDefinition(\n",
" name=\"get_weather\",\n",
" description=\"Retrieve the expected weather for a specified location and date.\",\n",
" parameters=ToolParameters(\n",
" type=\"object\",\n",
" properties={\n",
" \"place\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The location for which the weather is being queried.\",\n",
" },\n",
" \"date\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The date for which the weather is being queried.\",\n",
" },\n",
" },\n",
" required=[\"place\", \"date\"],\n",
" ),\n",
" ),\n",
")\n",
"\n",
"tools = [get_sunset_tool, get_weather_tool]\n",
"\n",
"response = client.chat.completions.create(\n",
" messages=messages, model=\"jamba-1.5-large\", tools=tools\n",
")\n",
"\n",
"\"\"\" AI models can be error-prone, it's crucial to ensure that the tool calls align with the expectations.\n",
"The below code snippet demonstrates how to handle tool calls in the response and invoke the tool function\n",
"to get the delivery date for the user's order. After retrieving the delivery date, we pass the response back\n",
"to the AI model to continue the conversation, using the ToolMessage object. \"\"\"\n",
"\n",
"assistant_message = response.choices[0].message\n",
"messages.append(assistant_message) # Adding the assistant message to the chat history\n",
"\n",
"too_call_id_to_result = {}\n",
"tool_calls = assistant_message.tool_calls\n",
"if tool_calls:\n",
" for tool_call in tool_calls:\n",
" if tool_call.function.name == \"get_weather\":\n",
" \"\"\"Verify get_weather tool call arguments and invoke the function to get the weather forecast:\"\"\"\n",
" func_arguments = tool_call.function.arguments\n",
" args = json.loads(func_arguments)\n",
"\n",
" if \"place\" in args and \"date\" in args:\n",
" result = get_weather(args[\"place\"], args[\"date\"])\n",
" too_call_id_to_result[tool_call.id] = result\n",
" else:\n",
" print(f\"Got unexpected arguments in function call - {args}\")\n",
"\n",
" elif tool_call.function.name == \"get_sunset_hour\":\n",
" \"\"\"Verify get_sunset_hour tool call arguments and invoke the function to get the weather forecast:\"\"\"\n",
" func_arguments = tool_call.function.arguments\n",
" args = json.loads(func_arguments)\n",
"\n",
" if \"place\" in args and \"date\" in args:\n",
" result = get_sunset_hour(args[\"place\"], args[\"date\"])\n",
" too_call_id_to_result[tool_call.id] = result\n",
" else:\n",
" print(f\"Got unexpected arguments in function call - {args}\")\n",
"\n",
" else:\n",
" print(f\"Unexpected tool call found - {tool_call.function.name}\")\n",
"else:\n",
" print(\"No tool calls found\")\n",
"\n",
"if too_call_id_to_result:\n",
" \"\"\"Continue the conversation by passing the sunset and weather back to the AI model:\"\"\"\n",
"\n",
" for tool_id_called, result in too_call_id_to_result.items():\n",
" tool_message = ToolMessage(\n",
" role=\"tool\", tool_call_id=tool_id_called, content=str(result)\n",
" )\n",
" messages.append(tool_message)\n",
"\n",
" response = client.chat.completions.create(\n",
" messages=messages, model=\"jamba-1.5-large\", tools=tools\n",
" )\n",
" print(response.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Async Stream Chat Completions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"\n",
"from ai21 import AsyncAI21Client, AsyncAI21AzureClient\n",
"from ai21.models.chat import ChatMessage\n",
"\n",
"system = \"You're a support engineer in a SaaS company\"\n",
"messages = [\n",
" ChatMessage(content=system, role=\"system\"),\n",
" ChatMessage(content=\"Hello, I need help with a signup process.\", role=\"user\"),\n",
" ChatMessage(\n",
" content=\"Hi Alice, I can help you with that. What seems to be the problem?\",\n",
" role=\"assistant\",\n",
" ),\n",
" ChatMessage(\n",
" content=\"I am having trouble signing up for your product with my Google account.\",\n",
" role=\"user\",\n",
" ),\n",
"]\n",
"\n",
"client = AsyncAI21AzureClient(base_url=base_url, api_key=api_key)\n",
"\n",
"\n",
"async def main():\n",
" response = await client.chat.completions.create(\n",
" messages=messages,\n",
" model=model,\n",
" max_tokens=100,\n",
" stream=True,\n",
" )\n",
" async for chunk in response:\n",
" print(chunk.choices[0].delta.content, end=\"\")\n",
"\n",
"\n",
"loop = asyncio.get_event_loop()\n",
"loop.create_task(main())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Asynch Chat Completions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"\n",
"from ai21 import AsyncAI21AzureClient\n",
"from ai21.models.chat import ChatMessage\n",
"\n",
"\n",
"system = \"You're a support engineer in a SaaS company\"\n",
"messages = [\n",
" ChatMessage(content=system, role=\"system\"),\n",
" ChatMessage(content=\"Hello, I need help with a signup process.\", role=\"user\"),\n",
" ChatMessage(\n",
" content=\"Hi Alice, I can help you with that. What seems to be the problem?\",\n",
" role=\"assistant\",\n",
" ),\n",
" ChatMessage(\n",
" content=\"I am having trouble signing up for your product with my Google account.\",\n",
" role=\"user\",\n",
" ),\n",
"]\n",
"\n",
"client = AsyncAI21AzureClient(base_url=base_url, api_key=api_key)\n",
"\n",
"\n",
"async def main():\n",
" response = await client.chat.completions.create(\n",
" messages=messages,\n",
" model=model,\n",
" max_tokens=100,\n",
" temperature=0.7,\n",
" top_p=1.0,\n",
" stop=[\"\\n\"],\n",
" )\n",
"\n",
" print(response)\n",
"\n",
"\n",
"loop = asyncio.get_event_loop()\n",
"loop.create_task(main())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Aditional resources\n",
"\n",
"Here are some additional reference: \n",
"\n",
"* [Plan and manage costs (marketplace)](https://learn.microsoft.com/azure/ai-studio/how-to/costs-plan-manage#monitor-costs-for-models-offered-through-the-azure-marketplace)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "jupyter",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}