In torchserve/inf2/llama2/workspace/inf2_handler.py:
def inference(self, tokenized_input):
    # Stream generation: generate() runs in a worker thread and pushes decoded
    # text into self.output_streamer as tokens are produced.
    generation_kwargs = dict(
        tokenized_input,
        streamer=self.output_streamer,
        max_new_tokens=self.max_length,
    )
    # Reset the model's generation state before starting a new generation pass.
    self.model.reset_generation()
    thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
    thread.start()

    # Map the current micro-batch to the request ids it serves, then relay each
    # decoded fragment back to those requests as an intermediate response.
    micro_batch_idx = self.handle.get_micro_batch_idx()
    micro_batch_req_id_map = self.get_micro_batch_req_id_map(micro_batch_idx)
    for new_text in self.output_streamer:
        logger.debug("send response stream")
        send_intermediate_predict_response(
            new_text[: len(micro_batch_req_id_map)],
            micro_batch_req_id_map,
            "Intermediate Prediction success",
            200,
            self.context,
        )

    # Wait for generation to finish. The final response per request is empty
    # because all text has already been streamed as intermediate responses.
    thread.join()

    return [""] * len(micro_batch_req_id_map)