torchserve/inf2/llama2/workspace/inf2_handler.py
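"""TorchServe handler for streaming Llama 2 text completion on AWS Inferentia2 (inf2),
using transformers-neuronx with TorchServe micro batching."""
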
import logging
import os
from abc import ABC
from threading import Thread
import torch_neuronx
from transformers import AutoConfig, LlamaTokenizer
from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter
from transformers_neuronx.llama.model import LlamaForSampling
from ts.handler_utils.hf_batch_streamer import TextIteratorStreamerBatch
from ts.handler_utils.micro_batching import MicroBatching
from ts.protocol.otf_message_handler import send_intermediate_predict_response
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class LLMHandler(BaseHandler, ABC):
    """
    Transformers handler class for text completion streaming on Inferentia2.
    """

    def __init__(self):
        super().__init__()
        self.initialized = False
        self.max_length = None
        self.tokenizer = None
        self.output_streamer = None
        # enable micro batching
        self.handle = MicroBatching(self)

    def initialize(self, ctx):
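        """Load the tokenizer, compile the Llama model for Neuron, and set up
        micro batching and the output streamer from the model YAML config."""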
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        model_checkpoint_dir = ctx.model_yaml_config.get("handler", {}).get(
            "model_checkpoint_dir", ""
        )
        model_checkpoint_path = f"{model_dir}/{model_checkpoint_dir}"

        os.environ["NEURONX_CACHE"] = "on"
        os.environ["NEURONX_DUMP_TO"] = f"{model_dir}/neuron_cache"
        os.environ["NEURON_CC_FLAGS"] = "--model-type=transformer-inference"

        # micro batching initialization
        micro_batching_parallelism = ctx.model_yaml_config.get(
            "micro_batching", {}
        ).get("parallelism", None)
        if micro_batching_parallelism:
            logger.info(
                f"Setting micro batching parallelism from model_config_yaml: {micro_batching_parallelism}"
            )
            self.handle.parallelism = micro_batching_parallelism

        micro_batch_size = ctx.model_yaml_config.get("micro_batching", {}).get(
            "micro_batch_size", 1
        )
        logger.info(f"Setting micro batching size: {micro_batch_size}")
        self.handle.micro_batch_size = micro_batch_size

        # settings for model compilation and loading
        amp = ctx.model_yaml_config.get("handler", {}).get("amp", "f32")
        tp_degree = ctx.model_yaml_config.get("handler", {}).get("tp_degree", 6)
        self.max_length = ctx.model_yaml_config.get("handler", {}).get("max_length", 50)

        # allocate "tp_degree" number of neuron cores to the worker process
        os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree)
        try:
            num_neuron_cores_available = (
                torch_neuronx.xla_impl.data_parallel.device_count()
            )
            assert num_neuron_cores_available >= int(tp_degree)
        except (RuntimeError, AssertionError) as error:
            logger.error(
                f"Required number of neuron cores for tp_degree {tp_degree} "
                f"is not available: {error}"
            )
            raise error

        self.tokenizer = LlamaTokenizer.from_pretrained(model_checkpoint_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = LlamaForSampling.from_pretrained(
            model_checkpoint_path,
            batch_size=self.handle.micro_batch_size,
            amp=amp,
            tp_degree=tp_degree,
        )
        logger.info("Starting to compile the model")
        self.model.to_neuron()
        logger.info("Model has been successfully compiled")

        model_config = AutoConfig.from_pretrained(model_checkpoint_path)
        self.model = HuggingFaceGenerationModelAdapter(model_config, self.model)
        self.output_streamer = TextIteratorStreamerBatch(
            self.tokenizer,
            batch_size=self.handle.micro_batch_size,
            skip_special_tokens=True,
        )

        self.initialized = True

    def preprocess(self, requests):
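        """Decode and strip each request, pad the micro batch up to the compiled
        batch size, and tokenize the batch as PyTorch tensors."""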
        input_text = []
        for req in requests:
            data = req.get("data") or req.get("body")
            if isinstance(data, (bytes, bytearray)):
                data = data.decode("utf-8")
            logger.info(f"received req={data}")
            input_text.append(data.strip())

        # Ensure the compiled model can handle the input received
        if len(input_text) > self.handle.micro_batch_size:
            raise ValueError(
                f"Model is compiled for batch size {self.handle.micro_batch_size} "
                f"but received input of size {len(input_text)}"
            )

        # Pad input to match compiled model batch size
        input_text.extend([""] * (self.handle.micro_batch_size - len(input_text)))

        return self.tokenizer(input_text, return_tensors="pt", padding=True)

    def inference(self, tokenized_input):
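        """Run generation in a background thread and stream intermediate tokens
        back to the clients of this micro batch as they are produced. The final
        return value is a list of empty strings because the generated text has
        already been sent via send_intermediate_predict_response."""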
        generation_kwargs = dict(
            tokenized_input,
            streamer=self.output_streamer,
            max_new_tokens=self.max_length,
        )
        self.model.reset_generation()
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        micro_batch_idx = self.handle.get_micro_batch_idx()
        micro_batch_req_id_map = self.get_micro_batch_req_id_map(micro_batch_idx)
        for new_text in self.output_streamer:
            logger.debug("send response stream")
            send_intermediate_predict_response(
                new_text[: len(micro_batch_req_id_map)],
                micro_batch_req_id_map,
                "Intermediate Prediction success",
                200,
                self.context,
            )

        thread.join()
        return [""] * len(micro_batch_req_id_map)

    def postprocess(self, inference_output):
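        """Responses are streamed from inference(), so there is nothing left to
        post-process; the (empty) outputs are returned unchanged."""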
        return inference_output

    def get_micro_batch_req_id_map(self, micro_batch_idx: int):
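        """Map positions within this micro batch (0..micro_batch_size - 1) to the
        TorchServe request ids of the full batch, skipping padded slots that have
        no corresponding request."""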
        start_idx = micro_batch_idx * self.handle.micro_batch_size
        micro_batch_req_id_map = {
            index: self.context.request_ids[batch_index]
            for index, batch_index in enumerate(
                range(start_idx, start_idx + self.handle.micro_batch_size)
            )
            if batch_index in self.context.request_ids
        }
        return micro_batch_req_id_map