跳到主要内容

工具与自定义

工具的基本概念

工具是代理可以调用的函数,用于执行特定的操作或检索信息。在 DeepAgents 中,工具是代理与外部世界交互的主要方式。

工具的三个要素

  1. 函数名:自动用作工具名称
  2. 类型注解:定义参数类型和返回类型
  3. 文档字符串:描述工具的功能和使用方式

定义工具

最简形式

def search_wikipedia(topic: str) -> str:
"""在维基百科上搜索指定主题的信息。"""
# 实现搜索逻辑
return f"关于 {topic} 的信息..."

标准工具定义

def fetch_stock_price(
symbol: str,
date: str | None = None
) -> dict[str, any]:
"""
获取股票价格信息。

参数:
- symbol: 股票代码(如 'AAPL')
- date: 查询日期(格式为 YYYY-MM-DD,可选)

返回:
包含价格、涨跌幅等信息的字典
"""
# 实现逻辑
return {
"symbol": symbol,
"price": 150.25,
"change": 2.5
}

工具最佳实践

1. 清晰的文档字符串

# ❌ 不好:文档不清晰
def process(x):
"""处理数据"""
pass

# ✅ 好:文档详细
def analyze_sentiment(text: str) -> dict[str, float]:
"""
分析文本的情感倾向。

返回包含 positive, negative, neutral 分数的字典,
每个分数在 0 到 1 之间,总和为 1。
"""
pass

2. 合理的参数粒度

# ❌ 过于复杂
def query_database(query: str, limit: int, offset: int,
cache: bool, timeout: int,
format: str, debug: bool): pass

# ✅ 简化版本
def query_database(query: str, limit: int = 10) -> list[dict]:
"""查询数据库"""
pass

3. 错误处理

def fetch_api_data(endpoint: str) -> str:
"""从 API 获取数据。"""
try:
response = requests.get(f"https://api.example.com/{endpoint}")
response.raise_for_status()
return response.text
except requests.RequestException as e:
return f"错误:无法获取数据({str(e)})"
except Exception as e:
return f"未知错误:{str(e)}"

使用工具

添加到代理

from deepagents import create_deep_agent

def add(a: int, b: int) -> int:
"""计算两个数的和"""
return a + b

def subtract(a: int, b: int) -> int:
"""计算两个数的差"""
return a - b

agent = create_deep_agent(
tools=[add, subtract],
system_prompt="你是一个数学助手"
)

result = agent.invoke({
"messages": [{"role": "user", "content": "10 加 5 是多少?"}]
})

多个工具组合

def web_search(query: str) -> str:
"""在互联网上搜索信息"""
pass

def read_document(path: str) -> str:
"""读取本地文档"""
pass

def summarize_text(text: str) -> str:
"""总结文本"""
pass

agent = create_deep_agent(
tools=[web_search, read_document, summarize_text],
system_prompt="""你是一个研究助手。

可以使用的工具:
- web_search:搜索互联网上的信息
- read_document:读取本地文档
- summarize_text:总结长文本

根据用户的需求选择合适的工具。"""
)

工具的类型

1. 数据查询工具

def query_customer_database(customer_id: int) -> dict:
"""查询客户信息数据库"""
# 模拟数据库查询
return {
"id": customer_id,
"name": "张三",
"email": "zhangsan@example.com",
"purchase_history": [...]
}

def search_products(keywords: str) -> list[dict]:
"""搜索产品"""
return [...]

2. 数据处理工具

def parse_csv(content: str) -> list[dict]:
"""解析 CSV 格式的数据"""
import csv
from io import StringIO

reader = csv.DictReader(StringIO(content))
return list(reader)

def convert_currency(amount: float, from_currency: str, to_currency: str) -> float:
"""货币换算"""
# 使用实时汇率或缓存的汇率
rates = {
"USD_CNY": 7.0,
"EUR_USD": 1.1,
}
return amount * rates.get(f"{from_currency}_{to_currency}", 1.0)

3. 外部 API 集成工具

import os
import requests

def weather_forecast(city: str, days: int = 3) -> str:
"""获取天气预报"""
api_key = os.environ.get("WEATHER_API_KEY")
response = requests.get(
f"https://api.weatherapi.com/v1/forecast.json",
params={
"key": api_key,
"q": city,
"days": days,
"aqi": "yes"
}
)
return response.json()

def send_email(to: str, subject: str, body: str) -> bool:
"""发送邮件"""
import smtplib
from email.mime.text import MIMEText

try:
msg = MIMEText(body)
msg["Subject"] = subject
msg["From"] = "agent@example.com"
msg["To"] = to

# 实际发送逻辑
return True
except Exception as e:
print(f"邮件发送失败:{e}")
return False

4. 文件操作工具

def save_report(filename: str, content: str) -> str:
"""保存报告到文件"""
try:
with open(filename, 'w', encoding='utf-8') as f:
f.write(content)
return f"报告已保存到 {filename}"
except Exception as e:
return f"保存失败:{str(e)}"

def load_config(config_file: str) -> dict:
"""加载配置文件"""
import json
with open(config_file, 'r') as f:
return json.load(f)

高级工具定义

使用枚举限制选项

from enum import Enum

class SortOrder(str, Enum):
ASCENDING = "ascending"
DESCENDING = "descending"

def list_items(
category: str,
sort_by: SortOrder = SortOrder.ASCENDING,
limit: int = 10
) -> list[dict]:
"""
列出指定分类的项目。

参数:
- category: 项目分类
- sort_by: 排序方式(ascending 或 descending)
- limit: 返回的最大项目数
"""
pass

返回复杂数据结构

from typing import TypedDict

class SearchResult(TypedDict):
title: str
url: str
snippet: str
relevance_score: float

def advanced_search(query: str, max_results: int = 5) -> list[SearchResult]:
"""执行高级搜索"""
pass

异步工具

import asyncio

async def fetch_multiple_apis(*endpoints: str) -> dict[str, any]:
"""并行调用多个 API"""
async def fetch(endpoint):
# 模拟异步请求
await asyncio.sleep(1)
return f"数据来自 {endpoint}"

results = await asyncio.gather(*[fetch(ep) for ep in endpoints])
return {endpoint: result for endpoint, result in zip(endpoints, results)}

工具的安全性考虑

参数验证

def delete_file(filepath: str) -> str:
"""删除文件(需要谨慎)"""
# 验证路径安全性
import os
from pathlib import Path

# 防止目录遍历攻击
base_dir = "/safe/directory"
target = Path(base_dir) / filepath

try:
if not str(target).startswith(base_dir):
raise ValueError("不允许访问此路径")
target.unlink()
return f"文件已删除:{filepath}"
except Exception as e:
return f"删除失败:{str(e)}"

权限检查

def modify_settings(setting_name: str, value: any) -> str:
"""修改系统设置"""
restricted_settings = {"admin_password", "api_keys", "database_url"}

if setting_name in restricted_settings:
return f"无权修改受限设置:{setting_name}"

# 执行修改
return f"设置已更新:{setting_name}"

将工具组织为类

对于相关工具的集合,可以组织为类:

class DataAnalesisTool:
"""数据分析工具集"""

def calculate_statistics(self, data: list[float]) -> dict[str, float]:
"""计算统计信息"""
import statistics
return {
"mean": statistics.mean(data),
"median": statistics.median(data),
"stdev": statistics.stdev(data),
}

def find_outliers(self, data: list[float], threshold: float = 2.0) -> list[float]:
"""找出异常值"""
mean = sum(data) / len(data)
stdev = (sum((x - mean) ** 2 for x in data) / len(data)) ** 0.5
return [x for x in data if abs((x - mean) / stdev) > threshold]

def trend_analysis(self, data: list[float]) -> str:
"""分析数据趋势"""
if len(data) < 2:
return "数据不足"

first_half_avg = sum(data[:len(data)//2]) / (len(data)//2)
last_half_avg = sum(data[len(data)//2:]) / (len(data) - len(data)//2)

if last_half_avg > first_half_avg:
return "上升趋势"
else:
return "下降趋势"

# 使用工具类
analysis = DataAnalesisTool()

agent = create_deep_agent(
tools=[
analysis.calculate_statistics,
analysis.find_outliers,
analysis.trend_analysis,
],
system_prompt="你是一个数据分析专家"
)

工具的文档生成

代理使用工具的文档字符串来理解如何使用工具。健壮的文档对代理的性能至关重要:

def transfer_funds(
from_account: str,
to_account: str,
amount: float
) -> dict[str, any]:
\"\"\"
在账户之间转账。

此操作会立即执行,且无法撤销。请在调用前确认所有参数。

参数:
- from_account: 源账户号(格式:CH-XXXXXX)
- to_account: 目标账户号(格式:CH-XXXXXX)
- amount: 转账金额,正数,单位为元

返回:
{
"success": bool,
"transaction_id": str,
"timestamp": str,
"message": str
}

示例:
transfer_funds("CH-000001", "CH-000002", 100.0)

异常情况:
- 账户不存在时返回 success=False
- 余额不足时返回 success=False
- 金额为负数时返回 success=False
\"\"\"
pass

调试和测试工具

单元测试

import unittest

class TestCalculationTools(unittest.TestCase):

def test_add_basic(self):
result = add(2, 3)
self.assertEqual(result, 5)

def test_add_negative(self):
result = add(-2, 3)
self.assertEqual(result, 1)

if __name__ == "__main__":
unittest.main()

集成测试

def test_agent_with_tools():
agent = create_deep_agent(
tools=[add, subtract, multiply],
system_prompt="你是一个计算器"
)

result = agent.invoke({
"messages": [{"role": "user", "content": "计算 (10 + 5) * 2"}]
})

# 验证代理正确使用了工具
assert "30" in result["messages"][-1].content

工具的性能优化

缓存机制

from functools import lru_cache

@lru_cache(maxsize=128)
def get_user_profile(user_id: int) -> dict:
"""获取用户资料(带缓存)"""
# 数据库查询
pass

def clear_cache():
"""清除缓存"""
get_user_profile.cache_clear()

超时控制

import signal

def timeout_handler(signum, frame):
raise TimeoutError("工具执行超时")

def long_running_task(input_data: str) -> str:
"""长时间运行的任务"""
# 设置 30 秒超时
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(30)

try:
# 执行任务
result = process_large_dataset(input_data)
signal.alarm(0) # 取消超时
return result
except TimeoutError:
return "任务超时"

下一步