手写 AI Agent(06):长期记忆
本文是《手写 AI Agent》系列教程的第 6 篇,建议按顺序阅读。
一、概述
实现跨会话的长期记忆系统,包括 SQLite 持久化、Embedding 语义检索和记忆的写入/更新/遗忘策略
读完本文,你会理解"为什么需要长期记忆",以及它为什么这样设计。
二、核心概念
为什么需要长期记忆?
短期记忆(上下文窗口)在会话结束后就消失了。长期记忆解决三个问题:
三、代码实现
下面是本章核心代码:
"""
Long-term memory system -- cross-session knowledge retention.

Components:
- MemoryStore: SQLite storage engine
- EmbeddingService: embedding service (supports OpenAI and a local fallback)
- LongTermMemory: long-term memory manager

"Hand-written AI Agent" course, chapter 6.
"""
from __future__ import annotations

import json
import math
import os
import sqlite3
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
# ============================================================
# 记忆条目
# ============================================================
@dataclass
class Memory:
    """A single memory entry as stored in (and loaded from) SQLite."""

    id: int                      # primary key assigned by SQLite
    content: str                 # the remembered text
    category: str                # one of: user_pref / project_info / tech_decision / general
    embedding: list[float] = field(default_factory=list)  # semantic vector; empty if none stored
    created_at: float = field(default_factory=time.time)  # unix timestamp of creation
    access_count: int = 0        # number of times this memory was retrieved
    last_accessed: float = 0.0   # unix timestamp of last retrieval (0.0 = never)

    def to_dict(self) -> dict:
        """Serialize metadata for display/logging; the embedding is not included."""
        return {
            "id": self.id,
            "content": self.content,
            "category": self.category,
            "created_at": self.created_at,
            "access_count": self.access_count,
            "last_accessed": self.last_accessed,
        }
# ============================================================
# Embedding 服务
# ============================================================
class EmbeddingService:
    """
    Embedding service -- converts text into a vector.

    Two modes:
    1. "openai": real semantic embeddings via an OpenAI-compatible API
       (pointed at DashScope's compatible endpoint; requires an API key).
    2. "simple": a character-hash vector -- fast and key-free, but NOT a
       semantic embedding; demo/test use only.
    """

    def __init__(
        self,
        provider: str = "simple",
        api_key: Optional[str] = None,
        model: str = "text-embedding-3-small",
        dimensions: int = 1536,
    ):
        """
        Args:
            provider: "openai" or "simple"
            api_key: API key (required when provider="openai"; falls back to
                the DASHSCOPE_API_KEY environment variable)
            model: embedding model name
            dimensions: vector dimensionality

        Raises:
            ImportError: provider="openai" but the openai package is missing.
            ValueError: provider="openai" but no API key can be found.
        """
        self.provider = provider
        self.dimensions = dimensions
        if provider == "openai":
            # Keep the try body minimal: only the import can raise ImportError.
            try:
                from openai import OpenAI
            except ImportError:
                raise ImportError(
                    "openai is required for OpenAI embeddings. "
                    "Install: pip install openai"
                )
            api_key = api_key or os.environ.get("DASHSCOPE_API_KEY")
            if not api_key:
                raise ValueError("未找到 API 密钥!请设置 DASHSCOPE_API_KEY 环境变量")
            self.client = OpenAI(
                api_key=api_key,
                base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
            )
            self.model = model

    def embed(self, text: str) -> list[float]:
        """
        Convert text into a vector.

        Args:
            text: input text

        Returns:
            vector (list of floats, length == self.dimensions)
        """
        if self.provider == "openai":
            return self._embed_openai(text)
        return self._embed_simple(text)

    def _embed_openai(self, text: str) -> list[float]:
        """Generate an embedding through the OpenAI-compatible API."""
        response = self.client.embeddings.create(
            model=self.model,
            input=text,
            dimensions=self.dimensions,
        )
        return response.data[0].embedding

    def _embed_simple(self, text: str) -> list[float]:
        """
        Simple hash-based embedding (demo only).

        Hashes each character (with its position) into a bucket, then
        L2-normalizes. This is NOT a semantic embedding -- accuracy is low.
        NOTE: str hashing is salted per process (PYTHONHASHSEED), so the
        vectors are only reproducible within a single process run.
        """
        vector = [0.0] * self.dimensions
        for i, char in enumerate(text):
            idx = hash(char + str(i)) % self.dimensions
            vector[idx] += 1.0
        # L2-normalize (skip for the empty string, whose norm is 0).
        norm = math.sqrt(sum(v * v for v in vector))
        if norm > 0:
            vector = [v / norm for v in vector]
        return vector
# ============================================================
# 向量相似度
# ============================================================
def cosine_similarity(a: list[float], b: list[float]) -> float:
    """
    Cosine similarity of two vectors.

    Returns:
        similarity in [-1, 1] (closer to 1 = more similar); 0.0 when the
        vectors have different lengths or either has zero norm.
    """
    if len(a) != len(b):
        return 0.0
    dot_product = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot_product / (norm_a * norm_b)
# ============================================================
# SQLite 存储
# ============================================================
class MemoryStore:
    """
    SQLite-backed memory storage engine.

    Schema:
    - memories: memory content plus metadata (category, embedding as JSON,
      timestamps, access counters)
    Supports CRUD, keyword/category search and a forgetting policy.

    NOTE: every method opens a short-lived connection to db_path, so the
    store is safe to use across calls but does not share a transaction.
    """

    def __init__(self, db_path: str = "memory.db"):
        """
        Args:
            db_path: path of the SQLite database file
        """
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Create the table and indexes if they do not exist yet."""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS memories (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    content TEXT NOT NULL,
                    category TEXT NOT NULL DEFAULT 'general',
                    embedding TEXT,
                    created_at REAL NOT NULL,
                    access_count INTEGER DEFAULT 0,
                    last_accessed REAL DEFAULT 0
                )
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_category
                ON memories(category)
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_created
                ON memories(created_at)
            """)

    @staticmethod
    def _row_to_memory(row) -> Memory:
        """Convert a raw `SELECT *` row into a Memory object."""
        return Memory(
            id=row[0],
            content=row[1],
            category=row[2],
            embedding=json.loads(row[3]) if row[3] else [],
            created_at=row[4],
            access_count=row[5],
            last_accessed=row[6],
        )

    def add(
        self,
        content: str,
        category: str = "general",
        embedding: Optional[list[float]] = None,
    ) -> int:
        """
        Add one memory.

        Args:
            content: memory text
            category: classification tag
            embedding: optional semantic vector (stored as JSON text)

        Returns:
            id of the new memory
        """
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                """INSERT INTO memories
                   (content, category, embedding, created_at)
                   VALUES (?, ?, ?, ?)""",
                (
                    content,
                    category,
                    json.dumps(embedding) if embedding else None,
                    time.time(),
                ),
            )
            return cursor.lastrowid

    def get(self, memory_id: int) -> Optional[Memory]:
        """Fetch a single memory by id, or None if it does not exist."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                "SELECT * FROM memories WHERE id = ?",
                (memory_id,),
            )
            row = cursor.fetchone()
            if not row:
                return None
            return self._row_to_memory(row)

    def search_by_category(
        self,
        category: str,
        limit: int = 10,
    ) -> list[Memory]:
        """Return the newest memories in a category (newest first)."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                "SELECT * FROM memories WHERE category = ? ORDER BY created_at DESC LIMIT ?",
                (category, limit),
            )
            return [self._row_to_memory(row) for row in cursor.fetchall()]

    def search_by_keyword(
        self,
        keyword: str,
        limit: int = 10,
    ) -> list[Memory]:
        """
        Substring search over memory content (LIKE-based fallback, no FTS).

        Most-accessed, then newest, first. NOTE: `%` and `_` in the keyword
        act as LIKE wildcards -- acceptable here, since broader matching
        only widens recall.
        """
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                """SELECT * FROM memories
                   WHERE content LIKE ?
                   ORDER BY access_count DESC, created_at DESC
                   LIMIT ?""",
                (f"%{keyword}%", limit),
            )
            return [self._row_to_memory(row) for row in cursor.fetchall()]

    def update_access(self, memory_id: int):
        """Increment a memory's access counter and stamp last_accessed."""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute(
                """UPDATE memories
                   SET access_count = access_count + 1,
                       last_accessed = ?
                   WHERE id = ?""",
                (time.time(), memory_id),
            )

    def delete(self, memory_id: int) -> bool:
        """Delete one memory. Returns True if a row was actually removed."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                "DELETE FROM memories WHERE id = ?",
                (memory_id,),
            )
            return cursor.rowcount > 0

    def get_all(self, limit: int = 100) -> list[Memory]:
        """Return up to `limit` memories, newest first."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                "SELECT * FROM memories ORDER BY created_at DESC LIMIT ?",
                (limit,),
            )
            return [self._row_to_memory(row) for row in cursor.fetchall()]

    def count(self) -> int:
        """Total number of stored memories."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute("SELECT COUNT(*) FROM memories")
            return cursor.fetchone()[0]

    def forget_old(self, max_age_days: float = 30, min_access: int = 2) -> int:
        """
        Forgetting policy: delete memories that are both old and rarely used.

        Args:
            max_age_days: memories older than this many days are candidates
            min_access: memories accessed fewer than this many times are candidates

        Returns:
            number of memories deleted
        """
        cutoff = time.time() - (max_age_days * 86400)
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                """DELETE FROM memories
                   WHERE created_at < ? AND access_count < ?""",
                (cutoff, min_access),
            )
            return cursor.rowcount
# ============================================================
# 长期记忆管理器
# ============================================================
class LongTermMemory:
    """
    Long-term memory manager -- ties together storage, embeddings and retrieval.

    Usage:
        ltm = LongTermMemory(db_path="memory.db", embedding_provider="simple")
        ltm.add("用户喜欢用 VS Code", category="user_pref")
        memories = ltm.search("编辑器偏好")
    """

    # Maximum number of memories returned by a search.
    MAX_RESULTS = 5

    # Category -> display label (hoisted so it is not rebuilt per item).
    _CATEGORY_LABELS = {
        "user_pref": "👤 用户偏好",
        "project_info": "📁 项目信息",
        "tech_decision": "⚙️ 技术决策",
        "general": "📝 通用",
    }

    def __init__(
        self,
        db_path: str = "memory.db",
        embedding_provider: str = "simple",
        api_key: Optional[str] = None,
    ):
        """
        Args:
            db_path: SQLite database file path
            embedding_provider: "simple" or "openai"
            api_key: API key (only needed for the "openai" provider)
        """
        self.store = MemoryStore(db_path)
        self.embedding = EmbeddingService(
            provider=embedding_provider,
            api_key=api_key,
        )

    def add(self, content: str, category: str = "general") -> int:
        """
        Add one memory (embedding is computed and stored alongside).

        Args:
            content: memory text
            category: user_pref / project_info / tech_decision / general

        Returns:
            id of the stored memory
        """
        vector = self.embedding.embed(content)
        return self.store.add(
            content=content,
            category=category,
            embedding=vector,
        )

    def search(self, query: str, limit: Optional[int] = None) -> list[Memory]:
        """
        Search memories.

        Strategy:
        1. keyword search first (cheap recall, over-fetching 3x candidates)
        2. re-rank candidates by embedding cosine similarity

        NOTE: memories sharing no substring with the query are never found,
        even if semantically related -- the keyword step bounds recall.

        Args:
            query: search text
            limit: max results (None -> MAX_RESULTS; an explicit 0 returns [])

        Returns:
            memories ordered by relevance
        """
        if limit is None:
            limit = self.MAX_RESULTS

        # 1. keyword recall
        candidates = self.store.search_by_keyword(query, limit=limit * 3)
        if not candidates:
            return []

        # 2. semantic re-rank
        query_vector = self.embedding.embed(query)
        scored = []
        for memory in candidates:
            if memory.embedding:
                score = cosine_similarity(query_vector, memory.embedding)
            else:
                score = 0.5  # neutral default when no embedding is stored
            scored.append((score, memory))
        scored.sort(key=lambda pair: pair[0], reverse=True)

        # 3. record the access and return the top hits
        results = []
        for _score, memory in scored[:limit]:
            self.store.update_access(memory.id)
            results.append(memory)
        return results

    def format_for_context(self, memories: list[Memory]) -> str:
        """
        Format memories as text ready to inject into the model context.

        Args:
            memories: memory list

        Returns:
            formatted text, or "" for an empty list
        """
        if not memories:
            return ""
        lines = ["📚 相关记忆:"]
        for i, m in enumerate(memories, 1):
            label = self._CATEGORY_LABELS.get(m.category, "📝 通用")
            lines.append(f" {i}. [{label}] {m.content}")
        return "\n".join(lines)

    def cleanup(self, max_age_days: float = 30) -> int:
        """
        Remove stale, rarely-accessed memories.

        Args:
            max_age_days: memories older than this may be deleted

        Returns:
            number of memories deleted
        """
        return self.store.forget_old(max_age_days=max_age_days)

    def stats(self) -> dict:
        """Return {"total": N, "by_category": {category: count}} (samples up to 1000)."""
        total = self.store.count()
        all_memories = self.store.get_all(limit=1000)
        categories = {}
        for m in all_memories:
            categories[m.category] = categories.get(m.category, 0) + 1
        return {
            "total": total,
            "by_category": categories,
        }
上面是长期记忆系统的完整实现,包含 Memory、EmbeddingService、MemoryStore 和 LongTermMemory 四个核心组件。
# Add to Agent.__init__ (after context_manager is initialized).
from memory.long_term import LongTermMemory

# The following lines go inside Agent.__init__:
self.ltm = LongTermMemory(
    db_path="memory.db",
    embedding_provider="simple",  # use "openai" in production
)
# 在 chat() 方法开头添加(在 append user message 之前)
def chat(self, user_input: str) -> str:
    """
    One dialogue turn, with relevant long-term memories injected.

    Retrieved memories are appended as a temporary system message; that
    exact message object is removed again after the reply is produced,
    so injected memories never accumulate across turns.
    """
    # 🆕 retrieve relevant long-term memories
    memories = self.ltm.search(user_input)
    injected_msg = None
    if memories:
        injected_msg = {
            "role": "system",
            "content": self.ltm.format_for_context(memories),
        }
        self.messages.append(injected_msg)

    # original logic...
    self.messages.append({"role": "user", "content": user_input})
    # (the original dialogue logic goes here unchanged; it must set `reply`)

    # Remove the temporarily injected memory message by object identity --
    # simpler and safer than searching the list by content.
    if injected_msg is not None:
        self.messages.remove(injected_msg)
    return reply
def remember(self, content: str, category: str = "general") -> int:
    """Manually add one long-term memory; returns the new memory's id."""
    return self.ltm.add(content, category=category)
上面是 chat 方法的核心逻辑。
“””测试长期记忆系统”””
import tempfile
import os
from memory.long_term import (
MemoryStore, EmbeddingService, LongTermMemory,
cosine_similarity,
)
def test_embedding():
    """Exercise the hash-based embedding service."""
    service = EmbeddingService(provider="simple", dimensions=64)
    v1 = service.embed("我喜欢 Python 编程")
    v2 = service.embed("I love coding in Python")
    v3 = service.embed("今天天气很好")
    sim_12 = cosine_similarity(v1, v2)
    sim_13 = cosine_similarity(v1, v3)
    print(f"✅ 语义相近的相似度: {sim_12:.3f}")
    print(f"✅ 语义无关的相似度: {sim_13:.3f}")
    # NOTE(review): the hash embedding is not semantic and str hashing is
    # salted per process, so this ordering can occasionally flip between runs.
    assert sim_12 > sim_13, "语义相近的文本应该更相似"
def test_memory_store():
    """Exercise the SQLite store: add / get / search / count."""
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = os.path.join(tmpdir, "test.db")
        store = MemoryStore(db_path)

        # add memories
        id1 = store.add("用户喜欢用 VS Code", "user_pref")
        id2 = store.add("项目使用 Python 3.12", "project_info")
        id3 = store.add("API 使用 REST 风格", "tech_decision")
        print(f"✅ 添加 3 条记忆,ID: {id1}, {id2}, {id3}")

        # fetch one back
        m = store.get(id1)
        assert m.content == "用户喜欢用 VS Code"
        print(f"✅ 获取记忆: {m.content}")

        # keyword search
        results = store.search_by_keyword("Python")
        assert len(results) == 1
        print(f"✅ 关键词搜索: 找到 {len(results)} 条")

        # category search
        results = store.search_by_category("user_pref")
        assert len(results) == 1
        print(f"✅ 分类搜索: 找到 {len(results)} 条")

        # stats
        print(f"✅ 总记忆数: {store.count()}")
def test_long_term_memory():
    """Exercise the full pipeline: add, search, format."""
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = os.path.join(tmpdir, "ltm.db")
        ltm = LongTermMemory(db_path=db_path, embedding_provider="simple")

        # add several memories
        ltm.add("用户喜欢用 VS Code 编辑器", "user_pref")
        ltm.add("项目使用 FastAPI 框架", "project_info")
        ltm.add("数据库用 PostgreSQL", "project_info")
        ltm.add("部署在 AWS 上", "tech_decision")
        print(f"✅ 添加了 {ltm.stats()['total']} 条记忆")

        # search
        results = ltm.search("编辑器")
        print(f"✅ 搜索 '编辑器': 找到 {len(results)} 条")
        for m in results:
            print(f"  - {m.content}")

        results = ltm.search("数据库")
        print(f"✅ 搜索 '数据库': 找到 {len(results)} 条")
        for m in results:
            print(f"  - {m.content}")

        # formatting
        formatted = ltm.format_for_context(results)
        print(f"\n✅ 格式化输出:\n{formatted}")
if __name__ == "__main__":
    print("=" * 50)
    print("Embedding 测试")
    print("=" * 50)
    test_embedding()

    print("\n" + "=" * 50)
    print("MemoryStore 测试")
    print("=" * 50)
    test_memory_store()

    print("\n" + "=" * 50)
    print("LongTermMemory 测试")
    print("=" * 50)
    test_long_term_memory()

    print("\n🎉 全部测试通过!")
上面是长期记忆系统的测试代码,覆盖 Embedding、SQLite 存储和完整检索流程三部分。
def format_with_age(self, memory: Memory) -> str:
    """
    Render a human-readable freshness label based on last_accessed.

    Buckets: under 1h -> "刚刚更新"; under 1 day -> hours; otherwise days.
    NOTE(review): a never-accessed memory (last_accessed == 0.0) lands in
    the "days" bucket with a huge value -- confirm that is intended.
    """
    age = time.time() - memory.last_accessed
    if age < 3600:
        return "(刚刚更新)"
    elif age < 86400:
        return f"({age // 3600:.0f}小时前更新)"
    else:
        return f"({age // 86400:.0f}天前更新)"
上面是 format_with_age 方法的核心逻辑。
def auto_extract(self, conversation: str) -> list[str]:
    """
    Use an LLM to extract information worth remembering from a conversation.

    Builds the extraction prompt; the LLM call itself is not implemented yet.
    Returns an empty list until then, so callers always receive a list[str]
    (the original stub implicitly returned None, contradicting its annotation).
    """
    prompt = f"从以下对话中提取值得长期记住的信息(用户偏好、项目信息等):\n{conversation}\n\n只输出需要记忆的内容,每行一条。"
    # TODO: send `prompt` to the LLM and split its reply into one memory per line.
    return []
上面是 auto_extract 方法的核心逻辑。
四、注意事项
运行代码时,可能会遇到几个问题:
ImportError: openai is required for OpenAI embeddings使用 provider="openai" 但没有安装 openai 包。解决方式:pip install openai,或者使用 provider="simple"(仅用于测试)。
simple 模式的哈希 Embedding 不是真正的语义向量。生产环境使用 OpenAI 的 text-embedding-3-small 模型:
ltm = LongTermMemory(
    embedding_provider="openai",
    api_key="sk-...",
)
db_path 使用了相对路径,但运行目录不同。使用绝对路径:
import os
db_path = os.path.join(os.path.dirname(__file__), "..", "memory.db")
五、总结
本文内容可以总结为以下几点:
- EmbeddingService:将文本转换为向量(支持 OpenAI 和简单哈希)
- MemoryStore:SQLite 持久化存储
- LongTermMemory:整合存储、Embedding 和语义检索
- 遗忘策略:自动清理过时且低访问的记忆
六、参考链接
- OpenAI API 文档
- Python 官方文档
本系列源码:关注后回复 “Agent” 获取
(完)
本文属于《手写 AI Agent》系列教程的第 6 篇。
如果你对本文的代码感兴趣,可以在公众号回复 “Agent6″,获取完整源码和课后练习。
下一篇文章会讲 让 AI 同时做多件事。
夜雨聆风