# Amazon SageMaker Using AWS Large Model Inference (LMI) Deep Learning Container with Speculative Decoding 

**Recommended Kernel(s):** You can run this notebook with any Amazon SageMaker Studio kernel. We recommend using the Data Science 3.0 kernel.

This notebook demonstrates how to deploy the [`meta-llama/Meta-Llama-3-8B`](https://huggingface.co/meta-llama/Meta-Llama-3-8B) Hugging Face model to a SageMaker endpoint for text generation with speculative decoding enabled. In this example, the SageMaker-managed [LMI (Large Model Inference)](https://docs.djl.ai/docs/serving/serving/docs/large_model_inference.html) Docker image serves as the inference image. LMI images feature a [DJL serving](https://github.com/deepjavalibrary/djl-serving) stack powered by the [Deep Java Library](https://djl.ai/).

**What is speculative decoding?** In the context of large language model inference, speculative decoding, as introduced by [Y. Leviathan et al. (ICML 2023)](https://arxiv.org/abs/2211.17192), is a technique for accelerating decoding with large, and therefore slow, LLMs in latency-critical applications. The key idea is to use a smaller, less capable but faster model, called the *draft model*, to generate candidate tokens that are then validated by the larger, more capable but slower model, called the *target model*. At each iteration, the draft model generates $K>1$ candidate tokens. A single forward pass of the *target model* then accepts none, some, or all of the candidate tokens, and the target model also contributes one token of its own, so every iteration produces at least one new token. The more closely the draft model's predictions match the target model's, the more candidate tokens are accepted and the larger the potential speedup.
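To make the accept/reject logic concrete, here is a minimal sketch of the greedy (temperature 0) case in plain Python. It is a simplification: the paper uses a rejection-sampling rule so that sampled outputs follow the target model's distribution, whereas here a candidate is accepted only if it equals the target model's greedy prediction, which makes the output identical to greedy decoding with the target model alone. The "models" are hypothetical functions mapping a list of token ids to the next token id, and `k` plays the role of $K$. `expected_tokens_per_step` evaluates $(1-\alpha^{K+1})/(1-\alpha)$, the expected number of tokens per iteration from Leviathan et al. under the simplifying assumption that each candidate is accepted independently with probability $\alpha$.

```python
from typing import Callable, List

# Hypothetical stand-in for a model: maps a token sequence to its greedy next token.
NextTokenFn = Callable[[List[int]], int]


def speculative_step(prefix: List[int], draft: NextTokenFn, target: NextTokenFn, k: int) -> List[int]:
    """One draft-then-verify iteration; returns the tokens appended in this step."""
    # 1) The draft model proposes k candidate tokens autoregressively.
    candidates: List[int] = []
    for _ in range(k):
        candidates.append(draft(prefix + candidates))

    # 2) The target model predicts the next token at each of the k + 1 positions.
    #    A real implementation gets all of them from one batched forward pass.
    target_preds = [target(prefix + candidates[:i]) for i in range(k + 1)]

    # 3) Keep the longest run of candidates that match the target's predictions...
    accepted: List[int] = []
    for candidate, predicted in zip(candidates, target_preds):
        if candidate != predicted:
            break
        accepted.append(candidate)
    # ...then add the target's own token: a correction, or a bonus token if all k matched.
    accepted.append(target_preds[len(accepted)])
    return accepted


def speculative_generate(
    prompt: List[int], draft: NextTokenFn, target: NextTokenFn, k: int, max_new_tokens: int
) -> List[int]:
    """Generate max_new_tokens tokens after prompt using repeated speculative steps."""
    tokens = list(prompt)
    while len(tokens) - len(prompt) < max_new_tokens:
        tokens += speculative_step(tokens, draft, target, k)
    # The last step may overshoot; trim to the requested length.
    return tokens[: len(prompt) + max_new_tokens]


def expected_tokens_per_step(alpha: float, k: int) -> float:
    """Expected tokens per iteration if each candidate is accepted independently with probability alpha."""
    if alpha == 1.0:
        return float(k + 1)
    return (1 - alpha ** (k + 1)) / (1 - alpha)
```

For example, `expected_tokens_per_step(0.8, 4)` is about 3.36, so under that independence assumption each target forward pass would yield roughly 3.4 tokens instead of one.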


## 1. Dependency Installation
### 1.1. Python Dependencies & Imports
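This notebook requires the following Python dependencies:
* [`sagemaker-core`](https://github.com/aws/sagemaker-core), the AWS SageMaker Core SDK
* [`huggingface_hub`](https://github.com/huggingface/huggingface_hub)

Let's install or upgrade these dependencies using the following commands: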

In [None]:
%pip install pip --upgrade --quiet
%pip install sagemaker-core huggingface_hub --quiet

In [None]:
import json
import os
import pathlib

import huggingface_hub
from sagemaker_core.helper.session_helper import Session, get_execution_role

In [None]:
PROJECT_ROOT_DIR = pathlib.Path.cwd()

SM_SESSION = Session()

REGION = SM_SESSION._region_name

SM_DEFAULT_EXECUTION_ROLE_ARN = get_execution_role()

# ml.g5.12xlarge provides 4 NVIDIA A10G GPUs (24 GB each).
INSTANCE_TYPE = "ml.g5.12xlarge"

# Alternative with 8 NVIDIA A100 GPUs:
# INSTANCE_TYPE = "ml.p4d.24xlarge"

# See https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers for the latest DLC.
container_image_uri = f"763104351884.dkr.ecr.{REGION}.amazonaws.com/djl-inference:0.29.0-tensorrtllm0.11.0-cu124"

default_bucket = SM_SESSION.default_bucket()

TARGET_MODEL = "meta-llama/Meta-Llama-3-8B"
DRAFT_MODEL = "sagemaker"

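### 1.2. (Optional) Download a Custom Draft Model

This example uses the SageMaker-provided draft model (`DRAFT_MODEL = "sagemaker"`), so no download is required. To use your own draft model instead, point `DRAFT_MODEL` at its Hugging Face model ID or at the base S3 prefix that holds its artifacts.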
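If you would rather host your own draft model in S3 than use the SageMaker-provided one, the following is a minimal sketch of staging it (the draft model should use the same tokenizer as the target model). The helper names `upload_dir_to_s3` and `stage_draft_model` are hypothetical. The sketch assumes `huggingface_hub.snapshot_download` for the download and an S3 client exposing boto3's `upload_file(Filename, Bucket, Key)` method, such as `boto3.client("s3")`. It skips the `.cache` metadata folder that recent `huggingface_hub` versions may write into `local_dir`.

```python
import pathlib


def upload_dir_to_s3(local_dir, bucket: str, prefix: str, s3_client) -> str:
    """Upload every file under local_dir to s3://<bucket>/<prefix>/ and return that URI."""
    root = pathlib.Path(local_dir)
    prefix = prefix.strip("/")
    for path in sorted(root.rglob("*")):
        relative = path.relative_to(root)
        # Skip directories and the `.cache` metadata folder huggingface_hub may create.
        if not path.is_file() or ".cache" in relative.parts:
            continue
        key = f"{prefix}/{relative.as_posix()}"
        s3_client.upload_file(Filename=str(path), Bucket=bucket, Key=key)
    return f"s3://{bucket}/{prefix}/"


def stage_draft_model(repo_id: str, bucket: str, prefix: str, s3_client, local_dir: str = "draft_model") -> str:
    """Download a Hugging Face model snapshot locally, then upload it to S3 (needs network access)."""
    # Imported lazily so the upload helper above works without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    return upload_dir_to_s3(local_dir, bucket, prefix, s3_client)
```

With AWS credentials configured, something like `DRAFT_MODEL = stage_draft_model("<draft-model-repo-id>", default_bucket, "speculative-decoding/draft-model", boto3.client("s3"))` would return an S3 prefix to use in place of `"sagemaker"`. Passing the client in, rather than creating it inside the helper, keeps the upload logic easy to test without AWS access.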

## 2. Deploy Speculative Decoding-Enabled Endpoint
### 2.1. Endpoint Deployment

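We configure the models to use and the server startup parameters via environment variables:

* `HF_MODEL_ID` (an alternative to `OPTION_MODEL_ID`) points to the Hugging Face model ID or the base S3 prefix of the target model artifacts.
* `OPTION_SPECULATIVE_DRAFT_MODEL` points to the Hugging Face model ID or the base S3 prefix of the draft model artifacts, or, as in this example, to `sagemaker`, which selects the SageMaker-provided draft model.
* `OPTION_GPU_MEMORY_UTILIZATION` sets the fraction of GPU memory the serving engine may use.
* `HF_TOKEN` is a Hugging Face access token, needed here because `meta-llama/Meta-Llama-3-8B` is a gated model.

You can read about the other configuration parameters in the LMI [documentation](https://github.com/deepjavalibrary/djl-serving/tree/master/serving/docs/lmi).

Note that you can replace the target and draft models with your own.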
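If you plan to try several target/draft combinations, a small builder keeps the variables consistent. This is a minimal sketch; `build_speculative_env` is a hypothetical helper that uses only the variable names listed above. It converts the memory fraction to a string because SageMaker container environment values must be strings.

```python
from typing import Dict, Optional


def build_speculative_env(
    target_model: str,
    draft_model: str,
    gpu_memory_utilization: float = 0.85,
    hf_token: Optional[str] = None,
) -> Dict[str, str]:
    """Assemble the LMI environment variables for a speculative decoding endpoint."""
    if not 0.0 < gpu_memory_utilization <= 1.0:
        raise ValueError("gpu_memory_utilization must be in (0, 1]")
    env = {
        "HF_MODEL_ID": target_model,  # HF model ID or S3 prefix of the target model
        "OPTION_SPECULATIVE_DRAFT_MODEL": draft_model,  # HF model ID, S3 prefix, or "sagemaker"
        # SageMaker container environment values must be strings.
        "OPTION_GPU_MEMORY_UTILIZATION": str(gpu_memory_utilization),
    }
    if hf_token:  # only gated models such as Meta-Llama-3-8B need a token
        env["HF_TOKEN"] = hf_token
    return env
```

For example, `build_speculative_env(TARGET_MODEL, DRAFT_MODEL, hf_token=os.environ.get("HF_TOKEN"))` produces the same keys as the dictionary in the next cell, omitting `HF_TOKEN` when no token is set.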

In [None]:
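environment = {
    # Target model: Hugging Face model ID (or S3 prefix) of the large model.
    "HF_MODEL_ID": TARGET_MODEL,
    # Draft model: "sagemaker" selects the SageMaker-provided draft model.
    "OPTION_SPECULATIVE_DRAFT_MODEL": DRAFT_MODEL,
    "OPTION_GPU_MEMORY_UTILIZATION": "0.85",
    # Meta-Llama-3-8B is gated; read the token from the environment instead of hard-coding it.
    "HF_TOKEN": os.environ.get("HF_TOKEN", "<YourHFToken>"),
}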

In [None]:
from sagemaker_core.shapes import ContainerDefinition, ProductionVariant
from sagemaker_core.resources import Model, EndpointConfig, Endpoint
from time import gmtime, strftime

In [None]:
container_definition = ContainerDefinition(
    image=container_image_uri,
    environment=environment,
)

In [None]:
# Add a timestamp suffix so that repeated runs produce unique resource names.
model_name = f'speculative-decoding-hugging-face-{strftime("%H-%M-%S", gmtime())}'

model = Model.create(
    model_name=model_name,
    primary_container=container_definition,
    execution_role_arn=SM_DEFAULT_EXECUTION_ROLE_ARN,
)

LLM inference containers can take longer to start than containers serving smaller models, mainly because downloading and loading the model takes longer. We therefore raise the container startup health check and model data download timeouts from their default values. Each endpoint deployment takes a few minutes. We also set `routing_strategy` to `LEAST_OUTSTANDING_REQUESTS`, which routes requests to the instances with the most spare capacity; this only has an effect when the endpoint is backed by multiple instances.
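To illustrate the idea behind least-outstanding-requests routing, here is a conceptual sketch: pick the instance with the fewest in-flight requests. It is not SageMaker's actual router; the function name, the instance names, and the random tie-breaking are all assumptions made for illustration.

```python
import random
from typing import Dict


def pick_least_outstanding(outstanding: Dict[str, int], rng: random.Random) -> str:
    """Return the instance with the fewest in-flight requests, breaking ties randomly."""
    if not outstanding:
        raise ValueError("no instances to route to")
    fewest = min(outstanding.values())
    candidates = sorted(name for name, count in outstanding.items() if count == fewest)
    return rng.choice(candidates)
```

For example, with `{"instance-1": 3, "instance-2": 0, "instance-3": 1}` the sketch picks `instance-2`. Compared with random routing, this avoids queuing new requests behind long-running generations on a busy instance, which matters for LLM workloads whose request durations vary widely.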

In [None]:
from sagemaker_core.shapes import ProductionVariantRoutingConfig

routing_config = ProductionVariantRoutingConfig(
    routing_strategy="LEAST_OUTSTANDING_REQUESTS"
)


In [None]:
endpoint_config = EndpointConfig.create(
    endpoint_config_name=model_name,
    production_variants=[
        ProductionVariant(
            variant_name=model_name,
            initial_instance_count=1,
            instance_type=INSTANCE_TYPE,
            model_name=model,
            container_startup_health_check_timeout_in_seconds=3600,
            model_data_download_timeout_in_seconds=3600,
            routing_config=routing_config
        )
    ]
)

In [None]:
endpoint = Endpoint.create(
    endpoint_name=model_name,
    endpoint_config_name=endpoint_config # Pass `EndpointConfig` object created above
)

This cell blocks until the endpoint reaches the `InService` status, which is required before the endpoint can be invoked.

In [None]:
endpoint.wait_for_status("InService")

### 2.2. Endpoint Invocation

Let's invoke our endpoint and get a sample response.

In [None]:
response = endpoint.invoke(
    body=json.dumps({
        "inputs": "What is the capital of France?",
        # LMI reads generation settings from the nested "parameters" object.
        "parameters": {
            "max_new_tokens": 512,
            "temperature": 0.0,
        },
    }),
    content_type="application/json",
)
response['Body'].read()
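To read the reply and gauge the latency benefit of speculative decoding, here is a minimal sketch. It assumes the LMI JSON response has the form `{"generated_text": "..."}` (a list of such objects is accepted defensively); `parse_generated_text` and `time_requests` are hypothetical helpers, and `send` can be any zero-argument function that performs one invocation and returns the response body.

```python
import json
import time
from typing import Callable, List


def parse_generated_text(body: bytes) -> str:
    """Extract the generated text from an LMI JSON response body."""
    payload = json.loads(body)
    # Defensive: also accept a list of response objects and use the first one.
    if isinstance(payload, list):
        payload = payload[0]
    return payload["generated_text"]


def time_requests(send: Callable[[], bytes], n: int = 5) -> List[float]:
    """Call send() n times and return the wall-clock latency of each call in seconds."""
    latencies = []
    for _ in range(n):
        start = time.perf_counter()
        send()
        latencies.append(time.perf_counter() - start)
    return latencies
```

To use it, wrap the `endpoint.invoke(...)` call above in a `lambda` that returns `response['Body'].read()` and pass it as `send`. Comparing the latencies with an endpoint deployed without `OPTION_SPECULATIVE_DRAFT_MODEL` shows the effect of speculative decoding on your own prompts.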

## 3. Clean Up

In [None]:
# Delete the SageMaker resources created in this notebook:
# the endpoint first, then the endpoint configuration and model it was built from.
endpoint.delete()
endpoint_config.delete()
model.delete()