# runnable-hub/python/runnable_workers/chainWorker/worker.py
from typing import Dict
from runnable_hub import RunnableWorker, RunnableContext, RunnableStatus
from .request.chainRequest import ChainRequest, ChainFunction
from .request.chainFunction import ChainFunctionType
from .response import ChainResponse
import sys
from jinja2 import Environment
import os
import re
import json
# enter --> llmEnd <---- llmStart
# | ^
# | |
# v |
# functionCallStart ---> functionCallEnd ---> finish
class Worker(RunnableWorker):
    """Worker driving a CHAIN runnable as a small state machine.

    The current state lives in ``context.data["runtime"]["nextStep"]`` and
    follows the diagram above: the chain alternates between LLM completions
    and tool/agent function calls until the onNext post-script emits a
    ``<finalAnswer>`` tag, at which point the chain finishes.
    """

    runnableCode = "CHAIN"
    Request = ChainRequest
    Response = ChainResponse
    # Interpreter path recorded for python sub-runnables.
    pythonBin = sys.executable

    def __init__(self, storePath = "/tmp/python"):
        # NOTE(review): storePath is accepted for interface compatibility but
        # is currently unused — confirm whether it should be persisted.
        self_dir = os.path.dirname(__file__)
        self.jinjaEnv = Environment()
        # Post-script appended to every user-supplied onNext script; it is
        # responsible for emitting the <finalAnswer>/<function>/<message> tags
        # that onNext parses from stdout.
        with open(self_dir + "/next.py", "r") as h:
            self.nextPostScript = h.read()

    @staticmethod
    def fnInputMerge(fn: ChainFunction, inputs: Dict | str) -> Dict:
        """Merge call-time inputs over the function's preset inputs.

        A bare string is bound to the function's first declared input
        parameter; a dict is merged key-by-key, with call-time values
        overriding presets. Preset inputs are copied, never mutated.
        """
        finalInputs = fn.presetInputs.copy()
        if isinstance(inputs, str):
            # Bind the bare string to the first declared input parameter.
            # Assumes fn.inputDefine is non-empty for string-form calls.
            finalInputs.update({
                fn.inputDefine[0].name: inputs
            })
        else:
            finalInputs.update(inputs)
        return finalInputs

    async def onNext(self, context: RunnableContext[ChainRequest, ChainResponse]) -> RunnableContext:
        """Advance the chain one transition and return the updated context.

        Each call either schedules the next sub-runnable via
        ``context.promise.resolve`` or terminates the chain by setting
        ``context.status`` / ``context.response``.
        """
        if context.data.get("runtime") is None:
            # --- enter: render prompts and kick off the first LLM call ---
            # todo merge function call defines
            renderData = context.request.data
            renderData["tool_info"] = ""
            renderData["function_info"] = ""
            renderData["agent_info"] = ""
            for fn in context.request.functions:
                # Describe each function as "name: name(arg: type, ...) - desc"
                # so the prompt template can list the available calls.
                args = ", ".join([f"{arg.name}: {arg.type.value}" for arg in fn.inputDefine])
                info_block = f"{fn.name}: {fn.name}({args}) - {fn.description}\n"
                renderData["function_info"] += info_block
                if fn.type == ChainFunctionType.TOOL:
                    renderData["tool_info"] += info_block
                elif fn.type == ChainFunctionType.AGENT:
                    renderData["agent_info"] += info_block
            systemPrompt = self.jinjaEnv.from_string(context.request.systemPrompt).render(**renderData)
            userPrompt = self.jinjaEnv.from_string(context.request.userPrompt).render(**renderData)
            context.promise.resolve["llm"] = {
                "runnableCode": "LLM",
                "setting": context.request.llm.model_dump(),
                "systemPrompt": systemPrompt,
                "userPrompt": userPrompt,
            }
            context.data["runtime"] = {
                "history": [],
                # systemPrompt is cached so follow-up LLM calls reuse it.
                "systemPrompt": systemPrompt,
                "nextStep": "llmEnd",
                "lastCompletion": None,
            }
            return context
        elif context.data["runtime"]["nextStep"] == "llmEnd":
            # --- llmEnd: LLM replied; run the onNext script on the completion ---
            if context.promise.result["llm"] is None:
                raise ValueError("LLM response is missing")
            context.promise.resolve["processMessage"] = {
                "runnableCode": "PYTHON",
                "data": {
                    "completion": context.promise.result["llm"]["content"],
                },
                "run": context.request.onNext + "\n" + self.nextPostScript,
            }
            context.data["runtime"]["nextStep"] = "functionCallStart"
            context.data["runtime"]["lastCompletion"] = context.promise.result["llm"]["content"]
            context.data["runtime"]["history"] += context.promise.result["llm"]["messages"]
            return context
        elif context.data["runtime"]["nextStep"] == "functionCallStart":
            # --- functionCallStart: parse script stdout for an answer or a call ---
            if context.promise.result["processMessage"] is None:
                raise ValueError("processMessage response is missing")
            stdout = context.promise.result["processMessage"]["stdout"]
            matches = re.findall(r"<finalAnswer>(.*?)</finalAnswer>", stdout, re.DOTALL)
            if len(matches) > 0:
                # Chain finished successfully with a final answer.
                context.status = RunnableStatus.SUCCESS
                context.response = ChainResponse(finalAnswer=matches[0],history=context.data["runtime"]["history"])
                return context
            matches = re.findall(r"<function>(.*?)</function>", stdout, re.DOTALL)
            if len(matches) > 0:
                fnCall = json.loads(matches[0])
                fnDefines = [fn for fn in context.request.functions if fn.name == fnCall["name"]]
                if len(fnDefines) == 0:
                    context.status = RunnableStatus.ERROR
                    context.errorMessage = f"Function:{fnCall['name']} not found"
                    return context
                fnDefine = fnDefines[0]
                if fnDefine.type == ChainFunctionType.TOOL:
                    runnableCode = "TOOL"
                elif fnDefine.type == ChainFunctionType.AGENT:
                    runnableCode = "AGENT"
                else:
                    # BUG FIX: previously an unrecognized function type left
                    # runnableCode unbound and the next line raised NameError.
                    # Fail with an explicit ERROR status instead.
                    context.status = RunnableStatus.ERROR
                    context.errorMessage = f"Function:{fnDefine.name} has unsupported type {fnDefine.type}"
                    return context
                context.promise.resolve["functionCall"] = {
                    "runnableCode": runnableCode,
                    "toolCode": fnDefine.name,
                    "toolVersion": fnDefine.version,
                    "inputs": self.fnInputMerge(fnDefine, fnCall["input"]),
                }
                context.data["runtime"]["nextStep"] = "functionCallEnd"
                return context
            # Script produced neither a final answer nor a function call.
            context.status = RunnableStatus.ERROR
            context.errorMessage = "finalAnswer or function not found"
            return context
        elif context.data["runtime"]["nextStep"] == "functionCallEnd":
            # --- functionCallEnd: feed the function outputs back to the script ---
            if context.promise.result["functionCall"] is None:
                raise ValueError("functionCall response is missing")
            context.promise.resolve["processMessage"] = {
                "runnableCode": "PYTHON",
                "data": {
                    "function": json.dumps(context.promise.result["functionCall"]["outputs"]),
                    "completion": context.data["runtime"]["lastCompletion"],
                },
                "run": context.request.onNext + "\n" + self.nextPostScript,
            }
            context.data["runtime"]["nextStep"] = "llmStart"
            return context
        elif context.data["runtime"]["nextStep"] == "llmStart":
            # --- llmStart: script produced a follow-up message for the LLM ---
            if context.promise.result["processMessage"] is None:
                raise ValueError("processMessage response is missing")
            stdout = context.promise.result["processMessage"]["stdout"]
            matches = re.findall(r"<message>(.*?)</message>", stdout, re.DOTALL)
            if len(matches) > 0:
                context.promise.resolve["llm"] = {
                    "runnableCode": "LLM",
                    "setting": context.request.llm.model_dump(),
                    "systemPrompt": context.data["runtime"]["systemPrompt"],
                    "userPrompt": matches[0],
                    "history": context.data["runtime"]["history"],
                }
                context.data["runtime"]["nextStep"] = "llmEnd"
                return context
            context.status = RunnableStatus.ERROR
            context.errorMessage = "message not found"
            return context
        # Unknown state: leave the context untouched.
        return context