乐于分享
好东西不私藏

手写 AI Agent(06):长期记忆

手写 AI Agent(06):长期记忆

本文是《手写 AI Agent》系列教程的第 6 篇,建议按顺序阅读。


一、概述

实现跨会话的长期记忆系统,包括 SQLite 持久化、Embedding 语义检索和记忆的写入/更新/遗忘策略

读完本文,你会理解 为什么需要长期记忆?,以及它为什么这样设计。


二、核心概念

为什么需要长期记忆?

短期记忆(上下文窗口)在会话结束后就消失了。长期记忆解决三个问题:


三、代码实现

下面是本章核心代码:

“””

长期记忆系统 —— 跨会话的知识保留。

包含:

– MemoryStore: SQLite 存储引擎

– EmbeddingService: Embedding 服务(支持 OpenAI 和本地)

– LongTermMemory: 长期记忆管理器

从零手写 AI Agent 课程 · 第 6 章

“””

from __future__ import annotations

import json

import math

import os

import sqlite3

import time

from dataclasses import dataclass, field

from pathlib import Path

from typing import Optional

# ============================================================

# 记忆条目

# ============================================================

@dataclass

class Memory:

“””单条记忆”””

    id: int

    content: str          # 记忆内容

    category: str         # 分类:user_pref / project_info / tech_decision / general

    embedding: list[float] = field(default_factory=list)

    created_at: float = field(default_factory=time.time)

    access_count: int = 0

    last_accessed: float = 0.0

defto_dict(self) -> dict:

return {

“id”self.id,

“content”self.content,

“category”self.category,

“created_at”self.created_at,

“access_count”self.access_count,

“last_accessed”self.last_accessed,

        }

# ============================================================

# Embedding 服务

# ============================================================

class EmbeddingService:

“””

    Embedding 服务 —— 将文本转换为向量。

    支持两种模式:

    1. OpenAI API(精确,需要 API Key)

    2. 简单哈希(快速,不需要 API Key,但精度较低)

    “””

def__init__(

self,

        provider: str = “simple”,

        api_key: Optional[str] = None,

        model: str = “text-embedding-3-small”,

        dimensions: int = 1536,

    ):

“””

        Args:

            provider: “openai” 或 “simple”

            api_key: OpenAI API Key(provider=”openai” 时需要)

            model: Embedding 模型名

            dimensions: 向量维度

        “””

self.provider = provider

self.dimensions = dimensions

if provider == “openai”:

try:

from openai import OpenAI

                api_key = api_key or os.environ.get(“DASHSCOPE_API_KEY”)

ifnot api_key:

                    raise ValueError(“未找到 API 密钥!请设置 DASHSCOPE_API_KEY 环境变量”)

self.client = OpenAI(

                    api_key=api_key,

                    base_url=“https://dashscope.aliyuncs.com/compatible-mode/v1”,

                )

self.model = model

except ImportError:

                raise ImportError(

“openai is required for OpenAI embeddings. “

“Install: pip install openai”

                )

defembed(self, text: str) -> list[float]:

“””

        将文本转换为向量。

        Args:

            text: 输入文本

        Returns:

            向量(list of floats)

        “””

ifself.provider == “openai”:

returnself._embed_openai(text)

else:

returnself._embed_simple(text)

def_embed_openai(self, text: str) -> list[float]:

“””使用 OpenAI API 生成 Embedding”””

        response = self.client.embeddings.create(

            model=self.model,

            input=text,

            dimensions=self.dimensions,

        )

return response.data[0].embedding

def_embed_simple(self, text: str) -> list[float]:

“””

        简单的哈希 Embedding(仅用于演示)。

        将文本哈希后映射到固定维度的向量。

        注意:这不是真正的语义 Embedding,精度很低!

        “””

        vector = [0.0] * self.dimensions

# 使用字符的哈希值填充向量

for i, char inenumerate(text):

            idx = hash(char + str(i)) % self.dimensions

            vector[idx] += 1.0

# 归一化

        norm = math.sqrt(sum(v * v for v in vector))

if norm > 0:

            vector = [v / norm for v in vector]

return vector

# ============================================================

# 向量相似度

# ============================================================

defcosine_similarity(a: list[float], b: list[float]) -> float:

“””

    计算两个向量的余弦相似度。

    Returns:

        相似度,范围 [-1, 1],越接近 1 越相似

    “””

iflen(a) != len(b):

return 0.0

    dot_product = sum(x * y for x, y inzip(a, b))

    norm_a = math.sqrt(sum(x * x for x in a))

    norm_b = math.sqrt(sum(x * x for x in b))

if norm_a == 0 or norm_b == 0:

return 0.0

return dot_product / (norm_a * norm_b)

# ============================================================

# SQLite 存储

# ============================================================

class MemoryStore:

“””

    基于 SQLite 的记忆存储引擎。

    表结构:

    – memories: 存储记忆内容和元信息

    – 支持 CRUD 操作

    “””

def__init__(self, db_path: str = “memory.db”):

“””

        Args:

            db_path: SQLite 数据库文件路径

        “””

self.db_path = db_path

self._init_db()

def_init_db(self):

“””初始化数据库表”””

with sqlite3.connect(self.db_path) as conn:

            conn.execute(“””

                CREATE TABLE IF NOT EXISTS memories (

                    id INTEGER PRIMARY KEY AUTOINCREMENT,

                    content TEXT NOT NULL,

                    category TEXT NOT NULL DEFAULT ‘general’,

                    embedding TEXT,

                    created_at REAL NOT NULL,

                    access_count INTEGER DEFAULT 0,

                    last_accessed REAL DEFAULT 0

                )

            “””)

            conn.execute(“””

                CREATE INDEX IF NOT EXISTS idx_category 

                ON memories(category)

            “””)

            conn.execute(“””

                CREATE INDEX IF NOT EXISTS idx_created 

                ON memories(created_at)

            “””)

defadd(

self,

        content: str,

        category: str = “general”,

        embedding: Optional[list[float]] = None,

    ) -> int:

“””

        添加一条记忆。

        Returns:

            新记忆的 ID

        “””

with sqlite3.connect(self.db_path) as conn:

            cursor = conn.execute(

“””INSERT INTO memories 

                   (content, category, embedding, created_at) 

                   VALUES (?, ?, ?, ?)”””,

                (

                    content,

                    category,

                    json.dumps(embedding) if embedding elseNone,

                    time.time(),

                ),

            )

return cursor.lastrowid

defget(self, memory_id: int) -> Optional[Memory]:

“””获取单条记忆”””

with sqlite3.connect(self.db_path) as conn:

            cursor = conn.execute(

“SELECT * FROM memories WHERE id = ?”,

                (memory_id,),

            )

            row = cursor.fetchone()

ifnot row:

returnNone

returnMemory(

                id=row[0],

                content=row[1],

                category=row[2],

                embedding=json.loads(row[3]) if row[3] else [],

                created_at=row[4],

                access_count=row[5],

                last_accessed=row[6],

            )

defsearch_by_category(

self,

        category: str,

        limit: int = 10,

    ) -> list[Memory]:

“””按分类搜索记忆”””

with sqlite3.connect(self.db_path) as conn:

            cursor = conn.execute(

“SELECT * FROM memories WHERE category = ? ORDER BY created_at DESC LIMIT ?”,

                (category, limit),

            )

return [

Memory(

                    id=row[0], content=row[1], category=row[2],

                    embedding=json.loads(row[3]) if row[3] else [],

                    created_at=row[4], access_count=row[5],

                    last_accessed=row[6],

                )

for row in cursor.fetchall()

            ]

defsearch_by_keyword(

self,

        keyword: str,

        limit: int = 10,

    ) -> list[Memory]:

“””按关键词搜索记忆(SQLite FTS 回退方案)”””

with sqlite3.connect(self.db_path) as conn:

            cursor = conn.execute(

“””SELECT * FROM memories 

                   WHERE content LIKE ? 

                   ORDER BY access_count DESC, created_at DESC 

                   LIMIT ?”””,

                (f“%{keyword}%”, limit),

            )

return [

Memory(

                    id=row[0], content=row[1], category=row[2],

                    embedding=json.loads(row[3]) if row[3] else [],

                    created_at=row[4], access_count=row[5],

                    last_accessed=row[6],

                )

for row in cursor.fetchall()

            ]

defupdate_access(self, memory_id: int):

“””更新记忆的访问计数”””

with sqlite3.connect(self.db_path) as conn:

            conn.execute(

“””UPDATE memories 

                   SET access_count = access_count + 1,

                       last_accessed = ? 

                   WHERE id = ?”””,

                (time.time(), memory_id),

            )

defdelete(self, memory_id: int) -> bool:

“””删除一条记忆”””

with sqlite3.connect(self.db_path) as conn:

            cursor = conn.execute(

“DELETE FROM memories WHERE id = ?”,

                (memory_id,),

            )

return cursor.rowcount > 0

defget_all(self, limit: int = 100) -> list[Memory]:

“””获取所有记忆”””

with sqlite3.connect(self.db_path) as conn:

            cursor = conn.execute(

“SELECT * FROM memories ORDER BY created_at DESC LIMIT ?”,

                (limit,),

            )

return [

Memory(

                    id=row[0], content=row[1], category=row[2],

                    embedding=json.loads(row[3]) if row[3] else [],

                    created_at=row[4], access_count=row[5],

                    last_accessed=row[6],

                )

for row in cursor.fetchall()

            ]

defcount(self) -> int:

“””获取记忆总数”””

with sqlite3.connect(self.db_path) as conn:

            cursor = conn.execute(“SELECT COUNT(*) FROM memories”)

return cursor.fetchone()[0]

defforget_old(self, max_age_days: float = 30, min_access: int = 2) -> int:

“””

        遗忘策略:删除过时且很少访问的记忆。

        Args:

            max_age_days: 超过此天数的记忆可能被删除

            min_access: 访问次数低于此值的记忆可能被删除

        Returns:

            删除的记忆数量

        “””

        cutoff = time.time() – (max_age_days * 86400)

with sqlite3.connect(self.db_path) as conn:

            cursor = conn.execute(

“””DELETE FROM memories 

                   WHERE created_at < ? AND access_count < ?”””,

                (cutoff, min_access),

            )

return cursor.rowcount

# ============================================================

# 长期记忆管理器

# ============================================================

class LongTermMemory:

“””

    长期记忆管理器 —— 整合存储、Embedding 和检索。

    使用示例:

        ltm = LongTermMemory(

            db_path=”memory.db”,

            embedding_provider=”simple”,

        )

        # 写入记忆

        ltm.add(“用户喜欢用 VS Code”, category=”user_pref”)

        # 检索记忆

        memories = ltm.search(“编辑器偏好”)

    “””

# 检索时返回的最大记忆数

    MAX_RESULTS = 5

def__init__(

self,

        db_path: str = “memory.db”,

        embedding_provider: str = “simple”,

        api_key: Optional[str] = None,

    ):

self.store = MemoryStore(db_path)

self.embedding = EmbeddingService(

            provider=embedding_provider,

            api_key=api_key,

        )

defadd(self, content: str, category: str = “general”) -> int:

“””

        添加一条记忆。

        Args:

            content: 记忆内容

            category: 分类(user_pref / project_info / tech_decision / general)

        Returns:

            记忆 ID

        “””

# 生成 Embedding

        vector = self.embedding.embed(content)

# 存储

returnself.store.add(

            content=content,

            category=category,

            embedding=vector,

        )

defsearch(self, query: str, limit: Optional[int] = None) -> list[Memory]:

“””

        搜索记忆。

        策略:

        1. 先用关键词搜索

        2. 如果有 Embedding,再用语义搜索排序

        Args:

            query: 搜索查询

            limit: 返回数量限制

        Returns:

            按相关性排序的记忆列表

        “””

        limit = limit orself.MAX_RESULTS

# 1. 关键词搜索

        keyword_results = self.store.search_by_keyword(query, limit=limit * 3)

ifnot keyword_results:

return []

# 2. 如果有 Embedding,用语义相似度重新排序

        query_vector = self.embedding.embed(query)

        scored = []

for memory in keyword_results:

if memory.embedding:

                similarity = cosine_similarity(query_vector, memory.embedding)

                scored.append((similarity, memory))

else:

                scored.append((0.5, memory))  # 无 Embedding 给默认分

# 按相似度排序

        scored.sort(key=lambda x: x[0], reverse=True)

# 更新访问计数

        results = []

for score, memory in scored[:limit]:

self.store.update_access(memory.id)

            results.append(memory)

return results

defformat_for_context(self, memories: list[Memory]) -> str:

“””

        将记忆格式化为可注入上下文的文本。

        Args:

            memories: 记忆列表

        Returns:

            格式化后的文本

        “””

ifnot memories:

return“”

        lines = [“📚 相关记忆:”]

for i, m inenumerate(memories, 1):

            category_labels = {

“user_pref”“👤 用户偏好”,

“project_info”“📁 项目信息”,

“tech_decision”“⚙️ 技术决策”,

“general”“📝 通用”,

            }

            label = category_labels.get(m.category, “📝 通用”)

            lines.append(f”  {i}. [{label}] {m.content}”)

return“\n”.join(lines)

defcleanup(self, max_age_days: float = 30) -> int:

“””

        清理过时记忆。

        Args:

            max_age_days: 删除超过此天数的低访问记忆

        Returns:

            删除的记忆数量

        “””

returnself.store.forget_old(max_age_days=max_age_days)

defstats(self) -> dict:

“””获取记忆统计信息”””

        total = self.store.count()

        all_memories = self.store.get_all(limit=1000)

# 按分类统计

        categories = {}

for m in all_memories:

            categories[m.category] = categories.get(m.category, 0) + 1

return {

“total”: total,

“by_category”: categories,

        }

上面是 Memory 类的核心实现,包含 to_dict 等关键方法。

# 在 Agent.__init__ 中添加(在 context_manager 初始化之后)

from memory.long_term import LongTermMemory

# 以下代码添加到 Agent.__init__ 方法中

self.ltm = LongTermMemory(

    db_path=“memory.db”,

    embedding_provider=“simple”,  # 生产环境用 §§0003§§

)

# 在 chat() 方法开头添加(在 append user message 之前)

defchat(self, user_input: str) -> str:

# 🆕 检索相关长期记忆

    memories = self.ltm.search(user_input)

if memories:

        memory_context = self.ltm.format_for_context(memories)

# 注入记忆到系统提示(临时)

self.messages.append({

“role”“system”,

“content”: memory_context,

        })

# 原有逻辑…

self.messages.append({“role”“user”“content”: user_input})

# (此处为原有的对话逻辑,保持不变)

# 清理临时注入的记忆(最后一条 system 消息)

if memories:

self.messages.pop(-len(self.messages) + self.messages.index(

next(m for m inreversed(self.messages) if m[“role”] == “system”and memory_context in m[“content”])

        ))

return reply

defremember(self, content: str, category: str = “general”) -> int:

“””手动添加一条长期记忆”””

returnself.ltm.add(content, category=category)

上面是 chat 方法的核心逻辑。

“””测试长期记忆系统”””

import tempfile

import os

from memory.long_term import (

    MemoryStore, EmbeddingService, LongTermMemory,

    cosine_similarity,

)

deftest_embedding():

“””测试 Embedding 服务”””

    service = EmbeddingService(provider=“simple”, dimensions=64)

    v1 = service.embed(“我喜欢 Python 编程”)

    v2 = service.embed(“I love coding in Python”)

    v3 = service.embed(“今天天气很好”)

    sim_12 = cosine_similarity(v1, v2)

    sim_13 = cosine_similarity(v1, v3)

print(f“✅ 语义相近的相似度: {sim_12:.3f}”)

print(f“✅ 语义无关的相似度: {sim_13:.3f}”)

    assert sim_12 > sim_13, “语义相近的文本应该更相似”

deftest_memory_store():

“””测试 SQLite 存储”””

with tempfile.TemporaryDirectory() as tmpdir:

        db_path = os.path.join(tmpdir, “test.db”)

        store = MemoryStore(db_path)

# 添加记忆

        id1 = store.add(“用户喜欢用 VS Code”“user_pref”)

        id2 = store.add(“项目使用 Python 3.12”“project_info”)

        id3 = store.add(“API 使用 REST 风格”“tech_decision”)

print(f“✅ 添加 3 条记忆,ID: {id1}, {id2}, {id3}”)

# 获取记忆

        m = store.get(id1)

        assert m.content == “用户喜欢用 VS Code”

print(f“✅ 获取记忆: {m.content}”)

# 关键词搜索

        results = store.search_by_keyword(“Python”)

        assert len(results) == 1

print(f“✅ 关键词搜索: 找到 {len(results)} 条”)

# 分类搜索

        results = store.search_by_category(“user_pref”)

        assert len(results) == 1

print(f“✅ 分类搜索: 找到 {len(results)} 条”)

# 统计

print(f“✅ 总记忆数: {store.count()}”)

deftest_long_term_memory():

“””测试完整的长期记忆系统”””

with tempfile.TemporaryDirectory() as tmpdir:

        db_path = os.path.join(tmpdir, “ltm.db”)

        ltm = LongTermMemory(db_path=db_path, embedding_provider=“simple”)

# 添加多条记忆

        ltm.add(“用户喜欢用 VS Code 编辑器”“user_pref”)

        ltm.add(“项目使用 FastAPI 框架”“project_info”)

        ltm.add(“数据库用 PostgreSQL”“project_info”)

        ltm.add(“部署在 AWS 上”“tech_decision”)

print(f“✅ 添加了 {ltm.stats()[‘total’]} 条记忆”)

# 搜索

        results = ltm.search(“编辑器”)

print(f“✅ 搜索 ‘编辑器’: 找到 {len(results)} 条”)

for m in results:

print(f”   – {m.content}”)

        results = ltm.search(“数据库”)

print(f“✅ 搜索 ‘数据库’: 找到 {len(results)} 条”)

for m in results:

print(f”   – {m.content}”)

# 格式化

        formatted = ltm.format_for_context(results)

print(f“\n✅ 格式化输出:\n{formatted}”)

if __name__ == “__main__”:

print(“=” * 50)

print(“Embedding 测试”)

print(“=” * 50)

test_embedding()

print(“\n” + “=” * 50)

print(“MemoryStore 测试”)

print(“=” * 50)

test_memory_store()

print(“\n” + “=” * 50)

print(“LongTermMemory 测试”)

print(“=” * 50)

test_long_term_memory()

print(“\n🎉 全部测试通过!”)

上面是 test_embedding 方法的核心逻辑。

defformat_with_age(self, memory: Memory) -> str:

    age = time.time() – memory.last_accessed

if age < 3600:

return f“(刚刚更新)”

elif age < 86400:

return f“({age // 3600:.0f}小时前更新)”

else:

return f“({age // 86400:.0f}天前更新)”

上面是 format_with_age 方法的核心逻辑。

defauto_extract(self, conversation: str) -> list[str]:

“””用 LLM 从对话中提取值得记忆的信息”””

    prompt = f“从以下对话中提取值得长期记住的信息(用户偏好、项目信息等):\n{conversation}\n\n只输出需要记忆的内容,每行一条。”

# 调用 LLM…

上面是 auto_extract 方法的核心逻辑。


四、注意事项

运行代码时,可能会遇到几个问题:

错误 1:ImportError: openai is required for OpenAI embeddings

使用 provider="openai" 但没有安装 openai 包。解决方式:pip install openai,或者使用 provider="simple"(仅用于测试)。

错误 2:语义搜索精度很低simple 模式的哈希 Embedding 不是真正的语义向量。

生产环境使用 OpenAI 的 text-embedding-3-small 模型:

ltm = LongTermMemory(

    embedding_provider=“openai”,

    api_key=“sk-…”,

)

错误 3:记忆数据库文件找不到db_path 使用了相对路径,但运行目录不同。

使用绝对路径:

import os

db_path = os.path.join(os.path.dirname(__file__), “..”“memory.db”)



五、总结

本文内容可以总结为以下几点:

  1. | 将文本转换为向量(支持 OpenAI 和简单哈希) |
  1. | SQLite 持久化存储 |
  1. | 整合存储、Embedding 和语义检索 |
  1. | 自动清理过时且低访问的记忆 |

六、参考链接

  • OpenAI API 文档
  • Python 官方文档
  • 本系列源码:关注后回复 “Agent” 获取

(完)

本文属于《手写 AI Agent》系列教程的第 6 篇。

如果你对本文的代码感兴趣,可以在公众号回复 “Agent6″,获取完整源码和课后练习。

下一篇文章会讲 让 AI 同时做多件事。