
Chapter 7: Memory and Context Management

An Agent without memory is like a goldfish with a seven-second attention span: every conversation starts from zero, with no way to accumulate experience or understand context.


7.1 Short-Term vs. Long-Term Memory

7.1.1 A Layered Model of Memory

An Agent's memory system borrows from models in human cognitive science and is organized into several tiers:

┌──────────────────────────────────────────────────────────┐
│                    The Memory Pyramid                    │
│                                                          │
│                    ┌──────────┐                          │
│                    │ Sensory  │  ← raw current input     │
│                    │ memory   │    duration: milliseconds│
│                    └────┬─────┘                          │
│                         ▼                                │
│                  ┌──────────────┐                        │
│                  │   Working    │  ← current dialogue    │
│                  │   memory     │    context             │
│                  │              │    duration: minutes   │
│                  │              │    capacity: limited   │
│                  └──────┬───────┘                        │
│                         ▼                                │
│              ┌─────────────────────┐                     │
│              │  Short-term memory  │  ← current session  │
│              │                     │    history          │
│              │                     │    duration: hours  │
│              │                     │    capacity: medium │
│              └──────────┬──────────┘                     │
│                         ▼                                │
│              ┌─────────────────────┐                     │
│              │  Long-term memory   │  ← knowledge kept   │
│              │                     │    across sessions  │
│              │                     │    duration: forever│
│              │                     │    capacity: large  │
│              └─────────────────────┘                     │
└──────────────────────────────────────────────────────────┘

7.1.2 Implementing Each Memory Tier

python
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from enum import Enum

class MemoryType(Enum):
    WORKING = "working"       # Working memory: current reasoning context
    SHORT_TERM = "short"      # Short-term memory: current session
    LONG_TERM = "long"        # Long-term memory: persisted across sessions

@dataclass
class MemoryItem:
    """A single memory entry."""
    content: str                    # Memory content
    memory_type: MemoryType         # Memory tier
    timestamp: datetime = field(default_factory=datetime.now)
    importance: float = 0.5         # Importance score (0-1)
    metadata: dict = field(default_factory=dict)
    access_count: int = 0           # Number of times accessed
    last_accessed: datetime = field(default_factory=datetime.now)
    embedding: list[float] | None = None  # Vector embedding (for retrieval)
    source: str = ""                # Where the memory came from
    expires_at: datetime | None = None    # Expiry time

    def touch(self):
        """Record an access."""
        self.access_count += 1
        self.last_accessed = datetime.now()

    @property
    def is_expired(self) -> bool:
        if self.expires_at is None:
            return False
        return datetime.now() > self.expires_at

    @property
    def age_hours(self) -> float:
        return (datetime.now() - self.timestamp).total_seconds() / 3600

7.1.3 Managing Short-Term Memory

Short-term memory holds the current session's context and is fed directly to the LLM as input:

python
class ShortTermMemory:
    """Short-term memory: conversation context management."""

    def __init__(self, max_messages: int = 50, max_tokens: int = 8000):
        self.max_messages = max_messages
        self.max_tokens = max_tokens
        self.messages: list[dict] = []
        self._token_counter = lambda text: len(text) // 2  # crude approximation: ~2 chars per token

    def add_message(self, role: str, content: str):
        """Append a message."""
        self.messages.append({
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat()
        })
        self._trim()

    def add_system(self, content: str):
        """Set the system message (always retained)."""
        # The system message goes at the very front
        if self.messages and self.messages[0]["role"] == "system":
            self.messages[0]["content"] = content
        else:
            self.messages.insert(0, {
                "role": "system",
                "content": content,
                "timestamp": datetime.now().isoformat()
            })

    def get_context(self) -> list[dict]:
        """Return the current context."""
        return [
            {"role": m["role"], "content": m["content"]}
            for m in self.messages
        ]

    def _trim(self):
        """Trim the context to fit the limits."""
        # Always keep system messages
        system_msgs = [m for m in self.messages if m["role"] == "system"]
        other_msgs = [m for m in self.messages if m["role"] != "system"]

        # Trim by message count
        if len(other_msgs) > self.max_messages:
            other_msgs = other_msgs[-self.max_messages:]

        # Trim by token count
        total_tokens = sum(self._token_counter(m["content"]) for m in other_msgs)
        while total_tokens > self.max_tokens and len(other_msgs) > 2:
            removed = other_msgs.pop(0)
            total_tokens -= self._token_counter(removed["content"])

        self.messages = system_msgs + other_msgs

    def get_summary(self) -> str:
        """Return conversation statistics."""
        if not self.messages:
            return "No conversation history yet"

        user_msgs = [m for m in self.messages if m["role"] == "user"]
        assistant_msgs = [m for m in self.messages if m["role"] == "assistant"]

        return f"""Conversation stats:
- Total messages: {len(self.messages)}
- User messages: {len(user_msgs)}
- Assistant replies: {len(assistant_msgs)}
- First message: {self.messages[0].get('timestamp', 'N/A')}
- Latest message: {self.messages[-1].get('timestamp', 'N/A')}"""

7.1.4 Managing Long-Term Memory

Long-term memory must be persisted and support semantic retrieval:

python
class LongTermMemory:
    """Long-term memory: persisted across sessions."""

    def __init__(self, storage_backend: Any = None):
        self.memories: list[MemoryItem] = []
        self.storage = storage_backend  # e.g. a file or a database
        self._embedder = None           # embedding model

    def store(self, content: str, importance: float = 0.5, **metadata):
        """Store a memory."""
        item = MemoryItem(
            content=content,
            memory_type=MemoryType.LONG_TERM,
            importance=importance,
            metadata=metadata
        )

        # Generate an embedding vector
        if self._embedder:
            item.embedding = self._embedder.embed(content)

        self.memories.append(item)

        # Persist
        if self.storage:
            self.storage.save(item)

        return item

    def recall(self, query: str, top_k: int = 5, threshold: float = 0.3) -> list[MemoryItem]:
        """Retrieve relevant memories."""
        if not self._embedder:
            # Without an embedding model, fall back to keyword matching
            return self._keyword_search(query, top_k)

        # Vector search
        query_embedding = self._embedder.embed(query)

        scored = []
        for item in self.memories:
            if item.embedding is None:
                continue

            similarity = self._cosine_similarity(query_embedding, item.embedding)
            if similarity >= threshold:
                scored.append((similarity, item))

        # Sort on the score only: a tie would otherwise try to compare MemoryItem objects
        scored.sort(key=lambda pair: pair[0], reverse=True)

        results = [item for _, item in scored[:top_k]]

        # Record the accesses
        for item in results:
            item.touch()

        return results

    def forget(self, criteria: callable):
        """Forget: delete every memory matching the predicate."""
        self.memories = [
            item for item in self.memories
            if not criteria(item)
        ]

    def consolidate(self):
        """Consolidate: merge near-duplicate memories, dropping the redundant ones."""
        if not self._embedder:
            return

        # Find pairs of highly similar memories
        to_remove = set()
        for i in range(len(self.memories)):
            if i in to_remove:
                continue
            for j in range(i + 1, len(self.memories)):
                if j in to_remove:
                    continue

                mi, mj = self.memories[i], self.memories[j]
                if mi.embedding and mj.embedding:
                    sim = self._cosine_similarity(mi.embedding, mj.embedding)

                    if sim > 0.9:  # near-duplicates
                        # Keep the more important of the two
                        if mj.importance > mi.importance:
                            to_remove.add(i)
                            break  # i is marked for removal; stop comparing against it
                        else:
                            to_remove.add(j)

        self.memories = [
            item for i, item in enumerate(self.memories)
            if i not in to_remove
        ]

    def _keyword_search(self, query: str, top_k: int) -> list[MemoryItem]:
        """Keyword search (fallback when no embedder is available)."""
        query_words = set(query.lower().split())

        scored = []
        for item in self.memories:
            content_words = set(item.content.lower().split())
            overlap = len(query_words & content_words)
            if overlap > 0:
                scored.append((overlap, item))

        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [item for _, item in scored[:top_k]]

    @staticmethod
    def _cosine_similarity(a: list[float], b: list[float]) -> float:
        """Cosine similarity between two vectors."""
        import math
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = math.sqrt(sum(x * x for x in a))
        norm_b = math.sqrt(sum(x * x for x in b))
        return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0
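The keyword fallback ranks memories by how many words they share with the query. A minimal standalone sketch of that idea (the function name and sample documents are illustrative, not from the chapter):

```python
def keyword_score(query: str, documents: list[str], top_k: int = 2) -> list[str]:
    """Rank documents by the number of words they share with the query."""
    query_words = set(query.lower().split())
    scored = []
    for doc in documents:
        overlap = len(query_words & set(doc.lower().split()))
        if overlap > 0:
            scored.append((overlap, doc))
    # Sort on the score only, so ties never compare the documents themselves
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in scored[:top_k]]

docs = [
    "user prefers dark mode",
    "meeting scheduled for friday",
    "user prefers concise answers",
]
print(keyword_score("what does the user prefer", docs))
```

Note that set-based overlap misses morphology ("prefer" does not match "prefers"), which is one reason the chapter treats this only as a degraded fallback.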

7.2 Context Window Management Strategies

7.2.1 The Context Window Challenge

An LLM's context window is finite (even as windows keep growing), and a running Agent can easily blow past the limit:

What the context window holds:
┌──────────────────────────────────────────────┐
│ System prompt          (~500-2000 tokens)    │
├──────────────────────────────────────────────┤
│ Tool definitions       (~1000-5000 tokens)   │  ← each extra tool adds ~200-500 tokens
├──────────────────────────────────────────────┤
│ Conversation history   (grows dynamically)   │  ← the biggest threat!
├──────────────────────────────────────────────┤
│ Tool call results      (grows dynamically)   │
├──────────────────────────────────────────────┤
│ Retrieved memories     (~1000-3000 tokens)   │
├──────────────────────────────────────────────┤
│ Reserved output space  (~2000-4000 tokens)   │
└──────────────────────────────────────────────┘
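Summing the slots makes the squeeze concrete. A back-of-the-envelope calculation, assuming a 32k window and mid-range figures from the table above (all numbers illustrative):

```python
def history_budget(window: int, system: int, tools: int,
                   memories: int, output_reserve: int) -> int:
    """Tokens left over for conversation history and tool results."""
    return window - system - tools - memories - output_reserve

# Everything except history is roughly fixed; history gets the remainder.
left = history_budget(window=32_000, system=1_000, tools=3_000,
                      memories=2_000, output_reserve=4_000)
print(left)  # 22000 tokens for history + tool results
```

A third of the window is spoken for before the first user turn arrives, which is why the strategies below all target the history slot.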

7.2.2 The Sliding Window Strategy

The simplest strategy: keep the most recent N messages.

python
class SlidingWindowManager:
    """Sliding-window context management."""

    def __init__(
        self,
        max_tokens: int,
        system_prompt_tokens: int,
        reserve_for_output: int = 4096
    ):
        self.max_tokens = max_tokens
        self.system_prompt_tokens = system_prompt_tokens
        self.reserve_for_output = reserve_for_output
        self.available_for_context = (
            max_tokens - system_prompt_tokens - reserve_for_output
        )

    def select_messages(
        self,
        messages: list[dict],
        token_counter: callable
    ) -> list[dict]:
        """Pick the messages to keep."""
        # Always keep system messages
        system = [m for m in messages if m["role"] == "system"]
        others = [m for m in messages if m["role"] != "system"]

        # Walk backwards from the newest message until the budget runs out
        selected = []
        used_tokens = 0

        for msg in reversed(others):
            msg_tokens = token_counter(msg["content"])

            if used_tokens + msg_tokens > self.available_for_context:
                break

            selected.insert(0, msg)
            used_tokens += msg_tokens

        return system + selected

    def utilization(self, messages: list[dict], token_counter: callable) -> float:
        """Fraction of the context window currently in use."""
        total = sum(token_counter(m["content"]) for m in messages)
        return total / self.max_tokens
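The core of the sliding window is the backwards walk: accumulate from the newest message until the budget is spent. A self-contained sketch of just that loop, using string length as a stand-in token count (budget and messages are made up):

```python
def newest_that_fit(messages: list[str], budget: int) -> list[str]:
    """Keep the newest messages whose combined length fits the budget."""
    selected: list[str] = []
    used = 0
    for msg in reversed(messages):       # newest first
        if used + len(msg) > budget:
            break                        # budget spent; drop everything older
        selected.insert(0, msg)          # re-insert in chronological order
        used += len(msg)
    return selected

history = ["aaaa", "bbbb", "cccc", "dddd"]
print(newest_that_fit(history, budget=10))  # ['cccc', 'dddd']
```

Breaking (rather than continuing) on the first message that doesn't fit is deliberate: skipping one old message and keeping an even older one would leave a confusing gap in the middle of the conversation.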

7.2.3 The Summary Compression Strategy

When the conversation history grows too long, compress the early turns into a summary:

python
class SummaryCompressor:
    """Summary compressor."""

    def __init__(self, llm):
        self.llm = llm

    def compress(
        self,
        messages: list[dict],
        max_summary_tokens: int = 500
    ) -> tuple[list[dict], str]:
        """
        Compress the conversation history.
        Returns: (retained recent messages, summary of the earlier ones)
        """
        if len(messages) <= 6:
            return messages, ""

        # Split into early and recent
        early = messages[:-4]  # the last 4 messages stay uncompressed
        recent = messages[-4:]

        # Summarize the early conversation
        conversation_text = "\n".join(
            f"{'User' if m['role'] == 'user' else 'Assistant'}: {m['content']}"
            for m in early
        )

        summary_prompt = f"""Compress the following conversation history into a concise summary.
Preserve the key information: topics discussed, decisions made, important data.

Conversation history:
{conversation_text}

Summarize in at most {max_summary_tokens} words."""

        summary = self.llm.chat(
            messages=[{"role": "user", "content": summary_prompt}],
            temperature=0.1
        ).content

        # Replace the early messages with the summary
        summary_message = {
            "role": "system",
            "content": f"[Summary of earlier conversation]\n{summary}"
        }

        return [summary_message] + recent, summary

    def incremental_summarize(
        self,
        existing_summary: str,
        new_messages: list[dict]
    ) -> str:
        """Update the summary incrementally."""
        if not new_messages:
            return existing_summary

        new_text = "\n".join(
            f"{'User' if m['role'] == 'user' else 'Assistant'}: {m['content']}"
            for m in new_messages
        )

        prompt = f"""Existing summary:
{existing_summary if existing_summary else "(none)"}

New conversation content:
{new_text}

Update the summary, folding in the new information. Keep it under 300 words."""

        return self.llm.chat(
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1
        ).content

7.2.4 The Priority Retention Strategy

Not all messages matter equally: system messages, tool-result summaries, and the user's key instructions should survive trimming first.

python
@dataclass
class MessagePriority:
    """A message with its priority."""
    message: dict
    priority: float = 0.5  # 0-1
    tokens: int = 0

class PriorityBasedManager:
    """Priority-based context management."""

    # Priority rules
    PRIORITY_RULES = {
        "system": 1.0,           # system messages: highest
        "tool_result": 0.8,      # tool results: high
        "tool_call": 0.7,        # tool calls: high
        "user": 0.6,             # user messages: medium-high
        "assistant": 0.5,        # assistant replies: medium
        "error": 0.3,            # error messages: low
    }

    def __init__(self, max_tokens: int):
        self.max_tokens = max_tokens

    def classify_message(self, message: dict) -> str:
        """Classify the message type."""
        role = message.get("role", "")

        if role == "system":
            return "system"
        elif role == "tool":
            return "tool_result"
        elif message.get("tool_calls"):
            return "tool_call"
        elif role == "user":
            return "user"
        elif role == "assistant":
            return "assistant"
        else:
            return "error"

    def select_messages(
        self,
        messages: list[dict],
        token_counter: callable
    ) -> list[dict]:
        """Select messages by priority."""

        # Classify and score
        prioritized = []
        total = len(messages)
        # enumerate() gives the position directly; list.index() would be O(n²)
        # and would return the wrong index for duplicate messages
        for index, msg in enumerate(messages):
            msg_type = self.classify_message(msg)
            priority = self.PRIORITY_RULES.get(msg_type, 0.5)

            # Time decay: the older the message, the lower its priority
            time_decay = 0.5 + 0.5 * (index / total)  # newer => less decay

            final_priority = priority * time_decay

            prioritized.append(MessagePriority(
                message=msg,
                priority=final_priority,
                tokens=token_counter(msg.get("content", ""))
            ))

        # Sort by priority
        prioritized.sort(key=lambda p: p.priority, reverse=True)

        # Greedy selection
        selected = []
        used_tokens = 0

        for p in prioritized:
            if used_tokens + p.tokens <= self.max_tokens:
                selected.append(p.message)
                used_tokens += p.tokens

        # Restore the original order
        original_order = {id(m): i for i, m in enumerate(messages)}
        selected.sort(key=lambda m: original_order.get(id(m), 0))

        return selected
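The score each message receives is its role priority scaled by a position-based decay. The arithmetic in isolation, with role weights copied from `PRIORITY_RULES` (the four-message conversation is made up):

```python
ROLE_PRIORITY = {"system": 1.0, "user": 0.6, "assistant": 0.5}

def score(role: str, index: int, total: int) -> float:
    """Role priority scaled by time decay: newer messages decay less."""
    time_decay = 0.5 + 0.5 * (index / total)
    return ROLE_PRIORITY[role] * time_decay

# In a 4-message conversation, note that the decay can even push the
# freshest user turn (0.6 * 0.875 = 0.525) above an aged system
# message (1.0 * 0.5 = 0.5) -- one reason production code often
# exempts system messages from scoring altogether.
print(score("system", 0, 4))   # 1.0 * 0.5   = 0.5
print(score("user", 1, 4))     # 0.6 * 0.625 = 0.375
print(score("user", 3, 4))     # 0.6 * 0.875 = 0.525
```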

7.2.5 A Hybrid Strategy

In production, a hybrid strategy is recommended:

python
class HybridContextManager:
    """Hybrid context manager."""

    def __init__(
        self,
        llm,
        max_tokens: int,
        system_prompt_tokens: int,
        reserve_output: int = 4096,
        keep_recent: int = 4,
        summary_max_tokens: int = 400
    ):
        self.llm = llm
        self.max_tokens = max_tokens
        self.available = max_tokens - system_prompt_tokens - reserve_output
        self.keep_recent = keep_recent
        self.compressor = SummaryCompressor(llm)
        self.priority_mgr = PriorityBasedManager(self.available)
        self.token_counter = lambda text: len(text) // 2
        self._cached_summary = ""

    def manage(self, messages: list[dict]) -> list[dict]:
        """Manage the context."""
        system = [m for m in messages if m["role"] == "system"]
        others = [m for m in messages if m["role"] != "system"]

        total_tokens = sum(
            self.token_counter(m.get("content", "")) for m in others
        )

        if total_tokens <= self.available:
            # No compression needed
            return messages

        # Strategy 1: slightly over budget -> trim by priority
        if total_tokens <= self.available * 1.5:
            selected = self.priority_mgr.select_messages(
                others, self.token_counter
            )
            return system + selected

        # Strategy 2: far over budget -> summarize first, then trim
        recent = others[-self.keep_recent:]
        old = others[:-self.keep_recent]

        # Incremental summary
        self._cached_summary = self.compressor.incremental_summarize(
            self._cached_summary, old
        )

        summary_msg = {
            "role": "system",
            "content": f"[Conversation history summary]\n{self._cached_summary}"
        }

        return system + [summary_msg] + recent

7.3 Memory Retrieval and Compression

7.3.1 Embedding Basics

Vector embeddings are the foundation of semantic retrieval. An embedding maps text to a high-dimensional vector such that semantically similar texts end up close together in vector space:

python
class TextEmbedder:
    """Text embedding generator."""

    def __init__(self, model: str = "text-embedding-3-small", api_key: str = ""):
        from openai import OpenAI
        self.client = OpenAI(api_key=api_key)
        self.model = model
        self._cache: dict[str, list[float]] = {}

    def embed(self, text: str) -> list[float]:
        """Embed a single text."""
        if text in self._cache:
            return self._cache[text]

        response = self.client.embeddings.create(
            input=text,
            model=self.model
        )

        embedding = response.data[0].embedding
        self._cache[text] = embedding
        return embedding

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of texts."""
        # Skip texts that are already cached
        to_embed = [t for t in texts if t not in self._cache]

        if to_embed:
            response = self.client.embeddings.create(
                input=to_embed,
                model=self.model
            )

            for text, data in zip(to_embed, response.data):
                self._cache[text] = data.embedding

        return [self._cache[t] for t in texts]
7.3.2 Memory Retrieval Strategies

python
class MemoryRetriever:
    """Memory retriever."""

    def __init__(self, embedder: TextEmbedder):
        self.embedder = embedder

    def retrieve(
        self,
        query: str,
        memories: list[MemoryItem],
        top_k: int = 5,
        strategy: str = "hybrid"
    ) -> list[tuple[MemoryItem, float]]:
        """
        Retrieve relevant memories.

        strategy:
        - "semantic": pure semantic search
        - "recency": pure recency search
        - "importance": pure importance search
        - "hybrid": combined search (recommended)
        """
        if not memories:
            return []

        if strategy == "semantic":
            return self._semantic_search(query, memories, top_k)
        elif strategy == "recency":
            return self._recency_search(memories, top_k)
        elif strategy == "importance":
            return self._importance_search(memories, top_k)
        else:
            return self._hybrid_search(query, memories, top_k)

    def _semantic_search(
        self, query: str, memories: list[MemoryItem], top_k: int
    ) -> list[tuple[MemoryItem, float]]:
        """Semantic search."""
        query_embedding = self.embedder.embed(query)

        scored = []
        for item in memories:
            if item.embedding is None:
                item.embedding = self.embedder.embed(item.content)

            similarity = LongTermMemory._cosine_similarity(
                query_embedding, item.embedding
            )
            scored.append((similarity, item))

        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [(item, score) for score, item in scored[:top_k]]

    def _recency_search(
        self, memories: list[MemoryItem], top_k: int
    ) -> list[tuple[MemoryItem, float]]:
        """Recency search."""
        now = datetime.now()

        scored = []
        for item in memories:
            age_hours = (now - item.timestamp).total_seconds() / 3600
            # Linear decay: 1.0 when fresh, losing 0.1 per 24 hours, reaching 0 at 240 h
            recency_score = max(0, 1.0 - age_hours / 240)
            scored.append((recency_score, item))

        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [(item, score) for score, item in scored[:top_k]]

    def _importance_search(
        self, memories: list[MemoryItem], top_k: int
    ) -> list[tuple[MemoryItem, float]]:
        """Importance search."""
        scored = [(item.importance, item) for item in memories]
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [(item, score) for score, item in scored[:top_k]]

    def _hybrid_search(
        self, query: str, memories: list[MemoryItem], top_k: int
    ) -> list[tuple[MemoryItem, float]]:
        """Hybrid search: combine semantic relevance, recency, and importance."""
        query_embedding = self.embedder.embed(query)
        now = datetime.now()

        # Weights
        alpha = 0.6   # semantic relevance
        beta = 0.2    # recency
        gamma = 0.2   # importance

        scored = []
        for item in memories:
            # Semantic score
            if item.embedding is None:
                item.embedding = self.embedder.embed(item.content)
            semantic = LongTermMemory._cosine_similarity(
                query_embedding, item.embedding
            )

            # Recency score
            age_hours = (now - item.timestamp).total_seconds() / 3600
            recency = max(0, 1.0 - age_hours / 240)

            # Importance score
            importance = item.importance

            # Weighted combination
            composite = alpha * semantic + beta * recency + gamma * importance
            scored.append((composite, item))

        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [(item, score) for score, item in scored[:top_k]]
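The hybrid score is a plain weighted sum. Computed by hand for one memory, with a cosine helper that mirrors the one above (the vectors and ages are made-up inputs):

```python
import math

def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb) if na and nb else 0.0

def hybrid_score(semantic: float, age_hours: float, importance: float,
                 alpha: float = 0.6, beta: float = 0.2, gamma: float = 0.2) -> float:
    """Weighted combination of semantic, recency, and importance scores."""
    recency = max(0, 1.0 - age_hours / 240)
    return alpha * semantic + beta * recency + gamma * importance

# Identical vectors give cosine 1.0; a fresh, maximally important memory
# scores ~1.0, while the same match five days old and half as important
# drops to ~0.8.
sim = cosine([1.0, 0.0], [1.0, 0.0])
print(hybrid_score(sim, age_hours=0, importance=1.0))    # ≈ 1.0
print(hybrid_score(sim, age_hours=120, importance=0.5))  # ≈ 0.8
```

Because recency and importance each carry only 0.2, they act as tie-breakers among semantically similar memories rather than overriding relevance.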

7.3.3 Memory Compression

python
class MemoryCompressor:
    """Memory compressor."""

    def __init__(self, llm):
        self.llm = llm

    def compress_memories(
        self,
        memories: list[MemoryItem],
        max_output_tokens: int = 500
    ) -> MemoryItem:
        """Compress several memories into one."""

        memory_texts = "\n".join(
            f"- [{m.timestamp.strftime('%m-%d %H:%M')}] {m.content}"
            for m in memories
        )

        prompt = f"""Compress the following memories into a single concise summary.
Keep every key piece of information (facts, data, decisions, preferences); drop the redundancy.

Original memories:
{memory_texts}

Compress into one memory (at most {max_output_tokens} words):"""

        compressed = self.llm.chat(
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1
        ).content

        # Inherit the metadata of the most important source memory
        most_important = max(memories, key=lambda m: m.importance)

        return MemoryItem(
            content=compressed,
            memory_type=MemoryType.LONG_TERM,
            importance=most_important.importance,
            metadata={
                "compressed_from": len(memories),
                "source_memories": [m.timestamp.isoformat() for m in memories]
            }
        )

    def extract_key_facts(self, content: str) -> list[str]:
        """Extract the key facts from a text."""
        prompt = f"""Extract the key facts from the following text, one fact per line.
Format: [category] fact
Categories: fact, preference, decision, todo

Text:
{content}"""

        response = self.llm.chat(
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1
        ).content

        return [line.strip() for line in response.strip().split("\n") if line.strip()]

7.4 Vector Database Integration

7.4.1 Why a Vector Database

Once the memory count reaches the thousands (or millions), scanning every memory and computing a similarity for each one on every query becomes infeasible. Vector databases use approximate nearest neighbor (ANN) indexes to keep retrieval fast.

Memory count vs. retrieval latency (illustrative):
┌──────────────────────────────────────────────┐
│  Memories │ Brute force (ms) │ Vector DB (ms)│
│           │ (scan all)       │ (ANN index)   │
│           │                  │               │
│  1000     │  ~100            │  ~10          │
│  10000    │  ~1000           │  ~15          │
│  100K     │  ~10000          │  ~20          │
│  1M       │  ~100000         │  ~30          │
│  10M      │  ❌              │  ~50          │
└──────────────────────────────────────────────┘
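Brute force is what the left column measures: one similarity computation per stored vector. A minimal sketch of that O(n) scan, on toy data, which is exactly the loop an ANN index avoids:

```python
import math

def cosine(a, b):
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb) if na and nb else 0.0

def brute_force_top_k(query, vectors, k=1):
    """Score every stored vector against the query: O(n) per search."""
    scored = sorted(
        ((cosine(query, v), i) for i, v in enumerate(vectors)),
        reverse=True,
    )
    return [i for _, i in scored[:k]]

store = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
print(brute_force_top_k([1.0, 0.1], store, k=2))  # indices of the 2 closest
```

An ANN index (e.g. HNSW, which ChromaDB uses below) trades a small amount of recall for sub-linear search, which is what flattens the right-hand column of the table.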

7.4.2 Using ChromaDB

ChromaDB is one of the easiest vector databases to get started with:

python
import chromadb

class ChromaMemoryStore:
    """Memory store backed by ChromaDB."""

    def __init__(
        self,
        collection_name: str = "agent_memory",
        persist_directory: str = "./chroma_db"
    ):
        # PersistentClient (ChromaDB >= 0.4) replaces the legacy
        # Settings(chroma_db_impl="duckdb+parquet", ...) configuration
        self.client = chromadb.PersistentClient(path=persist_directory)
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"}
        )

    def store(
        self,
        content: str,
        memory_id: str,
        metadata: dict | None = None,
        embedding: list[float] | None = None
    ):
        """Store a memory."""
        self.collection.upsert(
            ids=[memory_id],
            documents=[content],
            metadatas=[metadata or {}],
            # Pass None (not [None]) when no embedding is supplied,
            # so ChromaDB generates one itself
            embeddings=[embedding] if embedding is not None else None
        )

    def search(
        self,
        query: str,
        n_results: int = 5,
        where: dict | None = None,
        query_embedding: list[float] | None = None
    ) -> list[dict]:
        """Search memories."""
        results = self.collection.query(
            query_texts=[query] if query else None,
            query_embeddings=[query_embedding] if query_embedding else None,
            n_results=n_results,
            where=where
        )

        memories = []
        for i in range(len(results["ids"][0])):
            memories.append({
                "id": results["ids"][0][i],
                "content": results["documents"][0][i],
                "metadata": results["metadatas"][0][i],
                "distance": results["distances"][0][i] if results["distances"] else None
            })

        return memories

    def delete(self, memory_ids: list[str]):
        """Delete memories."""
        self.collection.delete(ids=memory_ids)

    def count(self) -> int:
        """Total number of memories."""
        return self.collection.count()

    def update_metadata(self, memory_id: str, metadata: dict):
        """Update metadata."""
        self.collection.update(
            ids=[memory_id],
            metadatas=[metadata]
        )

7.4.3 Putting the Memory System Together

python
import json
import uuid

class IntegratedMemorySystem:
    """Integrated memory system: short-term + long-term (vector database)."""

    def __init__(
        self,
        llm,
        embedder: TextEmbedder,
        vector_store: ChromaMemoryStore,
        short_term_max: int = 50,
        long_term_top_k: int = 5
    ):
        self.llm = llm
        self.embedder = embedder
        self.vector_store = vector_store
        self.short_term = ShortTermMemory(max_messages=short_term_max)
        self.retriever = MemoryRetriever(embedder)
        self.long_term_top_k = long_term_top_k
        self.compressor = MemoryCompressor(llm)

    def add_conversation(self, role: str, content: str):
        """Add a turn to short-term memory."""
        self.short_term.add_message(role, content)

    def store_to_long_term(
        self,
        content: str,
        importance: float = 0.5,
        category: str = "general",
        tags: list[str] | None = None
    ):
        """Store into long-term memory."""
        memory_id = str(uuid.uuid4())

        self.vector_store.store(
            content=content,
            memory_id=memory_id,
            metadata={
                "importance": importance,
                "category": category,
                "tags": json.dumps(tags or []),
                "created_at": datetime.now().isoformat()
            }
        )

    def build_context(self, query: str) -> list[dict]:
        """Build the full context (short-term memory + retrieved long-term memories)."""
        # 1. Short-term memory
        context = self.short_term.get_context()

        # 2. Retrieve relevant content from long-term memory
        long_term_results = self.vector_store.search(
            query=query,
            n_results=self.long_term_top_k
        )

        if long_term_results:
            memory_text = "\n".join(
                f"- {r['content']}" for r in long_term_results
            )

            memory_context = {
                "role": "system",
                "content": f"[Relevant memories]\n{memory_text}"
            }

            # Insert right after the system message
            if context and context[0]["role"] == "system":
                context.insert(1, memory_context)
            else:
                context.insert(0, memory_context)

        return context

    def end_session(self):
        """On session end, move the important information into long-term memory."""
        # Extract the key information from the conversation
        conversation = self.short_term.get_context()
        conversation_text = "\n".join(
            f"{m['role']}: {m['content']}" for m in conversation
            if m["role"] != "system"
        )

        if len(conversation_text) < 50:
            return

        # Ask the LLM what is worth remembering
        prompt = f"""This session is about to end. Extract the information worth remembering long-term.
Include: user preferences, important decisions, key data, todo items.
One item per line, format: [category] content

Conversation:
{conversation_text[:3000]}"""

        response = self.llm.chat(
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1
        ).content

        # Store into long-term memory
        for line in response.strip().split("\n"):
            if line.strip():
                self.store_to_long_term(
                    content=line.strip(),
                    importance=0.7,
                    category="session_summary"
                )

        # Clear short-term memory
        self.short_term.messages = []

7.5 Conversation History Management

7.5.1 Session Management

python
import uuid

@dataclass
class Session:
    """A session."""
    session_id: str
    title: str = ""
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    message_count: int = 0
    metadata: dict = field(default_factory=dict)

class SessionManager:
    """Session manager."""

    def __init__(self, memory_system: IntegratedMemorySystem):
        self.memory = memory_system
        self._sessions: dict[str, Session] = {}

    def create_session(self, session_id: str | None = None) -> Session:
        """Create a new session."""
        sid = session_id or str(uuid.uuid4())

        session = Session(session_id=sid)
        self._sessions[sid] = session

        self.memory.short_term = ShortTermMemory()
        self.memory.short_term.add_system(
            "You are a helpful AI Agent."
        )

        return session

    def switch_session(self, session_id: str):
        """Switch to another session."""
        if session_id not in self._sessions:
            raise ValueError(f"Session not found: {session_id}")

        # Save the current session
        self.end_session()

        # Switch
        session = self._sessions[session_id]
        # A real implementation would load this session's history from persistent storage

    def end_session(self):
        """End the current session."""
        self.memory.end_session()

    def list_sessions(self) -> list[dict]:
        """List all sessions."""
        return [
            {
                "session_id": s.session_id,
                "title": s.title,
                "message_count": s.message_count,
                "created_at": s.created_at.isoformat(),
                "updated_at": s.updated_at.isoformat()
            }
            for s in self._sessions.values()
        ]

    def auto_title(self, session_id: str, first_messages: list[str]):
        """Auto-generate a session title."""
        conversation = "\n".join(first_messages[:5])

        prompt = f"""Based on the opening of this conversation, generate a concise session title (at most 10 words):
{conversation}

Title:"""

        # Call llm.chat with the same signature used elsewhere in this chapter
        title = self.memory.llm.chat(
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3
        ).content.strip()

        if session_id in self._sessions:
            self._sessions[session_id].title = title

7.5.2 Persisting Conversation History

python
import sqlite3
import json
from pathlib import Path

class ConversationStore:
    """对话历史持久化存储"""
    
    def __init__(self, db_path: str = "conversations.db"):
        self.db_path = db_path
        self._init_db()
    
    def _init_db(self):
        """初始化数据库"""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS conversations (
                    id TEXT PRIMARY KEY,
                    session_id TEXT NOT NULL,
                    role TEXT NOT NULL,
                    content TEXT NOT NULL,
                    tokens INTEGER DEFAULT 0,
                    metadata TEXT DEFAULT '{}',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (session_id) REFERENCES sessions(id)
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS sessions (
                    id TEXT PRIMARY KEY,
                    title TEXT DEFAULT '',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    metadata TEXT DEFAULT '{}'
                )
            """)
            # 创建索引加速查询
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_conv_session 
                ON conversations(session_id, created_at)
            """)
    
    def save_message(
        self,
        session_id: str,
        role: str,
        content: str,
        tokens: int = 0,
        metadata: dict | None = None
    ):
        """保存消息"""
        import uuid
        msg_id = str(uuid.uuid4())
        
        with sqlite3.connect(self.db_path) as conn:
            # 确保会话记录存在,避免外键指向不存在的会话
            conn.execute(
                "INSERT OR IGNORE INTO sessions (id) VALUES (?)",
                (session_id,)
            )
            # 显式列名,避免依赖表的列顺序
            conn.execute(
                """INSERT INTO conversations
                   (id, session_id, role, content, tokens, metadata, created_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (
                    msg_id, session_id, role, content,
                    tokens, json.dumps(metadata or {}),
                    datetime.now().isoformat()
                )
            )
            # 更新会话的 updated_at
            conn.execute(
                "UPDATE sessions SET updated_at = ? WHERE id = ?",
                (datetime.now().isoformat(), session_id)
            )
    
    def load_history(
        self,
        session_id: str,
        limit: int = 100,
        offset: int = 0
    ) -> list[dict]:
        """加载对话历史"""
        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            rows = conn.execute(
                """SELECT role, content, tokens, metadata, created_at
                   FROM conversations
                   WHERE session_id = ?
                   ORDER BY created_at ASC
                   LIMIT ? OFFSET ?""",
                (session_id, limit, offset)
            ).fetchall()
            
            return [
                {
                    "role": row["role"],
                    "content": row["content"],
                    "tokens": row["tokens"],
                    "metadata": json.loads(row["metadata"]),
                    "created_at": row["created_at"]
                }
                for row in rows
            ]
    
    def search_history(
        self,
        session_id: str,
        keyword: str,
        limit: int = 20
    ) -> list[dict]:
        """搜索历史消息"""
        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            rows = conn.execute(
                """SELECT role, content, created_at
                   FROM conversations
                   WHERE session_id = ? AND content LIKE ?
                   ORDER BY created_at DESC
                   LIMIT ?""",
                (session_id, f"%{keyword}%", limit)
            ).fetchall()
            
            return [dict(row) for row in rows]
    
    def get_session_stats(self, session_id: str) -> dict:
        """获取会话统计"""
        with sqlite3.connect(self.db_path) as conn:
            total = conn.execute(
                "SELECT COUNT(*) FROM conversations WHERE session_id = ?",
                (session_id,)
            ).fetchone()[0]
            
            total_tokens = conn.execute(
                "SELECT COALESCE(SUM(tokens), 0) FROM conversations WHERE session_id = ?",
                (session_id,)
            ).fetchone()[0]
            
            first_msg = conn.execute(
                "SELECT MIN(created_at) FROM conversations WHERE session_id = ?",
                (session_id,)
            ).fetchone()[0]
            
            return {
                "total_messages": total,
                "total_tokens": total_tokens,
                "first_message_at": first_msg,
                "avg_tokens_per_message": total_tokens / total if total > 0 else 0
            }
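上面的存取模式可以用内存数据库快速验证。下面是一个自包含的小例子(使用 `:memory:` 数据库,表结构是正文建表语句的简化版,仅为演示):

```python
import json
import sqlite3
import uuid
from datetime import datetime

# 内存数据库,演示"保存消息 → 按时间加载历史"的最小流程
conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE conversations (
        id TEXT PRIMARY KEY,
        session_id TEXT NOT NULL,
        role TEXT NOT NULL,
        content TEXT NOT NULL,
        metadata TEXT DEFAULT '{}',
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    )
""")

sid = "demo-session"
for role, content in [("user", "你好"), ("assistant", "你好!有什么可以帮你?")]:
    conn.execute(
        "INSERT INTO conversations (id, session_id, role, content, metadata, created_at) "
        "VALUES (?, ?, ?, ?, ?, ?)",
        (str(uuid.uuid4()), sid, role, content,
         json.dumps({}), datetime.now().isoformat())
    )

rows = conn.execute(
    "SELECT role, content FROM conversations WHERE session_id = ? "
    "ORDER BY created_at ASC",
    (sid,)
).fetchall()
print(rows)
```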

7.6 记忆的遗忘与更新机制

7.6.1 为什么要"遗忘"

人类大脑的遗忘不是缺陷,而是特性。Agent 同样需要遗忘机制:

  • 控制成本:长期记忆越多,检索越慢、存储越贵
  • 过滤噪音:并非所有信息都值得长期保存
  • 适应变化:过时的信息可能产生误导

python
class MemoryForgetter:
    """记忆遗忘管理器"""
    
    def __init__(
        self,
        decay_rate: float = 0.01,  # 每天的重要性衰减率
        min_importance: float = 0.1,  # 低于此值的记忆被遗忘
        max_age_days: int = 90  # 超过此天数的记忆强制遗忘
    ):
        self.decay_rate = decay_rate
        self.min_importance = min_importance
        self.max_age_days = max_age_days
    
    def apply_decay(self, memories: list[MemoryItem]) -> list[MemoryItem]:
        """应用时间衰减(幂等:始终基于原始重要性计算,重复调用不会叠加衰减)"""
        now = datetime.now()
        
        for item in memories:
            # 首次衰减时记下原始重要性
            base = item.metadata.setdefault("base_importance", item.importance)
            age_days = (now - item.timestamp).total_seconds() / 86400
            
            # 指数衰减
            decayed_importance = base * (1 - self.decay_rate) ** age_days
            
            # 被访问过的记忆衰减更慢
            if item.access_count > 0:
                boost = min(0.3, 0.05 * item.access_count)
                decayed_importance = min(1.0, decayed_importance + boost)
            
            item.importance = decayed_importance
        
        return memories
    
    def get_forgettable(self, memories: list[MemoryItem]) -> list[MemoryItem]:
        """获取应该被遗忘的记忆"""
        self.apply_decay(memories)
        
        forgettable = []
        for item in memories:
            # 规则1:重要性过低
            if item.importance < self.min_importance:
                forgettable.append(item)
                continue
            
            # 规则2:年龄过大
            if item.age_hours / 24 > self.max_age_days:
                forgettable.append(item)
                continue
            
            # 规则3:过期
            if item.is_expired:
                forgettable.append(item)
        
        return forgettable
    
    def forget(self, memories: list[MemoryItem]) -> list[MemoryItem]:
        """执行遗忘"""
        forgettable_ids = {id(m) for m in self.get_forgettable(memories)}
        remaining = [m for m in memories if id(m) not in forgettable_ids]
        return remaining
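衰减公式本身可以单独验证:按上面的默认参数(decay_rate=0.01),一条 30 天前、初始重要性 1.0 的记忆会衰减到约 0.74,仍高于 min_importance=0.1;被遗忘通常要靠更长时间的衰减或 max_age_days 的强制规则。

```python
# 指数衰减:importance * (1 - decay_rate) ** age_days
decay_rate = 0.01
importance = 1.0

for age_days in (7, 30, 90):
    decayed = importance * (1 - decay_rate) ** age_days
    print(age_days, round(decayed, 3))
```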

7.6.2 记忆更新

当新信息与旧记忆冲突时,需要更新而非保留两者:

python
class MemoryUpdater:
    """记忆更新管理器"""
    
    def __init__(self, embedder: TextEmbedder, llm, threshold: float = 0.85):
        self.embedder = embedder
        self.llm = llm
        self.threshold = threshold  # 相似度阈值
    
    def check_conflicts(
        self,
        new_content: str,
        existing_memories: list[MemoryItem]
    ) -> list[tuple[MemoryItem, float]]:
        """检查新记忆是否与已有记忆冲突"""
        new_embedding = self.embedder.embed(new_content)
        
        conflicts = []
        for item in existing_memories:
            if item.embedding is None:
                item.embedding = self.embedder.embed(item.content)
            
            similarity = LongTermMemory._cosine_similarity(
                new_embedding, item.embedding
            )
            
            if similarity >= self.threshold:
                conflicts.append((item, similarity))
        
        return conflicts
    
    def resolve_conflict(
        self,
        new_content: str,
        old_memory: MemoryItem,
        similarity: float
    ) -> MemoryItem:
        """解决记忆冲突"""
        
        # 如果相似度极高(>0.95),可能是重复信息
        if similarity > 0.95:
            # 新内容已完全包含在旧记忆中:视为重复,保留旧的
            if new_content in old_memory.content:
                return old_memory
            # 让 LLM 判断哪个更准确
            prompt = f"""以下两条信息非常相似,请判断哪个更准确/更新:

信息A(时间:{old_memory.timestamp}):
{old_memory.content}

信息B(最新):
{new_content}

请回答:
1. 如果B更新更准确,返回 "UPDATE: B的理由"
2. 如果A仍然准确,返回 "KEEP: A的理由"
3. 如果两者互补,返回 "MERGE: 合并后的内容"
"""
            response = self.llm.chat(
                messages=[{"role": "user", "content": prompt}],
                temperature=0.1
            ).content
            
            if response.startswith("UPDATE"):
                # 用新信息更新旧记忆
                old_memory.content = new_content
                old_memory.timestamp = datetime.now()
                old_memory.importance = max(old_memory.importance, 0.7)
                return old_memory
            elif response.startswith("MERGE"):
                # 合并;若模型未按 "MERGE: 内容" 格式返回冒号,则保留原内容
                parts = response.split(":", 1)
                merged_content = parts[1].strip() if len(parts) > 1 else old_memory.content
                old_memory.content = merged_content
                old_memory.timestamp = datetime.now()
                old_memory.metadata["merged"] = True
                return old_memory
            else:
                return old_memory
        else:
            # 相似但不完全相同,可能需要补充
            old_memory.metadata["related_new_info"] = new_content
            return old_memory
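冲突检测依赖余弦相似度。正文假设 LongTermMemory 已提供 _cosine_similarity;它的行为可以用一个独立的参考实现来说明(示意版本,非正文 API):

```python
import math

def cosine_similarity(a: list[float], b: list[float]) -> float:
    """余弦相似度:方向相同为 1,正交为 0;零向量按 0 处理"""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)

print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # 1.0
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0
```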

7.6.3 记忆的自动重要性评估

python
class ImportanceEvaluator:
    """记忆重要性自动评估器"""
    
    def __init__(self, llm):
        self.llm = llm
    
    def evaluate(self, content: str, context: str = "") -> float:
        """评估记忆的重要性(0-1)"""
        prompt = f"""评估以下信息的重要性(0-10分)。

评估标准:
- 10分:关键事实、重要决策、用户核心偏好
- 7-9分:有用的信息、中等重要的发现
- 4-6分:一般性信息、可能有用
- 1-3分:临时信息、很快会过时

信息:{content}
{f"上下文:{context}" if context else ""}

只返回数字(0-10)。"""
        
        try:
            response = self.llm.chat(
                messages=[{"role": "user", "content": prompt}],
                temperature=0.1
            )
            score = float(response.content.strip())
            return min(1.0, max(0.0, score / 10))
        except Exception:
            return 0.5  # 解析失败或调用出错时,返回默认的中等重要性
    
    def batch_evaluate(
        self,
        items: list[tuple[str, str]]
    ) -> list[float]:
        """批量评估"""
        # 简化实现:逐个评估
        return [
            self.evaluate(content, context)
            for content, context in items
        ]
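evaluate 中"解析 → 归一化 → 截断"这段纯逻辑可以抽成独立函数单独测试(示意函数,normalize_score 并非正文 API):

```python
def normalize_score(raw: str, default: float = 0.5) -> float:
    """把 LLM 返回的 0-10 评分文本归一化到 [0, 1];解析失败时返回默认值"""
    try:
        score = float(raw.strip())
    except ValueError:
        return default
    return min(1.0, max(0.0, score / 10))

print(normalize_score("8"))      # 0.8
print(normalize_score("11"))     # 超出范围,截断为 1.0
print(normalize_score("不确定"))  # 解析失败 → 0.5
```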

7.7 常见陷阱与最佳实践

7.7.1 常见陷阱

陷阱1:把所有对话历史都发到 LLM

python
# ❌ 无限增长的上下文
def chat(messages: list):
    # messages 会不断增长,最终超出上下文窗口
    response = llm.chat(messages=messages)
    messages.append({"role": "assistant", "content": response.content})
    return response

# ✅ 管理上下文长度
def chat(messages: list, context_manager):
    managed = context_manager.manage(messages)
    response = llm.chat(messages=managed)
    return response
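context_manager.manage 的最简实现可以是一个滑动窗口(示意:按消息条数截断;实际中应按 token 预算截断,并配合摘要保留早期信息):

```python
def manage(messages: list[dict], max_recent: int = 20) -> list[dict]:
    """保留所有系统消息 + 最近 max_recent 条对话"""
    system = [m for m in messages if m["role"] == "system"]
    rest = [m for m in messages if m["role"] != "system"]
    return system + rest[-max_recent:]

history = [{"role": "system", "content": "你是一个有帮助的 AI Agent。"}]
history += [{"role": "user", "content": f"消息 {i}"} for i in range(100)]

managed = manage(history, max_recent=10)
print(len(managed))  # 11:1 条系统消息 + 最近 10 条
```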

陷阱2:摘要丢失关键细节

python
# ❌ 过度压缩
prompt = "请用一句话总结以下对话"  # 太短了!

# ✅ 保留关键信息
prompt = """请将对话压缩为摘要,确保保留:
1. 所有数字、日期、人名等具体信息
2. 用户明确表达的偏好和要求
3. 任何承诺或待办事项
4. 技术细节和代码片段

摘要长度:200-300字"""

陷阱3:忽略记忆的时效性

python
# ❌ 检索到过时的记忆
memory: "用户偏好使用 Python 2.7"  # 2019年的记忆
# 直接使用这个记忆来推荐技术方案 → 错误!

# ✅ 检查时效性
if memory.age_hours > 24 * 365:  # 超过1年
    memory.importance *= 0.3  # 降低重要性
    # 或标记为"需要确认"
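时效性折扣也可以用"半衰期"表达,比固定阈值更平滑(示意实现,半衰期取 180 天为假设值):

```python
from datetime import datetime, timedelta

def staleness_factor(timestamp: datetime, half_life_days: float = 180.0) -> float:
    """每过 half_life_days,记忆权重减半"""
    age_days = (datetime.now() - timestamp).total_seconds() / 86400
    return 0.5 ** (age_days / half_life_days)

recent = staleness_factor(datetime.now() - timedelta(days=1))
stale = staleness_factor(datetime.now() - timedelta(days=365 * 5))
print(round(recent, 3), round(stale, 5))
```

检索打分时可将该因子与相似度、重要性相乘,过旧的记忆自然沉底。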

7.7.2 最佳实践

python
MEMORY_BEST_PRACTICES = """
## 记忆管理最佳实践

### ✅ 存储策略
- [ ] 区分短期/长期记忆,不要把所有东西都存长期
- [ ] 存储时标注重要性,方便后续过滤
- [ ] 记忆内容简洁化——存储结论而非原始对话
- [ ] 添加元数据(类别、标签、来源),方便检索

### ✅ 检索策略
- [ ] 使用混合检索(语义 + 时效 + 重要性)
- [ ] 设置相似度阈值,过滤低质量结果
- [ ] 检索结果数量控制在 3-7 条
- [ ] 将检索到的记忆格式化后注入上下文

### ✅ 上下文管理
- [ ] 监控上下文 Token 使用率
- [ ] 设置上下文预算(预留输出空间)
- [ ] 优先保留系统消息和最近对话
- [ ] 使用摘要而非简单截断

### ✅ 维护策略
- [ ] 定期执行遗忘(清理低重要性、过期记忆)
- [ ] 检测并解决记忆冲突
- [ ] 记忆整合(合并相似记忆)
- [ ] 监控记忆系统的存储大小和检索延迟
"""
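其中"设置上下文预算(预留输出空间)"可以落成一个简单算式(数字均为假设的示例值:模型上限、输出预留、系统消息占用):

```python
def context_budget(model_limit: int = 8192,
                   output_reserve: int = 1024,
                   system_tokens: int = 200) -> int:
    """可分配给历史消息的 token 预算 = 模型上限 - 输出预留 - 系统消息占用"""
    return max(0, model_limit - output_reserve - system_tokens)

print(context_budget())                   # 6968
print(context_budget(model_limit=4096))   # 2872
```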

7.8 本章小结

本章我们深入探讨了 Agent 记忆系统的设计与实现:

  1. 记忆层次:感知记忆、工作记忆、短期记忆、长期记忆的分层设计
  2. 上下文管理:滑动窗口、摘要压缩、优先级保留、混合策略
  3. 记忆检索:语义检索、时效检索、重要性检索、混合检索
  4. 向量数据库:ChromaDB 集成,实现高效的语义检索
  5. 对话历史管理:会话管理、持久化存储、搜索
  6. 遗忘与更新:时间衰减、冲突解决、重要性评估

核心洞察: 记忆系统是 Agent 区别于简单 Chatbot 的关键特征。好的记忆系统不是存储越多越好,而是"存该存的,忘该忘的,在需要时能快速找到对的"。记忆管理是一门平衡的艺术——在信息完整性和成本效率之间找到最佳平衡点。


卷二总结

恭喜你完成了卷二"基础篇"的学习!回顾一下我们走过的路:

各章核心收获:

  • 第4章 Agent 核心概念:理解了 Agent 的架构模型、核心组件、生命周期和评估体系
  • 第5章 LLM 与 Prompt Engineering:掌握了与 LLM 高效沟通的技巧——Prompt 设计、CoT 推理、模板管理
  • 第6章 工具调用:学会了赋予 Agent 行动能力——Function Calling、工具开发、错误处理
  • 第7章 记忆与上下文:实现了 Agent 的持久化能力——记忆分层、向量检索、上下文管理

现在你已经具备了构建一个完整 Agent 系统的所有基础知识。在卷三"进阶篇"中,我们将把这些组件整合为更复杂的系统——多 Agent 协作、生产级部署、安全与治理。准备好了吗?


下一卷:卷三《进阶篇》—— 多 Agent 协作、生产级部署、安全与治理、可观测性。

基于 MIT 许可发布