File size: 2,312 Bytes
abecf76
0375d91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abecf76
0375d91
 
abecf76
0375d91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abecf76
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import operator
from typing import Annotated, TypedDict, List

from langchain_community.chat_models import ChatOllama
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables import RunnableBranch
from langchain_core.runnables.history import RunnableWithMessageHistory
from langgraph.graph import END, StateGraph

from retriever import retrieve_context
from tools import tools

# Ollama model tag used for every chat-model call in this module.
_OLLAMA_MODEL_TAG = "qwen:1.8b"
model = ChatOllama(model=_OLLAMA_MODEL_TAG)

# Tool registry keyed by each tool's ``name``; tool_chain selects a tool by
# looking up the string held in the state's "next" field.
tools_with_names = {registered.name: registered for registered in tools}

class AgentState(TypedDict):
    """Shared state passed between the nodes of the agent graph."""

    # Conversation messages. The ``operator.add`` metadata is the reducer
    # applied when a node returns a ``messages`` list: the returned list is
    # concatenated onto the existing one instead of replacing it.
    # (The previous ``Annotated[List], []`` passed a single argument to
    # ``Annotated``, which raises TypeError when this class is created.)
    messages: Annotated[List, operator.add]
    # Routing key: read by the graph's conditional edge after "agent" and by
    # ``conditional_agent`` to decide whether to run the tool chain.
    next: str

def _tool_input_text(state):
    """Return the latest message's content; it is passed verbatim to the tool."""
    return state["messages"][-1].content


def _tool_name(state):
    """Return the name of the tool to run, taken from the state's "next" field."""
    return state["next"]


def _invoke_named_tool(selection):
    """Look up ``selection["tool"]`` in the registry and invoke it.

    The tool receives ``selection["message"]`` as its input. Raises KeyError
    when the name is not a registered tool.
    """
    chosen = tools_with_names[selection["tool"]]
    return chosen.invoke(selection["message"])


def _tool_result_update(result):
    """Wrap a tool's raw result as a state update whose "next" is "end"."""
    return {"messages": [AIMessage(content=str(result))], "next": "end"}


# Runs the tool named by state["next"] on the latest message's text and
# reports the stringified result back as an AI message.
tool_chain = (
    RunnableParallel(message=_tool_input_text, tool=_tool_name)
    | _invoke_named_tool
    | _tool_result_update
)

# Standing instructions for the assistant, placed at the top of every prompt
# rendered by ``prompt`` below.
system = """
You are a helpful assistant. Use tools if needed. Keep responses short.
"""

# Retrieval-augmented prompt: standing instructions, then retrieved context,
# then the user's question. ``system`` was previously defined but never used,
# so the instructions never reached the model. It is supplied as a partial
# variable rather than concatenated into the template text, so any braces it
# may contain are not parsed as template placeholders. The remaining input
# variables are still ``context`` and ``question``.
prompt = PromptTemplate.from_template(
    """{system}

{context}

{question}
""",
    partial_variables={"system": system.strip()},
)

def _question_text(state):
    """Return the content of the most recent message in the graph state."""
    return state["messages"][-1].content


# Fills ``prompt`` from the current graph state. Both parallel branches
# receive the full state dict.
# NOTE(review): retrieve_context is handed the whole state, not just the
# question text; confirm that matches its expected input.
context_chain = (
    RunnableParallel(
        context=RunnableLambda(retrieve_context),
        question=_question_text,
    )
    | prompt
)

def _package_agent_reply(reply_text):
    """Turn the model's text reply into a graph-state update.

    Emits the reply as an AI message followed by a fixed human follow-up
    prompt, and sets "next" to "tool_picker".

    NOTE(review): "tool_picker" is not a key of the routing map in
    create_graph (which only knows "tool" and "end"). conditional_agent's
    substring test would also send it to tool_chain, which then looks up
    tools_with_names["tool_picker"]. Confirm the intended routing.
    """
    follow_up = HumanMessage(content="Do you want to use a tool?")
    return {
        "messages": [AIMessage(content=reply_text), follow_up],
        "next": "tool_picker",
    }


# Retrieval-augmented answer: build the prompt, call the model, take the
# plain text, then package it as a state update.
agent = context_chain | model | StrOutputParser() | _package_agent_reply

def _tool_requested(state):
    """Report whether the substring "tool" occurs in the state's "next" value."""
    return "tool" in state["next"]


# Body of the graph's "agent" node: run the tool chain when a tool was
# requested, otherwise fall through to the retrieval-augmented model call.
conditional_agent = RunnableBranch(
    (_tool_requested, tool_chain),
    agent,
)

def create_graph():
    """Build and compile the agent/tool state graph.

    Nodes: "agent" (runs ``conditional_agent``; the entry point) and
    "tool_chain" (runs ``tool_chain``). After "agent", the value of
    state["next"] picks the successor: "tool" goes to "tool_chain" and "end"
    finishes. "tool_chain" always returns to "agent".

    NOTE(review): ``agent`` sets next="tool_picker", which is not a key of
    the routing map below. Confirm how that value is meant to resolve.
    """
    after_agent = {
        "tool": "tool_chain",
        "end": END,
    }

    def route_after_agent(state):
        return state["next"]

    builder = StateGraph(AgentState)
    builder.add_node("agent", conditional_agent)
    builder.set_entry_point("agent")
    builder.add_node("tool_chain", tool_chain)
    builder.add_conditional_edges("agent", route_after_agent, after_agent)
    builder.add_edge("tool_chain", "agent")
    return builder.compile()

app = create_graph()

# Chat histories keyed by session id, held in this process's memory.
_session_histories = {}


def _get_session_history(session_id):
    """Return the message history for ``session_id``, creating it on first use.

    RunnableWithMessageHistory reads ``.messages`` from this object and
    appends to it with ``add_messages``. It therefore must be a chat-history
    object (not a dict), and the same object must be returned for the same
    session so history actually accumulates.
    """
    history = _session_histories.get(session_id)
    if history is None:
        history = InMemoryChatMessageHistory()
        _session_histories[session_id] = history
    return history


# Wraps the compiled graph so each call sees the session's earlier messages
# and the new input/output messages are recorded afterwards.
# - history_messages_key is left unset: when it equals input_messages_key,
#   the loaded history replaces the caller's new messages instead of
#   preceding them. Unset, history + new messages go under "messages".
# - output_messages_key is required because the graph returns a dict with
#   more than one key ("messages" and "next").
# NOTE(review): if AgentState.messages accumulates via a reducer, the
# returned "messages" also contains the prior history and will be recorded
# again. Confirm which messages should be saved per turn.
chain = RunnableWithMessageHistory(
    app,
    _get_session_history,
    input_messages_key="messages",
    output_messages_key="messages",
)