deploy-models/deploy-llm-chatbot/sagemaker_lmi_chatbot.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Deploy a Model to a SageMaker Endpoint and run a simple Gradio UI\n",
"In this notebook, we will demonstrate how to deploy a large language model as as SageMaker endpoint and interact with it using a simple Gradio UI. This approach enables quick and easy experimentation with a wide range of models. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Setup\n",
"\n",
"First, we need to install the required libraries and set up the environment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -Uqq sagemaker\n",
"%pip install -Uqq \"huggingface_hub[cli]\"\n",
"%pip install -Uqq gradio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import boto3\n",
"import sagemaker\n",
"from pathlib import Path\n",
"from sagemaker.djl_inference.model import DJLModel\n",
"import re\n",
"\n",
"role = sagemaker.get_execution_role() # execution role for the endpoint\n",
"sess = sagemaker.session.Session() # sagemaker session for interacting with different AWS APIs\n",
"region = sess._region_name # region name of the current SageMaker Studio environment\n",
"bucket = sess.default_bucket() # default bucket name\n",
"prefix = \"gradio_chatbot\" \n",
"account_id = sess.account_id() "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Download Model\n",
"Next we will download the model from the Hugging Face model hub. Certain models are gated and require signing in to download. See the documentation [here](https://huggingface.co/docs/huggingface_hub/en/guides/cli) for the CLI commands to login. In this case, we will use the `Phi-3.5-mini-instruct` model that is publicly available."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"LLM_NAME = \"microsoft/Phi-3.5-mini-instruct\"\n",
"LLM_DIR = os.path.basename(LLM_NAME).lower()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can use the `huggingface-cli` to download the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!huggingface-cli download $LLM_NAME --local-dir $LLM_DIR"
]
},
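{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal alternative to the CLI, the same files can be fetched from Python with `snapshot_download` from `huggingface_hub`. This is just a sketch of the equivalent call; the CLI command above is all that is needed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import snapshot_download\n",
"\n",
"# equivalent to the CLI call above: download the model repository into LLM_DIR\n",
"snapshot_download(repo_id=LLM_NAME, local_dir=LLM_DIR)"
]
},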
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we can upload the model to our S3 bucket so that we can deploy it to a SageMaker endpoint."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm_s3_path = f\"s3://{bucket}/{prefix}/llm/{LLM_DIR}\"\n",
"!aws s3 sync $LLM_DIR $llm_s3_path"
]
},
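{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, the upload can be done with the SageMaker SDK instead of the AWS CLI. A minimal sketch that writes to the same S3 prefix:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# alternative to the `aws s3 sync` above: upload the local model directory with the SDK\n",
"llm_s3_path = sess.upload_data(path=LLM_DIR, bucket=bucket, key_prefix=f\"{prefix}/llm/{LLM_DIR}\")\n",
"print(llm_s3_path)"
]
},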
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Deploy Model\n",
"Now that the model is in our S3 bucket, we can deploy it to a SageMaker endpoint. We will use the `DJLModel` class from the `sagemaker.djl_inference` module to deploy the model. **DGL** stands for Deep Java Library, which is a Java framework for deep learning that also underpins the [Large Model Inference](https://docs.djl.ai/master/docs/serving/serving/docs/lmi/index.html) container which is what we will use to deploy the model.\n",
"\n",
"LMI allows us to deploy the models without needing to write any inference code. We can simply provide our desired deployment configuration via environment variables and pass it to the `DJLModel` class.\n",
"Below are example configurations that work with the `Phi-3.5-mini-instruct` model. You can adjust these configurations based on your model's requirements. See [here](https://docs.djl.ai/master/docs/serving/serving/docs/lmi/deployment_guide/configurations.html) for a complete list of configurations."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# define inference environment for LLM\n",
"llm_env = env = {\n",
" \"TENSOR_PARALLEL_DEGREE\": \"1\", # use 1 GPUs\n",
" \"OPTION_ROLLING_BATCH\": \"vllm\", # use VLLM rolling batch\n",
" \"OPTION_MAX_ROLLING_BATCH_SIZE\": \"32\", # max rolling batch size (controls the concurrency)\n",
" \"OPTION_DTYPE\": \"fp16\", # load weights in fp16\n",
" \"OPTION_MAX_MODEL_LEN\": \"16384\", # max context length in tokens for the model\n",
" \"OPTION_TRUST_REMOTE_CODE\": \"true\", # trust remote code\n",
" \"OPTION_GPU_MEMORY_UTILIZATION\": \"0.95\", # use 95% of GPU memory\n",
"}\n",
"\n",
"# create DJLModel object for LLM\n",
"# see here for LMI version updates https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers \n",
"sm_llm_model = DJLModel(\n",
" model_id=llm_s3_path,\n",
" djl_version=\"0.29.0\",\n",
" djl_framework=\"djl-lmi\",\n",
" role=role,\n",
" env=llm_env,\n",
")"
]
},
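{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, you can inspect the LMI container image that `DJLModel` resolves from `djl_version` and `djl_framework` by looking it up with `sagemaker.image_uris.retrieve`. This is only a sanity check and assumes the `djl-lmi` framework key and version `0.29.0` are available in your installed SDK."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# optional: look up the LMI container image URI that will back the endpoint\n",
"lmi_image_uri = sagemaker.image_uris.retrieve(framework=\"djl-lmi\", region=region, version=\"0.29.0\")\n",
"print(lmi_image_uri)"
]
},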
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"instance_type = \"ml.g5.2xlarge\"\n",
"endpoint_name = sagemaker.utils.name_from_base(f\"{re.sub('[._]+', '-', LLM_DIR)}\")\n",
"\n",
"llm_predictor = sm_llm_model.deploy(initial_instance_count=1,\n",
" instance_type=instance_type,\n",
" endpoint_name=endpoint_name,\n",
" container_startup_health_check_timeout=1800\n",
" )"
]
},
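{
"cell_type": "markdown",
"metadata": {},
"source": [
"Deployment can take several minutes. If the kernel restarts later, or you want to call the endpoint from another notebook, you can reattach to the already-deployed endpoint by name instead of redeploying. A minimal sketch using the generic `Predictor` class with JSON serialization:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sagemaker.predictor import Predictor\n",
"from sagemaker.serializers import JSONSerializer\n",
"from sagemaker.deserializers import JSONDeserializer\n",
"\n",
"# reattach to the running endpoint by name, sending and receiving JSON\n",
"llm_predictor = Predictor(\n",
"    endpoint_name=endpoint_name,\n",
"    sagemaker_session=sess,\n",
"    serializer=JSONSerializer(),\n",
"    deserializer=JSONDeserializer(),\n",
")"
]
},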
{
"cell_type": "markdown",
"metadata": {},
"source": [
"LMI includes a [chat completions](https://docs.djl.ai/master/docs/serving/serving/docs/lmi/user_guides/chat_input_output_schema.html) API which works with message based payloads. This only works with models that provide a [chat template](https://huggingface.co/docs/transformers/main/en/chat_templating) as part of their tokenizer.\n",
"\n",
"We can validate the deployment by sending a test request to the endpoint."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chat = [\n",
" {\"role\": \"system\", \"content\": \"You are a helpful AI assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n",
" {\"role\": \"assistant\", \"content\": \"I'm doing great. How can I help you today?\"},\n",
" {\"role\": \"user\", \"content\": \"Can you write a python function that parses BibTex using only standard python libraries?\"},\n",
"]\n",
"\n",
"result = llm_predictor.predict(\n",
"\n",
" {\"messages\": chat, \"max_tokens\": 2000, \"temperature\": 0.5}\n",
")\n",
"\n",
"response_message = result[\"choices\"][0][\"message\"]\n",
"\n",
"print(response_message[\"content\"])"
]
},
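{
"cell_type": "markdown",
"metadata": {},
"source": [
"The endpoint is not tied to the SageMaker Python SDK. An application can send the same chat completions payload through the `sagemaker-runtime` API with `boto3`; a minimal sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"smr_client = boto3.client(\"sagemaker-runtime\", region_name=region)\n",
"\n",
"# send the same chat completions payload directly through the runtime API\n",
"response = smr_client.invoke_endpoint(\n",
"    EndpointName=endpoint_name,\n",
"    ContentType=\"application/json\",\n",
"    Body=json.dumps({\"messages\": chat, \"max_tokens\": 500, \"temperature\": 0.5}),\n",
")\n",
"body = json.loads(response[\"Body\"].read())\n",
"print(body[\"choices\"][0][\"message\"][\"content\"])"
]
},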
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Build chat interface\n",
"Finally, we will build a simple Gradio UI that allows us to interact with the model. Gradio is a Python library that allows you to quickly create UIs for your models.\n",
"\n",
"Gradio provides a built in [ChatInterface](https://www.gradio.app/docs/gradio/chatinterface) class for chat based interfaces. We can use this class to build a chat interface for our model simply by implementing a chat function that takes in the latest message, the history of prior messages, and any other generation parameters we wish to configure."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import gradio as gr\n",
"from functools import partial"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def chat(predictor, message, history, system_prompt=None, max_tokens=100, temperature=0.5, top_p=0.99):\n",
" \n",
" messages = []\n",
" if system_prompt:\n",
" messages.append({\"role\": \"system\", \"content\": system_prompt})\n",
" \n",
" for turn in history:\n",
" user_msg, assistant_msg = turn\n",
" messages.append({\"role\": \"user\", \"content\": user_msg})\n",
" messages.append({\"role\": \"assistant\", \"content\": assistant_msg})\n",
" \n",
" messages.append({\"role\": \"user\", \"content\": message})\n",
" \n",
" response = predictor.predict({\"messages\": messages, \"max_tokens\": max_tokens, \"temperature\": temperature, \"top_p\": top_p})\n",
" \n",
" response_message = response[\"choices\"][0][\"message\"][\"content\"]\n",
" \n",
" return response_message"
]
},
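{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before wiring the function into Gradio, we can sanity-check it directly against the endpoint:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# quick sanity check of the chat function outside of Gradio\n",
"print(chat(llm_predictor, \"What is Amazon SageMaker?\", history=[], max_tokens=200))"
]
},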
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chat_function = partial(chat, llm_predictor) # create a partial function with the predictor\n",
"\n",
"gr.close_all()\n",
"\n",
"chat_interface = gr.ChatInterface(\n",
" chat_function,\n",
" title=\"Example Chatbot\",\n",
" description=f\"Example chatbot powered by {LLM_DIR}\",\n",
" additional_inputs=[\n",
" gr.Textbox(\"You are helpful AI.\", label=\"System Prompt\"),\n",
" gr.Slider(1, 4000, 500, label=\"Max Tokens\"),\n",
" gr.Slider(0.1, 1.0, 0.5, label=\"Temperature\"),\n",
" gr.Slider(0.1, 1.0, 0.99, label=\"Top P\", value=0.99),\n",
" ],\n",
" additional_inputs_accordion = \"Model Settings\",\n",
")\n",
"chat_interface.chatbot.render_markdown = True\n",
"chat_interface.chatbot.height = 400\n",
"\n",
"chat_interface.load()\n",
"\n",
"chat_interface.queue(default_concurrency_limit=10).launch(share=True) # set to False to keep private"
]
},
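{
"cell_type": "markdown",
"metadata": {},
"source": [
"When you are finished experimenting, shut down the Gradio server. This does not affect the SageMaker endpoint."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# stop the Gradio server and close the public share link\n",
"chat_interface.close()"
]
},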
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Clean Up"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"llm_predictor.delete_endpoint()\n",
"# optional: clean up the S3 bucket\n",
"# !aws s3 rm --recursive $llm_s3_path"
]
},
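{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, we can also remove the SageMaker model resource that was created during deployment. Note that this only deletes the model metadata in SageMaker, not the artifacts in S3."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# optional: delete the SageMaker model resource created during deploy()\n",
"sm_llm_model.delete_model()"
]
}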
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}