状态管理
状态的角色
状态是流经图的数据载体,包含工作流执行所需的所有信息。
定义状态
基本状态
from typing import TypedDict
class SimpleState(TypedDict):
query: str
response: str
error: str | None
复杂状态
from typing import TypedDict, Annotated, Any
from operator import add
class ComplexState(TypedDict):
# 基本字段
user_id: str
query: str
# 列表字段(使用operator.add自动合并)
messages: Annotated[list, add]
# 字典字段
metadata: dict
# 可选字段
error: str | None = None
# Any 类型
context: Any
状态修改
返回修改
def process_node(state: SimpleState) -> dict:
"""修改状态的标准方式"""
# 只返回修改的字段
return {
"response": "处理结果"
}
# 不需要返回所有字段
# 返回的字段会merge进state
完全替换
def node_complete_replace(state: SimpleState) -> SimpleState:
"""返回完整的新状态"""
state["response"] = "新结果"
return state
Reducers(状态约化)
使用operator.add
from typing import Annotated
from operator import add
class MessageState(TypedDict):
messages: Annotated[list, add]
# 多个节点都可以添加消息
def node1(state: MessageState) -> dict:
return {"messages": [{"sender": "node1", "content": "消息1"}]}
def node2(state: MessageState) -> dict:
return {"messages": [{"sender": "node2", "content": "消息2"}]}
# 结果自动合并
# state["messages"] 会包含来自两个节点的消息
自定义Reducer
from typing import Annotated
def custom_reducer(a: list, b: list) -> list:
"""自定义reducer:去重和排序"""
combined = a + b
unique = list(set(combined))
return sorted(unique)
class CustomState(TypedDict):
unique_items: Annotated[list, custom_reducer]
私有状态
某些节点的状态不应该跨越边界:
class PrivateState(TypedDict):
# 公共字段
user_query: str
final_response: str
# 私有字段(仅在特定节点内使用)
# 注意:TypedDict中无法真正隐藏
# 但可以通过命名约定标识为私有
_internal_cache: dict
_temp_results: list
状态结构最佳实践
❌ 不好的实践
class BadState(TypedDict):
# 包含过多信息
huge_data: list[dict] # 可能很大
all_results: dict # 临时结果
buffer: str # 缓冲区
✅ 好的实践
class GoodState(TypedDict):
# 清晰的职责划分
user_id: str # 用户标识
query: str # 用户查询
results: list # 最终结果
error: str | None # 错误信息
使用MessagesState
LangGraph提供预定义的消息状态:
from langgraph.graph import MessagesState
# MessagesState等价于:
# class MessagesState(TypedDict):
# messages: Annotated[list, add]
# 使用
def process(state: MessagesState):
# 访问messages
last_message = state["messages"][-1]
return {"messages": [new_message]}
状态初始化
initial_state = {
"query": "用户问题",
"response": "",
"error": None
}
result = app.invoke(initial_state)
状态流动
实际例子:对话状态
from typing import TypedDict, Annotated
from operator import add
from langgraph.graph import MessagesState
class ConversationState(TypedDict):
# 继承消息状态
messages: Annotated[list, add]
# 添加对话元数据
user_id: str
session_id: str
conversation_type: str
# 处理结果
final_response: str | None = None
turns: int = 0
def model_node(state: ConversationState) -> dict:
"""模型处理消息"""
messages = state["messages"]
# 调用模型
response = model.invoke(messages)
return {
"messages": [response],
"turns": state["turns"] + 1
}
def tool_node(state: ConversationState) -> dict:
"""执行工具"""
last_message = state["messages"][-1]
tool_results = execute_tools(last_message)
return {
"messages": tool_results
}
性能考虑
状态大小
# ❌ 大状态
class BigState(TypedDict):
all_documents: list # 可能包含GB的数据
# ✅ 优化后
class OptimizedState(TypedDict):
document_ids: list # 只存储ID
retrieved_count: int
状态更新频率
# ❌ 频繁更新
def frequent_updates(state):
state["counter"] += 1
for i in range(1000):
state["items"].append(i)
return state
# ✅ 批量更新
def batch_updates(state):
new_items = [i for i in range(1000)]
return {
"counter": state["counter"] + 1,
"items": new_items
}
调试状态
def debug_state(state: ConversationState):
"""打印状态用于调试"""
print(f"=== State Debug ===")
for key, value in state.items():
if isinstance(value, list) and len(value) > 3:
print(f"{key}: [{len(value)} items]")
else:
print(f"{key}: {value}")
return state
常见错误
❌ 状态中存储大的二进制数据 ❌ 运行时动态修改状态结构 ❌ 忘记使用reducer处理聚合 ❌ 在节点中修改了state但没有返回