工具与自定义
工具的基本概念
工具是代理可以调用的函数,用于执行特定的操作或检索信息。在 DeepAgents 中,工具是代理与外部世界交互的主要方式。
工具的三个要素
- 函数名:自动用作工具名称
- 类型注解: 定义参数类型和返回类型
- 文档字符串:描述工具的功能和使用方式
定义工具
最简形式
def search_wikipedia(topic: str) -> str:
"""在维基百科上搜索指定主题的信息。"""
# 实现搜索逻辑
return f"关于 {topic} 的信息..."
标准工具定义
def fetch_stock_price(
symbol: str,
date: str | None = None
) -> dict[str, any]:
"""
获取股票价格信息。
参数:
- symbol: 股票代码(如 'AAPL')
- date: 查询日期(格式为 YYYY-MM-DD,可选)
返回:
包含价格、涨跌幅等信息的字典
"""
# 实现逻辑
return {
"symbol": symbol,
"price": 150.25,
"change": 2.5
}
工具最佳实践
1. 清晰的文档字符串
# ❌ 不好:文档不清晰
def process(x):
"""处理数据"""
pass
# ✅ 好:文档详细
def analyze_sentiment(text: str) -> dict[str, float]:
"""
分析文本的情感倾向。
返回包含 positive, negative, neutral 分数的字典,
每个分数在 0 到 1 之间,总和为 1。
"""
pass
2. 合理的参数粒度
# ❌ 过于复杂
def query_database(query: str, limit: int, offset: int,
cache: bool, timeout: int,
format: str, debug: bool): pass
# ✅ 简化版本
def query_database(query: str, limit: int = 10) -> list[dict]:
"""查询数据库"""
pass
3. 错误处理
def fetch_api_data(endpoint: str) -> str:
"""从 API 获取数据。"""
try:
response = requests.get(f"https://api.example.com/{endpoint}")
response.raise_for_status()
return response.text
except requests.RequestException as e:
return f"错误:无法获取数据({str(e)})"
except Exception as e:
return f"未知错误:{str(e)}"
使用工具
添加到代理
from deepagents import create_deep_agent
def add(a: int, b: int) -> int:
"""计算两个数的和"""
return a + b
def subtract(a: int, b: int) -> int:
"""计算两个数的差"""
return a - b
agent = create_deep_agent(
tools=[add, subtract],
system_prompt="你是一个数学助手"
)
result = agent.invoke({
"messages": [{"role": "user", "content": "10 加 5 是多少?"}]
})
多个工具组合
def web_search(query: str) -> str:
"""在互联网上搜索信息"""
pass
def read_document(path: str) -> str:
"""读取本地文档"""
pass
def summarize_text(text: str) -> str:
"""总结文本"""
pass
agent = create_deep_agent(
tools=[web_search, read_document, summarize_text],
system_prompt="""你是一个研究助手。
可以使用的工具:
- web_search:搜索互联网上的信息
- read_document:读取本地文档
- summarize_text:总结长文本
根据用户的需求选择合适的工具。"""
)
工具的类型
1. 数据查询工具
def query_customer_database(customer_id: int) -> dict:
"""查询客户信息数据库"""
# 模拟数据库查询
return {
"id": customer_id,
"name": "张三",
"email": "zhangsan@example.com",
"purchase_history": [...]
}
def search_products(keywords: str) -> list[dict]:
"""搜索产品"""
return [...]
2. 数据处理工具
def parse_csv(content: str) -> list[dict]:
"""解析 CSV 格式的数据"""
import csv
from io import StringIO
reader = csv.DictReader(StringIO(content))
return list(reader)
def convert_currency(amount: float, from_currency: str, to_currency: str) -> float:
"""货币换算"""
# 使用实时汇率或缓存的汇率
rates = {
"USD_CNY": 7.0,
"EUR_USD": 1.1,
}
return amount * rates.get(f"{from_currency}_{to_currency}", 1.0)
3. 外部 API 集成工具
import os
import requests
def weather_forecast(city: str, days: int = 3) -> str:
"""获取天气预报"""
api_key = os.environ.get("WEATHER_API_KEY")
response = requests.get(
f"https://api.weatherapi.com/v1/forecast.json",
params={
"key": api_key,
"q": city,
"days": days,
"aqi": "yes"
}
)
return response.json()
def send_email(to: str, subject: str, body: str) -> bool:
"""发送邮件"""
import smtplib
from email.mime.text import MIMEText
try:
msg = MIMEText(body)
msg["Subject"] = subject
msg["From"] = "agent@example.com"
msg["To"] = to
# 实际发送逻辑
return True
except Exception as e:
print(f"邮件发送失败:{e}")
return False
4. 文件操作工具
def save_report(filename: str, content: str) -> str:
"""保存报告到文件"""
try:
with open(filename, 'w', encoding='utf-8') as f:
f.write(content)
return f"报告已保存到 {filename}"
except Exception as e:
return f"保存失败:{str(e)}"
def load_config(config_file: str) -> dict:
"""加载配置文件"""
import json
with open(config_file, 'r') as f:
return json.load(f)
高级工具定义
使用枚举限制选项
from enum import Enum
class SortOrder(str, Enum):
ASCENDING = "ascending"
DESCENDING = "descending"
def list_items(
category: str,
sort_by: SortOrder = SortOrder.ASCENDING,
limit: int = 10
) -> list[dict]:
"""
列出指定分类的项目。
参数:
- category: 项目分类
- sort_by: 排序方式(ascending 或 descending)
- limit: 返回的最大项目数
"""
pass
返回复杂数据结构
from typing import TypedDict
class SearchResult(TypedDict):
title: str
url: str
snippet: str
relevance_score: float
def advanced_search(query: str, max_results: int = 5) -> list[SearchResult]:
"""执行高级搜索"""
pass
异步工具
import asyncio
async def fetch_multiple_apis(*endpoints: str) -> dict[str, any]:
"""并行调用多个 API"""
async def fetch(endpoint):
# 模拟异步请求
await asyncio.sleep(1)
return f"数据来自 {endpoint}"
results = await asyncio.gather(*[fetch(ep) for ep in endpoints])
return {endpoint: result for endpoint, result in zip(endpoints, results)}
工具的安全性考虑
参数验证
def delete_file(filepath: str) -> str:
"""删除文件(需要谨慎)"""
# 验证路径安全性
import os
from pathlib import Path
# 防止目录遍历攻击
base_dir = "/safe/directory"
target = Path(base_dir) / filepath
try:
if not str(target).startswith(base_dir):
raise ValueError("不允许访问此路径")
target.unlink()
return f"文件已删除:{filepath}"
except Exception as e:
return f"删除失败:{str(e)}"
权限检查
def modify_settings(setting_name: str, value: any) -> str:
"""修改系统设置"""
restricted_settings = {"admin_password", "api_keys", "database_url"}
if setting_name in restricted_settings:
return f"无权修改受限设置:{setting_name}"
# 执行修改
return f"设置已更新:{setting_name}"
将工具组织为类
对于相关工具的集合,可以组织为类:
class DataAnalesisTool:
"""数据分析工具集"""
def calculate_statistics(self, data: list[float]) -> dict[str, float]:
"""计算统计信息"""
import statistics
return {
"mean": statistics.mean(data),
"median": statistics.median(data),
"stdev": statistics.stdev(data),
}
def find_outliers(self, data: list[float], threshold: float = 2.0) -> list[float]:
"""找出异常值"""
mean = sum(data) / len(data)
stdev = (sum((x - mean) ** 2 for x in data) / len(data)) ** 0.5
return [x for x in data if abs((x - mean) / stdev) > threshold]
def trend_analysis(self, data: list[float]) -> str:
"""分析数据趋势"""
if len(data) < 2:
return "数据不足"
first_half_avg = sum(data[:len(data)//2]) / (len(data)//2)
last_half_avg = sum(data[len(data)//2:]) / (len(data) - len(data)//2)
if last_half_avg > first_half_avg:
return "上升趋势"
else:
return "下降趋势"
# 使用工具类
analysis = DataAnalesisTool()
agent = create_deep_agent(
tools=[
analysis.calculate_statistics,
analysis.find_outliers,
analysis.trend_analysis,
],
system_prompt="你是一个数据分析专家"
)
工具的文档生成
代理使用工具的文档字符串来理解如何使用工具。健壮的文档对代理的性能至关重要:
def transfer_funds(
from_account: str,
to_account: str,
amount: float
) -> dict[str, any]:
\"\"\"
在账户之间转账。
此操作会立即执行,且无法撤销。请在调用前确认所有参数。
参数:
- from_account: 源账户号(格式:CH-XXXXXX)
- to_account: 目标账户号(格式:CH-XXXXXX)
- amount: 转账金额,正数,单位为元
返回:
{
"success": bool,
"transaction_id": str,
"timestamp": str,
"message": str
}
示例:
transfer_funds("CH-000001", "CH-000002", 100.0)
异常情况:
- 账户不存在时返回 success=False
- 余额不足时返回 success=False
- 金额为负数时返回 success=False
\"\"\"
pass