
Chapter 27: Code Review Assistant

Let AI be your code reviewer: build an automated code review system


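27.1 Requirements Analysis and Feature Planning

27.1.1 Business Background

Code review is a critical but very time-consuming part of software engineering:

  1. Efficiency bottleneck: reviewing a PR takes 30-60 minutes on average, and senior engineers spend 10+ hours a week on code review
  2. Inconsistent quality: manual review easily misses problems, especially security vulnerabilities and performance risks
  3. Slow knowledge transfer: code written by newcomers rarely receives constructive suggestions for improvement
  4. Inconsistent standards: different reviewers focus on different things, which leads to inconsistent code style

We need a code review assistant that provides:

  • Automated review: analyzes a PR as soon as it is submitted and gives immediate feedback while the developer waits for human review
  • Multi-dimensional checks: coding standards, security vulnerabilities, performance problems, best practices
  • Actionable suggestions: not only pointing out problems but also providing fixes and code examples
  • Deep Git integration: fits seamlessly into the existing development workflow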

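27.1.2 Feature List

| Feature | Description | Priority |
|---|---|---|
| AST syntax analysis | Parse code structure, detect syntax problems | P0 |
| Coding standards check | PEP 8, naming conventions, documentation conventions | P0 |
| Security vulnerability scanning | SQL injection, XSS, hard-coded secrets, etc. | P0 |
| Performance problem detection | N+1 queries, memory leaks, unnecessary copies | P1 |
| Complexity analysis | Cyclomatic complexity, cognitive complexity | P1 |
| Fix suggestion generation | Problem code compared with the fixed code | P0 |
| Git PR integration | Automatic comments, status checks | P1 |
| Review report generation | Complete report in HTML/PDF format | P2 |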

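27.1.3 Technology Choices

  • AST parsing: Python ast module + tree-sitter (multi-language support)
  • Security scanning: custom rule engine + Semgrep integration
  • LLM analysis: GPT-4o for semantic-level code understanding
  • Git integration: GitPython + GitHub/GitLab API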

27.2 Architecture Design

code-review-assistant/
├── app/
│   ├── main.py
│   ├── config.py
│   ├── analyzers/              # Analyzers
│   │   ├── ast_analyzer.py     # AST analysis
│   │   ├── security_analyzer.py # Security scanning
│   │   ├── performance_analyzer.py # Performance checks
│   │   ├── style_analyzer.py   # Style checks
│   │   └── complexity_analyzer.py # Complexity analysis
│   ├── agents/                 # Agents
│   │   ├── review_agent.py     # Review coordination agent
│   │   └── fix_agent.py        # Fix suggestion agent
│   ├── git/                    # Git integration
│   │   ├── git_client.py
│   │   └── github_client.py
│   ├── models/
│   │   ├── issue.py            # Issue model
│   │   └── report.py           # Report model
│   └── utils/
│       └── llm_client.py
├── tests/
└── requirements.txt

27.3 Core Implementation

27.3.1 Data Model

python
# app/models/issue.py
"""Code review issue model."""

from dataclasses import dataclass, field
from enum import Enum
from typing import Optional


class Severity(str, Enum):
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    CRITICAL = "critical"


class IssueCategory(str, Enum):
    STYLE = "style"          # Code style
    SECURITY = "security"    # Security vulnerability
    PERFORMANCE = "performance" # Performance problem
    COMPLEXITY = "complexity"  # Complexity
    BEST_PRACTICE = "best_practice" # Best practice
    BUG_RISK = "bug_risk"    # Potential bug


@dataclass
class CodeLocation:
    """Location of a piece of code."""
    file_path: str
    line_start: int
    line_end: int
    column_start: int = 0
    column_end: int = 0
    
    def __str__(self):
        if self.line_start == self.line_end:
            return f"{self.file_path}:{self.line_start}"
        return f"{self.file_path}:{self.line_start}-{self.line_end}"


@dataclass
class ReviewIssue:
    """A single review finding."""
    id: str
    title: str
    description: str
    severity: Severity
    category: IssueCategory
    location: CodeLocation
    code_snippet: str = ""
    fix_suggestion: str = ""
    fixed_code: str = ""
    confidence: float = 0.0
    cwe_id: Optional[str] = None  # CWE weakness ID
    rule_id: Optional[str] = None # Rule ID
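
As a rough illustration of how these models fit together, the sketch below builds one issue and prints its location. The dataclasses are trimmed copies of the ones above so that the snippet runs on its own; the file paths, line numbers, and issue text are made-up sample values.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Severity(str, Enum):
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    CRITICAL = "critical"


@dataclass
class CodeLocation:
    file_path: str
    line_start: int
    line_end: int

    def __str__(self):
        if self.line_start == self.line_end:
            return f"{self.file_path}:{self.line_start}"
        return f"{self.file_path}:{self.line_start}-{self.line_end}"


@dataclass
class ReviewIssue:
    id: str
    title: str
    severity: Severity
    location: CodeLocation
    cwe_id: Optional[str] = None


# Hypothetical sample values, for illustration only
issue = ReviewIssue(
    id="sec-CWE-798-12",
    title="Hard-coded secret",
    severity=Severity.CRITICAL,
    location=CodeLocation("app/config.py", 12, 12),
    cwe_id="CWE-798",
)

single_line = str(issue.location)
multi_line = str(CodeLocation("app/db.py", 40, 45))
print(single_line)
print(multi_line)
# Severity subclasses str, so members compare equal to plain strings,
# which keeps JSON serialization and dictionary lookups simple
print(issue.severity == "critical")
```

Deriving the enums from `str` is a small design choice that pays off later, when severities arrive as plain strings from an LLM response or a JSON report.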

27.3.2 AST Analyzer

python
# app/analyzers/ast_analyzer.py
"""AST syntax analyzer."""

import ast
from typing import List
from app.models.issue import ReviewIssue, Severity, IssueCategory, CodeLocation


class ASTAnalyzer:
    """Python AST analyzer."""
    
    def analyze(self, file_path: str, source_code: str) -> List[ReviewIssue]:
        """Analyze Python source code."""
        issues = []
        
        try:
            tree = ast.parse(source_code, filename=file_path)
        except SyntaxError as e:
            issues.append(ReviewIssue(
                id=f"ast-syntax-{e.lineno}",
                title="Syntax error",
                description=f"Python syntax error: {e.msg}",
                severity=Severity.ERROR,
                category=IssueCategory.BUG_RISK,
                location=CodeLocation(file_path, e.lineno or 1, e.lineno or 1),
                code_snippet=self._get_line(source_code, e.lineno or 1),
                confidence=1.0,
            ))
            return issues
        
        # Run each check
        issues.extend(self._check_unused_imports(tree, file_path, source_code))
        issues.extend(self._check_bare_except(tree, file_path, source_code))
        issues.extend(self._check_mutable_defaults(tree, file_path, source_code))
        issues.extend(self._check_shadowing(tree, file_path, source_code))
        issues.extend(self._check_unused_variables(tree, file_path, source_code))
        
        return issues
    
    def _get_line(self, source: str, lineno: int) -> str:
        lines = source.split('\n')
        if 1 <= lineno <= len(lines):
            return lines[lineno - 1].strip()
        return ""
    
    def _check_unused_imports(self, tree, path, source) -> List[ReviewIssue]:
        """Check for unused imports."""
        issues = []
        # Map each imported name to the line of its import statement
        imported = {}
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    # `import a.b` binds only the top-level name `a`
                    name = alias.asname or alias.name.split('.')[0]
                    imported[name] = node.lineno
            elif isinstance(node, ast.ImportFrom):
                if node.module == '__future__':
                    continue  # __future__ imports take effect without being referenced
                for alias in node.names:
                    if alias.name != '*':
                        imported[alias.asname or alias.name] = node.lineno
        
        # Every name that is read somewhere in the module
        used = {
            n.id for n in ast.walk(tree)
            if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
        }
        
        # Simplified check: re-exports via __all__ and names that appear only
        # in string annotations are not recognized, hence the moderate confidence
        for name, lineno in imported.items():
            if name not in used:
                issues.append(ReviewIssue(
                    id=f"ast-unused-import-{lineno}-{name}",
                    title=f"Unused import: {name}",
                    description=f"`{name}` is imported but never used in the code.",
                    severity=Severity.WARNING,
                    category=IssueCategory.STYLE,
                    location=CodeLocation(path, lineno, lineno),
                    code_snippet=self._get_line(source, lineno),
                    fix_suggestion="Remove the unused import statement.",
                    fixed_code="",  # The line is deleted
                    confidence=0.7,
                ))
        
        return issues
    
    def _check_bare_except(self, tree, path, source) -> List[ReviewIssue]:
        """Check for bare except clauses."""
        issues = []
        for node in ast.walk(tree):
            if isinstance(node, ast.ExceptHandler):
                if node.type is None:
                    issues.append(ReviewIssue(
                        id=f"ast-bare-except-{node.lineno}",
                        title="Bare except clause",
                        description=(
                            "A bare `except:` catches every exception, including "
                            "KeyboardInterrupt and SystemExit, "
                            "and can hide serious errors."),
                        severity=Severity.WARNING,
                        category=IssueCategory.BEST_PRACTICE,
                        location=CodeLocation(path, node.lineno, node.lineno),
                        code_snippet=self._get_line(source, node.lineno),
                        fix_suggestion="Use `except Exception:` instead of a bare except",
                        fixed_code="except Exception as e:",
                        confidence=0.95,
                    ))
        return issues
    
    def _check_mutable_defaults(self, tree, path, source) -> List[ReviewIssue]:
        """Check for mutable default arguments."""
        issues = []
        mutable_types = (ast.List, ast.Dict, ast.Set)
        
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                for default in node.args.defaults + node.args.kw_defaults:
                    if default and isinstance(default, mutable_types):
                        issues.append(ReviewIssue(
                            id=f"ast-mutable-default-{node.lineno}",
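                            title=f"Mutable default argument: {node.name}",
                            description=(
                                f"Function `{node.name}` uses a mutable object as a default argument. "
                                "In Python a default value is created once, when the function is defined, "
                                "so every call shares the same object and state leaks between calls."),
                            severity=Severity.WARNING,
                            category=IssueCategory.BUG_RISK,
                            location=CodeLocation(path, node.lineno, node.lineno),
                            code_snippet=self._get_line(source, node.lineno),
                            fix_suggestion="Use None as the default and create the object inside the function body",
                            # Only the first line needs an f-string; a bare {} in an
                            # f-string would be a syntax error
                            fixed_code=(
                                f"def {node.name}(..., param=None):\n"
                                "    if param is None:\n"
                                "        param = []  # or {}"
                            ),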
                            confidence=0.95,
                        ))
        return issues
    
    def _check_shadowing(self, tree, path, source) -> List[ReviewIssue]:
        """Check for names that shadow builtins."""
        issues = []
        builtins = {
            'list', 'dict', 'set', 'str', 'int', 'float', 'bool',
            'type', 'id', 'input', 'print', 'range', 'len', 'max',
            'min', 'sum', 'open', 'file', 'dir', 'vars', 'hash',
            'object', 'super', 'property', 'staticmethod', 'classmethod',
        }
        
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                for arg in node.args.args:
                    if arg.arg in builtins:
                        issues.append(ReviewIssue(
                            id=f"ast-shadow-{node.lineno}",
                            title=f"Shadows a builtin name: {arg.arg}",
                            description=(
                                f"Parameter `{arg.arg}` shadows a Python builtin function or type. "
                                f"The original `{arg.arg}` cannot be used inside this function's scope."),
                            severity=Severity.INFO,
                            category=IssueCategory.BEST_PRACTICE,
                            location=CodeLocation(path, node.lineno, node.lineno),
                            code_snippet=self._get_line(source, node.lineno),
                            fix_suggestion=f"Rename the parameter to `{arg.arg}_` or to something more descriptive",
                            confidence=0.8,
                        ))
            
            elif isinstance(node, ast.Assign):
                for target in node.targets:
                    if isinstance(target, ast.Name) and target.id in builtins:
                        issues.append(ReviewIssue(
                            id=f"ast-shadow-assign-{node.lineno}",
                            title=f"Shadows a builtin name: {target.id}",
                            description=f"Variable `{target.id}` shadows a builtin name.",
                            severity=Severity.INFO,
                            category=IssueCategory.BEST_PRACTICE,
                            location=CodeLocation(path, node.lineno, node.lineno),
                            code_snippet=self._get_line(source, node.lineno),
                            confidence=0.8,
                        ))
        return issues
    
    def _check_unused_variables(self, tree, path, source) -> List[ReviewIssue]:
        """Check for unused variables."""
        issues = []
        
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                # Collect every assignment
                assigned = set()
                used = set()
                
                for child in ast.walk(node):
                    if isinstance(child, ast.Assign):
                        for t in child.targets:
                            if isinstance(t, ast.Name):
                                assigned.add(t.id)
                    elif isinstance(child, ast.Name) and isinstance(child.ctx, ast.Load):
                        used.add(child.id)
                
                unused = assigned - used
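                for var in sorted(unused):
                    if not var.startswith('_'):  # A leading _ marks a variable as intentionally unused
                        issues.append(ReviewIssue(
                            # Include the variable name so IDs stay unique within a function
                            id=f"ast-unused-var-{node.lineno}-{var}",
                            title=f"Unused variable: {var}",
                            description=f"Variable `{var}` is assigned in function `{node.name}` but never used.",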
                            severity=Severity.INFO,
                            category=IssueCategory.STYLE,
                            location=CodeLocation(path, node.lineno, node.lineno),
                            confidence=0.6,
                        ))
        
        return issues
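
To make the AST-based approach concrete, here is a minimal standalone sketch of two of the checks above, run against a small sample. The helper names `find_bare_excepts` and `find_mutable_defaults` and the sample source are illustrative assumptions, not part of the project code.

```python
import ast

# Sample input with one mutable default and one bare except
SAMPLE = '''
def add_item(item, bucket=[]):
    try:
        bucket.append(item)
    except:
        pass
    return bucket
'''


def find_bare_excepts(tree: ast.AST) -> list:
    """Return the line numbers of `except:` clauses that name no exception type."""
    return [
        node.lineno
        for node in ast.walk(tree)
        if isinstance(node, ast.ExceptHandler) and node.type is None
    ]


def find_mutable_defaults(tree: ast.AST) -> list:
    """Return (function name, line) pairs for list/dict/set literal defaults."""
    hits = []
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # kw_defaults holds None for keyword-only args without a default
            defaults = node.args.defaults + node.args.kw_defaults
            if any(isinstance(d, (ast.List, ast.Dict, ast.Set)) for d in defaults):
                hits.append((node.name, node.lineno))
    return hits


tree = ast.parse(SAMPLE)
bare = find_bare_excepts(tree)
mutable = find_mutable_defaults(tree)
print(bare)
print(mutable)
```

Working on the parsed tree rather than on raw text means comments and string literals that merely mention `except:` cannot trigger a finding.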

27.3.3 Security Scanning Analyzer

python
# app/analyzers/security_analyzer.py
"""Security vulnerability scanning analyzer."""

import re
import ast
from typing import List
from app.models.issue import ReviewIssue, Severity, IssueCategory, CodeLocation


class SecurityAnalyzer:
    """Security vulnerability scanner."""
    
    # Common dangerous patterns
    DANGEROUS_PATTERNS = [
        {
            "pattern": r'(?:password|passwd|pwd|secret|api_key|apikey|token)\s*=\s*["\'][^"\']+["\']',
            "title": "Hard-coded sensitive information",
            "description": "A hard-coded password, key, or token was detected. Use environment variables or a secrets management service.",
            "severity": Severity.CRITICAL,
            "category": IssueCategory.SECURITY,
            "cwe": "CWE-798",
            "fix": "Use os.environ.get('SECRET_NAME') or a configuration management tool.",
        },
        {
            "pattern": r'eval\s*\(',
            "title": "Use of eval()",
            "description": "eval() executes arbitrary Python code and carries a code injection risk.",
            "severity": Severity.CRITICAL,
            "category": IssueCategory.SECURITY,
            "cwe": "CWE-95",
            "fix": "Use ast.literal_eval() or json.loads() instead.",
        },
        {
            "pattern": r'exec\s*\(',
            "title": "Use of exec()",
            "description": "exec() executes arbitrary code and carries a code injection risk.",
            "severity": Severity.CRITICAL,
            "category": IssueCategory.SECURITY,
            "cwe": "CWE-78",
            "fix": "Avoid exec(); consider plain functions or the strategy pattern.",
        },
        {
            "pattern": r'pickle\.loads?\s*\(',
            "title": "Unsafe deserialization",
            "description": "Deserializing with pickle can execute arbitrary code.",
            "severity": Severity.CRITICAL,
            "category": IssueCategory.SECURITY,
            "cwe": "CWE-502",
            "fix": "Use json or another safe serialization format.",
        },
        {
            "pattern": r'subprocess\.(?:call|run|Popen)\s*\([^)]*shell\s*=\s*True',
            "title": "Shell injection risk",
            "description": "Passing shell=True to subprocess can lead to command injection.",
            "severity": Severity.ERROR,
            "category": IssueCategory.SECURITY,
            "cwe": "CWE-78",
            "fix": "Use shell=False and pass the arguments as a list.",
        },
        {
            # execute( followed by an f-string, or by a string literal that is
            # then combined with % or .format(. The alternatives are grouped so
            # that an unrelated .format( call elsewhere is not flagged.
            "pattern": r'\.execute\s*\(\s*(?:f["\']|["\'][^"\']*["\']\s*(?:%|\.format\s*\())',
            "title": "Potential SQL injection",
            "description": "String concatenation or formatting is used to build a SQL statement. Use parameterized queries.",
            "severity": Severity.CRITICAL,
            "category": IssueCategory.SECURITY,
            "cwe": "CWE-89",
            "fix": "Use a parameterized query: cursor.execute('SELECT * FROM users WHERE id = %s', (user_id,))",
        },
        {
            "pattern": r'(?:(?:verify|ssl_verify|cert_verify)\s*=\s*False)',
            "title": "SSL certificate verification disabled",
            "description": "Disabling SSL verification exposes the connection to man-in-the-middle attacks.",
            "severity": Severity.ERROR,
            "category": IssueCategory.SECURITY,
            "cwe": "CWE-295",
            "fix": "Enable SSL verification, or use a custom CA certificate.",
        },
        {
            # Overlaps with the rule above for single-line calls; it carries a
            # more specific message for HTTP clients
            "pattern": r'(?:requests|httpx|aiohttp)\.(?:get|post|put|delete|patch)\s*\([^)]*verify\s*=\s*False',
            "title": "HTTP request with SSL verification disabled",
            "description": "The HTTP client has SSL certificate verification disabled.",
            "severity": Severity.ERROR,
            "category": IssueCategory.SECURITY,
            "cwe": "CWE-295",
            "fix": "Remove the verify=False argument.",
        },
        {
            # Match both debug=True and DEBUG = True
            "pattern": r'(?i:debug)\s*=\s*True',
            "title": "Debug mode in production",
            "description": "DEBUG=True may expose sensitive information.",
            "severity": Severity.WARNING,
            "category": IssueCategory.SECURITY,
            "cwe": "CWE-489",
            "fix": "Control the DEBUG setting through an environment variable.",
        },
        {
            "pattern": r'\.raw\s*=\s*request\.(?:body|data|get|post)',
            "title": "Unvalidated user input",
            "description": "Request data is used directly without validation.",
            "severity": Severity.WARNING,
            "category": IssueCategory.SECURITY,
            "cwe": "CWE-20",
            "fix": "Validate the input with Pydantic or another validation library.",
        },
    ]
    
    def analyze(self, file_path: str, source_code: str) -> List[ReviewIssue]:
        """Run the security scan."""
        issues = []
        lines = source_code.split('\n')
        
        for rule in self.DANGEROUS_PATTERNS:
            pattern = re.compile(rule["pattern"])
            for i, line in enumerate(lines, 1):
                if pattern.search(line):
                    issues.append(ReviewIssue(
                        id=f"sec-{rule['cwe']}-{i}",
                        title=rule["title"],
                        description=rule["description"],
                        severity=rule["severity"],
                        category=rule["category"],
                        location=CodeLocation(file_path, i, i),
                        code_snippet=line.strip(),
                        fix_suggestion=rule["fix"],
                        confidence=0.85,
                        cwe_id=rule["cwe"],
                    ))
        
        return issues
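
The line-by-line rule matching above can be tried in isolation. The following is a minimal sketch with two illustrative rules; the rule IDs, the exact patterns, and the sample source are assumptions chosen for the demonstration, not the project's rule set. It shows that a string-formatted query is flagged while a parameterized one is not.

```python
import re

# Two illustrative rules; IDs and patterns are assumptions for this sketch
RULES = [
    {
        "id": "hardcoded-secret",
        "pattern": re.compile(
            r'(?:password|secret|api_key|token)\s*=\s*["\'][^"\']+["\']'),
    },
    {
        "id": "sql-string-format",
        # execute( followed by an f-string, or by a literal that is then
        # combined with % or .format(
        "pattern": re.compile(
            r'\.execute\s*\(\s*(?:f["\']|["\'][^"\']*["\']\s*(?:%|\.format\s*\())'),
    },
]

SAMPLE = '''api_key = "sk-test-123"
cursor.execute("SELECT * FROM users WHERE id = %s" % user_id)
cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
api_key = os.environ.get("API_KEY")
'''


def scan(source: str) -> list:
    """Return (rule id, line number) for every line that matches a rule."""
    findings = []
    for lineno, line in enumerate(source.split("\n"), 1):
        for rule in RULES:
            if rule["pattern"].search(line):
                findings.append((rule["id"], lineno))
    return findings


findings = scan(SAMPLE)
print(findings)
```

Regex rules of this kind only see one line at a time, so calls split across several lines are missed; that limitation is one reason the technology choices pair the rule engine with Semgrep.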

27.3.4 Performance Analyzer

python
# app/analyzers/performance_analyzer.py
"""Performance problem detection analyzer."""

import ast
import re
from typing import List
from app.models.issue import ReviewIssue, Severity, IssueCategory, CodeLocation


class PerformanceAnalyzer:
    """Performance problem detector."""
    
    def analyze(self, file_path: str, source_code: str) -> List[ReviewIssue]:
        issues = []
        lines = source_code.split('\n')
        
        issues.extend(self._check_string_concat_in_loop(
            file_path, source_code, lines))
        issues.extend(self._check_list_append_in_loop(
            file_path, source_code, lines))
        issues.extend(self._check_n_plus_one_pattern(
            file_path, source_code))
        issues.extend(self._check_global_lookup(
            file_path, source_code, lines))
        
        return issues
    
    def _check_string_concat_in_loop(
        self, path, source, lines
    ) -> List[ReviewIssue]:
        """Detect string concatenation inside loops."""
        issues = []
        in_loop = False
        loop_start = 0
        
        for i, line in enumerate(lines, 1):
            stripped = line.strip()
            if re.match(r'(?:for|while)\s+', stripped):
                in_loop = True
                loop_start = i
            elif in_loop and stripped and not stripped.startswith('#'):
                if re.search(r'\w+\s*\+=\s*["\']', stripped):
                    issues.append(ReviewIssue(
                        id=f"perf-str-concat-{i}",
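                        title="String concatenation in a loop",
                        description=(
                            "Concatenating strings with += inside a loop is inefficient. "
                            "Each concatenation creates a new string object."),
                        severity=Severity.INFO,
                        category=IssueCategory.PERFORMANCE,
                        location=CodeLocation(path, loop_start, i),
                        code_snippet=stripped,
                        fix_suggestion="Collect the pieces in a list and join them, or use io.StringIO",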
                        fixed_code=(
                            "parts = []\n"
                            "for item in items:\n"
                            "    parts.append(str(item))\n"
                            "result = ''.join(parts)"
                        ),
                        confidence=0.7,
                    ))
            elif in_loop and stripped == "":
                in_loop = False
        
        return issues
    
    def _check_list_append_in_loop(
        self, path, source, lines
    ) -> List[ReviewIssue]:
        """Detect loops that can be replaced by a list comprehension."""
        issues = []
        
        try:
            tree = ast.parse(source)
        except SyntaxError:
            return issues
        
        for node in ast.walk(tree):
            if isinstance(node, ast.For):
                # Check whether the loop body is a single, simple append call
                if (len(node.body) == 1 and
                    isinstance(node.body[0], ast.Expr) and
                    isinstance(node.body[0].value, ast.Call)):
                    call = node.body[0].value
                    # Require exactly one argument so call.args[0] below is safe
                    if (isinstance(call.func, ast.Attribute) and
                        call.func.attr == 'append' and
                        len(call.args) == 1):
                        issues.append(ReviewIssue(
                            id=f"perf-list-comp-{node.lineno}",
                            title="Can be replaced by a list comprehension",
                            description=(
                                "A simple for loop with append can be replaced by a list "
                                "comprehension, which is more concise and usually faster."),
                            severity=Severity.INFO,
                            category=IssueCategory.PERFORMANCE,
                            location=CodeLocation(path, node.lineno,
                                                 node.end_lineno or node.lineno),
                            code_snippet=lines[node.lineno - 1].strip(),
                            fix_suggestion="Use a list comprehension",
                            fixed_code=(
                                f"result = [{ast.unparse(call.args[0])} "
                                f"for {ast.unparse(node.target)} in "
                                f"{ast.unparse(node.iter)}]"
                            ),
                            confidence=0.75,
                        ))
        
        return issues
    
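    def _check_n_plus_one_pattern(
        self, path, source
    ) -> List[ReviewIssue]:
        """Detect the N+1 query pattern."""
        issues = []
        try:
            tree = ast.parse(source)
        except SyntaxError:
            return issues
        
        # Simple heuristic: a method call named execute/query inside a loop body.
        # Walking the AST ties each finding to an actual loop, whereas a
        # file-wide regex matches any loop followed anywhere by a query.
        query_methods = {'execute', 'query'}
        for node in ast.walk(tree):
            if not isinstance(node, (ast.For, ast.While)):
                continue
            for child in (c for stmt in node.body for c in ast.walk(stmt)):
                if (isinstance(child, ast.Call) and
                    isinstance(child.func, ast.Attribute) and
                    child.func.attr in query_methods):
                    issues.append(ReviewIssue(
                        id=f"perf-n-plus-one-{child.lineno}",
                        title="Potential N+1 query problem",
                        description=(
                            "A database query may be running inside a loop. N+1 queries "
                            "cause many database round trips. Prefer batch queries "
                            "(IN clause, JOIN, prefetch_related, etc.)."),
                        severity=Severity.WARNING,
                        category=IssueCategory.PERFORMANCE,
                        location=CodeLocation(path, node.lineno, child.lineno),
                        fix_suggestion="Replace the per-iteration query with a batch query",
                        fixed_code=(
                            "# Bad\n"
                            "for user in users:\n"
                            "    orders = db.query(Order).filter_by(user_id=user.id)\n\n"
                            "# Good\n"
                            "user_ids = [u.id for u in users]\n"
                            "orders = db.query(Order).filter(Order.user_id.in_(user_ids))"
                        ),
                        confidence=0.5,  # Heuristic, so confidence is low
                    ))
                    break  # Report each loop only once
        
        return issues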
    
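    def _check_global_lookup(self, path, source, lines):
        """Detect repeated global lookups inside loops (currently only len())."""
        issues = []
        try:
            tree = ast.parse(source)
        except SyntaxError:
            return issues
        
        for node in ast.walk(tree):
            if isinstance(node, (ast.For, ast.While)):
                # Only code that runs on every iteration counts: the loop body,
                # plus the condition of a while loop. A for loop's iterable,
                # e.g. range(len(items)), is evaluated only once.
                repeated = list(node.body)
                if isinstance(node, ast.While):
                    repeated.append(node.test)
                for child in (c for part in repeated for c in ast.walk(part)):
                    if (isinstance(child, ast.Call) and
                        isinstance(child.func, ast.Name) and
                        child.func.id == 'len'):
                        issues.append(ReviewIssue(
                            id=f"perf-len-in-loop-{node.lineno}",
                            title="Repeated len() call in a loop",
                            description="len() is recomputed on every iteration; cache the result in a local variable.",
                            severity=Severity.INFO,
                            category=IssueCategory.PERFORMANCE,
                            location=CodeLocation(path, node.lineno, node.lineno),
                            fix_suggestion="Cache the result of len() in a variable",
                            fixed_code=(
                                "n = len(items)  # Computed once, outside the loop\n"
                                "for i in range(n):\n"
                                "    ..."
                            ),
                            confidence=0.6,
                        ))
                        break  # Report each loop only once
        
        return issues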
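
The list-comprehension check relies on `ast.unparse` (available from Python 3.9) to turn AST nodes back into source text. The sketch below isolates that step; the function name `suggest_comprehension` and the sample loop are assumptions for illustration. Unlike the analyzer, it also recovers the name of the target list. Note that the suggested rewrite is only equivalent to the loop when the list was empty before the loop started.

```python
import ast
from typing import Optional


def suggest_comprehension(source: str) -> Optional[str]:
    """Return a list-comprehension rewrite for the first simple append loop."""
    for node in ast.walk(ast.parse(source)):
        if not (isinstance(node, ast.For) and len(node.body) == 1):
            continue
        stmt = node.body[0]
        if not (isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Call)):
            continue
        call = stmt.value
        if (isinstance(call.func, ast.Attribute)
                and call.func.attr == "append"
                and len(call.args) == 1):
            target_list = ast.unparse(call.func.value)  # the list being appended to
            element = ast.unparse(call.args[0])         # the appended expression
            loop_var = ast.unparse(node.target)         # the loop variable
            iterable = ast.unparse(node.iter)           # what the loop iterates over
            return f"{target_list} = [{element} for {loop_var} in {iterable}]"
    return None


SAMPLE = """
squares = []
for x in nums:
    squares.append(x * x)
"""

suggestion = suggest_comprehension(SAMPLE)
print(suggestion)
```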

27.3.5 Review Coordination Agent

python
# app/agents/review_agent.py
"""Code review coordination agent."""

import os
import ast
from typing import List, Optional
from dataclasses import dataclass, field
from app.models.issue import ReviewIssue, Severity, IssueCategory
from app.analyzers.ast_analyzer import ASTAnalyzer
from app.analyzers.security_analyzer import SecurityAnalyzer
from app.analyzers.performance_analyzer import PerformanceAnalyzer
from app.utils.llm_client import llm_client


@dataclass
class FileReview:
    """Review result for a single file."""
    file_path: str
    language: str
    total_lines: int
    issues: List[ReviewIssue] = field(default_factory=list)
    
    @property
    def critical_count(self) -> int:
        return sum(1 for i in self.issues if i.severity == Severity.CRITICAL)
    
    @property
    def error_count(self) -> int:
        return sum(1 for i in self.issues if i.severity == Severity.ERROR)


@dataclass
class ReviewReport:
    """Review report."""
    files: List[FileReview] = field(default_factory=list)
    summary: str = ""
    overall_score: float = 0.0  # 0-100
    
    @property
    def total_issues(self) -> int:
        return sum(len(f.issues) for f in self.files)
    
    @property
    def critical_issues(self) -> List[ReviewIssue]:
        return [i for f in self.files for i in f.issues
                if i.severity == Severity.CRITICAL]


class ReviewAgent:
    """Code review coordination agent."""
    
    LANGUAGE_MAP = {
        '.py': 'python',
        '.js': 'javascript',
        '.ts': 'typescript',
        '.go': 'go',
        '.java': 'java',
        '.rs': 'rust',
    }
    
    def __init__(self):
        self._ast_analyzer = ASTAnalyzer()
        self._security_analyzer = SecurityAnalyzer()
        self._perf_analyzer = PerformanceAnalyzer()
    
    def review_file(self, file_path: str) -> FileReview:
        """Review a single file."""
        ext = os.path.splitext(file_path)[1]
        language = self.LANGUAGE_MAP.get(ext, 'unknown')
        
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            source = f.read()
        
        lines = source.count('\n') + 1
        issues = []
        
        if language == 'python':
            issues.extend(self._ast_analyzer.analyze(file_path, source))
        issues.extend(self._security_analyzer.analyze(file_path, source))
        issues.extend(self._perf_analyzer.analyze(file_path, source))
        
        # Sort by severity
        severity_order = {
            Severity.CRITICAL: 0, Severity.ERROR: 1,
            Severity.WARNING: 2, Severity.INFO: 3,
        }
        issues.sort(key=lambda x: severity_order.get(x.severity, 99))
        
        return FileReview(
            file_path=file_path,
            language=language,
            total_lines=lines,
            issues=issues,
        )
    
    async def review_with_llm(
        self, file_path: str, source_code: str
    ) -> List[ReviewIssue]:
        """Run a deeper semantic review with the LLM."""
        # CodeLocation is not part of this module's top-level imports
        from app.models.issue import CodeLocation
        
        system_prompt = """You are a senior code review expert. Review the code below and identify potential problems.

Review dimensions:
1. Logic errors and potential bugs
2. Security vulnerabilities
3. Performance problems
4. Maintainability and readability
5. Violations of best practices

For each problem, provide:
- A description of the problem
- Severity (critical/error/warning/info)
- A suggested fix
- The fixed code

Return a JSON array:
[{"title":"...","description":"...","severity":"...","line":N,"fix":"...","fixed_code":"..."}]"""
        
        # Truncate the code to stay within the token limit
        max_chars = 8000
        code = source_code[:max_chars]
        if len(source_code) > max_chars:
            code += "\n\n# ... (code truncated)"
        
        severity_map = {
            "critical": Severity.CRITICAL,
            "error": Severity.ERROR,
            "warning": Severity.WARNING,
            "info": Severity.INFO,
        }
        
        try:
            result = await llm_client.chat_json(
                messages=[{"role": "user",
                    "content": f"File: {file_path}\n\n```\n{code}\n```"}],
                system_prompt=system_prompt,
            )
            
            issues = []
            for item in result if isinstance(result, list) else []:
                if not isinstance(item, dict):
                    continue  # Skip malformed entries
                line = item.get("line")
                if not isinstance(line, int):
                    line = 0  # The model gave no usable line number
                issues.append(ReviewIssue(
                    id=f"llm-{file_path}-{line}",
                    title=item.get("title", ""),
                    description=item.get("description", ""),
                    severity=severity_map.get(
                        item.get("severity", "info"), Severity.INFO),
                    category=IssueCategory.BEST_PRACTICE,
                    # A real CodeLocation; an ad-hoc object with a lambda would
                    # bind `item` late and report the last item's line everywhere
                    location=CodeLocation(file_path, line, line),
                    fix_suggestion=item.get("fix", ""),
                    fixed_code=item.get("fixed_code", ""),
                    confidence=0.7,
                ))
            return issues
        except Exception:
            # The LLM review is best-effort: on any failure return no findings
            return []
    
    def generate_summary(self, report: ReviewReport) -> str:
        """Generate the review summary."""
        total = report.total_issues
        critical = len(report.critical_issues)
        
        if critical > 0:
            score = max(20, 100 - critical * 20 - total * 2)
        else:
            score = max(50, 100 - total * 5)
        report.overall_score = score
        
        lines = [
            "## Code Review Report",
            "",
            f"**Files reviewed**: {len(report.files)}",
            f"**Total lines of code**: {sum(f.total_lines for f in report.files)}",
            f"**Issues found**: {total}",
            f"  - 🔴 Critical: {critical}",
            f"  - 🟠 Error: {sum(f.error_count for f in report.files)}",
            f"  - 🟡 Warning: {sum(sum(1 for i in f.issues if i.severity == Severity.WARNING) for f in report.files)}",
            f"  - 🔵 Info: {sum(sum(1 for i in f.issues if i.severity == Severity.INFO) for f in report.files)}",
            "",
            f"**Score**: {score}/100",
        ]
        
        if critical > 0:
            lines.append("\n### ⚠️ Critical issues (must fix)")
            for issue in report.critical_issues[:5]:
                lines.append(f"- **{issue.title}** ({issue.location})")
        
        return '\n'.join(lines)
    
    async def generate_github_comment(
        self, file_review: FileReview
    ) -> str:
        """Generate a GitHub PR comment."""
        if not file_review.issues:
            return f"✅ `{file_review.file_path}` passed review; no issues found."
        
        parts = [f"## 🔍 Review results for `{file_review.file_path}`\n"]
        parts.append(f"Found **{len(file_review.issues)}** issue(s):\n")
        
        severity_emoji = {
            Severity.CRITICAL: "🔴",
            Severity.ERROR: "🟠",
            Severity.WARNING: "🟡",
            Severity.INFO: "🔵",
        }
        
        for issue in file_review.issues[:10]:  # Show at most 10
            emoji = severity_emoji.get(issue.severity, "⚪")
            parts.append(f"### {emoji} {issue.title}")
            parts.append(f"> {issue.description}")
            if issue.fix_suggestion:
                parts.append(f"\n**Suggestion**: {issue.fix_suggestion}")
            if issue.fixed_code:
                parts.append(f"\n```python\n{issue.fixed_code}\n```")
            parts.append("")
        
        return '\n'.join(parts)
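
The scoring rule inside `generate_summary` is easier to reason about when pulled out into a pure function. The sketch below mirrors that rule; the function name `review_score` and the issue counts are hypothetical and chosen only to exercise each branch.

```python
def review_score(total_issues: int, critical_issues: int) -> int:
    """Mirror the scoring rule used in generate_summary."""
    if critical_issues > 0:
        # Each critical issue costs 20 points and every issue another 2,
        # with a floor of 20
        return max(20, 100 - critical_issues * 20 - total_issues * 2)
    # Without critical issues each issue costs 5 points, with a floor of 50
    return max(50, 100 - total_issues * 5)


# Hypothetical issue counts, for illustration only
examples = {
    "clean": review_score(0, 0),
    "four minor issues": review_score(4, 0),
    "many minor issues": review_score(20, 0),
    "one critical of three": review_score(3, 1),
    "five critical of ten": review_score(10, 5),
}
for label, score in examples.items():
    print(f"{label}: {score}/100")
```

The two floors mean a report without critical issues never drops below 50, while a single critical issue already caps the score at 78 (100 - 20 - 2).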

27.3.6 Git Integration

python
# app/git/git_client.py
"""Git integration client."""

import subprocess
import os
from typing import List, Optional
from dataclasses import dataclass


@dataclass
class GitDiff:
    file_path: str
    status: str  # added, modified, deleted
    additions: int
    deletions: int
    patch: str = ""


class GitClient:
    """Wrapper around the Git command line."""
    
    def __init__(self, repo_path: str = "."):
        self.repo_path = os.path.abspath(repo_path)
    
    def _run(self, *args) -> str:
        result = subprocess.run(
            ["git"] + list(args),
            cwd=self.repo_path,
            capture_output=True, text=True, timeout=30,
        )
        if result.returncode != 0:
            raise RuntimeError(f"git {' '.join(args)} failed: {result.stderr}")
        return result.stdout.strip()
    
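    def get_changed_files(
        self, base: str = "main", head: str = "HEAD"
    ) -> List[GitDiff]:
        """Return the files changed between base and head."""
        # Three-dot range: compare head against its merge base with base,
        # which matches what a PR diff shows
        output = self._run(
            "diff", "--numstat", f"{base}...{head}"
        )
        
        diffs = []
        for line in output.split('\n'):
            if not line.strip():
                continue
            # --numstat prints: <additions>\t<deletions>\t<path>
            parts = line.split('\t')
            if len(parts) >= 3:
                # Binary files report '-' for both counts
                additions = int(parts[0]) if parts[0] != '-' else 0
                deletions = int(parts[1]) if parts[1] != '-' else 0
                path = parts[2]
                diffs.append(GitDiff(
                    file_path=path,
                    # --numstat carries no add/delete information,
                    # so every entry is labelled "modified"
                    status="modified",
                    additions=additions,
                    deletions=deletions,
                ))
        
        return diffs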
    
    def get_file_content(self, file_path: str, ref: str = "HEAD") -> str:
        """Return the content of a file at the given ref."""
        return self._run("show", f"{ref}:{file_path}")
    
    def get_current_branch(self) -> str:
        return self._run("branch", "--show-current")
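
`get_changed_files` above labels every entry `"modified"`, because `--numstat` output does not say whether a file was added or deleted. One possible way to recover that information is to parse `git diff --name-status` as well. The sketch below is not part of the original `GitClient`: `parse_name_status` and the label mapping are hypothetical names, and the sample string stands in for the output of something like `self._run("diff", "--name-status", f"{base}...{head}")`.

```python
from typing import Dict

# First-column letters printed by `git diff --name-status`.
# The label strings on the right are our own choice (assumption).
_STATUS_LABELS = {
    "A": "added",
    "M": "modified",
    "D": "deleted",
    "R": "renamed",
    "C": "copied",
}


def parse_name_status(output: str) -> Dict[str, str]:
    """Map each changed path to a status label (hypothetical helper)."""
    statuses: Dict[str, str] = {}
    for line in output.splitlines():
        parts = line.split("\t")
        if len(parts) < 2:
            continue  # skip blank or malformed lines
        code = parts[0][:1]  # "R100" -> "R" (renames carry a similarity score)
        path = parts[-1]     # for renames and copies the new path comes last
        statuses[path] = _STATUS_LABELS.get(code, "modified")
    return statuses


sample = "M\tapp/main.py\nA\tapp/git/git_client.py\nD\told.py\nR100\ta.py\tb.py"
statuses = parse_name_status(sample)
print(statuses)
```

The resulting mapping could then be used to fill in `GitDiff.status` and, for example, to skip deleted files before requesting their content.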

27.3.7 The Complete PR Review Flow

python
# app/main.py
"""Code review assistant - FastAPI entry point"""

import logging
import os
import uuid
from contextlib import asynccontextmanager
from types import SimpleNamespace
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from app.agents.review_agent import ReviewAgent, ReviewReport
from app.git.git_client import GitClient, GitDiff

logger = logging.getLogger(__name__)

review_agent = ReviewAgent()
reports_store: dict = {}  # in-memory only; reports are lost on restart


class ReviewRequest(BaseModel):
    repo_path: str = "."
    base_branch: str = "main"
    target_branch: str = "HEAD"
    files: Optional[List[str]] = None  # explicit files; None reviews all changes


@asynccontextmanager
async def lifespan(app: FastAPI):
    print("🔍 Code Review Assistant started")
    yield


app = FastAPI(title="Code Review Assistant",
              version="1.0.0", lifespan=lifespan)


@app.post("/api/v1/review")
async def review_code(req: ReviewRequest):
    """Run a code review."""
    git = GitClient(req.repo_path)
    report = ReviewReport()
    
    if req.files:
        # Files were named explicitly, so no diff statistics are available
        changed = [GitDiff(file_path=f, status="modified",
                           additions=0, deletions=0) for f in req.files]
    else:
        changed = git.get_changed_files(req.base_branch, req.target_branch)
    
    for diff in changed:
        try:
            # Existence check: raises if the file is missing at the target ref
            # (for example, it was deleted). review_file itself reads the
            # working-tree copy, so the target branch should be checked out.
            git.get_file_content(diff.file_path, req.target_branch)
            file_review = review_agent.review_file(
                os.path.join(req.repo_path, diff.file_path)
            )
            report.files.append(file_review)
        except Exception as e:
            logger.warning("Review failed for %s: %s", diff.file_path, e)
            # Empty placeholder exposing the attributes the report code reads
            report.files.append(SimpleNamespace(
                file_path=diff.file_path,
                language='unknown',
                total_lines=0,
                issues=[],
                critical_count=0,
                error_count=0,
            ))
    
    summary = review_agent.generate_summary(report)
    report.summary = summary
    
    report_id = uuid.uuid4().hex[:8]
    reports_store[report_id] = report
    
    return {
        "report_id": report_id,
        "total_files": len(report.files),
        "total_issues": report.total_issues,
        "critical": len(report.critical_issues),
        "score": report.overall_score,
        "summary": summary,
    }


@app.get("/api/v1/review/{report_id}")
async def get_review_report(report_id: str):
    """Return the full review report."""
    report = reports_store.get(report_id)
    if not report:
        raise HTTPException(status_code=404, detail="Report not found")
    
    return {
        "summary": report.summary,
        "score": report.overall_score,
        "files": [
            {
                "path": f.file_path,
                "language": f.language,
                "lines": f.total_lines,
                "issues": [
                    {
                        "title": i.title,
                        "severity": i.severity.value,
                        "category": i.category.value,
                        "location": str(i.location),
                        "description": i.description,
                        "fix": i.fix_suggestion,
                        "confidence": i.confidence,
                    }
                    for i in f.issues
                ],
            }
            for f in report.files
        ],
    }


@app.get("/api/v1/review/{report_id}/comment")
async def get_github_comment(report_id: str):
    """Return the review results formatted as GitHub PR comments."""
    report = reports_store.get(report_id)
    if not report:
        raise HTTPException(status_code=404, detail="Report not found")
    
    comments = []
    for f in report.files:
        comment = await review_agent.generate_github_comment(f)
        comments.append(comment)
    
    return {"comments": comments}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app.main:app", host="0.0.0.0", port=8001)
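
The `/comment` endpoint only returns Markdown strings; something still has to post them to the pull request. A minimal sketch of that last step follows, assuming the GitHub REST endpoint `POST /repos/{owner}/{repo}/issues/{issue_number}/comments` (pull requests share the issue comment API). The helper name `build_comment_request`, the repository names, and the token placeholder are all hypothetical, and the request is only built, not sent.

```python
import json
import urllib.request


def build_comment_request(
    owner: str, repo: str, pr_number: int, body: str, token: str
) -> urllib.request.Request:
    """Build (but do not send) a request that adds a comment to a PR."""
    url = (
        f"https://api.github.com/repos/{owner}/{repo}"
        f"/issues/{pr_number}/comments"
    )
    payload = json.dumps({"body": body}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=payload,
        method="POST",
        headers={
            "Authorization": f"Bearer {token}",
            "Accept": "application/vnd.github+json",
            "Content-Type": "application/json",
        },
    )


# Hypothetical values for illustration only
request = build_comment_request(
    "octo-org", "demo-repo", 42, "## Review results", "<token>"
)
print(request.get_method(), request.full_url)
# urllib.request.urlopen(request) would send it; left out so the sketch
# has no side effects.
```

Separating request construction from sending keeps the formatting logic easy to unit test without network access.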

27.4 Testing

python
# tests/test_analyzers.py
"""Analyzer tests"""

import pytest
from app.analyzers.ast_analyzer import ASTAnalyzer
from app.analyzers.security_analyzer import SecurityAnalyzer
from app.analyzers.performance_analyzer import PerformanceAnalyzer


class TestASTAnalyzer:
    
    def test_bare_except(self):
        code = """
try:
    do_something()
except:
    pass
"""
        analyzer = ASTAnalyzer()
        issues = analyzer.analyze("test.py", code)
        assert any("bare except" in i.title.lower() for i in issues)
    
    def test_mutable_default(self):
        code = """
def add_item(item, items=[]):
    items.append(item)
    return items
"""
        analyzer = ASTAnalyzer()
        issues = analyzer.analyze("test.py", code)
        # "可变" means "mutable"; it must match the title ASTAnalyzer emits
        assert any("可变" in i.title for i in issues)
    
    def test_shadow_builtin(self):
        code = """
def process(list):
    return [x for x in list]
"""
        analyzer = ASTAnalyzer()
        issues = analyzer.analyze("test.py", code)
        # "内置" means "built-in"; it must match the title ASTAnalyzer emits
        assert any("内置" in i.title for i in issues)


class TestSecurityAnalyzer:
    
    def test_hardcoded_password(self):
        code = 'password = "my_secret_123"'
        analyzer = SecurityAnalyzer()
        issues = analyzer.analyze("test.py", code)
        # "硬编码" means "hardcoded"; it must match the title SecurityAnalyzer emits
        assert any("硬编码" in i.title for i in issues)
    
    def test_eval_usage(self):
        code = "result = eval(user_input)"
        analyzer = SecurityAnalyzer()
        issues = analyzer.analyze("test.py", code)
        assert any("eval" in i.title for i in issues)
    
    def test_sql_injection(self):
        code = 'query = f"SELECT * FROM users WHERE id = {user_id}"'
        analyzer = SecurityAnalyzer()
        issues = analyzer.analyze("test.py", code)
        assert any("SQL" in i.title for i in issues)
    
    def test_pickle_deserialize(self):
        code = "data = pickle.loads(received_data)"
        analyzer = SecurityAnalyzer()
        issues = analyzer.analyze("test.py", code)
        # "反序列化" means "deserialization"; it must match the emitted title
        assert any("反序列化" in i.title for i in issues)

27.5 Lessons Learned

27.5.1 Key Design Decisions

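  1. Static analysis plus LLM, on two tracks: static analysis is fast and deterministic, which suits problems with clear-cut rules (security vulnerabilities, coding conventions); an LLM suits semantic understanding and context-dependent suggestions. The two complement each other: run static analysis first to get deterministic results, then use the LLM to add deeper analysis.

  2. Confidence labeling: every issue carries a confidence score. High-confidence issues (such as hardcoded passwords or eval()) can block the CI pipeline outright; low-confidence issues are presented as suggestions.

  3. Actionable fix suggestions: pointing out a problem is not enough; the review must also supply the fixed code. This substantially increases the share of suggestions that developers adopt.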
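
The first two decisions can be sketched together: run deterministic checks first, let an optional LLM pass add what they missed, and gate CI only on findings that are both severe and high-confidence. Everything named here is an assumption for illustration: `Finding`, `review`, `should_block_ci`, the 0.9 threshold, the rule that only `error` and `critical` findings may block, and `fake_llm`, which stands in for a real model call.

```python
import ast
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Finding:
    title: str
    severity: str      # "info" | "warning" | "error" | "critical"
    confidence: float  # 0.0 to 1.0
    source: str        # "static" or "llm"


def review(
    source_code: str,
    static_checks: List[Callable[[str], List[Finding]]],
    llm_check: Optional[Callable[[str, List[Finding]], List[Finding]]] = None,
) -> List[Finding]:
    """Run static checks first, then let the LLM add findings they missed."""
    findings: List[Finding] = []
    for check in static_checks:
        findings.extend(check(source_code))
    if llm_check is not None:
        seen = {f.title for f in findings}
        # The LLM receives the static findings so it can avoid repeating them
        findings.extend(
            f for f in llm_check(source_code, findings) if f.title not in seen
        )
    return findings


def should_block_ci(findings: List[Finding], min_confidence: float = 0.9) -> bool:
    """Block only on severe findings we are confident about (assumed policy)."""
    blocking = {"error", "critical"}
    return any(
        f.severity in blocking and f.confidence >= min_confidence
        for f in findings
    )


def eval_check(source_code: str) -> List[Finding]:
    """Static rule: flag direct calls to eval()."""
    return [
        Finding("eval() call", "critical", 0.95, "static")
        for node in ast.walk(ast.parse(source_code))
        if isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == "eval"
    ]


def fake_llm(source_code: str, known: List[Finding]) -> List[Finding]:
    """Stand-in for a model call; returns a canned low-confidence suggestion."""
    return [Finding("Consider a clearer variable name", "info", 0.5, "llm")]


findings = review("result = eval(user_input)", [eval_check], fake_llm)
print([f.title for f in findings], should_block_ci(findings))
```

With this policy the `eval()` finding blocks the pipeline, while the LLM's naming suggestion only appears in the report.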

27.5.2 Pitfalls

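  • Limits of AST analysis: Python's AST cannot detect runtime behavior (such as dynamic attribute access or the __getattr__ magic method). These cases need an LLM pass to fill the gap.
  • False positives from regex scanning: simple regex matching produces many false positives (for example, a variable whose name contains "password" but does not hold a password). More precise patterns combined with context analysis reduce them.
  • Performance on large repositories: scanning a 100,000-line codebase needs optimization. Analyze different files in parallel across multiple processes.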
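
To illustrate the false-positive point, the sketch below compares a naive regex with an AST-based check that flags only assignments of a non-empty string literal to a secret-like name. It is not the chapter's `SecurityAnalyzer`; `find_hardcoded_secrets`, both patterns, and the sample code are hypothetical.

```python
import ast
import re
from typing import List, Tuple

NAIVE = re.compile(r"password", re.IGNORECASE)
SECRET_NAME = re.compile(r"password|passwd|secret|token|api_key", re.IGNORECASE)


def find_hardcoded_secrets(source_code: str) -> List[Tuple[int, str]]:
    """Return (line, name) for string literals assigned to secret-like names."""
    hits: List[Tuple[int, str]] = []
    for node in ast.walk(ast.parse(source_code)):
        if not isinstance(node, ast.Assign):
            continue
        value = node.value
        # Only a non-empty string literal counts as a hardcoded secret
        is_literal = (
            isinstance(value, ast.Constant)
            and isinstance(value.value, str)
            and value.value != ""
        )
        if not is_literal:
            continue
        for target in node.targets:
            if isinstance(target, ast.Name) and SECRET_NAME.search(target.id):
                hits.append((node.lineno, target.id))
    return hits


sample = '''password = "my_secret_123"
password_min_length = 8
password = os.environ["DB_PASSWORD"]
def check_password(password): ...
'''

naive_hits = [
    lineno
    for lineno, line in enumerate(sample.splitlines(), start=1)
    if NAIVE.search(line)
]
ast_hits = find_hardcoded_secrets(sample)
print(naive_hits, ast_hits)
```

The regex matches all four lines, while the AST check reports only the literal assignment on line 1. It still has blind spots (a UI label such as `password_label = "Password:"` would be flagged), which is where additional context analysis or the LLM pass comes in.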

27.5.3 Future Extensions

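  • Multi-language support: integrate tree-sitter to cover JavaScript/TypeScript/Go/Rust
  • IDE plugins: VSCode/JetBrains plugins for real-time review
  • CI/CD integration: review PRs automatically from GitHub Actions / GitLab CI
  • Learned rules: derive new checks from the team's review history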

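Chapter summary: The code review assistant shows how much Agent techniques can offer in developer tooling. The key is combining the determinism of static analysis with the comprehension of an LLM. With a layered architecture (static analysis → rule matching → LLM deep analysis), we built a review system that is both fast and context-aware.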

Released under the MIT License