跳到主要内容

状态管理

状态的角色

状态是流经图的数据载体,包含工作流执行所需的所有信息。

定义状态

基本状态

from typing import TypedDict

class SimpleState(TypedDict):
query: str
response: str
error: str | None

复杂状态

from typing import TypedDict, Annotated, Any
from operator import add

class ComplexState(TypedDict):
# 基本字段
user_id: str
query: str

# 列表字段(使用operator.add自动合并)
messages: Annotated[list, add]

# 字典字段
metadata: dict

# 可选字段
error: str | None = None

# Any 类型
context: Any

状态修改

返回修改

def process_node(state: SimpleState) -> dict:
"""修改状态的标准方式"""
# 只返回修改的字段
return {
"response": "处理结果"
}

# 不需要返回所有字段
# 返回的字段会merge进state

完全替换

def node_complete_replace(state: SimpleState) -> SimpleState:
"""返回完整的新状态"""
state["response"] = "新结果"
return state

Reducers(状态约化)

使用operator.add

from typing import Annotated
from operator import add

class MessageState(TypedDict):
messages: Annotated[list, add]

# 多个节点都可以添加消息
def node1(state: MessageState) -> dict:
return {"messages": [{"sender": "node1", "content": "消息1"}]}

def node2(state: MessageState) -> dict:
return {"messages": [{"sender": "node2", "content": "消息2"}]}

# 结果自动合并
# state["messages"] 会包含来自两个节点的消息

自定义Reducer

from typing import Annotated

def custom_reducer(a: list, b: list) -> list:
"""自定义reducer:去重和排序"""
combined = a + b
unique = list(set(combined))
return sorted(unique)

class CustomState(TypedDict):
unique_items: Annotated[list, custom_reducer]

私有状态

某些节点的状态不应该跨越边界:

class PrivateState(TypedDict):
# 公共字段
user_query: str
final_response: str

# 私有字段(仅在特定节点内使用)
# 注意:TypedDict中无法真正隐藏
# 但可以通过命名约定标识为私有
_internal_cache: dict
_temp_results: list

状态结构最佳实践

❌ 不好的实践

class BadState(TypedDict):
# 包含过多信息
huge_data: list[dict] # 可能很大
all_results: dict # 临时结果
buffer: str # 缓冲区

✅ 好的实践

class GoodState(TypedDict):
# 清晰的职责划分
user_id: str # 用户标识
query: str # 用户查询
results: list # 最终结果
error: str | None # 错误信息

使用MessagesState

LangGraph提供预定义的消息状态:

from langgraph.graph import MessagesState

# MessagesState等价于:
# class MessagesState(TypedDict):
# messages: Annotated[list, add]

# 使用
def process(state: MessagesState):
# 访问messages
last_message = state["messages"][-1]
return {"messages": [new_message]}

状态初始化

initial_state = {
"query": "用户问题",
"response": "",
"error": None
}

result = app.invoke(initial_state)

状态流动

实际例子:对话状态

from typing import TypedDict, Annotated
from operator import add
from langgraph.graph import MessagesState

class ConversationState(TypedDict):
# 继承消息状态
messages: Annotated[list, add]

# 添加对话元数据
user_id: str
session_id: str
conversation_type: str

# 处理结果
final_response: str | None = None
turns: int = 0

def model_node(state: ConversationState) -> dict:
"""模型处理消息"""
messages = state["messages"]
# 调用模型
response = model.invoke(messages)

return {
"messages": [response],
"turns": state["turns"] + 1
}

def tool_node(state: ConversationState) -> dict:
"""执行工具"""
last_message = state["messages"][-1]
tool_results = execute_tools(last_message)

return {
"messages": tool_results
}

性能考虑

状态大小

# ❌ 大状态
class BigState(TypedDict):
all_documents: list # 可能包含GB的数据

# ✅ 优化后
class OptimizedState(TypedDict):
document_ids: list # 只存储ID
retrieved_count: int

状态更新频率

# ❌ 频繁更新
def frequent_updates(state):
state["counter"] += 1
for i in range(1000):
state["items"].append(i)
return state

# ✅ 批量更新
def batch_updates(state):
new_items = [i for i in range(1000)]
return {
"counter": state["counter"] + 1,
"items": new_items
}

调试状态

def debug_state(state: ConversationState):
"""打印状态用于调试"""
print(f"=== State Debug ===")
for key, value in state.items():
if isinstance(value, list) and len(value) > 3:
print(f"{key}: [{len(value)} items]")
else:
print(f"{key}: {value}")

return state

常见错误

❌ 状态中存储大的二进制数据 ❌ 运行时动态修改状态结构 ❌ 忘记使用reducer处理聚合 ❌ 在节点中修改了state但没有返回

下一步