# Using TorchServe on SageMaker inf2.24xlarge with Llama-2-13B

## Contents

This notebook runs on a SageMaker notebook instance with the conda_python3 kernel and demonstrates how to use TorchServe to deploy Llama-2-13B on a SageMaker inf2.24xlarge instance. The example uses several advanced features:

* Neuronx AOT precompile model
* TorchServe microbatching
* TorchServe LLM batching with streaming response on SageMaker
* SageMaker uncompressed model artifacts

## Step 1: Let's bump up SageMaker and import stuff

The wheel installed here is a private preview wheel; your account needs to be added to the allowlist before you can run this function.

In [None]:
!python --version

In [None]:
# Install the latest aws cli v2 if it is not installed
!curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
!unzip awscliv2.zip
# Follow the instructions below to install AWS CLI v2 from the terminal
!cat aws/README.md

In [None]:
# Confirm the version is aws-cli/2.xx.xx
!aws --version
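
If you would rather check the CLI version programmatically than by eye, the check can be sketched in Python as below. The helper names and the sample string are illustrative assumptions; the only thing taken from the notebook is the expected `aws-cli/2.xx.xx` prefix.

```python
import re


def parse_aws_cli_version(version_output: str) -> tuple:
    """Extract (major, minor, patch) from the text printed by `aws --version`."""
    match = re.search(r"aws-cli/(\d+)\.(\d+)\.(\d+)", version_output)
    if match is None:
        raise ValueError(f"Unrecognized `aws --version` output: {version_output!r}")
    return tuple(int(part) for part in match.groups())


def is_cli_v2(version_output: str) -> bool:
    """Return True when the major version is 2 or newer."""
    return parse_aws_cli_version(version_output)[0] >= 2


# Hypothetical sample string; the real output depends on your installation.
sample = "aws-cli/2.13.25 Python/3.11.5 Linux/5.10 exe/x86_64.amzn.2"
print(is_cli_v2(sample))
```

In a notebook cell, `out = !aws --version` captures the command output as a list of lines that you can join and pass to `is_cli_v2`.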

In [None]:
#%pip install sagemaker pip --upgrade  --quiet
!pip install numpy
!pip install pillow
!pip install -U sagemaker
!pip install -U boto 
!pip install -U botocore
!pip install -U boto3

In [None]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

boto3_session=boto3.session.Session(region_name="us-west-2")
smr = boto3.client('sagemaker-runtime-demo')
sm = boto3.client('sagemaker')
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess= sagemaker.session.Session(boto3_session, sagemaker_client=sm, sagemaker_runtime_client=smr)  # sagemaker session for interacting with different AWS APIs
region = sess._region_name  # region name of the current SageMaker Studio environment
account = sess.account_id()  # account_id of the current SageMaker Studio environment

# Configuration:
bucket_name = sess.default_bucket()
prefix = "torchserve"
output_path = f"s3://{bucket_name}/{prefix}"
print(f'account={account}, region={region}, role={role}, output_path={output_path}')

## Step 2: Build a BYOD TorchServe Docker container and push it to Amazon ECR

In [None]:
# Show the Dockerfile that installs our own dependencies
!cat workspace/docker/Dockerfile

In [2]:
%%capture build_output

baseimage = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:1.13.1-neuronx-py310-sdk2.13.2-ubuntu20.04"
reponame = "neuronx"
versiontag = "2-13-2"

# Build our own docker image
!cd workspace/docker && ./build_and_push.sh {reponame} {versiontag} {baseimage} {region} {account}

In [None]:
if 'Error response from daemon' in str(build_output):
    print(build_output)
    raise SystemExit('\n\n!!There was an error with the container build!!')
else:
    container = str(build_output).strip().split('\n')[-1]

print(container)
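
The cell above relies on the build script printing the image URI as the last line of its output. A slightly more defensive version of that parsing can be sketched as follows. The function name, the sample log, and the regular expression for a private ECR image URI are assumptions made for illustration, not something `build_and_push.sh` guarantees.

```python
import re

# Assumed shape of a private ECR image URI:
#   <12-digit account>.dkr.ecr.<region>.amazonaws.com/<repo>:<tag>
ECR_URI = re.compile(r"^\d{12}\.dkr\.ecr\.[a-z0-9-]+\.amazonaws\.com/[\w./-]+:[\w.-]+$")


def extract_image_uri(build_log: str) -> str:
    """Return the image URI from the last non-empty line of a build log."""
    if "Error response from daemon" in build_log:
        raise RuntimeError("Docker reported an error during the container build")
    lines = [line.strip() for line in build_log.splitlines() if line.strip()]
    if not lines:
        raise RuntimeError("Build log is empty")
    candidate = lines[-1]
    if not ECR_URI.match(candidate):
        raise RuntimeError(f"Last log line is not an ECR image URI: {candidate!r}")
    return candidate


# Hypothetical log text, for illustration only.
sample_log = (
    "Sending build context to Docker daemon\n"
    "Successfully tagged neuronx:2-13-2\n"
    "123456789012.dkr.ecr.us-west-2.amazonaws.com/neuronx:2-13-2\n"
)
print(extract_image_uri(sample_log))
```

Validating the last line catches the case where the script exits early and the final line is an unrelated message rather than an image URI.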

## Step 3: AOT Pre-Compile Model on EC2 

This step precompiles the model for a given batch size and saves the result in a cache directory so that later model loads can reuse it. This can significantly reduce model loading latency, down to a few seconds. This example has already run this step and saved the model artifacts in the **TorchServe model zoo: `s3://torchserve/mar_files/sm-neuronx/llama-2-13b-neuronx-b1/`**. To use those artifacts, **skip this step and jump directly to option 2 of Step 4**.

### Precompile the model on a local EC2 instance

Follow the [instructions](https://github.com/pytorch/serve/blob/master/examples/large_models/inferentia2/llama2/Readme.md?plain=1#L56) on a local EC2 instance to precompile the model. The precompiled model will be stored in `model_store/llama-2-13b/neuron_cache`.

#### Main steps on EC2
                     
* Download llama-2-13b from HF
* Save the model as split checkpoints compatible with `transformers-neuronx`
* Create the model artifacts locally on EC2
* Start TorchServe to load the model
* Precompile the model and save it in `model_store/llama-2-13b/neuron_cache`
* Run inference to test

#### Key notes

1. **Turn on neuron_cache** in the `initialize` function of `inf2_handler.py`
```
os.environ["NEURONX_CACHE"] = "on"
os.environ["NEURONX_DUMP_TO"] = f"{model_dir}/neuron_cache"
```

2. Set **micro_batch_size: 1** in model-config.yaml to precompile the model with batch size 1 in this example

In [3]:
!cat workspace/model-config.yaml

minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 10800
batchSize: 16

handler:
    model_checkpoint_dir: "llama-2-13b-split"
    amp: "bf16"
    tp_degree: 12 
    max_length: 100

micro_batching:
    micro_batch_size: 1
    parallelism:
        preprocess: 2
        inference: 1
        postprocess: 2


3. The generated cache files are **tightly coupled to the neuronx SDK version**
```
This example is based on Neuron DLC 2.13.2: 763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:1.13.1-neuronx-py310-sdk2.13.2-ubuntu20.04
```
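
Key note 2 above combines `batchSize: 16` with `micro_batch_size: 1`: the frontend gathers up to 16 requests, and the handler processes them in micro-batches that match the batch size the model was compiled for. The splitting step alone can be sketched as below. This is an illustration with hypothetical names, not TorchServe's `MicroBatching` implementation, which also runs preprocess, inference, and postprocess in parallel threads.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def split_into_micro_batches(batch: Sequence[T], micro_batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of `batch`, each with at most `micro_batch_size` items."""
    if micro_batch_size < 1:
        raise ValueError("micro_batch_size must be at least 1")
    for start in range(0, len(batch), micro_batch_size):
        yield list(batch[start:start + micro_batch_size])


# 16 requests gathered by the frontend (batchSize: 16), split with micro_batch_size: 1.
requests = [f"prompt-{i}" for i in range(16)]
micro_batches = list(split_into_micro_batches(requests, micro_batch_size=1))
print(len(micro_batches))  # 16
```

With these settings each call into the compiled model sees exactly one request, which is why a model precompiled for batch size 1 can still sit behind a frontend batch of 16.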

## Step 4: Upload model artifacts to S3

You can choose either option 1 or option 2.

* option 1: Sample commands to copy llama-2-13b from EC2 to S3
```
print(output_path)
!aws s3 cp llama-2-13b {output_path}/llama-2-13b --recursive
s3_uri = f"{output_path}/llama-2-13b/"
```

* option 2: The model artifacts for this example are available in the **TorchServe model zoo: `s3://torchserve/mar_files/sm-neuronx/llama-2-13b-neuronx-b1/`**, which supports llama2-13b on neuronx with batchSize = 1, based on torch-neuronx==1.13.1.1.10.1. You can copy them to your SageMaker S3 model repo, for example:

In [None]:
!aws s3 cp s3://torchserve/mar_files/sm-neuronx/llama-2-13b-neuronx-b1/ llama-2-13b-neuronx-b1 --recursive

In [None]:
!aws s3 cp llama-2-13b-neuronx-b1 {output_path}/llama-2-13b-neuronx-b1 --recursive

In [None]:
s3_uri = f"{output_path}/llama-2-13b-neuronx-b1/"
print(s3_uri)
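
Note that `s3_uri` ends with a slash. To our understanding, SageMaker expects an `S3Prefix` URI for uncompressed model artifacts to end with `/`, so it can be worth normalizing the value before passing it on. A minimal sketch, using a hypothetical helper name and a made-up bucket:

```python
from urllib.parse import urlparse


def normalize_s3_prefix(uri: str) -> str:
    """Validate an s3:// URI and make sure its key prefix ends with a slash."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"Not a valid S3 URI: {uri!r}")
    key = parsed.path.lstrip("/")
    if key and not key.endswith("/"):
        key += "/"
    return f"s3://{parsed.netloc}/{key}"


# "my-bucket" is a placeholder; the notebook uses sess.default_bucket().
print(normalize_s3_prefix("s3://my-bucket/torchserve/llama-2-13b-neuronx-b1"))
```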

## Step 5: Start building SageMaker endpoint

In this step, we build a SageMaker endpoint from scratch.

### Create SageMaker endpoint

You need to specify the instance type and the endpoint name.

In [20]:
from datetime import datetime

instance_type = "ml.inf2.24xlarge"
endpoint_name = sagemaker.utils.name_from_base("ts-inf2-llama2-13b-b1")

model = Model(
    name="torchserve-inf2-llama2-13b" + datetime.now().strftime("%Y-%m-%d-%H-%M-%S"),
    # Enable SageMaker uncompressed model artifacts
    model_data={
        "S3DataSource": {
                "S3Uri": s3_uri,
                "S3DataType": "S3Prefix",
                "CompressionType": "None",
        }
    },
    image_uri=container,
    role=role,
    sagemaker_session=sess,
    env={"TS_INSTALL_PY_DEP_PER_MODEL": "true"},
)
print(model)

<sagemaker.model.Model object at 0x7fda64bf7040>


In [21]:
model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    volume_size=512, # increase the size to store large model
    model_data_download_timeout=3600, # increase the timeout to download large model
    container_startup_health_check_timeout=600, # increase the timeout to load large model
)

Your model is not compiled. Please compile your model before using Inferentia.


---------------!
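
The dashes above are the progress indicator that `model.deploy` prints while it waits for the endpoint. If you reconnect to the notebook later, or deploy without waiting, you can poll the endpoint status yourself. The loop below is a sketch: the function name and its parameters are hypothetical, and the status source is injected as a callable so the loop can be tried without an AWS connection.

```python
import time
from typing import Callable


def wait_until_in_service(
    get_status: Callable[[], str],
    poll_seconds: float = 30.0,
    timeout_seconds: float = 3600.0,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll `get_status` until it returns "InService"; raise on failure or timeout."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = get_status()
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError("Endpoint creation failed")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Endpoint still in status {status!r} after timeout")
        sleep(poll_seconds)


# With the clients from Step 1, the status source could be:
#   get_status = lambda: sm.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
# Here a canned sequence stands in for the real API.
fake_statuses = iter(["Creating", "Creating", "InService"])
print(wait_until_in_service(lambda: next(fake_statuses), sleep=lambda seconds: None))
```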

## Run the inference with streaming response

### TorchServe streaming response sample code

TorchServe [TextIteratorStreamerBatch](https://github.com/pytorch/serve/blob/d0ae857abfe6d36813c88e531316149a5a354a93/ts/handler_utils/hf_batch_streamer.py#L7C7-L7C32) extends HF `BaseStreamer` to support a batch of inference requests. It streams responses via [send_intermediate_predict_response](https://github.com/pytorch/serve/blob/d0ae857abfe6d36813c88e531316149a5a354a93/ts/protocol/otf_message_handler.py#L355).
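
The general pattern behind iterator-style streamers is a producer thread that pushes text chunks into a queue while the consumer iterates over them as they arrive. The sketch below shows only that pattern, using hypothetical names and a fake generator in place of the model. It is not the TorchServe or HF implementation; in the real handler each chunk would be forwarded to the client rather than collected in a list.

```python
from queue import Queue
from threading import Thread

_END = object()  # sentinel that marks the end of generation


class QueueStreamer:
    """Hand text chunks from a producer thread to an iterating consumer."""

    def __init__(self):
        self._queue = Queue()

    def put(self, text):
        self._queue.put(text)

    def end(self):
        self._queue.put(_END)

    def __iter__(self):
        return self

    def __next__(self):
        item = self._queue.get()  # blocks until the producer supplies a chunk
        if item is _END:
            raise StopIteration
        return item


def fake_generate(streamer, tokens):
    """Stand-in for model generation: emit tokens one by one, then signal the end."""
    for token in tokens:
        streamer.put(token)
    streamer.end()


streamer = QueueStreamer()
worker = Thread(target=fake_generate, args=(streamer, [" going", " to", " the", " beach"]))
worker.start()
chunks = [chunk for chunk in streamer]  # consumed while the worker is still producing
worker.join()
print("".join(chunks))
```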

In [5]:
!cat workspace/inf2_handler.py

import logging
import os
from abc import ABC
from threading import Thread

import torch_neuronx
from transformers import AutoConfig, LlamaTokenizer
from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter
from transformers_neuronx.llama.model import LlamaForSampling

from ts.handler_utils.hf_batch_streamer import TextIteratorStreamerBatch
from ts.handler_utils.micro_batching import MicroBatching
from ts.protocol.otf_message_handler import send_intermediate_predict_response
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class LLMHandler(BaseHandler, ABC):
    """
    Transformers handler class for text completion streaming on Inferentia2
    """

    def __init__(self):
        super().__init__()
        self.initialized = False
        self.max_length = None
        self.tokenizer = None
        self.output_streamer = None
        # enable micro batching
        self.handle = MicroBa

### SageMaker streaming response

In [22]:
import io

class Parser:
    """
    A helper class for parsing the byte stream input. 
    
    The output of the model will be in the following format:
    ```
    b'{"outputs": [" a"]}\n'
    b'{"outputs": [" challenging"]}\n'
    b'{"outputs": [" problem"]}\n'
    ...
    ```
    
    While usually each PayloadPart event from the event stream will contain a byte array 
    with a full json, this is not guaranteed and some of the json objects may be split across
    PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```
    
    This class accounts for this by concatenating bytes written via the 'write' function
    and then exposing a method which will return lines (ending with a '\n' character) within
    the buffer via the 'scan_lines' function. It tracks the position of the last byte read
    so that previously returned bytes are not exposed again.
    """
    
    def __init__(self):
        self.buff = io.BytesIO()
        self.read_pos = 0
        
    def write(self, content):
        # Append the new bytes at the end of the buffer
        self.buff.seek(0, io.SEEK_END)
        self.buff.write(content)
        
    def scan_lines(self):
        # Yield only complete lines; a trailing partial line stays buffered
        self.buff.seek(self.read_pos)
        for line in self.buff.readlines():
            if line.endswith(b'\n'):
                self.read_pos += len(line)
                yield line[:-1]
                
    def reset(self):
        self.read_pos = 0
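
To see the buffering idea in isolation, here is a self-contained sketch that feeds it a JSON object split across two chunks, as in the docstring example. `LineBuffer` is a hypothetical variant that uses a `bytearray` instead of `io.BytesIO`, and the sample chunks are made up for illustration.

```python
import json


class LineBuffer:
    """Accumulate byte chunks and return only complete, newline-terminated lines."""

    def __init__(self):
        self._pending = bytearray()

    def feed(self, chunk: bytes) -> list:
        """Add a chunk and return every complete line now available (newline stripped)."""
        self._pending.extend(chunk)
        lines = []
        while True:
            newline_at = self._pending.find(b"\n")
            if newline_at < 0:
                break  # the remainder is a partial line; keep it for the next chunk
            lines.append(bytes(self._pending[:newline_at]))
            del self._pending[:newline_at + 1]
        return lines


# The first JSON object is split across two chunks; the second chunk also carries a full line.
events = [b'{"outputs": ', b'[" problem"]}\n{"outputs": [" solved"]}\n']
buffer = LineBuffer()
decoded = []
for chunk in events:
    for line in buffer.feed(chunk):
        decoded.append(json.loads(line)["outputs"][0])
print(decoded)
```

Nothing is emitted for the first chunk because it contains no newline; both objects are decoded once the second chunk completes them.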

In [23]:
import json

body = "Today the weather is really nice and I am planning on".encode('utf-8')
resp = smr.invoke_endpoint_with_response_stream(EndpointName=endpoint_name, Body=body, ContentType="application/json")
event_stream = resp['Body']
parser = Parser()
for event in event_stream:
    parser.write(event['PayloadPart']['Bytes'])
    for line in parser.scan_lines():
        print(line.decode("utf-8"), end=' ')

Today the weather is really nice and I am planning on going to the beach. I am going to take my camera and take some pictures of the beach. I am going to take pictures of the sand, the water, and the people. I am also going to take pictures of the sunset. I am really excited to go to the beach and take pictures. The beach is a great place to take pictures. The sand, the water, and the people are all great subjects for pictures. The sunset is also a great subject for pictures 

## Clean up the environment

In [None]:
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()