乐于分享
好东西不私藏

LangChain 源码剖析-工具类详解(BaseTool)

LangChain 源码剖析-工具类详解(BaseTool)

– 许多人工智能应用程序通过自然语言与用户交互。然而,一些用例要求模型使用结构化输入直接与外部系统(如API、数据库或文件系统)交互。

– 工具是代理调用以执行操作的组件。它们通过定义良好的输入和输出与世界互动,扩展了模型功能。

– 工具封装了一个可调用函数及其输入模式。这些可以传递给兼容的聊天模型,允许模型决定是否调用工具以及使用什么参数。在这些场景中,工具调用使模型能够生成符合指定输入模式的请求。

创建工具

基本工具定义

– 创建工具最简单的方法是使用@tool装饰器。默认情况下,函数的docstring成为工具的描述,帮助模型理解何时使用它:

from langchain.tools import tool@tooldef search_database(query: str, limit: int = 10) -> str:    """Search the customer database for records matching the query.    Args:        query: Search terms to look for        limit: Maximum number of results to return    """    return f"Found {limit} results for '{query}'"

自定义工具名称

– 默认情况下,工具名称来自函数名称。当您需要更具描述性的内容时,请覆盖它:

@tool("web_search")  # Custom namedef search(query: str) -> str:    """Search the web for information."""    return f"Results for: {query}"print(search.name)  # web_search

自定义工具描述

– 覆盖自动生成的工具描述,以获得更清晰的模型指导:

@tool("calculator", description="Performs arithmetic calculations. Use this for any math problems.")def calc(expression: str) -> str:    """Evaluate mathematical expressions."""    return str(eval(expression))

高级模式定义

– 使用Pydantic模型定义复杂的输入:

from pydantic import BaseModel, Fieldfrom typing import Literalclass WeatherInput(BaseModel):    """Input for weather queries."""    location: str = Field(description="City name or coordinates")    units: Literal["celsius""fahrenheit"] = Field(        default="celsius",        description="Temperature unit preference"    )    include_forecast: bool = Field(        default=False,        description="Include 5-day forecast"    )@tool(args_schema=WeatherInput)def get_weather(location: str, units: str = "celsius", include_forecast: bool = False) -> str:    """Get current weather and optional forecast."""    temp = 22 if units == "celsius" else 72    result = f"Current weather in {location}{temp} degrees {units[0].upper()}"    if include_forecast:        result += "\nNext 5 days: Sunny"    return result

– 使用JSON模式定义复杂的输入:

weather_schema = {    "type""object",    "properties": {        "location": {"type""string"},        "units": {"type""string"},        "include_forecast": {"type""boolean"}    },    "required": ["location""units""include_forecast"]}@tool(args_schema=weather_schema)def get_weather(location: str, units: str = "celsius", include_forecast: bool = False) -> str:    """Get current weather and optional forecast."""    temp = 22 if units == "celsius" else 72    result = f"Current weather in {location}{temp} degrees {units[0].upper()}"    if include_forecast:        result += "\nNext 5 days: Sunny"    return result

访问上下文

– 工具可以通过ToolRuntime参数访问运行时信息,该参数提供:

– State: 在执行过程中流动的可变数据(例如,消息、计数器、自定义字段)

– Context: 不可变的配置,如用户ID、会话详细信息或特定于应用程序的配置

– Store: 跨对话的持久长期记忆

– Stream Writer: 在工具执行时流式传输自定义更新

– Config: 运行Config以执行

– Tool Call ID: 当前工具调用的ID

ToolRuntime

– 使用ToolRuntime访问单个参数中的所有运行时信息。只需将runtime:ToolRuntime添加到您的工具签名中,它就会自动注入,而不会暴露给LLM。

访问状态(Accessing state)

– 工具可以使用ToolRuntime访问当前图形状态:

from langchain.tools import tool, ToolRuntime# Access the current conversation state@tooldef summarize_conversation(    runtime: ToolRuntime) -> str:    """Summarize the conversation so far."""    messages = runtime.state["messages"]    human_msgs = sum(1 for m in messages if m.__class__.__name__ == "HumanMessage")    ai_msgs = sum(1 for m in messages if m.__class__.__name__ == "AIMessage")    tool_msgs = sum(1 for m in messages if m.__class__.__name__ == "ToolMessage")    return f"Conversation has {human_msgs} user messages, {ai_msgs} AI responses, and {tool_msgs} tool results"# Access custom state fields@tooldef get_user_preference(    pref_name: str,    runtime: ToolRuntime  # ToolRuntime parameter is not visible to the model) -> str:    """Get a user preference value."""    preferences = runtime.state.get("user_preferences", {})    return preferences.get(pref_name, "Not set")

更新状态(Updating state)

– 使用命令更新代理的状态或控制图的执行流

from langgraph.types import Commandfrom langchain.messages import RemoveMessagefrom langgraph.graph.message import REMOVE_ALL_MESSAGESfrom langchain.tools import tool, ToolRuntime# Update the conversation history by removing all messages@tooldef clear_conversation() -> Command:    """Clear the conversation history."""    return Command(        update={            "messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)],        }    )# Update the user_name in the agent state@tooldef update_user_name(    new_name: str,    runtime: ToolRuntime) -> Command:    """Update the user's name."""    return Command(update={"user_name": new_name})

上下文(Context)

– 工具可以通过ToolRuntime访问运行时上下文(runtime.context)

from dataclasses import dataclassfrom langchain_openai import ChatOpenAIfrom langchain.agents import create_agentfrom langchain.tools import tool, ToolRuntimeUSER_DATABASE = {    "user123": {        "name""Alice Johnson",        "account_type""Premium",        "balance"5000,        "email""alice@example.com"    },    "user456": {        "name""Bob Smith",        "account_type""Standard",        "balance"1200,        "email""bob@example.com"    }}@dataclassclass UserContext:    user_id: str@tooldef get_account_info(runtime: ToolRuntime[UserContext]) -> str:    """Get the current user's account information."""    user_id = runtime.context.user_id    if user_id in USER_DATABASE:        user = USER_DATABASE[user_id]        return f"Account holder: {user['name']}\nType: {user['account_type']}\nBalance: ${user['balance']}"    return "User not found"model = ChatOpenAI(model="gpt-4o")agent = create_agent(    model,    tools=[get_account_info],    context_schema=UserContext,    system_prompt="You are a financial assistant.")result = agent.invoke(    {"messages": [{"role""user""content""What's my current balance?"}]},    context=UserContext(user_id="user123"))

内存/存储(Memory)

– 工具可以通过ToolRuntime访问和更新存储(runtime.store):

from typing import Anyfrom langgraph.store.memory import InMemoryStorefrom langchain.agents import create_agentfrom langchain.tools import tool, ToolRuntime# Access memory@tooldef get_user_info(user_id: str, runtime: ToolRuntime) -> str:    """Look up user info."""    store = runtime.store    user_info = store.get(("users",), user_id)    return str(user_info.value) if user_info else "Unknown user"# Update memory@tooldef save_user_info(user_id: str, user_info: dict[strAny], runtime: ToolRuntime) -> str:    """Save user info."""    store = runtime.store    store.put(("users",), user_id, user_info)    return "Successfully saved user info."store = InMemoryStore()agent = create_agent(    model,    tools=[get_user_info, save_user_info],    store=store)# First session: save user infoagent.invoke({    "messages": [{"role""user""content""Save the following user: userid: abc123, name: Foo, age: 25, email: foo@langchain.dev"}]})# Second session: get user infoagent.invoke({    "messages": [{"role""user""content""Get user info for user with id 'abc123'"}]})# Here is the user info for user with ID "abc123":# - Name: Foo# - Age: 25# - Email: foo@langchain.dev

流式编写器(Stream Writer)

– 使用runtime.Stream_writer在工具执行时流式传输自定义更新。这对于向用户提供有关工具正在做什么的实时反馈非常有用。

from langchain.tools import tool, ToolRuntime@tooldef get_weather(city: str, runtime: ToolRuntime) -> str:    """Get weather for a given city."""    writer = runtime.stream_writer    # Stream custom updates as the tool executes    writer(f"Looking up data for city: {city}")    writer(f"Acquired data for city: {city}")    return f"It's always sunny in {city}!"