{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "# Deploy JetBrains AI Mellum Python Model Package from AWS Marketplace\n"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Mellum is JetBrains' first large language model (LLM) optimized for code-related tasks.\n",
    "\n",
    "It is designed for integration into professional developer tooling (for example, intelligent code suggestions in IDEs), AI-powered coding assistants, and research on code understanding and generation.\n",
    "\n",
    "This sample notebook shows you how to deploy [JetBrains AI Mellum Python](https://aws.amazon.com/marketplace/pp/prodview-btjfafaielwwa) using Amazon SageMaker.\n",
    "\n",
    "> **Note**: This is a reference notebook. It cannot run until you make the changes it describes, such as specifying your model package ARN.\n",
    "\n",
    "## Prerequisites:\n",
    "1. **Note**: This notebook contains elements that render correctly in the Jupyter interface. Open this notebook from an Amazon SageMaker Notebook Instance or Amazon SageMaker Studio.\n",
    "1. Ensure that the IAM role you use has the **AmazonSageMakerFullAccess** policy attached.\n",
    "\n",
    "\n",
    "## Contents:\n",
    "1. [Subscribe to the model package](#1.-Subscribe-to-the-model-package)\n",
    "2. [Create an endpoint and perform real-time inference](#2.-Create-an-endpoint-and-perform-real-time-inference)\n",
    "   1. [Create an endpoint](#A.-Create-an-endpoint)\n",
    "   2. [Create input payload](#B.-Create-input-payload)\n",
    "   3. [Perform real-time inference](#C.-Perform-real-time-inference)\n",
    "   4. [Visualize output](#D.-Visualize-output)\n",
    "   5. [Delete the endpoint](#E.-Delete-the-endpoint)\n",
    "3. [Clean-up](#3.-Clean-up)\n",
    "   1. [Delete the model](#A.-Delete-the-model)\n",
    "   2. [Unsubscribe to the listing (optional)](#B.-Unsubscribe-to-the-listing-(optional))\n",
    "\n",
    "## Usage instructions\n",
    "You can run this notebook one cell at a time (press Shift+Enter to run a cell)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Subscribe to the model package"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To subscribe to the model package:\n",
    "1. Open the model package listing page: [JetBrains AI Mellum Python](https://aws.amazon.com/marketplace/pp/prodview-btjfafaielwwa).\n",
    "1. On the AWS Marketplace listing, choose **Continue to subscribe**.\n",
    "1. On the **Subscribe to this software** page, review the EULA, pricing, and support terms, and choose **Accept Offer** if you and your organization agree with them.\n",
    "1. Choose **Continue to configuration**, and then choose a **Region**. A **Product Arn** is displayed. This is the model package ARN that you need to specify when creating a deployable model with the SageMaker Python SDK. Copy the ARN for your Region and paste it into the following cell."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": "model_package_arn = \"<Customer to specify Model package ARN corresponding to their AWS region>\"",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import json\n",
    "\n",
    "import sagemaker as sage\n",
    "from sagemaker import get_execution_role\n",
    "from sagemaker import ModelPackage\n",
    "import boto3"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
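    "# IAM execution role of this notebook; it is passed to the model below\n",
    "role = get_execution_role()\n",
    "\n",
    "sagemaker_session = sage.Session()\n",
    "\n",
    "bucket = sagemaker_session.default_bucket()\n",
    "# SageMaker Runtime client used to invoke the endpoint, in the session's Region\n",
    "runtime = sagemaker_session.boto_session.client(\"sagemaker-runtime\")\n",
    "bucket"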
   ],
   "outputs": [],
   "execution_count": null
  },
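  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, confirm that the ARN you pasted belongs to the same Region as your SageMaker session, since the listing shows a separate ARN for each Region. The following cell is a minimal sketch: `arn_region` and `check_arn_region` are hypothetical helpers, not part of the SageMaker Python SDK, and they assume the standard ARN layout `arn:partition:service:region:account-id:resource`. After running it, you can call `check_arn_region(model_package_arn, sagemaker_session.boto_region_name)`, which raises a `ValueError` on a Region mismatch or if the placeholder ARN was never replaced."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Hypothetical helpers (not part of the SageMaker SDK): verify that a model\n",
    "# package ARN belongs to the Region your SageMaker session runs in.\n",
    "# Assumed ARN layout: arn:partition:service:region:account-id:resource\n",
    "\n",
    "\n",
    "def arn_region(arn: str) -> str:\n",
    "    \"\"\"Return the Region field (the 4th colon-separated field) of an ARN.\"\"\"\n",
    "    parts = arn.split(\":\", 5)\n",
    "    if len(parts) != 6 or parts[0] != \"arn\":\n",
    "        raise ValueError(f\"Not a valid ARN: {arn!r}\")\n",
    "    return parts[3]\n",
    "\n",
    "\n",
    "def check_arn_region(arn: str, expected_region: str) -> None:\n",
    "    \"\"\"Raise ValueError if the ARN's Region differs from expected_region.\"\"\"\n",
    "    region = arn_region(arn)\n",
    "    if region != expected_region:\n",
    "        raise ValueError(\n",
    "            f\"Model package ARN is for {region!r}, but the session uses \"\n",
    "            f\"{expected_region!r}. Copy the ARN listed for your Region.\"\n",
    "        )"
   ],
   "outputs": [],
   "execution_count": null
  },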
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Create an endpoint and perform real-time inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To understand how real-time inference with Amazon SageMaker works, see the [Amazon SageMaker hosting documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-hosting.html)."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model_name = \"jbai-mellum-python\"\n",
    "\n",
    "content_type = \"application/json\"\n",
    "\n",
    "# The recommended instance type for real-time inference is 'ml.g6e.xlarge',\n",
    "# but it can be hard to obtain, so this notebook uses the more widely available 'ml.g5.2xlarge'.\n",
    "real_time_inference_instance_type = \"ml.g5.2xlarge\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A. Create an endpoint"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
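    "# Create a deployable model from the model package.\n",
    "model = ModelPackage(\n",
    "    role=role, model_package_arn=model_package_arn, sagemaker_session=sagemaker_session\n",
    ")\n",
    "\n",
    "# Deploy the model to a real-time endpoint named after model_name.\n",
    "# Inference below calls the endpoint directly through the runtime client,\n",
    "# so the return value of deploy() is not needed.\n",
    "model.deploy(\n",
    "    initial_instance_count=1,\n",
    "    instance_type=real_time_inference_instance_type,\n",
    "    endpoint_name=model_name,\n",
    ")"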
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the endpoint has been created, you can perform real-time inference."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### B. Create input payload"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "payload = {\n",
    "  \"prefix\": \"from settings import USER\\n\\ndef main():\\n    print(\\\"Hello {\",\n",
    "  \"suffix\": \"\",\n",
    "  \"filepath\": \"main.py\",\n",
    "  \"context\": [\n",
    "      {\n",
    "          \"type\": \"DirectoryFile\",\n",
    "          \"filepath\": \"settings.py\",\n",
    "          \"content\": \"USER = 'cat'\\n\"\n",
    "      }\n",
    "  ],\n",
    "  \"max_length\": 32,\n",
    "  \"stop_token\": \"\\n\\n\",\n",
    "  \"use_control\": \"off\"\n",
    "}"
   ],
   "outputs": [],
   "execution_count": null
  },
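  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The request is a plain dictionary, so other completion requests can be built the same way. The following cell is a minimal sketch of a builder for it: `build_completion_request` is a hypothetical helper, its field names and default values are copied from the example payload, and the meaning of each field is inferred from that example rather than from documented API behavior. For instance, `build_completion_request(payload[\"prefix\"], context_files={\"settings.py\": \"USER = 'cat'\\n\"}) == payload` evaluates to `True`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "from typing import Dict, Optional\n",
    "\n",
    "\n",
    "# Hypothetical helper (not part of any SDK). Field names and defaults mirror\n",
    "# the example payload above; their exact semantics are assumptions.\n",
    "def build_completion_request(\n",
    "    prefix: str,\n",
    "    suffix: str = \"\",\n",
    "    filepath: str = \"main.py\",\n",
    "    context_files: Optional[Dict[str, str]] = None,\n",
    "    max_length: int = 32,\n",
    "    stop_token: str = \"\\n\\n\",\n",
    "    use_control: str = \"off\",\n",
    ") -> dict:\n",
    "    \"\"\"Build a request dict; context_files maps file paths to file contents.\"\"\"\n",
    "    # Each extra file becomes a \"DirectoryFile\" context entry, as in the example.\n",
    "    context = [\n",
    "        {\"type\": \"DirectoryFile\", \"filepath\": path, \"content\": content}\n",
    "        for path, content in (context_files or {}).items()\n",
    "    ]\n",
    "    return {\n",
    "        \"prefix\": prefix,\n",
    "        \"suffix\": suffix,\n",
    "        \"filepath\": filepath,\n",
    "        \"context\": context,\n",
    "        \"max_length\": max_length,\n",
    "        \"stop_token\": stop_token,\n",
    "        \"use_control\": use_control,\n",
    "    }"
   ],
   "outputs": [],
   "execution_count": null
  },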
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### C. Perform real-time inference"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "def run_inference(request, endpoint):\n",
    "    # Send the JSON request through the SageMaker Runtime client\n",
    "    raw_response = runtime.invoke_endpoint(\n",
    "        EndpointName=endpoint,\n",
    "        Body=json.dumps(request),\n",
    "        ContentType=content_type,\n",
    "    )\n",
    "    status_code = raw_response[\"ResponseMetadata\"][\"HTTPStatusCode\"]\n",
    "    body = raw_response[\"Body\"].read().decode()\n",
    "    if not 200 <= status_code < 300:\n",
    "        raise RuntimeError(f\"Request failed with the following message:\\n {body}\")\n",
    "\n",
    "    # The body is a stream of server-sent events (SSE): keep the JSON payload\n",
    "    # of every \"data:\" line and skip the final \"end\" marker.\n",
    "    messages = []\n",
    "    for line in body.splitlines():\n",
    "        if line.startswith(\"data:\"):\n",
    "            # SSE allows an optional space after \"data:\", so strip whitespace.\n",
    "            event_body = line[len(\"data:\"):].strip()\n",
    "            if event_body != \"end\":\n",
    "                messages.append(json.loads(event_body))\n",
    "    return messages\n",
    "\n",
    "output = run_inference(payload, model_name)"
   ],
   "outputs": [],
   "execution_count": null
  },
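  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The endpoint returns its result as server-sent events (SSE), and `run_inference` keeps the JSON payload of every `data:` line while skipping the final `end` marker. To see that parsing step in isolation, without calling the endpoint, the following cell applies a standalone version of the same loop to a sample body. Both `parse_sse_messages` and the sample body are hypothetical: the body's shape (one `data:` line per event, JSON messages with `type` and `completion` keys, and a closing `data: end`) is inferred from this notebook's parsing and display code, not copied from real endpoint output."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import json\n",
    "\n",
    "\n",
    "def parse_sse_messages(body: str) -> list:\n",
    "    \"\"\"Return the JSON payloads of all 'data:' lines, skipping the 'end' marker.\"\"\"\n",
    "    messages = []\n",
    "    for line in body.splitlines():\n",
    "        if line.startswith(\"data:\"):\n",
    "            # SSE allows an optional space after \"data:\", so strip whitespace.\n",
    "            event_body = line[len(\"data:\"):].strip()\n",
    "            if event_body != \"end\":\n",
    "                messages.append(json.loads(event_body))\n",
    "    return messages\n",
    "\n",
    "\n",
    "# Hypothetical response body in the assumed SSE layout (not real model output).\n",
    "sample_body = (\n",
    "    'data: {\"type\": \"FirstLine\", \"completion\": \"x = 1\"}\\n'\n",
    "    \"\\n\"\n",
    "    'data: {\"type\": \"Rest\", \"completion\": \"\"}\\n'\n",
    "    \"\\n\"\n",
    "    \"data: end\\n\"\n",
    ")\n",
    "\n",
    "print(parse_sse_messages(sample_body))"
   ],
   "outputs": [],
   "execution_count": null
  },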
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### D. Visualize output"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "TPL = \"\"\"Code completion:\n",
    "```\n",
    "{completion}\n",
    "```\n",
    "\n",
    "All decoded SSE messages from the endpoint:\n",
    "{raw_response}\"\"\"\n",
    "\n",
    "def pretty_print_output(inference_output):\n",
    "    first_line = \"\"\n",
    "    rest = \"\"\n",
    "    for message in inference_output:\n",
    "        if message[\"type\"] == \"FirstLine\":\n",
    "            first_line = message[\"completion\"]\n",
    "        elif message[\"type\"] == \"Rest\":\n",
    "            rest = message[\"completion\"]\n",
    "\n",
    "    print(\n",
    "        TPL.format(\n",
    "            completion=first_line + rest,\n",
    "            raw_response=json.dumps(inference_output, indent=2),\n",
    "        )\n",
    "    )\n",
    "\n",
    "pretty_print_output(output)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### E. Delete the endpoint"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that you have successfully performed real-time inference, you no longer need the endpoint. Delete it to stop incurring charges for it."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model.sagemaker_session.delete_endpoint(model_name)\n",
    "model.sagemaker_session.delete_endpoint_config(model_name)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Clean-up"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A. Delete the model"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model.delete_model()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### B. Unsubscribe to the listing (optional)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you would like to unsubscribe from the model package, follow these steps. Before you cancel the subscription, ensure that you do not have any [deployable model](https://console.aws.amazon.com/sagemaker/home#/models) created from the model package or using the algorithm. You can find this information by looking at the container name associated with the model.\n",
    "\n",
    "**Steps to unsubscribe from the product on AWS Marketplace**:\n",
    "1. Navigate to the __Machine Learning__ tab on the [__Your Software subscriptions page__](https://aws.amazon.com/marketplace/ai/library?productType=ml&ref_=mlmp_gitdemo_indust).\n",
    "2. Locate the listing whose subscription you want to cancel, and then choose __Cancel Subscription__.\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
