# NOTE(review): removed web-scrape residue above (Hugging Face Spaces page
# header: "Runtime error" status, file size, commit hashes, gutter line
# numbers) — it was not part of the source file and broke Python parsing.
import operator
from typing import Annotated, List, TypedDict

from langchain_community.chat_models import ChatOllama
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from langchain_core.runnables.history import RunnableWithMessageHistory
from langgraph.graph import END, StateGraph

from retriever import retrieve_context
from tools import tools
# Lightweight local Ollama model that backs every chain in this module.
model = ChatOllama(model="qwen:1.8b")

# Dispatch table: registered tool name -> tool object, so the tool chain can
# look tools up by the name carried in graph state.
tools_with_names = {t.name: t for t in tools}
class AgentState(TypedDict):
    """State dict threaded through every node of the LangGraph graph.

    The original annotation ``messages: Annotated[List], []`` was a
    SyntaxError (an unparenthesized tuple is not a valid annotation).
    The intended LangGraph idiom is ``Annotated[List, operator.add]``:
    the ``operator.add`` reducer tells LangGraph to *append* each node's
    new messages to the running list instead of overwriting it.
    """

    # Full conversation so far; new node output is concatenated, not replaced.
    messages: Annotated[List, operator.add]
    # Routing hint produced by each node (e.g. "tool_picker", "end").
    next: str
def _tool_invocation_payload(state):
    """Extract the latest message text and the selected tool name from state."""
    return {
        "message": state["messages"][-1].content,
        "tool": state["next"],
    }


def _dispatch_tool(payload):
    """Run the named tool and fold its stringified output back into graph state."""
    result = tools_with_names[payload["tool"]].invoke(payload["message"])
    return {"messages": [AIMessage(content=str(result))], "next": "end"}


# NOTE(review): the tool name is taken from state["next"]; upstream currently
# sets next="tool_picker", so a tool with that exact name must exist in the
# tools module for this lookup to succeed — verify against tools.py.
tool_chain = RunnableLambda(_tool_invocation_payload) | _dispatch_tool
system = """
You are a helpful assistant. Use tools if needed. Keep responses short.
"""
prompt = PromptTemplate.from_template("""{context}
{question}
""")
# Build the prompt inputs in parallel — retrieved context plus the text of the
# most recent message — then render the prompt template. (retrieve_context
# receives the whole state dict; presumably it reads the messages itself —
# verify against retriever.py.)
context_chain = RunnableParallel(
    context=RunnableLambda(retrieve_context),
    question=RunnableLambda(lambda state: state["messages"][-1].content),
) | prompt
def _package_agent_reply(reply_text):
    """Wrap the model's text reply for graph state and queue the tool question."""
    return {
        "messages": [
            AIMessage(content=reply_text),
            HumanMessage(content="Do you want to use a tool?"),
        ],
        "next": "tool_picker",
    }


# RAG answer pipeline: prompt from retrieved context -> model -> plain text ->
# packaged graph-state update flagged for the tool-picker step.
agent = context_chain | model | StrOutputParser() | _package_agent_reply
# Route to the tool chain only when the previous step explicitly asked for a
# tool; otherwise run the regular agent.
# Fix: use state.get("next", "") — the very first graph invocation starts with
# no "next" key, so the original x["next"] raised KeyError before the agent
# could even run.
conditional_agent = RunnableBranch(
    (lambda state: "tool" in state.get("next", ""), tool_chain),
    agent,
)
def create_graph():
    """Build and compile the agent/tool LangGraph state machine.

    Layout: "agent" (conditional_agent) is the entry point. After it runs,
    a router inspects state["next"] and either hands off to "tool_chain" or
    ends the run; "tool_chain" always loops back to "agent".

    Returns:
        The compiled graph (a runnable application).
    """
    graph_builder = StateGraph(AgentState)
    graph_builder.add_node("agent", conditional_agent)
    graph_builder.add_node("tool_chain", tool_chain)
    graph_builder.set_entry_point("agent")
    # Fix: the agent sets next="tool_picker" and the tool chain sets "end",
    # but the original router returned x["next"] verbatim against a path map
    # with only the keys "tool" and "end" — so "tool_picker" raised KeyError
    # at runtime. Normalize here: anything mentioning "tool" goes to the tool
    # chain, everything else ends (mirrors conditional_agent's branch test).
    graph_builder.add_conditional_edges(
        "agent",
        lambda state: "tool" if "tool" in state["next"] else "end",
        {
            "tool": "tool_chain",
            "end": END,
        },
    )
    graph_builder.add_edge("tool_chain", "agent")
    return graph_builder.compile()
app = create_graph()

# Per-session chat histories, kept for the lifetime of the process.
_session_histories = {}


def _get_session_history(session_id):
    """Return (creating on first use) the message history for *session_id*.

    Fix: the original passed ``lambda session_id: {}``, but
    RunnableWithMessageHistory requires the getter to return a
    BaseChatMessageHistory (a plain dict has no .messages/.add_messages,
    causing a runtime error) — and a fresh ``{}`` per call would never
    persist history anyway.
    """
    if session_id not in _session_histories:
        _session_histories[session_id] = InMemoryChatMessageHistory()
    return _session_histories[session_id]


chain = RunnableWithMessageHistory(
    app,
    _get_session_history,
    input_messages_key="messages",
    history_messages_key="messages",
)
# NOTE(review): removed a stray "|" web-scrape artifact that terminated the file.