in packages/constructs/L3/ai/gaia-l3-construct/lib/model-interfaces/langchain/functions/request-handler/adapters/base/base.py [0:0]
def run_with_chain_v2(self, user_prompt, workspace_id=None):
    """Run the conversational chain and return the answer payload.

    When ``workspace_id`` is given, a RAG pipeline is assembled
    (history-aware retriever + stuff-documents QA chain); otherwise the
    prompt is piped straight into the LLM. Streaming is used when the
    model requests it and it is not globally disabled.

    Args:
        user_prompt: The user's input text for this turn.
        workspace_id: Optional workspace to retrieve documents from
            (enables the RAG flow).

    Returns:
        dict with ``sessionId``, ``type`` ("text"), ``content`` (the
        answer) and ``metadata`` (model info, documents, prompts, usage).

    Raises:
        ValueError: If ``self.llm`` has not been configured.
        Exception: Any error raised by the underlying chain is logged
            and re-raised.
    """
    if not self.llm:
        raise ValueError("llm must be set")
    self.callback_handler.prompts = []
    documents = []
    retriever = None
    if workspace_id:
        retriever = WorkspaceRetriever(workspace_id=workspace_id)
        # Only stream the last llm call (otherwise the internal
        # llm response will be visible)
        llm_without_streaming = self.get_llm({"streaming": False})
        history_aware_retriever = create_history_aware_retriever(
            llm_without_streaming,
            retriever,
            self.get_condense_question_prompt(),
        )
        question_answer_chain = create_stuff_documents_chain(
            self.llm,
            self.get_qa_prompt(),
        )
        chain = create_retrieval_chain(
            history_aware_retriever, question_answer_chain
        )
    else:
        chain = self.get_prompt() | self.llm

    # The same per-session history object is returned regardless of the
    # session_id argument; session scoping is handled by self.chat_history.
    conversation = RunnableWithMessageHistory(
        chain,
        lambda session_id: self.chat_history,
        history_messages_key="chat_history",
        input_messages_key="input",
        output_messages_key="output",
    )

    config = {"configurable": {"session_id": self.session_id}}
    try:
        if not self.disable_streaming and self.model_kwargs.get("streaming", False):
            # Accumulate streamed chunks. RAG chains emit dicts with an
            # "answer" key; plain LLM chains emit AIMessageChunk objects
            # (assumed content is a list of {"text": ...} parts — the
            # Anthropic-style shape; TODO confirm for other providers).
            answer = ""
            for chunk in conversation.stream(
                input={"input": user_prompt}, config=config
            ):
                logger.debug("chunk", chunk=chunk)
                if "answer" in chunk:
                    answer = answer + chunk["answer"]
                elif isinstance(chunk, AIMessageChunk):
                    for c in chunk.content:
                        if "text" in c:
                            answer = answer + c.get("text")
        else:
            response = conversation.invoke(
                input={"input": user_prompt}, config=config
            )
            if "answer" in response:
                answer = response.get("answer")  # RAG flow
            else:
                answer = response.content
    except Exception as e:
        logger.exception(e)
        # Bare raise re-raises the active exception without disturbing
        # its traceback (preferred over `raise e`).
        raise

    if workspace_id:
        # In the RAG flow, the history is not updated automatically
        self.chat_history.add_message(HumanMessage(user_prompt))
        self.chat_history.add_message(AIMessage(answer))
    if retriever is not None:
        # Surface the documents used for this answer in the metadata.
        documents = [
            {
                "page_content": doc.page_content,
                "metadata": doc.metadata,
            }
            for doc in retriever.get_last_search_documents()
        ]

    metadata = {
        "modelId": self.model_id,
        "modelKwargs": self.model_kwargs,
        "mode": self._mode,
        "sessionId": self.session_id,
        "userId": self.user_id,
        "documents": documents,
        "prompts": self.callback_handler.prompts,
        "usage": self.callback_handler.usage,
    }

    self.chat_history.add_metadata(metadata)

    if (
        self.callback_handler.usage is not None
        and "total_tokens" in self.callback_handler.usage
    ):
        # Used by Cloudwatch filters to generate a metric of token usage.
        logger.info(
            "Usage Metric",
            # Each unique value of model id will create a
            # new cloudwatch metric (each one has a cost)
            model=self.model_id,
            metric_type="token_usage",
            value=self.callback_handler.usage.get("total_tokens"),
        )

    return {
        "sessionId": self.session_id,
        "type": "text",
        "content": answer,
        "metadata": metadata,
    }