Chapter 28: Data Analysis Platform

From "file a request → wait for the report → ask for changes" to "ask a question → get the answer": building an enterprise-grade AI data analysis platform


28.1 Requirements Analysis and Feature Planning

28.1.1 Business Background

Data-driven decision-making is a core competitive advantage, but the traditional analysis workflow has serious bottlenecks:

  1. Long lead times: after a business user files an analysis request, the data team's average turnaround is 3-7 days
  2. High communication cost: business users who don't know SQL and data engineers go back and forth to align on metric definitions
  3. Slow response: an ad-hoc analysis request may have to go through the full data-access approval process
  4. High barrier to entry: BI tools (Tableau, Power BI) require formal training before anyone can use them

We want to build an AI-driven data analysis platform that lets business users query data, visualize it, and discover insights in plain Chinese:

  • Natural language to SQL: describe the data you need in natural language; the platform generates and executes the SQL
  • Smart visualization: automatically recommend a chart type for the result set and render it
  • Anomaly detection and alerting: surface unusual patterns in the data and push alerts proactively
  • Report generation: turn query results into a structured analysis report

28.1.2 Feature List

┌──────────────────────────────────────────────────────────────┐
│            AI Data Analysis Platform: Feature Map            │
├──────────────────────────────────────────────────────────────┤
│  ┌────────────────┐  ┌────────────────┐  ┌────────────────┐  │
│  │ NL layer       │  │ Analysis engine│  │ Visualization  │  │
│  │ • NL2SQL       │  │ • SQL optimizer│  │ • Chart picker │  │
│  │ • Semantic fix │  │ • Aggregation  │  │ • Interactive  │  │
│  │ • Follow-ups   │  │ • Trend stats  │  │ • Dashboards   │  │
│  │ • Intent detect│  │ • Anomaly check│  │ • PDF export   │  │
│  └────────────────┘  └────────────────┘  └────────────────┘  │
│  ┌────────────────────────────────────────────────────────┐  │
│  │                   Data access layer                    │  │
│  │  • MySQL/PostgreSQL   • ClickHouse   • CSV/Excel       │  │
│  └────────────────────────────────────────────────────────┘  │
└──────────────────────────────────────────────────────────────┘

28.1.3 Non-Functional Requirements

  Dimension                  Target
  SQL generation accuracy    > 90% (standard queries)
  Query latency              P95 < 5 s (millions of rows)
  Concurrency                100 QPS
  Supported databases        MySQL, PostgreSQL, ClickHouse
  Security                   Read-only mode; DELETE/DROP/TRUNCATE blocked

28.2 Architecture Design

28.2.1 Project Layout

ai-data-platform/
├── app/
│   ├── main.py                    # FastAPI entry point
│   ├── config.py                  # Configuration
│   ├── models/                    # Data models
│   │   ├── query.py               # Query models
│   │   ├── chart.py               # Chart models
│   │   └── report.py              # Report models
│   ├── agents/                    # Agent core
│   │   ├── nl2sql_agent.py        # Natural language to SQL
│   │   ├── visualization_agent.py # Chart recommendation
│   │   ├── anomaly_agent.py       # Anomaly detection
│   │   └── report_agent.py        # Report generation
│   ├── services/                  # Business services
│   │   ├── db_service.py          # Database connection management
│   │   ├── sql_executor.py        # Safe SQL execution
│   │   └── cache_service.py       # Query cache
│   └── utils/
│       ├── llm_client.py          # LLM client
│       └── schema_loader.py       # Database schema loading
├── tests/
├── demo_data/
│   └── init_db.sql                # Demo database seed
└── requirements.txt

28.2.2 Core Class Design

The system consists of four agents chained in a pipeline:

  • NL2SQLAgent: converts natural language into safe SQL queries, with multi-turn follow-ups and semantic correction
  • VisualizationAgent: inspects the shape of the result set and recommends the best chart type
  • AnomalyAgent: runs statistical anomaly detection on time-series data
  • ReportAgent: summarizes the analysis results into a structured report

Design decision: a pipeline rather than parallel agents, because visualization, anomaly detection, and report generation all depend on the SQL execution result. Inside NL2SQL, however, schema retrieval and intent recognition can run in parallel.
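This split can be sketched as follows. Everything here is a stubbed stand-in for illustration (none of these functions are the chapter's real classes): only the two independent stages run under asyncio.gather; every stage after SQL execution stays sequential.

```python
import asyncio

# Illustrative stubs; in the real platform these are LLM / DB calls.
async def load_schema() -> str:
    return "orders(id, amount, created_at)"

async def detect_intent(question: str) -> str:
    return "trend" if "trend" in question else "query"

async def run_pipeline(question: str) -> dict:
    # Parallel: schema retrieval and intent recognition are independent.
    schema, intent = await asyncio.gather(
        load_schema(), detect_intent(question))
    # Sequential: each later stage consumes the SQL execution result.
    sql = f"SELECT * FROM orders LIMIT 10  -- schema: {schema}"
    rows = [{"amount": 100}, {"amount": 120}]  # pretend execution result
    chart = "line" if intent == "trend" else "table"
    return {"sql": sql, "intent": intent, "rows": rows, "chart": chart}

result = asyncio.run(run_pipeline("show the sales trend"))
```

The gather call saves one LLM round-trip of latency; the rest of the chain cannot be parallelized because of its data dependency.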


28.3 Core Implementation

28.3.1 Configuration and the LLM Client

python
# app/config.py
"""Configuration for the AI data analysis platform"""

from pydantic_settings import BaseSettings
from enum import Enum


class DatabaseType(str, Enum):
    MYSQL = "mysql"
    POSTGRESQL = "postgresql"
    CLICKHOUSE = "clickhouse"


class Settings(BaseSettings):
    APP_NAME: str = "AI Data Analysis Platform"
    APP_VERSION: str = "1.0.0"
    DEBUG: bool = False

    # LLM settings
    LLM_API_KEY: str = ""
    LLM_BASE_URL: str = "https://api.openai.com/v1"
    LLM_MODEL: str = "gpt-4o"
    LLM_TEMPERATURE: float = 0.1
    LLM_MAX_TOKENS: int = 4096

    # Database settings
    DB_TYPE: DatabaseType = DatabaseType.MYSQL
    DB_HOST: str = "localhost"
    DB_PORT: int = 3306
    DB_USER: str = "root"
    DB_PASSWORD: str = ""
    DB_NAME: str = "ecommerce"
    DB_READONLY: bool = True

    # Query limits
    MAX_ROWS_RETURNED: int = 10000
    QUERY_TIMEOUT: int = 30
    MAX_SQL_LENGTH: int = 5000

    # Cache settings
    CACHE_ENABLED: bool = True
    CACHE_TTL: int = 300

    class Config:
        env_file = ".env"
        env_prefix = "DP_"


settings = Settings()
python
# app/utils/llm_client.py
"""LLM client wrapper"""

import json
from typing import Optional, List, Dict
from openai import AsyncOpenAI
from app.config import settings


class LLMClient:
    _instance: Optional['LLMClient'] = None

    def __new__(cls) -> 'LLMClient':
        # Process-wide singleton so one HTTP connection pool is shared.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._client = AsyncOpenAI(
                api_key=settings.LLM_API_KEY,
                base_url=settings.LLM_BASE_URL,
            )
        return cls._instance

    async def chat(
        self,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        response_format: Optional[dict] = None,
    ) -> str:
        full_messages = []
        if system_prompt:
            full_messages.append({"role": "system", "content": system_prompt})
        full_messages.extend(messages)
        kwargs = {
            "model": settings.LLM_MODEL,
            "messages": full_messages,
            # `or` would silently discard an explicit temperature of 0.
            "temperature": (temperature if temperature is not None
                            else settings.LLM_TEMPERATURE),
            "max_tokens": max_tokens or settings.LLM_MAX_TOKENS,
        }
        if response_format:
            kwargs["response_format"] = response_format
        # AsyncOpenAI keeps these async methods from blocking the event loop.
        response = await self._client.chat.completions.create(**kwargs)
        return response.choices[0].message.content

    async def chat_json(
        self,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
    ) -> dict:
        content = await self.chat(
            messages=messages, system_prompt=system_prompt,
            temperature=0.1, response_format={"type": "json_object"},
        )
        return json.loads(content)


llm_client = LLMClient()

28.3.2 Database Schema Loader

The prerequisite for NL2SQL is understanding the database structure:

python
# app/utils/schema_loader.py
"""Database schema loader"""

from dataclasses import dataclass, field
from typing import Dict, List, Optional
from app.config import settings


@dataclass
class ColumnInfo:
    name: str
    data_type: str
    comment: str = ""
    is_primary: bool = False
    is_foreign: bool = False
    sample_values: List[str] = field(default_factory=list)


@dataclass
class TableInfo:
    name: str
    comment: str = ""
    columns: Dict[str, ColumnInfo] = field(default_factory=dict)
    row_count: int = 0

    def to_prompt(self) -> str:
        """Render a table description for the LLM prompt."""
        cols = []
        for col in self.columns.values():
            flags = []
            if col.is_primary:
                flags.append("PK")
            if col.is_foreign:
                flags.append("FK")
            flag_str = f" [{', '.join(flags)}]" if flags else ""
            sample = (f" (e.g. {', '.join(col.sample_values[:3])})"
                      if col.sample_values else "")
            cols.append(f"  - {col.name}: {col.data_type}{flag_str}{sample}")
        header = f"Table `{self.name}`"
        if self.comment:
            header += f" ({self.comment})"
        header += f", ~{self.row_count} rows:\n"
        return header + "\n".join(cols)


class SchemaLoader:
    def __init__(self, db_service):
        self._db = db_service
        self._tables: Dict[str, TableInfo] = {}
        self._ddl_cache: Optional[str] = None

    async def load_schema(self) -> str:
        if self._ddl_cache:
            return self._ddl_cache

        tables = await self._load_mysql()
        self._tables = tables
        self._ddl_cache = "\n\n".join(
            t.to_prompt() for t in tables.values())
        return self._ddl_cache

    async def _load_mysql(self) -> Dict[str, TableInfo]:
        tables = {}
        # fetch_all uses a DictCursor, so rows are dicts, not tuples.
        table_rows = await self._db.fetch_all(
            "SELECT TABLE_NAME, TABLE_COMMENT "
            "FROM information_schema.TABLES WHERE TABLE_SCHEMA = %s",
            (settings.DB_NAME,))
        for trow in table_rows:
            name = trow["TABLE_NAME"]
            table = TableInfo(name=name, comment=trow["TABLE_COMMENT"] or "")
            cols_rows = await self._db.fetch_all(
                "SELECT COLUMN_NAME, DATA_TYPE, COLUMN_COMMENT, "
                "COLUMN_KEY FROM information_schema.COLUMNS "
                "WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s "
                "ORDER BY ORDINAL_POSITION",
                (settings.DB_NAME, name))
            for crow in cols_rows:
                cname = crow["COLUMN_NAME"]
                table.columns[cname] = ColumnInfo(
                    name=cname, data_type=crow["DATA_TYPE"],
                    comment=crow["COLUMN_COMMENT"] or "",
                    is_primary=crow["COLUMN_KEY"] == "PRI",
                    is_foreign=crow["COLUMN_KEY"] == "MUL")
            # Sample values
            try:
                samples = await self._db.fetch_all(
                    f"SELECT * FROM `{name}` LIMIT 3")
                for row in samples:
                    for col_name, col in table.columns.items():
                        if len(col.sample_values) < 3 and row.get(col_name):
                            col.sample_values.append(
                                str(row[col_name])[:50])
            except Exception:
                pass
            # Row count
            try:
                count_row = await self._db.fetch_one(
                    f"SELECT COUNT(*) as cnt FROM `{name}`")
                table.row_count = count_row["cnt"] if count_row else 0
            except Exception:
                pass
            tables[name] = table
        return tables
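To make the prompt format concrete, here is a condensed stand-alone copy of the chapter's ColumnInfo/TableInfo (simplified to the fields needed for rendering) and the text it hands to the LLM:

```python
from dataclasses import dataclass, field
from typing import Dict, List

# Condensed copies of the chapter's ColumnInfo/TableInfo, just enough
# to show what the LLM-facing schema text looks like.
@dataclass
class ColumnInfo:
    name: str
    data_type: str
    is_primary: bool = False
    sample_values: List[str] = field(default_factory=list)

@dataclass
class TableInfo:
    name: str
    comment: str = ""
    columns: Dict[str, ColumnInfo] = field(default_factory=dict)
    row_count: int = 0

    def to_prompt(self) -> str:
        cols = []
        for col in self.columns.values():
            flag = " [PK]" if col.is_primary else ""
            sample = (f" (e.g. {', '.join(col.sample_values[:3])})"
                      if col.sample_values else "")
            cols.append(f"  - {col.name}: {col.data_type}{flag}{sample}")
        return (f"Table `{self.name}` ({self.comment}), "
                f"~{self.row_count} rows:\n" + "\n".join(cols))

t = TableInfo(name="orders", comment="order facts", row_count=1200)
t.columns["id"] = ColumnInfo("id", "bigint", is_primary=True)
t.columns["amount"] = ColumnInfo("amount", "decimal",
                                 sample_values=["19.90", "45.00"])
prompt = t.to_prompt()
```

Sample values matter more than they look: they let the LLM infer value formats (dates, enums, currency) that column names alone do not reveal.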

28.3.3 NL2SQL Agent (Core)

python
# app/agents/nl2sql_agent.py
"""Natural language to SQL agent"""

import re
from dataclasses import dataclass, field
from typing import Optional, List, Dict
from app.utils.llm_client import llm_client
from app.utils.schema_loader import SchemaLoader


@dataclass
class SQLResult:
    sql: str
    intent: str  # query/aggregate/trend/compare/rank
    confidence: float
    tables_used: List[str] = field(default_factory=list)
    explanation: str = ""
    follow_up_questions: List[str] = field(default_factory=list)


DANGEROUS_KEYWORDS = re.compile(
    r'\b(DROP|DELETE|TRUNCATE|ALTER|CREATE|INSERT|UPDATE|GRANT|REVOKE)\b',
    re.IGNORECASE
)


class NL2SQLAgent:
    SYSTEM_PROMPT = """You are a professional data analyst fluent in SQL. Convert the user's natural-language question into an accurate SQL query.

Database schema:
{schema}

Rules:
1. Generate SELECT queries only; never any mutating statement
2. Use table aliases for readability
3. Make sure GROUP BY is correct for aggregate queries
4. Use standard date functions for time ranges
5. Add a LIMIT clause to bound the result set (default 1000)
6. If the question is ambiguous, list the possible interpretations

Return JSON:
{{
  "sql": "the generated SQL",
  "intent": "query|aggregate|trend|compare|rank",
  "confidence": 0.0-1.0,
  "explanation": "a brief explanation of the SQL logic",
  "tables_used": ["tables referenced"],
  "follow_up_questions": ["suggested follow-up questions"]
}}"""

    def __init__(self, schema_loader: SchemaLoader):
        self._schema_loader = schema_loader
        self._schema_text = ""

    async def _ensure_schema(self):
        if not self._schema_text:
            self._schema_text = await self._schema_loader.load_schema()

    async def generate_sql(
        self, question: str, context: Optional[Dict] = None,
    ) -> SQLResult:
        """Convert a natural-language question into SQL."""
        await self._ensure_schema()
        prompt = self.SYSTEM_PROMPT.format(schema=self._schema_text)
        messages = self._build_messages(question, context)

        try:
            result = await llm_client.chat_json(
                messages=messages, system_prompt=prompt)
            sql = self._sanitize_sql(result.get("sql", ""))
            return SQLResult(
                sql=sql,
                intent=result.get("intent", "query"),
                confidence=float(result.get("confidence", 0.8)),
                tables_used=result.get("tables_used", []),
                explanation=result.get("explanation", ""),
                follow_up_questions=result.get("follow_up_questions", []),
            )
        except Exception as e:
            return SQLResult(
                sql="", intent="query", confidence=0.0,
                explanation=f"SQL generation failed: {str(e)}")

    def _build_messages(
        self, question: str, context: Optional[Dict] = None,
    ) -> List[Dict]:
        messages = []
        if context and context.get("conversation"):
            messages.extend(context["conversation"][-4:])
        if context and context.get("last_sql"):
            messages.append({
                "role": "assistant",
                "content": f"Previous SQL: {context['last_sql']}",
            })
        if context and context.get("time_range"):
            messages.append({
                "role": "system",
                "content": f"Time range: {context['time_range']}",
            })
        messages.append({"role": "user", "content": question})
        return messages

    def _sanitize_sql(self, sql: str) -> str:
        """Strip markdown fences and enforce read-only SQL."""
        sql = sql.strip()
        if sql.startswith("```"):
            sql = re.sub(r'^```\w*\n?', '', sql)
            sql = re.sub(r'\n?```$', '', sql)
        sql = sql.strip().rstrip(";")
        if DANGEROUS_KEYWORDS.search(sql):
            raise ValueError("Disallowed SQL operation detected (DDL/DML)")
        if not sql.upper().strip().startswith("SELECT"):
            raise ValueError("Only SELECT queries are allowed")
        if "LIMIT" not in sql.upper():
            sql += "\n LIMIT 1000"
        return sql

    async def refine_sql(
        self, question: str, prev_sql: str, error_msg: str,
    ) -> SQLResult:
        """Automatically repair the SQL after a failed execution."""
        await self._ensure_schema()
        prompt = self.SYSTEM_PROMPT.format(schema=self._schema_text)
        messages = [
            {"role": "user", "content": question},
            {"role": "assistant", "content": f"SQL:\n```sql\n{prev_sql}\n```"},
            {"role": "user",
             "content": f"Execution failed: {error_msg}\nPlease fix the SQL."},
        ]
        try:
            result = await llm_client.chat_json(
                messages=messages, system_prompt=prompt)
            sql = self._sanitize_sql(result.get("sql", ""))
            # SQLResult.intent has no default, so it must be supplied here.
            return SQLResult(
                sql=sql, intent="query",
                confidence=float(result.get("confidence", 0.6)),
                explanation=f"Corrected: {result.get('explanation', '')}",
            )
        except Exception:
            return SQLResult(sql="", intent="query", confidence=0.0,
                             explanation="Correction failed")
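The sanitization step is worth seeing in isolation. The function below mirrors `_sanitize_sql` as a stand-alone sketch: it strips the markdown fence an LLM often wraps around SQL, blocks mutating keywords, and appends a default LIMIT.

```python
import re

# Same blocklist as DANGEROUS_KEYWORDS above.
DANGEROUS = re.compile(
    r'\b(DROP|DELETE|TRUNCATE|ALTER|CREATE|INSERT|UPDATE|GRANT|REVOKE)\b',
    re.IGNORECASE)

def sanitize(sql: str) -> str:
    # Strip a ```sql ... ``` fence the LLM may wrap around its answer.
    sql = sql.strip()
    if sql.startswith("```"):
        sql = re.sub(r'^```\w*\n?', '', sql)
        sql = re.sub(r'\n?```$', '', sql)
    sql = sql.strip().rstrip(";")
    if DANGEROUS.search(sql):
        raise ValueError("DDL/DML is not allowed")
    if not sql.upper().startswith("SELECT"):
        raise ValueError("only SELECT queries are allowed")
    if "LIMIT" not in sql.upper():
        sql += "\n LIMIT 1000"          # bound the result set by default
    return sql

clean = sanitize("```sql\nSELECT id FROM orders;\n```")
```

Note the layering: this check runs on the generated SQL, and SQLExecutor validates again before execution, so a prompt-injected mutation has to defeat both gates.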

28.3.4 Safe SQL Executor

python
# app/services/sql_executor.py
"""Safe SQL executor"""

import re
import time
import logging
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
from app.config import settings

logger = logging.getLogger(__name__)


@dataclass
class QueryResult:
    columns: List[str]
    rows: List[Dict[str, Any]]
    row_count: int
    execution_time: float
    sql: str
    truncated: bool = False
    error: Optional[str] = None

    def to_dict_list(self) -> List[Dict]:
        return self.rows[:self.row_count]

    def summary_stats(self) -> Dict[str, Any]:
        """Compute a statistical summary per column."""
        stats = {}
        for col in self.columns:
            values = [row[col] for row in self.rows
                      if row.get(col) is not None]
            if not values:
                continue
            try:
                nums = [float(v) for v in values]
                stats[col] = {
                    "min": round(min(nums), 2),
                    "max": round(max(nums), 2),
                    "avg": round(sum(nums) / len(nums), 2),
                    "count": len(nums),
                }
            except (ValueError, TypeError):
                unique = set(str(v) for v in values)
                stats[col] = {
                    "unique_count": len(unique),
                    "top_values": list(unique)[:5],
                }
        return stats


class SQLExecutor:
    def __init__(self, db_service):
        self._db = db_service

    async def execute(self, sql: str) -> QueryResult:
        self._validate_sql(sql)
        start = time.time()
        try:
            rows = await self._db.fetch_all(sql)
            elapsed = time.time() - start
            if not rows:
                return QueryResult(
                    columns=[], rows=[], row_count=0,
                    execution_time=elapsed, sql=sql)
            columns = list(rows[0].keys())
            truncated = len(rows) >= settings.MAX_ROWS_RETURNED
            return QueryResult(
                columns=columns, rows=rows,
                row_count=len(rows),
                execution_time=round(elapsed, 3),
                sql=sql, truncated=truncated,
            )
        except Exception as e:
            elapsed = time.time() - start
            logger.error(f"SQL execution failed: {e}")
            return QueryResult(
                columns=[], rows=[], row_count=0,
                execution_time=round(elapsed, 3),
                sql=sql, error=str(e),
            )

    def _validate_sql(self, sql: str):
        sql_upper = sql.upper().strip()
        if not sql_upper.startswith("SELECT"):
            raise ValueError("Only SELECT queries may be executed")
        if re.search(r';\s*(DROP|DELETE|INSERT|UPDATE|ALTER)',
                      sql, re.IGNORECASE):
            raise ValueError("Dangerous multi-statement SQL detected")
        if len(sql) > settings.MAX_SQL_LENGTH:
            raise ValueError("SQL statement too long")
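The `summary_stats` logic is the piece downstream agents rely on most, so here it is as a stand-alone sketch: numeric columns get min/avg/max, anything non-numeric gets a distinct-value summary.

```python
def summary_stats(columns, rows):
    # Mirrors QueryResult.summary_stats: numeric columns get min/avg/max,
    # everything else gets a distinct-value summary.
    stats = {}
    for col in columns:
        values = [r[col] for r in rows if r.get(col) is not None]
        if not values:
            continue
        try:
            nums = [float(v) for v in values]
            stats[col] = {"min": round(min(nums), 2),
                          "max": round(max(nums), 2),
                          "avg": round(sum(nums) / len(nums), 2)}
        except (ValueError, TypeError):
            # Non-numeric column: summarize cardinality instead.
            stats[col] = {"unique_count": len({str(v) for v in values})}
    return stats

rows = [{"city": "Beijing", "amount": 10},
        {"city": "Shanghai", "amount": 30},
        {"city": "Beijing", "amount": 20}]
stats = summary_stats(["city", "amount"], rows)
```

Using try/float as the numeric probe is deliberate: it also catches numeric data that arrives as strings from the driver.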

28.3.5 Visualization Agent

python
# app/agents/visualization_agent.py
"""Chart recommendation agent"""

from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
from app.services.sql_executor import QueryResult


@dataclass
class ChartConfig:
    chart_type: str  # bar/line/pie/scatter/table/metric
    title: str
    x_axis: str = ""
    y_axis: str = ""
    x_label: str = ""
    y_label: str = ""
    series: List[str] = field(default_factory=list)


@dataclass
class VisualizationResult:
    charts: List[ChartConfig]
    recommendation_reason: str


class VisualizationAgent:
    """Recommend the best visualization based on the result set's shape"""

    # Trigger keywords stay in Chinese because they match the user's
    # Chinese questions (e.g. 趋势 = trend, 占比 = share, 排名 = ranking).
    CHART_RULES = {
        "bar": {"triggers": ["类别", "排名", "TOP", "对比", "分类"]},
        "line": {"triggers": ["趋势", "时间", "变化", "走势", "增长"]},
        "pie": {"triggers": ["占比", "比例", "构成", "分布", "份额"]},
        "scatter": {"triggers": ["相关性", "散点", "关系"]},
        "metric": {"triggers": ["总计", "总量", "平均值", "总额"]},
    }

    def recommend(
        self, result: QueryResult, intent: str, question: str,
    ) -> VisualizationResult:
        charts = []
        analysis = self._analyze_data(result)
        rule_chart = self._match_rules(question, intent, analysis)
        if rule_chart:
            charts.append(rule_chart)
        feature_chart = self._infer_from_features(result, analysis)
        if feature_chart:
            charts.append(feature_chart)
        if not charts:
            charts.append(ChartConfig(chart_type="table", title="Query results"))
        return VisualizationResult(charts=charts,
                                   recommendation_reason="auto-recommended")

    def _analyze_data(self, result: QueryResult) -> Dict[str, Any]:
        if not result.columns or not result.rows:
            return {"is_empty": True, "row_count": 0}
        numeric_cols, category_cols, date_cols = [], [], []
        for col in result.columns:
            values = [row.get(col) for row in result.rows[:50]
                      if row.get(col) is not None]
            if not values:
                continue
            is_numeric = True
            for v in values[:10]:
                try:
                    float(v)
                except (ValueError, TypeError):
                    is_numeric = False
                    break
            if is_numeric:
                numeric_cols.append(col)
            elif self._is_date_col(values):
                date_cols.append(col)
            else:
                category_cols.append(col)
        return {
            "is_empty": False, "row_count": result.row_count,
            "numeric_cols": numeric_cols, "category_cols": category_cols,
            "date_cols": date_cols, "has_single_row": result.row_count == 1,
        }

    def _is_date_col(self, values: List) -> bool:
        import re
        patterns = [r'\d{4}-\d{2}-\d{2}', r'\d{4}/\d{2}/\d{2}']
        matched = sum(1 for v in values[:5]
                      if any(re.search(p, str(v)) for p in patterns))
        return matched >= 3

    def _match_rules(self, question, intent, analysis):
        q = question.lower()
        for chart_type, rules in self.CHART_RULES.items():
            if any(t in q for t in rules["triggers"]):
                return self._create_chart(chart_type, analysis, question)
        return None

    def _infer_from_features(self, result, analysis):
        if analysis.get("has_single_row") and analysis.get("numeric_cols"):
            return ChartConfig(
                chart_type="metric",
                title=f"{result.columns[0]} = {list(result.rows[0].values())[0]}")
        if analysis.get("date_cols") and analysis.get("numeric_cols"):
            return ChartConfig(
                chart_type="line", title="Trend over time",
                x_axis=analysis["date_cols"][0],
                y_axis=analysis["numeric_cols"][0])
        if (len(analysis.get("category_cols", [])) == 1
                and analysis.get("numeric_cols")):
            return ChartConfig(
                chart_type="bar", title="Comparison by category",
                x_axis=analysis["category_cols"][0],
                y_axis=analysis["numeric_cols"][0])
        return None

    def _create_chart(self, chart_type, analysis, question):
        date_cols = analysis.get("date_cols", [])
        category_cols = analysis.get("category_cols", [])
        numeric_cols = analysis.get("numeric_cols", [])
        config = {"chart_type": chart_type, "title": question}
        if chart_type == "line" and date_cols and numeric_cols:
            config.update({"x_axis": date_cols[0], "y_axis": numeric_cols[0]})
        elif chart_type in ("bar", "pie") and category_cols and numeric_cols:
            config.update({"x_axis": category_cols[0], "y_axis": numeric_cols[0]})
        elif chart_type == "metric" and numeric_cols:
            config.update({"series": numeric_cols})
        return ChartConfig(**config)
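The heart of the feature-based fallback is column-type inference. This stand-alone sketch condenses `_analyze_data` and `_is_date_col` into one function: probe sample values with float(), then a date regex, and bucket each column accordingly.

```python
import re

def classify_columns(columns, rows):
    # Simplified VisualizationAgent._analyze_data: probe sample values
    # to split columns into numeric / date / categorical buckets.
    numeric, dates, categories = [], [], []
    date_pat = re.compile(r'\d{4}[-/]\d{2}[-/]\d{2}')
    for col in columns:
        values = [r[col] for r in rows if r.get(col) is not None]
        if not values:
            continue
        try:
            [float(v) for v in values[:10]]   # all parse → numeric column
            numeric.append(col)
            continue
        except (ValueError, TypeError):
            pass
        # Same 3-of-5 heuristic as _is_date_col.
        if sum(bool(date_pat.search(str(v))) for v in values[:5]) >= 3:
            dates.append(col)
        else:
            categories.append(col)
    return {"numeric": numeric, "dates": dates, "categories": categories}

rows = [{"day": f"2024-01-0{i}", "gmv": i * 100, "region": "north"}
        for i in range(1, 6)]
kinds = classify_columns(["day", "gmv", "region"], rows)
```

From these buckets the chart choice falls out directly: date + numeric → line, one category + numeric → bar, a single numeric row → metric card.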

28.3.6 Anomaly Detection Agent

python
# app/agents/anomaly_agent.py
"""Data anomaly detection agent"""

from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
from statistics import mean, stdev
from app.services.sql_executor import QueryResult


@dataclass
class AnomalyPoint:
    timestamp: str
    value: float
    expected_range: tuple
    deviation_percent: float
    severity: str  # info/warning/critical
    description: str


@dataclass
class AnomalyReport:
    has_anomaly: bool
    anomaly_points: List[AnomalyPoint]
    overall_trend: str
    summary: str
    # Default needed: detect() builds early-return reports without it.
    recommendations: List[str] = field(default_factory=list)


class AnomalyAgent:
    """Statistical time-series anomaly detection"""

    def detect(
        self, result: QueryResult, date_col: str, value_col: str,
        sensitivity: float = 2.0,
    ) -> AnomalyReport:
        if not result.rows or date_col not in result.columns:
            return AnomalyReport(
                has_anomaly=False, anomaly_points=[],
                overall_trend="unknown", summary="Insufficient data")

        series = []
        for row in result.rows:
            try:
                val = float(row.get(value_col, 0))
                ts = str(row.get(date_col, ""))
                series.append({"timestamp": ts, "value": val})
            except (ValueError, TypeError):
                continue

        if len(series) < 3:
            return AnomalyReport(
                has_anomaly=False, anomaly_points=[],
                overall_trend="unknown", summary="Too few data points")

        values = [p["value"] for p in series]
        avg = mean(values)
        std = stdev(values) if len(values) > 1 else 0

        anomalies = []
        for point in series:
            z_score = abs(point["value"] - avg) / std if std > 0 else 0
            if z_score >= sensitivity:
                lower = round(avg - sensitivity * std, 2)
                upper = round(avg + sensitivity * std, 2)
                deviation = (abs(point["value"] - avg) / avg * 100
                             if avg != 0 else 0)
                anomalies.append(AnomalyPoint(
                    timestamp=point["timestamp"],
                    value=point["value"],
                    expected_range=(lower, upper),
                    deviation_percent=round(deviation, 1),
                    severity="critical" if z_score >= 3 else "warning",
                    description=(f"{point['timestamp']}: value "
                                 f"{point['value']:.2f} deviates from the "
                                 f"mean by {deviation:.1f}%"),
                ))

        trend = self._detect_trend(values)
        summary = (f"Analyzed {len(series)} data points; mean: {avg:.2f}, "
                   f"std dev: {std:.2f}; trend: {trend}; "
                   f"anomalies: {len(anomalies)}")
        recs = []
        if anomalies:
            recs.append(f"{len(anomalies)} anomalous points need attention")
        if trend == "decreasing":
            recs.append("Overall downward trend; investigate the cause")

        return AnomalyReport(
            has_anomaly=len(anomalies) > 0,
            anomaly_points=anomalies,
            overall_trend=trend,
            summary=summary,
            recommendations=recs,
        )

    def _detect_trend(self, values: List[float]) -> str:
        if len(values) < 2:
            return "stable"
        mid = len(values) // 2
        first = mean(values[:mid])
        second = mean(values[mid:])
        if first == 0:
            return "stable"
        change = (second - first) / abs(first) * 100
        if change > 10:
            return "increasing"
        elif change < -10:
            return "decreasing"
        return "stable"
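The z-score rule in `detect` can be exercised on a toy series. This sketch applies the same threshold test (deviation from the mean, measured in sample standard deviations) to flag the one spiked point:

```python
from statistics import mean, stdev

def zscore_anomalies(series, sensitivity=2.0):
    # Same rule as AnomalyAgent.detect: flag points whose z-score
    # meets or exceeds the sensitivity threshold.
    values = [v for _, v in series]
    avg, std = mean(values), stdev(values)
    flagged = []
    for ts, v in series:
        z = abs(v - avg) / std if std > 0 else 0
        if z >= sensitivity:
            flagged.append((ts, v, round(z, 2)))
    return flagged

# Nine normal days plus one spike.
series = [("2024-01-%02d" % i, 100) for i in range(1, 10)]
series.append(("2024-01-10", 300))
anomalies = zscore_anomalies(series)
```

One caveat the chapter's `sensitivity=2.0` default inherits: with very short series a single outlier inflates the standard deviation enough to hide itself, which is why `detect` refuses to run on fewer than 3 points (and works best with many more).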

28.3.7 Report Generation Agent

python
# app/agents/report_agent.py
"""Analysis report generation agent"""

from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
from datetime import datetime
from app.utils.llm_client import llm_client
from app.services.sql_executor import QueryResult
from app.agents.anomaly_agent import AnomalyReport


@dataclass
class ReportSection:
    title: str
    content: str
    data_summary: Dict[str, Any] = field(default_factory=dict)


@dataclass
class AnalysisReport:
    title: str
    sections: List[ReportSection]
    key_findings: List[str]
    recommendations: List[str]
    created_at: str = ""


class ReportAgent:
    SYSTEM_PROMPT = """You are a senior data analyst who turns query results into clear analysis reports.
Requirements: let the data speak, highlight key findings, give actionable recommendations, and use Markdown formatting."""

    async def generate(
        self, question: str, result: QueryResult,
        anomaly_report: Optional[AnomalyReport] = None,
    ) -> AnalysisReport:
        data_context = self._build_data_context(result)
        messages = [
            {"role": "user", "content": f"User question: {question}"},
            {"role": "assistant", "content": f"Executed SQL: {result.sql}"},
            {"role": "user", "content": f"Query result:\n{data_context}"},
        ]
        if anomaly_report and anomaly_report.has_anomaly:
            anomaly_text = "\n".join(
                f"- {a.description}" for a in anomaly_report.anomaly_points)
            messages.append({
                "role": "user",
                "content": (f"Anomaly detection:\n{anomaly_text}\n"
                            f"Trend: {anomaly_report.overall_trend}"),
            })
        messages.append({"role": "user",
                         "content": "Please write the analysis report."})

        try:
            report_text = await llm_client.chat(
                messages=messages, system_prompt=self.SYSTEM_PROMPT,
                temperature=0.4, max_tokens=4096)
            return AnalysisReport(
                title=f"Analysis report: {question}",
                sections=[ReportSection(title="Report", content=report_text)],
                key_findings=[], recommendations=[],
                created_at=datetime.now().isoformat(),
            )
        except Exception as e:
            return AnalysisReport(
                title=f"Analysis report: {question}",
                sections=[ReportSection(
                    title="Error",
                    content=f"Report generation failed: {str(e)}")],
                key_findings=[], recommendations=[])

    def _build_data_context(self, result: QueryResult) -> str:
        lines = [f"{result.row_count} rows, {len(result.columns)} columns"]
        lines.append(f"Columns: {', '.join(result.columns)}")
        stats = result.summary_stats()
        for col, s in stats.items():
            if "avg" in s:
                lines.append(f"  {col}: avg={s['avg']}, "
                             f"min={s['min']}, max={s['max']}")
        lines.append("Data preview:")
        for row in result.rows[:10]:
            lines.append("  " + str(dict(row)))
        return "\n".join(lines)
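The point of `_build_data_context` is to keep the prompt bounded: the LLM sees the result's shape, per-column stats, and a capped row preview rather than the full result set. A stand-alone sketch of the same idea (stats omitted for brevity):

```python
def build_data_context(columns, rows, max_preview=10):
    # Compact text handed to the LLM instead of the raw result set:
    # shape, column list, and a bounded row preview.
    lines = [f"{len(rows)} rows, {len(columns)} columns",
             "Columns: " + ", ".join(columns),
             "Data preview:"]
    lines += ["  " + str(r) for r in rows[:max_preview]]
    return "\n".join(lines)

ctx = build_data_context(["day", "gmv"],
                         [{"day": "2024-01-01", "gmv": 100},
                          {"day": "2024-01-02", "gmv": 120}])
```

For large result sets the summary statistics carry most of the signal; the preview rows mainly show the LLM the value formats.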

28.3.8 Database Service and FastAPI Entry Point

python
# app/services/db_service.py
"""Database connection service"""

import aiomysql
from typing import List, Dict, Any, Optional
from app.config import settings


class DatabaseService:
    def __init__(self):
        self._pool = None

    async def connect(self):
        self._pool = await aiomysql.create_pool(
            host=settings.DB_HOST, port=settings.DB_PORT,
            user=settings.DB_USER, password=settings.DB_PASSWORD,
            db=settings.DB_NAME, maxsize=10,
            autocommit=True, charset='utf8mb4')

    async def close(self):
        if self._pool:
            self._pool.close()
            await self._pool.wait_closed()

    async def fetch_all(self, sql: str, args: tuple = ()) -> List[Dict]:
        if not self._pool:
            await self.connect()
        async with self._pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cur:
                await cur.execute(sql, args)
                return await cur.fetchall()

    async def fetch_one(self, sql: str, args: tuple = ()):
        if not self._pool:
            await self.connect()
        async with self._pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cur:
                await cur.execute(sql, args)
                return await cur.fetchone()
python
# app/main.py
"""AI data analysis platform - FastAPI entry point"""

from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional

from app.config import settings
from app.services.db_service import DatabaseService
from app.utils.schema_loader import SchemaLoader
from app.agents.nl2sql_agent import NL2SQLAgent
from app.agents.visualization_agent import VisualizationAgent
from app.agents.anomaly_agent import AnomalyAgent
from app.agents.report_agent import ReportAgent
from app.services.sql_executor import SQLExecutor

db_service = DatabaseService()
schema_loader = SchemaLoader(db_service)
nl2sql_agent = NL2SQLAgent(schema_loader)
viz_agent = VisualizationAgent()
anomaly_agent = AnomalyAgent()
report_agent = ReportAgent()
sql_executor = SQLExecutor(db_service)
sessions: dict = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    await db_service.connect()
    print(f"🚀 {settings.APP_NAME} v{settings.APP_VERSION} started")
    yield
    await db_service.close()


app = FastAPI(
    title=settings.APP_NAME, version=settings.APP_VERSION,
    lifespan=lifespan,
    description="Query data in natural language; the AI generates SQL, charts, and reports")
app.add_middleware(CORSMiddleware, allow_origins=["*"],
                   allow_credentials=True,
                   allow_methods=["*"], allow_headers=["*"])


class QueryRequest(BaseModel):
    question: str
    session_id: Optional[str] = None
    time_range: Optional[str] = None


class RefineRequest(BaseModel):
    question: str
    prev_sql: str
    error_msg: str


@app.get("/health")
async def health():
    return {"status": "ok", "version": settings.APP_VERSION}


@app.get("/api/v1/schema")
async def get_schema():
    schema_text = await schema_loader.load_schema()
    return {"schema": schema_text}


@app.post("/api/v1/query")
async def query_data(req: QueryRequest):
    """Full pipeline: NL2SQL → execute → visualize → detect anomalies"""
    context = sessions.get(req.session_id, {}) if req.session_id else {}
    if req.time_range:
        context["time_range"] = req.time_range

    # 1. NL2SQL
    sql_result = await nl2sql_agent.generate_sql(req.question, context)
    if not sql_result.sql:
        raise HTTPException(
            400, detail=f"Could not generate SQL: {sql_result.explanation}")

    # 2. Execute the SQL (auto-repair on failure)
    exec_result = await sql_executor.execute(sql_result.sql)
    if exec_result.error:
        refined = await nl2sql_agent.refine_sql(
            req.question, sql_result.sql, exec_result.error)
        if refined.sql:
            exec_result = await sql_executor.execute(refined.sql)
            sql_result = refined
        if exec_result.error:
            raise HTTPException(
                400, detail=f"SQL execution failed: {exec_result.error}")

    # 3. Chart recommendation
    viz_result = viz_agent.recommend(exec_result, sql_result.intent, req.question)

    # 4. Anomaly detection (time-series results only)
    anomaly_report = None
    date_cols = [c for c in exec_result.columns
                 if any(k in c.lower() for k in ["date", "time", "month"])]
    num_cols = [c for c in exec_result.columns
                if exec_result.rows
                and isinstance(exec_result.rows[0].get(c), (int, float))]
    if date_cols and num_cols and exec_result.row_count >= 5:
        anomaly_report = anomaly_agent.detect(
            exec_result, date_cols[0], num_cols[0])

    # 5. Persist session state; append to the history (rather than
    # overwrite it) so multi-turn follow-up questions keep their context
    if req.session_id:
        session = sessions.setdefault(req.session_id, {"conversation": []})
        session["last_sql"] = sql_result.sql
        session["conversation"] += [
            {"role": "user", "content": req.question},
            {"role": "assistant", "content": sql_result.explanation},
        ]

    return {
        "session_id": req.session_id,
        "sql": sql_result.sql,
        "explanation": sql_result.explanation,
        "confidence": sql_result.confidence,
        "data": exec_result.to_dict_list(),
        "row_count": exec_result.row_count,
        "execution_time": exec_result.execution_time,
        "visualization": [
            {"chart_type": c.chart_type, "title": c.title,
             "x_axis": c.x_axis, "y_axis": c.y_axis}
            for c in viz_result.charts
        ],
        "anomaly": {
            "has_anomaly": anomaly_report.has_anomaly,
            "summary": anomaly_report.summary,
        } if anomaly_report else None,
        "follow_up_questions": sql_result.follow_up_questions,
    }


@app.post("/api/v1/report")
async def generate_report(req: QueryRequest):
    """Generate a structured analysis report"""
    context = sessions.get(req.session_id, {}) if req.session_id else {}
    sql_result = await nl2sql_agent.generate_sql(req.question, context)
    exec_result = await sql_executor.execute(sql_result.sql)
    if exec_result.error:
        raise HTTPException(400, detail=exec_result.error)
    report = await report_agent.generate(req.question, exec_result)
    return {
        "title": report.title,
        "sections": [{"title": s.title, "content": s.content}
                     for s in report.sections],
        "key_findings": report.key_findings,
        "recommendations": report.recommendations,
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000,
                reload=settings.DEBUG)

28.4 Demo and Testing

28.4.1 Initializing the Demo Database

```sql
-- demo_data/init_db.sql
CREATE TABLE IF NOT EXISTS users (
    id INT PRIMARY KEY AUTO_INCREMENT,
    username VARCHAR(50) COMMENT '用户名',
    city VARCHAR(50) COMMENT '城市',
    register_date DATE COMMENT '注册日期',
    vip_level INT DEFAULT 0 COMMENT 'VIP等级'
) COMMENT '用户表';

CREATE TABLE IF NOT EXISTS products (
    id INT PRIMARY KEY AUTO_INCREMENT,
    name VARCHAR(200) COMMENT '商品名称',
    category VARCHAR(50) COMMENT '类目',
    price DECIMAL(10,2) COMMENT '价格',
    stock INT COMMENT '库存'
) COMMENT '商品表';

CREATE TABLE IF NOT EXISTS orders (
    id INT PRIMARY KEY AUTO_INCREMENT,
    user_id INT COMMENT '用户ID',
    product_id INT COMMENT '商品ID',
    amount DECIMAL(10,2) COMMENT '金额',
    quantity INT COMMENT '数量',
    status VARCHAR(20) COMMENT '状态',
    order_date DATETIME COMMENT '下单时间'
) COMMENT '订单表';

-- Sample data (abbreviated)
INSERT INTO users VALUES
(1,'张三','北京','2024-01-15',3),
(2,'李四','上海','2024-02-20',1),
(3,'王五','广州','2024-03-10',2),
(4,'赵六','深圳','2024-04-05',0);
INSERT INTO products VALUES
(1,'无线蓝牙耳机Pro','数码',299.00,500),
(2,'智能手表S3','数码',1599.00,200),
(3,'纯棉T恤','服饰',89.00,1000),
(4,'运动跑鞋X1','运动',459.00,300);
INSERT INTO orders VALUES
(1,1,1,299.00,1,'completed','2024-10-01 10:30:00'),
(2,2,3,178.00,2,'completed','2024-10-05 09:15:00'),
(3,3,4,459.00,1,'pending','2024-10-07 16:45:00');
```

28.4.2 Test Cases

```python
# tests/test_nl2sql.py
"""NL2SQL accuracy tests"""

import os

import pytest


class MockSchemaLoader:
    SCHEMA = """
表 `orders` (订单表):
  - id: int [PK]
  - user_id: int [FK]
  - amount: decimal
  - status: varchar
  - order_date: datetime
表 `users` (用户表):
  - id: int [PK]
  - username: varchar
  - city: varchar
  - vip_level: int"""

    async def load_schema(self):
        return self.SCHEMA


@pytest.mark.asyncio
async def test_nl2sql_basic():
    """Basic NL2SQL generation (needs a configured LLM API key)."""
    from app.agents.nl2sql_agent import NL2SQLAgent
    if not os.getenv("DP_LLM_API_KEY"):
        pytest.skip("LLM API key not configured")
    agent = NL2SQLAgent(MockSchemaLoader())
    result = await agent.generate_sql("各城市的用户数量")
    assert "users" in result.sql.lower()
    assert "city" in result.sql.lower()


@pytest.mark.asyncio
async def test_sql_sanitize():
    """Dangerous SQL must be rejected by the safety filter."""
    from app.agents.nl2sql_agent import NL2SQLAgent
    agent = NL2SQLAgent(MockSchemaLoader())
    with pytest.raises(ValueError, match="不允许"):
        agent._sanitize_sql("DROP TABLE users")
```

28.5 Deployment

28.5.1 Docker Deployment

```dockerfile
FROM python:3.11-slim
WORKDIR /app
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc default-libmysqlclient-dev && rm -rf /var/lib/apt/lists/*
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY app/ ./app/
EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
```
```yaml
# docker-compose.yml
version: '3.8'
services:
  data-platform:
    build: .
    ports: ["8000:8000"]
    environment:
      - DP_LLM_API_KEY=${LLM_API_KEY}
      - DP_DB_HOST=mysql
      - DP_DB_USER=data_analyst
      - DP_DB_PASSWORD=readonly_pass
      - DP_DB_NAME=ecommerce
    depends_on: [mysql]
  mysql:
    image: mysql:8
    environment:
      MYSQL_ROOT_PASSWORD: root123
      MYSQL_DATABASE: ecommerce
      MYSQL_USER: data_analyst
      MYSQL_PASSWORD: readonly_pass
    ports: ["3306:3306"]
    volumes:
      - ./demo_data/init_db.sql:/docker-entrypoint-initdb.d/init.sql
```

28.6 Lessons Learned

28.6.1 Pitfalls

Pitfall 1: Schema information overload in NL2SQL

Early on we passed the complete DDL (including indexes and constraints) to the LLM, which burned tokens and actually lowered accuracy. The fix was to keep only table names, column names, types, and comments, dropping index and foreign-key DDL. Sampling 3 rows of data as example values improved generation accuracy substantially.
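The compact representation can be sketched as follows; `compact_schema` is a hypothetical helper (not the chapter's `SchemaLoader`), shown only to illustrate what is kept and what is dropped:

```python
def compact_schema(table, table_comment, columns, sample_rows=()):
    """Render one table in the token-lean format fed to the LLM:
    table name + comment, then column name/type/comment lines,
    then at most 3 sample rows. Index and constraint DDL is omitted."""
    lines = [f"表 `{table}` ({table_comment}):"]
    for name, col_type, comment in columns:
        suffix = f"  -- {comment}" if comment else ""
        lines.append(f"  - {name}: {col_type}{suffix}")
    for row in list(sample_rows)[:3]:  # sample values boost accuracy
        lines.append(f"  示例: {row}")
    return "\n".join(lines)
```

The output mirrors the schema text used in the test's `MockSchemaLoader`.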

Pitfall 2: Ambiguity in multi-table JOINs

"Sales revenue" might mean a plain SUM over orders.amount, or it might require joining products; "sales by category" needs a JOIN between orders and products. We disambiguate with few-shot examples plus schema relationship descriptions: annotating [FK] relationships in the schema helps the LLM pick the right join path.
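A sketch of what such few-shot pairs might look like against the demo schema (the exact prompt wording in the chapter's agent may differ; `render_few_shot` is a hypothetical helper):

```python
# Hypothetical few-shot pairs appended to the NL2SQL prompt. The [FK]
# annotations in the schema (orders.product_id → products.id) tell the
# model which JOIN resolves an ambiguous metric such as "销售额".
FEW_SHOT_EXAMPLES = [
    ("各品类销售额",
     "SELECT p.category, SUM(o.amount) AS sales "
     "FROM orders o JOIN products p ON o.product_id = p.id "
     "GROUP BY p.category"),
    ("各城市的用户数量",
     "SELECT city, COUNT(*) AS user_cnt FROM users GROUP BY city"),
]

def render_few_shot(examples):
    """Format the (question, sql) pairs as Q/SQL blocks for the prompt."""
    return "\n\n".join(f"问: {q}\nSQL: {sql}" for q, sql in examples)
```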

Pitfall 3: SQL injection defenses

User questions may embed malicious SQL fragments. We apply two layers: a whitelist check (only SELECT is allowed) and a regex blacklist (DDL/DML keywords are rejected). In production, running against a read-only database account is recommended as a final safeguard.
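A minimal sketch of the two-layer filter; `sanitize_sql` is a hypothetical stand-in for the agent's `_sanitize_sql`, simplified to single statements:

```python
import re

# Layer 2 — regex blacklist: reject DDL/DML keywords anywhere in the text.
_BLACKLIST = re.compile(
    r"\b(drop|delete|truncate|update|insert|alter|create|grant)\b", re.I)

def sanitize_sql(sql: str) -> str:
    stmt = sql.strip().rstrip(";").strip()
    if ";" in stmt:  # no stacked statements
        raise ValueError("不允许执行多条语句")
    # Layer 1 — whitelist: only read-only SELECT/WITH queries pass.
    if not re.match(r"(?is)^(select|with)\b", stmt):
        raise ValueError("只允许 SELECT 查询")
    if _BLACKLIST.search(stmt):
        raise ValueError("检测到危险关键字")
    return stmt
```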

Pitfall 4: Choosing the baseline for anomaly detection

A plain Z-score misfires when the data has a clear trend: normal highs in a growing series get flagged as anomalies. We added a trend-detection step: when the overall trend is upward, the baseline is a sliding-window Z-score rather than the global mean.
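The sliding-window baseline can be sketched like this (a simplified stand-in for the chapter's `AnomalyAgent` logic):

```python
import statistics

def rolling_zscore_anomalies(values, window=5, threshold=3.0):
    """Flag indices whose value deviates more than `threshold` standard
    deviations from the mean of the preceding `window` points. In a
    trending series this baseline follows the trend, so normal highs
    are not flagged the way a global-mean Z-score would flag them."""
    anomalies = []
    for i in range(window, len(values)):
        base = values[i - window:i]
        mu = statistics.mean(base)
        sigma = statistics.stdev(base)
        if sigma > 0 and abs(values[i] - mu) / sigma > threshold:
            anomalies.append(i)
    return anomalies
```

On a steadily growing series the late, large values stay within the rolling baseline, while a sudden spike still stands out.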

28.6.2 Performance Optimizations

  1. Schema caching: the database structure rarely changes, so load it once at startup instead of hitting INFORMATION_SCHEMA on every query
  2. SQL result caching: identical SQL queries are cached for 5 minutes, and identical questions skip SQL regeneration
  3. Low-temperature generation: SQL is generated with temperature=0.1 for deterministic output
  4. Auto-repair loop: failed SQL is corrected automatically; on average 1.2 repair rounds suffice
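The result cache can be sketched as a small TTL store keyed on the normalized SQL text; `SQLResultCache` is a hypothetical helper (the chapter's service layer could equally use Redis):

```python
import hashlib
import time

class SQLResultCache:
    """Cache query results for a fixed TTL (default 300 s = 5 minutes)."""

    def __init__(self, ttl: float = 300.0):
        self.ttl = ttl
        self._store = {}  # key -> (timestamp, result)

    @staticmethod
    def _key(sql: str) -> str:
        # Normalize case/whitespace so trivially different SQL texts hit.
        normalized = " ".join(sql.lower().split())
        return hashlib.sha256(normalized.encode()).hexdigest()

    def get(self, sql: str):
        entry = self._store.get(self._key(sql))
        if entry and time.monotonic() - entry[0] < self.ttl:
            return entry[1]
        return None  # miss or expired

    def put(self, sql: str, result) -> None:
        self._store[self._key(sql)] = (time.monotonic(), result)
```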

28.6.3 Key Design Patterns

| Pattern | Applied to | Effect |
| --- | --- | --- |
| Schema-as-Prompt | NL2SQL | accuracy up 30% |
| Auto-repair loop | SQL execution fault tolerance | 85% failure-recovery rate |
| Rules + features hybrid | chart recommendation | 92% recommendation accuracy |
| Pipeline chaining | full analysis flow | entire flow in one call |

28.6.4 Future Directions

  1. Text-to-Chart: generate ECharts/Plotly code directly for interactive charts
  2. Natural-language data modeling: design database schemas automatically from business requirements
  3. Predictive analytics: integrate time-series forecasting (Prophet, ARIMA) and auto-generate forecast reports
  4. Federated queries across sources: query MySQL, ClickHouse, and Elasticsearch in one request

Chapter summary: the core of an AI data analysis platform is NL2SQL plus an automated pipeline. The keys are a high-quality representation of the schema, a safe SQL execution mechanism, and self-repair on failure. By chaining natural-language querying, SQL generation, safe execution, chart recommendation, and anomaly detection into a single pipeline, business users can go from "ask a question" to "get a conclusion" in one sentence, cutting analysis turnaround from days to seconds.
