Skip to content

第21章:状态机与流程编排

让Agent从"单步执行"进化为"可控的工作流"


21.1 引言

当 Agent 系统从简单的"输入→输出"模式演进到需要多步骤、多决策的复杂流程时,状态机(State Machine)和流程编排(Orchestration)就成为必不可少的工程工具。状态机赋予 Agent 确定性的行为控制,流程编排则让多个 Agent 或步骤能够协同完成复杂任务。

本章将系统讲解如何使用有限状态机、LangGraph 等工具来设计和实现 Agent 工作流,包括条件分支、循环、人工介入等高级模式。

本章学习目标

  • 理解有限状态机(FSM)在 Agent 系统中的作用
  • 掌握状态图的设计与转换规则
  • 使用 LangGraph 实现状态图编排
  • 实现条件分支、循环和人工介入模式
  • 设计可靠的复杂工作流编排模式

21.2 有限状态机(FSM)设计

21.2.1 为什么Agent需要状态机

LLM 本质上是无状态的——每次调用都是独立的。但在实际应用中,Agent 的行为需要依赖上下文和当前所处的阶段:

  • 客服系统:识别问题→收集信息→尝试解决→转人工
  • 代码助手:理解需求→设计方案→编写代码→审查→修改
  • 数据分析:接收请求→查询数据→分析结果→生成报告

没有状态管理,Agent 就像一个每次醒来都失忆的助手——无法维持对话的连贯性,也无法在复杂任务中保持正确的执行顺序。

21.2.2 状态机基础实现

python
from enum import Enum
from typing import Callable, Optional, Any
from dataclasses import dataclass, field
from datetime import datetime


class AgentState(Enum):
    """Enumeration of the states an agent session can be in."""
    IDLE = "idle"                    # idle, nothing in progress
    COLLECTING_INFO = "collecting"    # collecting information
    PROCESSING = "processing"         # processing the request
    AWAITING_CONFIRMATION = "confirm" # waiting for confirmation
    EXECUTING = "executing"           # executing the action
    REVIEWING = "reviewing"           # reviewing the result
    ESCALATING = "escalating"         # escalating the case
    COMPLETED = "completed"           # finished
    FAILED = "failed"                 # failed


@dataclass
class StateTransition:
    """A single state-transition rule.

    AgentStateMachine.transition selects rules by (from_state, to_state),
    calls ``condition(context, **kwargs)`` when a condition is set and, if the
    rule is taken, calls ``action(context, **kwargs)`` before switching state.
    """
    from_state: AgentState  # state the rule starts from
    to_state: AgentState  # state the rule leads to
    condition: Optional[Callable] = None  # optional guard; a falsy result skips the rule
    action: Optional[Callable] = None  # optional side effect run when the rule is taken
    description: str = ""  # free-text description of the rule


@dataclass
class StateContext:
    """State context: the data that travels with a session from state to state."""
    session_id: str
    current_state: AgentState
    previous_state: Optional[AgentState] = None
    data: dict = field(default_factory=dict)
    history: list[dict] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)

    def transition_to(self, new_state: AgentState):
        """Switch to ``new_state`` and append the move to ``history``.

        The state that was current becomes ``previous_state``, ``updated_at``
        is refreshed, and a ``{"from", "to", "timestamp"}`` record is stored.
        """
        origin = self.current_state
        moment = datetime.now()

        self.previous_state = origin
        self.current_state = new_state
        self.updated_at = moment

        record = {
            "from": origin.value,
            "to": new_state.value,
            "timestamp": moment.isoformat(),
        }
        self.history.append(record)


class AgentStateMachine:
    """Finite state machine for an agent session.

    The machine starts in ``AgentState.IDLE``. Transitions must be registered
    with ``add_transition`` before they can be taken; pairs registered with
    ``add_invalid_transition`` are always rejected.
    """

    def __init__(self, session_id: str):
        self.context = StateContext(
            session_id=session_id,
            current_state=AgentState.IDLE,
        )
        self._transitions: list[StateTransition] = []
        self._invalid_transitions: set[tuple[str, str]] = set()
        self._state_handlers: dict[AgentState, Callable] = {}

    def add_transition(self, transition: StateTransition):
        """Register a transition rule."""
        self._transitions.append(transition)

    def add_invalid_transition(self, from_s: AgentState, to_s: AgentState):
        """Register a (from, to) pair that must always be rejected."""
        self._invalid_transitions.add((from_s.value, to_s.value))

    def set_handler(self, state: AgentState, handler: Callable):
        """Set the handler called with the context right after ``state`` is entered."""
        self._state_handlers[state] = handler

    def transition(self, target_state: AgentState, **kwargs) -> bool:
        """Try to move from the current state to ``target_state``.

        Rules registered for (current, target) are tried in registration
        order; the first one whose condition is absent or truthy is taken.
        Its action runs first, then the context switches state, then the
        handler of the new state (if any) runs.

        Args:
            target_state: State to move to.
            **kwargs: Forwarded to the rule's condition and action.

        Returns:
            True when a rule was taken.

        Raises:
            InvalidTransitionError: The pair is forbidden, no rule is
                registered for it, or every registered rule's condition
                was falsy.
        """
        current = self.context.current_state

        # Forbidden pairs win over any registered rule.
        if (current.value, target_state.value) in self._invalid_transitions:
            raise InvalidTransitionError(
                f"禁止从 {current.value} 转换到 {target_state.value}"
            )

        matching = [
            t for t in self._transitions
            if t.from_state == current and t.to_state == target_state
        ]

        for rule in matching:
            if rule.condition and not rule.condition(self.context, **kwargs):
                continue

            if rule.action:
                rule.action(self.context, **kwargs)

            self.context.transition_to(target_state)

            handler = self._state_handlers.get(target_state)
            if handler:
                handler(self.context)

            return True

        # Tell "rules exist but none applied" apart from "no rule at all";
        # the two situations call for different fixes by the caller.
        if matching:
            raise InvalidTransitionError(
                f"从 {current.value} 到 {target_state.value} 的转换条件未满足"
            )
        raise InvalidTransitionError(
            f"没有从 {current.value} 到 {target_state.value} 的转换规则"
        )

    def get_current_state(self) -> AgentState:
        """Return the state the machine is currently in."""
        return self.context.current_state

    def get_state_history(self) -> list[dict]:
        """Return the list of recorded transitions (the context's own list, not a copy)."""
        return self.context.history

    def reset(self):
        """Reset to the initial state with a fresh context for the same session id."""
        self.context = StateContext(
            session_id=self.context.session_id,
            current_state=AgentState.IDLE,
        )


class InvalidTransitionError(Exception):
    """Raised by AgentStateMachine.transition when a requested transition cannot be taken."""
    pass

21.2.3 客服Agent状态机实例

python
# customer_service_fsm.py

def build_customer_service_fsm(session_id: str) -> AgentStateMachine:
    """Assemble the state machine used by the customer-service agent.

    Registers eight transition rules (listed in the table below) and the
    entry handlers for the COLLECTING_INFO, PROCESSING and COMPLETED states.
    """

    def store_query(ctx, **kw):
        ctx.data.update({"query": kw.get("query", "")})

    def info_is_complete(ctx, **kw):
        return ctx.data.get("info_complete", False)

    def has_solution(ctx, **kw):
        return ctx.data.get("solution", None) is not None

    def must_escalate(ctx, **kw):
        return ctx.data.get("need_escalation", False)

    def user_confirmed(ctx, **kw):
        return kw.get("confirmed", False)

    def user_declined(ctx, **kw):
        return not kw.get("confirmed", True)

    def execution_ok(ctx, **kw):
        return ctx.data.get("execution_success", True)

    def execution_broke(ctx, **kw):
        return not ctx.data.get("execution_success", True)

    # (from_state, to_state, condition, action, description)
    rule_table = [
        (AgentState.IDLE, AgentState.COLLECTING_INFO,
         None, store_query, "收到用户查询,开始收集信息"),
        (AgentState.COLLECTING_INFO, AgentState.PROCESSING,
         info_is_complete, None, "信息收集完毕,开始处理"),
        (AgentState.PROCESSING, AgentState.AWAITING_CONFIRMATION,
         has_solution, None, "生成解决方案,等待用户确认"),
        (AgentState.PROCESSING, AgentState.ESCALATING,
         must_escalate, None, "无法处理,升级到人工"),
        (AgentState.AWAITING_CONFIRMATION, AgentState.EXECUTING,
         user_confirmed, None, "用户确认,执行方案"),
        (AgentState.AWAITING_CONFIRMATION, AgentState.PROCESSING,
         user_declined, None, "用户不满意,重新处理"),
        (AgentState.EXECUTING, AgentState.COMPLETED,
         execution_ok, None, "执行成功"),
        (AgentState.EXECUTING, AgentState.FAILED,
         execution_broke, None, "执行失败"),
    ]

    machine = AgentStateMachine(session_id)

    for source, target, guard, effect, note in rule_table:
        machine.add_transition(StateTransition(
            from_state=source,
            to_state=target,
            condition=guard,
            action=effect,
            description=note,
        ))

    # State entry handlers
    entry_handlers = (
        (AgentState.COLLECTING_INFO, _handle_collecting),
        (AgentState.PROCESSING, _handle_processing),
        (AgentState.COMPLETED, _handle_completed),
    )
    for state, callback in entry_handlers:
        machine.set_handler(state, callback)

    return machine


def _handle_collecting(ctx: StateContext):
    """Print the session id and the data gathered so far."""
    session, gathered = ctx.session_id, ctx.data
    print(f"[{session}] 收集信息中... 当前数据: {gathered}")


def _handle_processing(ctx: StateContext):
    """Print the session id and the query being processed."""
    session = ctx.session_id
    query = ctx.data.get('query', '')
    print(f"[{session}] 处理查询: {query}")


def _handle_completed(ctx: StateContext):
    """Print a completion notice for the session."""
    session = ctx.session_id
    print(f"[{session}] 任务完成!")


# Usage example: drive one session from IDLE to COMPLETED.
fsm = build_customer_service_fsm("session-001")
# The query keyword is stored into ctx.data by the IDLE -> COLLECTING_INFO action.
fsm.transition(AgentState.COLLECTING_INFO, query="无法登录账户")
# Each following flag satisfies the condition of the next transition.
fsm.context.data["info_complete"] = True
fsm.transition(AgentState.PROCESSING)
fsm.context.data["solution"] = "重置密码"
fsm.transition(AgentState.AWAITING_CONFIRMATION)
fsm.transition(AgentState.EXECUTING, confirmed=True)
fsm.context.data["execution_success"] = True
fsm.transition(AgentState.COMPLETED)
print(f"状态历史: {fsm.get_state_history()}")

21.2.4 TypeScript实现

typescript
/** States an agent session can be in; the string values are what gets recorded in the history. */
enum AgentState {
  IDLE = "idle",
  COLLECTING = "collecting",
  PROCESSING = "processing",
  CONFIRMING = "confirming",
  EXECUTING = "executing",
  COMPLETED = "completed",
  FAILED = "failed",
}

/** A transition rule from one state to another. */
interface TransitionRule {
  from: AgentState;
  to: AgentState;
  // Optional guard: AgentFSM.transition skips the rule when it returns false.
  condition?: (ctx: StateContext) => boolean;
  // Optional side effect, run by AgentFSM.transition when the rule is taken.
  action?: (ctx: StateContext) => void;
}

/** Mutable context held by an AgentFSM instance. */
interface StateContext {
  sessionId: string;
  current: AgentState;
  // Set by AgentFSM.transition to the state that was current before the move.
  previous?: AgentState;
  data: Record<string, any>;
  // One entry per completed transition; ts is an ISO-8601 timestamp.
  history: Array<{ from: string; to: string; ts: string }>;
}

/**
 * Minimal finite state machine: rules are registered with addRule and taken
 * with transition. The machine starts in AgentState.IDLE.
 */
export class AgentFSM {
  private ctx: StateContext;
  private rules: TransitionRule[] = [];

  constructor(sessionId: string) {
    this.ctx = { sessionId, current: AgentState.IDLE, data: {}, history: [] };
  }

  /** Register a transition rule; rules are tried in registration order. */
  addRule(rule: TransitionRule): void {
    this.rules.push(rule);
  }

  /**
   * Move to `target` using the first registered rule for (current, target)
   * whose condition is absent or returns true. The rule's action runs first,
   * then `extraData` is merged into the context data, then the state
   * switches and the move is appended to the history.
   *
   * @returns true when a rule was taken
   * @throws Error when no rule applies
   */
  transition(target: AgentState, extraData?: Record<string, any>): boolean {
    const origin = this.ctx.current;

    const chosen = this.rules.find(
      (candidate) =>
        candidate.from === origin &&
        candidate.to === target &&
        (!candidate.condition || candidate.condition(this.ctx)),
    );

    if (chosen === undefined) {
      throw new Error(`Invalid: ${origin} → ${target}`);
    }

    if (chosen.action) {
      chosen.action(this.ctx);
    }
    if (extraData) {
      Object.assign(this.ctx.data, extraData);
    }

    this.ctx.previous = origin;
    this.ctx.current = target;

    const stamp = new Date().toISOString();
    this.ctx.history.push({ from: origin, to: target, ts: stamp });

    return true;
  }

  /** Current state of the machine. */
  getState(): AgentState {
    return this.ctx.current;
  }

  /** The context's data object (not a copy). */
  getData(): Record<string, any> {
    return this.ctx.data;
  }
}

21.3 状态图与转换规则

21.3.1 状态图设计原则

设计状态图时需要遵循几个关键原则:

  1. 状态最小化:每个状态代表一个明确的阶段,避免"万能状态"
  2. 转换明确:每个转换都有清晰的条件和动作
  3. 终态可达:任何状态都能通过某种路径到达终态(完成或失败)
  4. 错误处理:任何状态都应该能转到失败状态

21.3.2 声明式状态图定义

python
from typing import TypedDict


class StateGraphConfig(TypedDict):
    """Declarative configuration of a state graph."""
    name: str
    initial_state: str
    final_states: list[str]
    states: dict[str, dict]       # state definitions, keyed by state name
    transitions: list[dict]        # transition rules ("from", "to", optional "condition")


# Declarative definition of an order-processing state graph
order_processing_graph: StateGraphConfig = {
    "name": "order_processing",
    "initial_state": "received",
    "final_states": ["completed", "cancelled", "failed"],
    "states": {
        "received": {
            "description": "订单已接收",
            "timeout_seconds": 300,
        },
        "validated": {
            "description": "订单验证通过",
        },
        "payment_pending": {
            "description": "等待支付",
            "timeout_seconds": 1800,
        },
        "paid": {
            "description": "支付完成",
        },
        "shipped": {
            "description": "已发货",
        },
        "completed": {
            "description": "订单完成",
        },
        "cancelled": {
            "description": "订单取消",
        },
        "failed": {
            "description": "处理失败",
        },
    },
    "transitions": [
        {"from": "received", "to": "validated", "condition": "valid_order"},
        {"from": "received", "to": "cancelled", "condition": "user_cancel"},
        {"from": "validated", "to": "payment_pending"},
        {"from": "payment_pending", "to": "paid", "condition": "payment_success"},
        {"from": "payment_pending", "to": "cancelled", "condition": "payment_timeout"},
        {"from": "paid", "to": "shipped"},
        {"from": "shipped", "to": "completed", "condition": "delivery_confirmed"},
        # Error transition: the "*" source matches any state and leads to failed
        {"from": "*", "to": "failed", "condition": "error_occurred"},
    ],
}

21.3.3 通用状态图引擎

python
class StateGraphEngine:
    """通用状态图执行引擎"""
    
    def __init__(self, config: StateGraphConfig):
        self.config = config
        self._condition_handlers: dict[str, Callable] = {}
        self._state_entry_handlers: dict[str, Callable] = {}
        self._state_exit_handlers: dict[str, Callable] = {}
        self._timeout_timers: dict[str, any] = {}
    
    def register_condition(self, name: str, handler: Callable):
        """注册条件判断函数"""
        self._condition_handlers[name] = handler
    
    def on_enter(self, state: str, handler: Callable):
        """注册状态进入回调"""
        self._state_entry_handlers[state] = handler
    
    def on_exit(self, state: str, handler: Callable):
        """注册状态退出回调"""
        self._state_exit_handlers[state] = handler
    
    def execute(self, context: dict) -> str:
        """从初始状态执行状态图"""
        current = self.config["initial_state"]
        
        while current not in self.config["final_states"]:
            # 执行进入回调
            entry_handler = self._state_entry_handlers.get(current)
            if entry_handler:
                entry_handler(context, current)
            
            # 查找可用的转换
            next_state = self._find_next_transition(current, context)
            
            if next_state is None:
                # 没有可用的转换,检查是否有默认错误处理
                if "failed" in self.config["final_states"]:
                    next_state = "failed"
                else:
                    raise RuntimeError(f"状态 {current} 无法继续执行")
            
            # 执行退出回调
            exit_handler = self._state_exit_handlers.get(current)
            if exit_handler:
                exit_handler(context, current)
            
            context["current_state"] = next_state
            context.setdefault("state_history", []).append({
                "from": current, "to": next_state,
                "timestamp": __import__('datetime').datetime.now().isoformat(),
            })
            
            current = next_state
        
        return current
    
    def _find_next_transition(self, current: str, context: dict) -> str | None:
        """查找下一个状态"""
        for trans in self.config["transitions"]:
            from_state = trans["from"]
            if from_state != "*" and from_state != current:
                continue
            
            condition_name = trans.get("condition")
            if condition_name:
                handler = self._condition_handlers.get(condition_name)
                if handler and handler(context):
                    return trans["to"]
            else:
                # 无条件转换(第一个匹配的)
                return trans["to"]
        
        return None

21.4 LangGraph的状态图编排

21.4.1 LangGraph简介

LangGraph 是 LangChain 生态中用于构建有状态、多步骤 Agent 应用的框架。它基于图论概念,将 Agent 的工作流建模为节点(Node)和边(Edge)组成的图。

核心概念:

  • State:在图的各个节点之间传递的可变状态对象
  • Node:执行特定逻辑的函数,接收状态并可能修改状态
  • Edge:连接节点的有向边,可以是固定的或条件性的

21.4.2 使用LangGraph构建研究助手

python
from typing import TypedDict, Annotated, Literal
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, AIMessage


class ResearchState(TypedDict):
    """State definition for the research assistant graph."""
    messages: Annotated[list, add_messages]  # conversation messages
    research_topic: str                       # research topic
    search_results: list[dict]                # search results
    analysis: str                             # analysis result
    report: str                               # final report
    iteration_count: int                      # number of iterations so far
    max_iterations: int                       # maximum number of iterations
    needs_more_research: bool                 # whether more research is needed


def search_node(state: ResearchState) -> dict:
    """Search node: perform the information retrieval step.

    Reads ``research_topic`` from the state and returns a state update with
    ``search_results`` and one AI message reporting how many results were
    found. The results are mocked; a real implementation would call a
    search API here.
    """
    topic = state.get("research_topic", "")

    # Mocked search results
    results = [
        {"title": f"{topic}的最新进展", "relevance": 0.95},
        {"title": f"{topic}的技术分析", "relevance": 0.85},
    ]

    return {
        "search_results": results,
        "messages": [AIMessage(content=f"找到了 {len(results)} 条相关结果")],
    }


def analyze_node(state: ResearchState) -> dict:
    """Analysis node: analyse the search results and decide whether to loop.

    Returns a state update with the analysis text, the incremented
    ``iteration_count`` and ``needs_more_research``. More research is
    requested only while there are fewer than 5 results AND the iteration
    budget (``max_iterations``, default 2) is not used up, so the
    search/analyze loop is always bounded.
    """
    results = state.get("search_results", [])
    analysis = f"基于 {len(results)} 条搜索结果的分析..."

    # Count this pass first so the budget check includes it.
    iteration = state.get("iteration_count", 0) + 1
    budget_left = iteration < state.get("max_iterations", 2)

    # Both conditions must hold; with `or`, a search that keeps returning
    # fewer than 5 results would loop forever.
    needs_more = len(results) < 5 and budget_left

    return {
        "analysis": analysis,
        "needs_more_research": needs_more,
        "iteration_count": iteration,
        "messages": [AIMessage(content=f"完成分析,{'需要' if needs_more else '不需要'}更多研究")],
    }


def write_report_node(state: ResearchState) -> dict:
    """Report node: wrap the analysis text into a report and announce it."""
    summary = state.get("analysis", "")
    document = f"# 研究报告\n\n## 摘要\n{summary}\n\n## 详细内容\n..."
    notice = AIMessage(content="研究报告已生成")

    update = {"report": document, "messages": [notice]}
    return update


def should_continue(state: ResearchState) -> Literal["search", "report"]:
    """Conditional edge: route back to "search" while more research is needed, else to "report"."""
    wants_more = state.get("needs_more_research", False)
    return "search" if wants_more else "report"


def build_research_graph() -> StateGraph:
    """Build the research assistant graph and return the result of ``compile()``.

    Layout: search -> analyze -> (search | report) -> END, where the branch
    after "analyze" is chosen by ``should_continue``.
    """
    workflow = StateGraph(ResearchState)

    # Nodes
    node_table = (
        ("search", search_node),
        ("analyze", analyze_node),
        ("report", write_report_node),
    )
    for node_name, node_fn in node_table:
        workflow.add_node(node_name, node_fn)

    # Entry point
    workflow.set_entry_point("search")

    # Edges
    workflow.add_edge("search", "analyze")
    branch_targets = {"search": "search", "report": "report"}
    workflow.add_conditional_edges("analyze", should_continue, branch_targets)
    workflow.add_edge("report", END)

    compiled = workflow.compile()
    return compiled


# Usage example: build the graph, run it once and print the report.
research_graph = build_research_graph()
# Only part of ResearchState is supplied here; the nodes read the remaining
# keys with state.get(...) defaults and fill them in as they run.
result = research_graph.invoke({
    "messages": [HumanMessage(content="研究量子计算的最新进展")],
    "research_topic": "量子计算",
    "max_iterations": 3,
    "iteration_count": 0,
})

print(result["report"])

21.4.3 LangGraph高级模式:带人工审核的工作流

python
from typing import TypedDict
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver  # 持久化


class ApprovalState(TypedDict):
    """State of the workflow with an approval step."""
    request: dict  # the submitted request; auto_review_node reads its "amount"
    review_result: str  # review status string written by the nodes
    approved: bool  # whether the request is approved
    execution_result: str  # outcome text written by execute_node
    reviewer_notes: str  # notes written by the review nodes


def submit_node(state: ApprovalState) -> dict:
    """Submit node: mark the request as pending and not yet approved."""
    initial_update = dict(review_result="pending", approved=False)
    return initial_update


def auto_review_node(state: ApprovalState) -> dict:
    """Automatic review node.

    Requests with an amount of at most 1000 are approved automatically;
    anything above is flagged for manual review.
    """
    amount = state.get("request", {}).get("amount", 0)

    if amount > 1000:
        approved = False
        verdict = "needs_manual_review"
        notes = f"金额 ¥{amount},需要人工审核"
    else:
        # Small amounts are approved without a human.
        approved = True
        verdict = "auto_approved"
        notes = f"金额 ¥{amount},自动审批通过"

    return {
        "approved": approved,
        "review_result": verdict,
        "reviewer_notes": notes,
    }


def manual_review_node(state: ApprovalState) -> dict:
    """Manual review node: report that a human review is pending.

    The function itself only returns a status update; it does not block or
    wait. Any actual pause for human input has to be arranged by the graph
    runtime that executes this node.
    """
    pending_update = dict(
        review_result="manual_review_pending",
        reviewer_notes="等待人工审核...",
    )
    return pending_update


def execute_node(state: ApprovalState) -> dict:
    """Execute node: report success when approved, otherwise report the rejection."""
    outcome = "执行成功" if state.get("approved") else "已拒绝,未执行"
    return {"execution_result": outcome}


def needs_manual_review(state: ApprovalState) -> str:
    """Router: return "manual_review" when the auto review asked for it, else "execute"."""
    flagged = state.get("review_result") == "needs_manual_review"
    return "manual_review" if flagged else "execute"


def is_approved(state: ApprovalState) -> str:
    """Router: continue to "execute" when approved, otherwise end the graph."""
    if not state.get("approved"):
        return END
    return "execute"


def build_approval_graph():
    """Build the approval workflow and return the compiled graph.

    Layout: submit -> auto_review -> (manual_review | execute); after
    manual_review the graph goes to execute when approved, otherwise it ends.

    The graph is compiled with a checkpointer and a breakpoint before
    "manual_review", so a run that needs a human decision stops there.
    The caller is expected to update the checkpointed state (for example
    set ``approved``) and then resume the same thread.
    """
    graph = StateGraph(ApprovalState)

    graph.add_node("submit", submit_node)
    graph.add_node("auto_review", auto_review_node)
    graph.add_node("manual_review", manual_review_node)
    graph.add_node("execute", execute_node)

    graph.set_entry_point("submit")
    graph.add_edge("submit", "auto_review")
    graph.add_conditional_edges("auto_review", needs_manual_review)
    graph.add_conditional_edges("manual_review", is_approved)
    graph.add_edge("execute", END)

    # MemorySaver keeps the state between the pause and the resume. Without
    # interrupt_before the graph would run straight through manual_review
    # and end unapproved, never waiting for a human.
    return graph.compile(
        checkpointer=MemorySaver(),
        interrupt_before=["manual_review"],
    )

21.5 复杂工作流的编排模式

21.5.1 编排模式概览

| 模式 | 适用场景 | 特点 |
| --- | --- | --- |
| 顺序模式 | 步骤间有明确依赖 | A→B→C→D |
| 并行模式 | 步骤间无依赖 | A+B+C→聚合 |
| 条件分支 | 根据条件选择路径 | A→(条件)→B或C |
| 循环模式 | 需要反复优化 | A→B→(条件)→A或END |
| Saga模式 | 分布式事务 | A→B→C→(失败)→补偿 |
| Pipeline模式 | 数据流处理 | ETL风格的连续处理 |

21.5.2 顺序管道模式

python
from typing import Generic, TypeVar, Callable
from dataclasses import dataclass

T = TypeVar('T')
U = TypeVar('U')


@dataclass
class PipelineResult:
    """Outcome of a pipeline run or of a single step."""
    success: bool
    data: object  # payload: the step/pipeline output, or the last good value on failure
    step_name: str = ""  # step that produced this result ("complete" for a full run)
    error: str = ""  # error text when success is False


class Pipeline(Generic[T]):
    """Generic sequential pipeline: each step receives the previous step's output."""

    def __init__(self, name: str = "pipeline"):
        self.name = name
        self._steps: list[tuple[str, Callable]] = []

    def add_step(self, name: str, fn: Callable) -> 'Pipeline':
        """Append a step and return the pipeline so calls can be chained."""
        self._steps.append((name, fn))
        return self

    def execute(self, initial_data: T) -> PipelineResult:
        """Run all steps in order, starting from ``initial_data``.

        A step may return a plain value or a ``PipelineResult``. A successful
        ``PipelineResult`` is unwrapped so the next step always receives the
        raw payload; an unsuccessful one stops the run and is returned as is.

        Returns:
            ``PipelineResult(success=True, step_name="complete")`` carrying
            the last step's output, or a failed result naming the step that
            raised (its ``data`` is the last value computed before it).
        """
        data = initial_data

        for step_name, fn in self._steps:
            try:
                outcome = fn(data)
            except Exception as e:
                return PipelineResult(
                    success=False, data=data, step_name=step_name,
                    error=str(e),
                )

            if isinstance(outcome, PipelineResult):
                if not outcome.success:
                    return outcome
                # Pass the payload on, not the wrapper.
                data = outcome.data
            else:
                data = outcome

        return PipelineResult(success=True, data=data, step_name="complete")


# Usage example: a four-step text-processing pipeline.
pipeline = Pipeline("text_processing")
pipeline.add_step("clean", lambda d: d.strip().lower())  # trim and lower-case
pipeline.add_step("tokenize", lambda d: d.split())  # split on whitespace
pipeline.add_step("filter", lambda d: [w for w in d if len(w) > 2])  # keep tokens longer than 2 chars
pipeline.add_step("count", lambda d: PipelineResult(True, len(d), "count"))  # a step may return a PipelineResult itself

result = pipeline.execute("  Hello World, This is a Test Pipeline  ")
print(f"结果: {result}")

21.5.3 并行扇出-扇入模式

python
import asyncio
from typing import Any


class ParallelFanOut:
    """Parallel fan-out / fan-in: run several async tasks on one input, then combine."""

    def __init__(self, name: str = "parallel"):
        self.name = name
        self._tasks: list[tuple[str, Callable]] = []
        self._aggregator: Callable = None

    def add_task(self, name: str, fn: Callable):
        """Register a task under ``name``; returns self for chaining."""
        self._tasks.append((name, fn))
        return self

    def set_aggregator(self, fn: Callable):
        """Set the function that combines the per-task results; returns self."""
        self._aggregator = fn
        return self

    async def execute(self, input_data: Any) -> dict:
        """Call every task with ``input_data`` and gather the outcomes.

        A task that raises contributes ``{"error": <message>}`` under its
        name instead of aborting the run. The name-to-outcome mapping is
        passed to the aggregator when one is set, otherwise returned as is.
        """
        # Fan out
        pending = [task(input_data) for _, task in self._tasks]
        outcomes = await asyncio.gather(*pending, return_exceptions=True)

        # Collect per task name
        collected = {
            label: {"error": str(outcome)} if isinstance(outcome, Exception) else outcome
            for (label, _), outcome in zip(self._tasks, outcomes)
        }

        # Fan in
        if not self._aggregator:
            return collected
        return self._aggregator(collected)


# 使用示例
async def analyze_sentiment(text: str) -> dict:
    """Mock sentiment analysis: wait briefly, then return a fixed result."""
    simulated_latency = 0.1
    await asyncio.sleep(simulated_latency)  # simulated latency
    return dict(sentiment="positive", confidence=0.92)

async def extract_entities(text: str) -> dict:
    """Mock entity extraction: wait briefly, then return a fixed entity list."""
    simulated_latency = 0.15
    await asyncio.sleep(simulated_latency)
    found = ["Apple", "iPhone", "Tim Cook"]
    return {"entities": found}

async def classify_topic(text: str) -> dict:
    """Mock topic classification: wait briefly, then return a fixed topic."""
    simulated_latency = 0.05
    await asyncio.sleep(simulated_latency)
    return dict(topic="technology", subtopics=["mobile", "AI"])

# Wire the three analysis coroutines into one fan-out / fan-in job.
fanout = ParallelFanOut("text_analysis")
fanout.add_task("sentiment", analyze_sentiment)
fanout.add_task("entities", extract_entities)
fanout.add_task("topic", classify_topic)
fanout.set_aggregator(
    lambda results: {
        "combined": results,
        "summary": f"共分析 {len(results)} 个维度",
    }
)

# `await` is only valid inside a coroutine, so plain module-level code has to
# drive the coroutine with asyncio.run. (Inside an already running event loop,
# e.g. a notebook cell, use `await fanout.execute(...)` instead.)
result = asyncio.run(fanout.execute("Apple released new iPhone with AI features"))
# result: {"combined": {...}, "summary": "共分析 3 个维度"}

21.5.4 Saga补偿模式

python
class SagaOrchestrator:
    """
    Saga pattern for multi-step transactions.

    Every step is registered with a compensation function. When a step
    fails, the compensations of the steps that already completed are run in
    reverse order. Step and compensation functions may be plain functions
    or coroutine functions.
    """

    def __init__(self):
        # Entries are (name, execute function, compensation function).
        self._steps: list[tuple[str, Callable, Callable]] = []

    def add_step(self, name: str, execute_fn: Callable, compensate_fn: Callable):
        """Append a step with its compensation; returns self for chaining."""
        self._steps.append((name, execute_fn, compensate_fn))
        return self

    @staticmethod
    async def _invoke(fn: Callable, context: dict):
        """Call ``fn(context)`` and await the result when it is awaitable."""
        import inspect

        outcome = fn(context)
        if inspect.isawaitable(outcome):
            # Without this, an `async def` step would only create a coroutine
            # object and its body would never run.
            outcome = await outcome
        return outcome

    async def execute(self, context: dict) -> dict:
        """Run the steps in order, compensating on the first failure.

        A dict returned by a step is merged into ``context``.

        Returns:
            The same ``context`` dict with ``saga_status`` set to
            ``"completed"``, or to ``"compensated"`` together with
            ``failed_at`` naming the step that raised.
        """
        completed_steps = []

        for step_name, execute_fn, compensate_fn in self._steps:
            try:
                result = await self._invoke(execute_fn, context)
                if isinstance(result, dict):
                    context.update(result)
                completed_steps.append((step_name, compensate_fn))
            except Exception as e:
                print(f"步骤 {step_name} 失败: {e},开始补偿...")
                # Undo in reverse order; a failing compensation is reported
                # but does not stop the remaining ones.
                for comp_name, comp_fn in reversed(completed_steps):
                    try:
                        await self._invoke(comp_fn, context)
                        print(f"补偿 {comp_name} 成功")
                    except Exception as ce:
                        print(f"补偿 {comp_name} 失败: {ce}")

                context["saga_status"] = "compensated"
                context["failed_at"] = step_name
                return context

        context["saga_status"] = "completed"
        return context


# 使用示例
async def reserve_stock(ctx):
    """Example step: announce the stock reservation and flag it in the result."""
    order_id = ctx.get('order_id')
    print(f"预留库存: {order_id}")
    return dict(stock_reserved=True)

async def cancel_stock(ctx):
    """Example compensation for reserve_stock: announce the release."""
    notice = "释放预留库存"
    print(notice)

async def charge_payment(ctx):
    """Example step: announce the charge and flag it in the result."""
    amount = ctx.get('amount')
    print(f"扣款: ¥{amount}")
    return dict(payment_charged=True)

async def refund_payment(ctx):
    """Example compensation for charge_payment: announce the refund."""
    notice = "退款"
    print(notice)

async def create_shipment(ctx):
    """Example step: announce the shipment creation and flag it in the result."""
    order_id = ctx.get('order_id')
    print(f"创建发货单: {order_id}")
    return dict(shipment_created=True)

async def cancel_shipment(ctx):
    """Example compensation for create_shipment: announce the cancellation."""
    notice = "取消发货单"
    print(notice)


# Register each step together with the compensation that undoes it.
saga = SagaOrchestrator()
saga.add_step("reserve_stock", reserve_stock, cancel_stock)
saga.add_step("charge_payment", charge_payment, refund_payment)
saga.add_step("create_shipment", create_shipment, cancel_shipment)

# `await` is only valid inside a coroutine, so plain module-level code has to
# drive the coroutine with asyncio.run. (Inside an already running event loop,
# e.g. a notebook cell, use `await saga.execute(...)` instead.)
result = asyncio.run(saga.execute({"order_id": "ORD-001", "amount": 299}))

21.6 条件分支与循环

21.6.1 条件分支

python
class ConditionalRouter:
    """Conditional router: dispatch a context to the first route whose condition holds."""

    def __init__(self):
        # Entries are (name, condition, handler), tried in registration order.
        self._routes: list[tuple[str, Callable, Callable]] = []
        # Fallback handler; None until add_default is called.
        self._default = None

    def add_route(self, name: str, condition: Callable, handler: Callable):
        """Register a route; returns self for chaining."""
        self._routes.append((name, condition, handler))
        return self

    def add_default(self, handler: Callable):
        """Set the handler used when no route matches; returns self for chaining."""
        self._default = handler
        return self

    def route(self, context: dict) -> dict:
        """Dispatch ``context`` and return ``{"route": <name>, "result": <handler output>}``.

        The route name is ``"default"`` when the fallback handler was used.

        Raises:
            RuntimeError: No route matched and no default handler is set.
        """
        for name, condition, handler in self._routes:
            if condition(context):
                return {"route": name, "result": handler(context)}
        if self._default is not None:
            return {"route": "default", "result": self._default(context)}
        raise RuntimeError("没有匹配的路由且没有默认处理器")


# Usage example: two routes keyed on ctx["complexity"] plus a fallback.
router = ConditionalRouter()
router.add_route(
    "simple_query", 
    lambda ctx: ctx.get("complexity", "high") == "low",
    lambda ctx: f"简单回答: {ctx['query']}"
)
# A missing "complexity" key defaults to "high", so it lands on this route.
router.add_route(
    "research_required",
    lambda ctx: ctx.get("complexity", "high") == "high",
    lambda ctx: f"需要深入研究: {ctx['query']}"
)
router.add_default(lambda ctx: "请提供更多信息")

# complexity is "low", so the "simple_query" route handles this context.
result = router.route({"query": "什么是Python?", "complexity": "low"})

21.6.2 循环模式

python
class IterativeLoop:
    """Iteration controller: repeat a step until it converges or the budget runs out."""

    def __init__(
        self,
        max_iterations: int = 5,
        convergence_fn: Callable = None,
    ):
        self.max_iterations = max_iterations
        self.convergence_fn = convergence_fn
        self._step_fn: Callable = None

    def set_step(self, fn: Callable):
        """Set the step function ``fn(state) -> state``; returns self for chaining."""
        self._step_fn = fn
        return self

    def execute(self, initial_state: dict) -> dict:
        """Run the step up to ``max_iterations`` times.

        After every step a snapshot of the state (without its "history" key)
        is recorded. The loop stops early when ``convergence_fn(state)`` is
        truthy. The returned state carries "converged", "iterations" and
        "history".
        """
        state = initial_state
        snapshots = []
        converged = False
        rounds_done = self.max_iterations

        for round_no in range(1, self.max_iterations + 1):
            # One step
            state = self._step_fn(state)
            snapshot = {key: value for key, value in state.items() if key != "history"}
            snapshots.append({"iteration": round_no, "state_snapshot": snapshot})

            # Convergence check
            if self.convergence_fn and self.convergence_fn(state):
                converged = True
                rounds_done = round_no
                break

        state["converged"] = converged
        state["iterations"] = rounds_done
        state["history"] = snapshots
        return state


# 使用示例:迭代优化Prompt
def optimize_prompt_step(state: dict) -> dict:
    """Mock one prompt-optimisation step, updating ``state`` in place.

    Raises the score by 0.15 plus a small amount derived from the prompt's
    hash (capped at 1.0), appends an optimisation marker to the prompt and
    increments the iteration counter.
    """
    prompt = state.get("current_prompt", "")
    previous_score = state.get("score", 0)
    round_no = state.get("iteration", 0) + 1

    # Simulated improvement: every step raises the score.
    jitter = (hash(prompt) % 10) / 100
    state["score"] = min(1.0, previous_score + 0.15 + jitter)
    state["current_prompt"] = f"{prompt}\n[优化 #{round_no}]"
    state["iteration"] = round_no
    return state


# Stop as soon as the score reaches 0.9, or after 10 iterations at the latest.
loop = IterativeLoop(
    max_iterations=10,
    convergence_fn=lambda s: s.get("score", 0) >= 0.9,
)
loop.set_step(optimize_prompt_step)

result = loop.execute({
    "current_prompt": "你是一个助手",
    "score": 0.3,
})
print(f"迭代 {result['iterations']} 次,最终分数: {result['score']}")

21.7 人工介入(Human-in-the-Loop)设计

21.7.1 HITL的必要性

并非所有决策都应该由 Agent 自主完成。以下场景需要人工介入:

  • 高风险操作:大额资金转账、数据删除、权限变更
  • 模糊决策:多个方案优劣不明显,需要人类判断
  • 合规要求:法律、医疗等受监管领域
  • 质量控制:对关键输出进行人工审核

21.7.2 人工介入框架

python
import time
from enum import Enum
from typing import Optional
from dataclasses import dataclass


class ApprovalStatus(Enum):
    """Lifecycle status of an approval request."""
    PENDING = "pending"  # waiting for a decision
    APPROVED = "approved"  # approved
    REJECTED = "rejected"  # rejected
    EXPIRED = "expired"  # timed out before a decision


@dataclass
class HumanApprovalRequest:
    """A request for human approval."""
    request_id: str
    action_type: str           # kind of operation
    description: str           # description of the operation
    risk_level: str            # risk level: low/medium/high/critical
    context_data: dict         # context data
    status: ApprovalStatus = ApprovalStatus.PENDING
    created_at: float = field(default_factory=time.time)
    resolved_at: Optional[float] = None
    reviewer: Optional[str] = None
    review_notes: Optional[str] = None

    @property
    def is_expired(self) -> bool:
        """Whether more than 24 hours (86400 seconds) have passed since creation."""
        age_seconds = time.time() - self.created_at
        return age_seconds > 86400


class HumanInTheLoopManager:
    """Manager for human-in-the-loop approvals.

    Requests at or below the ``auto_approve_below`` risk level are approved
    automatically; all others are kept pending until ``approve`` or
    ``reject`` is called, or until they are older than ``timeout_seconds``.
    """

    # Risk levels ordered from least to most severe.
    _RISK_ORDER = ("low", "medium", "high", "critical")

    def __init__(
        self,
        auto_approve_below: str = "low",
        timeout_seconds: int = 86400,
    ):
        self.auto_approve_below = auto_approve_below
        self.timeout_seconds = timeout_seconds
        self._pending: dict[str, HumanApprovalRequest] = {}
        self._risk_handlers: dict[str, Callable] = {}
        self._reviewer_queue: list[str] = []  # reviewer queue

    def set_risk_handler(self, risk_level: str, handler: Callable):
        """Store a handling strategy for the given risk level."""
        self._risk_handlers[risk_level] = handler

    def request_approval(
        self,
        action_type: str,
        description: str,
        risk_level: str,
        context_data: dict,
    ) -> HumanApprovalRequest:
        """Submit an approval request.

        Returns:
            The request object. It is already APPROVED (reviewer
            ``"system:auto"``) when the risk level is at or below
            ``auto_approve_below``; otherwise it is PENDING and tracked by
            the manager.

        Raises:
            ValueError: ``risk_level`` or ``auto_approve_below`` is not one
                of low/medium/high/critical.
        """
        import uuid

        if risk_level not in self._RISK_ORDER:
            # Same exception type as the list lookup would raise, but with a
            # message that names the offending value.
            raise ValueError(f"未知的风险等级: {risk_level}")

        request = HumanApprovalRequest(
            request_id=str(uuid.uuid4())[:8],
            action_type=action_type,
            description=description,
            risk_level=risk_level,
            context_data=context_data,
        )

        # Automatic approval for low-risk requests
        if (self._RISK_ORDER.index(risk_level)
                <= self._RISK_ORDER.index(self.auto_approve_below)):
            request.status = ApprovalStatus.APPROVED
            request.resolved_at = time.time()
            request.reviewer = "system:auto"
            request.review_notes = "低风险自动审批"
            return request

        # Needs a human decision
        self._pending[request.request_id] = request

        # Notify reviewers (printed here; a real system would use a queue, mail, ...)
        print(f"[审批请求] ID: {request.request_id}, "
              f"操作: {action_type}, 风险: {risk_level}")
        print(f"  描述: {description}")

        return request

    def _resolve(
        self,
        request_id: str,
        status: ApprovalStatus,
        reviewer: str,
        notes: str,
    ) -> bool:
        """Close a pending request with ``status``; False when it is not pending."""
        request = self._pending.pop(request_id, None)
        if not request:
            return False

        request.status = status
        request.resolved_at = time.time()
        request.reviewer = reviewer
        request.review_notes = notes
        return True

    def approve(self, request_id: str, reviewer: str, notes: str = "") -> bool:
        """Approve a pending request; returns False when the id is not pending."""
        return self._resolve(request_id, ApprovalStatus.APPROVED, reviewer, notes)

    def reject(self, request_id: str, reviewer: str, reason: str) -> bool:
        """Reject a pending request; returns False when the id is not pending."""
        return self._resolve(request_id, ApprovalStatus.REJECTED, reviewer, reason)

    def check_pending(self, request_id: str) -> Optional[ApprovalStatus]:
        """Check a request that is still pending.

        Returns:
            None when the id is not (or no longer) pending, EXPIRED when the
            request is older than ``timeout_seconds`` (it is then removed
            from the pending set), otherwise its current status.
        """
        request = self._pending.get(request_id)
        if not request:
            return None

        # Use the manager's configured timeout rather than the fixed 24h
        # window of HumanApprovalRequest.is_expired.
        if (time.time() - request.created_at) > self.timeout_seconds:
            request.status = ApprovalStatus.EXPIRED
            request.resolved_at = time.time()
            del self._pending[request_id]
            return ApprovalStatus.EXPIRED

        return request.status

    def get_pending_requests(self) -> list[HumanApprovalRequest]:
        """Return all requests that are still waiting for a decision."""
        return list(self._pending.values())


# Usage example
hitl = HumanInTheLoopManager(auto_approve_below="low")

# Low-risk operation: approved automatically
req1 = hitl.request_approval(
    "send_email", "发送确认邮件给用户",
    "low", {"user_id": "u123", "email": "user@example.com"}
)
print(f"邮件发送审批: {req1.status.value}")

# High-risk operation: needs a human decision
req2 = hitl.request_approval(
    "refund", "退款 ¥5,000",
    "high", {"order_id": "ORD-001", "amount": 5000}
)
print(f"退款审批: {req2.status.value}")

# Simulate the human approval
hitl.approve(req2.request_id, reviewer="manager_zhang", notes="确认用户已退回商品")
# approve() removes the request from the pending set, so check_pending would
# return None here; the request object itself carries the final status.
print(f"退款审批: {req2.status.value}")

21.7.3 中断-恢复模式

python
class InterruptibleWorkflow:
    """
    Interruptible workflow: pauses at designated nodes and continues from
    the same node once external input has been supplied.

    Typical use: ``execute`` returns early with ``_interrupted_at`` set, the
    caller passes the human input to ``resume`` and then calls ``execute``
    again with the returned state.
    """

    def __init__(self):
        self._nodes: list[tuple[str, Callable]] = []
        self._interrupt_points: set[str] = set()
        self._state: dict = {}
        self._resume_data: dict = {}

    def add_node(self, name: str, fn: Callable):
        """Append a node; nodes run in the order they were added."""
        self._nodes.append((name, fn))

    def set_interrupt(self, name: str):
        """Mark the node ``name`` as an interrupt point."""
        self._interrupt_points.add(name)

    async def execute(self, initial_state: dict) -> dict:
        """Run the nodes in order, pausing before unresolved interrupt points.

        Execution starts at ``initial_state["_resume_at"]`` when present
        (the key is consumed), otherwise at the first node. A dict returned
        by a node is merged into the state.

        Returns:
            The state dict. When paused it carries ``_interrupted_at``,
            ``_interrupted_node_index`` and ``_resume_at``; when all nodes
            have run it carries ``_completed = True``.
        """
        self._state = initial_state
        # Consume the marker so a later fresh run starts from the beginning.
        start_index = self._state.pop("_resume_at", 0)

        for i in range(start_index, len(self._nodes)):
            name, fn = self._nodes[i]

            if name in self._interrupt_points:
                # The resolved flag is consumed by the interrupt it answers,
                # so later interrupt points pause again.
                if not self._state.pop("_interrupt_resolved", False):
                    self._state["_interrupted_at"] = name
                    self._state["_interrupted_node_index"] = i
                    # Record where to continue, so nodes that already ran
                    # are not executed a second time.
                    self._state["_resume_at"] = i
                    return self._state  # pause

            result = fn(self._state)
            if isinstance(result, dict):
                self._state.update(result)

        self._state["_completed"] = True
        return self._state

    def resume(self, input_data: dict) -> dict:
        """Merge ``input_data`` into the paused state and mark the interrupt resolved.

        The caller has to call ``execute`` again with the returned state.
        """
        self._state.update(input_data)
        self._state["_interrupt_resolved"] = True
        return self._state

21.8 最佳实践

21.8.1 状态机设计原则

  1. 状态应该正交:每个状态代表一个独立的阶段,不要让两个状态做类似的事
  2. 转换应该幂等:重复调用同一个转换不应该产生副作用
  3. 始终定义错误状态:确保任何状态下出错都能到达一个明确的错误处理状态
  4. 可视化状态图:用 Mermaid 或 Graphviz 画出状态图,确保逻辑正确
  5. 状态持久化:生产环境中,状态必须持久化到数据库,防止进程崩溃丢失

21.8.2 常见陷阱

陷阱1:过度设计状态机

不是所有 Agent 都需要状态机。简单的单轮问答或无状态 API 调用,用状态机反而增加复杂度。

陷阱2:忽略超时

状态机中的每个状态都应该有超时机制,否则 Agent 可能永远卡在某个状态。

陷阱3:状态爆炸

避免创建过多状态。当状态数量超过 10 个时,考虑是否可以用参数化状态替代。


21.9 本章小结

本章系统介绍了 Agent 状态机与流程编排的核心概念和实现:

  1. 有限状态机(FSM) 为 Agent 提供了确定性的行为控制
  2. LangGraph 提供了声明式的状态图编排能力
  3. 编排模式(顺序、并行、条件、Saga)解决了不同类型的工作流需求
  4. 人工介入(HITL) 是高风险场景下不可或缺的安全阀

记住:好的 Agent 架构不是让 AI 做所有决策,而是在合适的时机引入确定性控制和人类智慧。

基于 MIT 许可发布