Chapter 28: Data Analysis Platform
From "submit a request → wait for the report → revise again" to "ask once → get the result": building an enterprise-grade AI data analysis platform
28.1 Requirements Analysis and Feature Planning
28.1.1 Business Background
Data-driven decision making is a core competitive advantage, but the traditional analytics workflow has serious bottlenecks:
- Long lead times: after a business user files an analysis request, the data team's average turnaround is 3-7 days
- High communication cost: business users who don't know SQL and data engineers go back and forth repeatedly to align on metric definitions
- Slow response: a single ad-hoc analysis may require the full data-access approval process
- High barrier to entry: BI tools (Tableau, Power BI) require professional training to use
We need an AI-driven data analysis platform that lets business users query data, build visualizations, and discover insights in natural language:
- Natural language to SQL: describe the data need in plain language; the platform generates and executes the SQL
- Smart visualization: automatically recommend a chart type for the query result and render it
- Anomaly detection and alerting: automatically detect abnormal patterns in the data and proactively push alerts
- Report generation: automatically produce a structured analysis report from the query results
28.1.2 Feature List
┌─────────────────────────────────────────────────────────────┐
│           AI Data Analysis Platform: Feature Map            │
├─────────────────────────────────────────────────────────────┤
│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │
│ │ NL Layer        │ │ Analysis Engine │ │ Visualization   │ │
│ │ • NL2SQL        │ │ • SQL optimizer │ │ • Chart suggest │ │
│ │ • Semantic fix  │ │ • Aggregation   │ │ • Interactive   │ │
│ │ • Follow-ups    │ │ • Trend analysis│ │ • Dashboards    │ │
│ │ • Intent detect │ │ • Anomaly detect│ │ • PDF export    │ │
│ └─────────────────┘ └─────────────────┘ └─────────────────┘ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Data Access Layer                                       │ │
│ │ • MySQL/PostgreSQL   • ClickHouse   • CSV/Excel         │ │
│ └─────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
28.1.3 Non-Functional Requirements
| Dimension | Target |
|---|---|
| SQL generation accuracy | > 90% (standard queries) |
| Query latency | P95 < 5 s (millions of rows) |
| Concurrency | 100 QPS |
| Supported databases | MySQL, PostgreSQL, ClickHouse |
| Security | DELETE/DROP/TRUNCATE forbidden; read-only mode |
28.2 Architecture Design
28.2.1 Project Layout
ai-data-platform/
├── app/
│   ├── main.py                    # FastAPI entry point
│   ├── config.py                  # configuration
│   ├── models/                    # data models
│   │   ├── query.py               # query models
│   │   ├── chart.py               # chart models
│   │   └── report.py              # report models
│   ├── agents/                    # agent core
│   │   ├── nl2sql_agent.py        # natural language to SQL
│   │   ├── visualization_agent.py # chart recommendation
│   │   ├── anomaly_agent.py       # anomaly detection
│   │   └── report_agent.py        # report generation
│   ├── services/                  # business services
│   │   ├── db_service.py          # database connection management
│   │   ├── sql_executor.py        # safe SQL execution
│   │   └── cache_service.py       # query cache
│   └── utils/
│       ├── llm_client.py          # LLM client
│       └── schema_loader.py       # database schema loader
├── tests/
├── demo_data/
│   └── init_db.sql                # demo database initialization
└── requirements.txt
28.2.2 Core Class Design
The system consists of four agents chained in a pipeline:
- NL2SQLAgent: converts natural language into safe SQL, with multi-turn follow-ups and semantic correction
- VisualizationAgent: analyzes the shape of the query result and recommends the best chart type
- AnomalyAgent: runs statistical anomaly detection on time-series data
- ReportAgent: aggregates the analysis results into a structured report
Design decision: a pipeline rather than parallel execution, because visualization, anomaly detection, and report generation all depend on the SQL execution result. Inside NL2SQL, however, schema retrieval and intent recognition are independent and can run in parallel.
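The parallel half of that decision can be sketched with `asyncio.gather`. Here `load_schema_fragment` and `classify_intent` are hypothetical stand-ins for the real schema-retrieval and intent-recognition calls, not APIs from this chapter's code:

```python
import asyncio
from typing import Tuple


async def load_schema_fragment(question: str) -> str:
    # Hypothetical stand-in for retrieving only the tables relevant
    # to the question (e.g. via vector search over table descriptions).
    await asyncio.sleep(0)
    return "Table `orders` (...)"


async def classify_intent(question: str) -> str:
    # Hypothetical stand-in for an LLM intent-classification call.
    await asyncio.sleep(0)
    return "aggregate"


async def prepare_context(question: str) -> Tuple[str, str]:
    # The two calls are independent, so run them concurrently
    # before the SQL-generation step that needs both results.
    schema, intent = await asyncio.gather(
        load_schema_fragment(question),
        classify_intent(question),
    )
    return schema, intent


schema, intent = asyncio.run(prepare_context("sales by city"))
```

With real network-bound calls, the wall-clock cost of this step drops to the slower of the two calls instead of their sum.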
28.3 Core Implementation
28.3.1 Configuration and the LLM Client
# app/config.py
"""Configuration management for the AI data analysis platform"""
from enum import Enum

from pydantic_settings import BaseSettings, SettingsConfigDict


class DatabaseType(str, Enum):
    MYSQL = "mysql"
    POSTGRESQL = "postgresql"
    CLICKHOUSE = "clickhouse"


class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", env_prefix="DP_")

    APP_NAME: str = "AI Data Analysis Platform"
    APP_VERSION: str = "1.0.0"
    DEBUG: bool = False

    # LLM settings
    LLM_API_KEY: str = ""
    LLM_BASE_URL: str = "https://api.openai.com/v1"
    LLM_MODEL: str = "gpt-4o"
    LLM_TEMPERATURE: float = 0.1
    LLM_MAX_TOKENS: int = 4096

    # Database settings
    DB_TYPE: DatabaseType = DatabaseType.MYSQL
    DB_HOST: str = "localhost"
    DB_PORT: int = 3306
    DB_USER: str = "root"
    DB_PASSWORD: str = ""
    DB_NAME: str = "ecommerce"
    DB_READONLY: bool = True

    # Query limits
    MAX_ROWS_RETURNED: int = 10000
    QUERY_TIMEOUT: int = 30
    MAX_SQL_LENGTH: int = 5000

    # Cache settings
    CACHE_ENABLED: bool = True
    CACHE_TTL: int = 300


settings = Settings()
# app/utils/llm_client.py
"""LLM client wrapper"""
import json
from typing import Dict, List, Optional

from openai import AsyncOpenAI

from app.config import settings


class LLMClient:
    _instance: Optional['LLMClient'] = None

    def __new__(cls) -> 'LLMClient':
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # AsyncOpenAI so the async chat() below does not block the event loop
            cls._instance._client = AsyncOpenAI(
                api_key=settings.LLM_API_KEY,
                base_url=settings.LLM_BASE_URL,
            )
        return cls._instance

    async def chat(
        self,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        response_format: Optional[dict] = None,
    ) -> str:
        full_messages = []
        if system_prompt:
            full_messages.append({"role": "system", "content": system_prompt})
        full_messages.extend(messages)
        kwargs = {
            "model": settings.LLM_MODEL,
            "messages": full_messages,
            "temperature": temperature or settings.LLM_TEMPERATURE,
            "max_tokens": max_tokens or settings.LLM_MAX_TOKENS,
        }
        if response_format:
            kwargs["response_format"] = response_format
        response = await self._client.chat.completions.create(**kwargs)
        return response.choices[0].message.content

    async def chat_json(
        self,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
    ) -> dict:
        content = await self.chat(
            messages=messages, system_prompt=system_prompt,
            temperature=0.1, response_format={"type": "json_object"},
        )
        return json.loads(content)


llm_client = LLMClient()
28.3.2 Database Schema Loader
NL2SQL rests on one prerequisite: understanding the database structure:
# app/utils/schema_loader.py
"""Database schema loader"""
from dataclasses import dataclass, field
from typing import Dict, List, Optional

from app.config import settings


@dataclass
class ColumnInfo:
    name: str
    data_type: str
    comment: str = ""
    is_primary: bool = False
    is_foreign: bool = False
    sample_values: List[str] = field(default_factory=list)


@dataclass
class TableInfo:
    name: str
    comment: str = ""
    columns: Dict[str, ColumnInfo] = field(default_factory=dict)
    row_count: int = 0

    def to_prompt(self) -> str:
        """Render a table description for the LLM"""
        cols = []
        for col in self.columns.values():
            flags = []
            if col.is_primary:
                flags.append("PK")
            if col.is_foreign:
                flags.append("FK")
            flag_str = f" [{', '.join(flags)}]" if flags else ""
            sample = (f" (e.g. {', '.join(col.sample_values[:3])})"
                      if col.sample_values else "")
            cols.append(f"  - {col.name}: {col.data_type}{flag_str}{sample}")
        header = f"Table `{self.name}`"
        if self.comment:
            header += f" ({self.comment})"
        header += f", ~{self.row_count} rows:\n"
        return header + "\n".join(cols)


class SchemaLoader:
    def __init__(self, db_service):
        self._db = db_service
        self._tables: Dict[str, TableInfo] = {}
        self._ddl_cache: Optional[str] = None

    async def load_schema(self) -> str:
        if self._ddl_cache:
            return self._ddl_cache
        tables = await self._load_mysql()
        self._tables = tables
        self._ddl_cache = "\n\n".join(
            t.to_prompt() for t in tables.values())
        return self._ddl_cache

    async def _load_mysql(self) -> Dict[str, TableInfo]:
        tables = {}
        # fetch_all uses a DictCursor, so each row is a dict, not a tuple
        table_rows = await self._db.fetch_all(
            "SELECT TABLE_NAME, TABLE_COMMENT "
            "FROM information_schema.TABLES WHERE TABLE_SCHEMA = %s",
            (settings.DB_NAME,))
        for trow in table_rows:
            name = trow["TABLE_NAME"]
            comment = trow["TABLE_COMMENT"]
            cols_rows = await self._db.fetch_all(
                "SELECT COLUMN_NAME, DATA_TYPE, COLUMN_COMMENT, "
                "COLUMN_KEY FROM information_schema.COLUMNS "
                "WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s "
                "ORDER BY ORDINAL_POSITION",
                (settings.DB_NAME, name))
            table = TableInfo(name=name, comment=comment or "")
            for crow in cols_rows:
                cname = crow["COLUMN_NAME"]
                table.columns[cname] = ColumnInfo(
                    name=cname, data_type=crow["DATA_TYPE"],
                    comment=crow["COLUMN_COMMENT"] or "",
                    is_primary=crow["COLUMN_KEY"] == "PRI",
                    is_foreign=crow["COLUMN_KEY"] == "MUL")
            # Sample a few rows as example values
            try:
                samples = await self._db.fetch_all(
                    f"SELECT * FROM `{name}` LIMIT 3")
                for row in samples:
                    for col_name, col in table.columns.items():
                        if len(col.sample_values) < 3 and row.get(col_name):
                            col.sample_values.append(
                                str(row[col_name])[:50])
            except Exception:
                pass
            # Row count
            try:
                count_row = await self._db.fetch_one(
                    f"SELECT COUNT(*) AS cnt FROM `{name}`")
                table.row_count = count_row["cnt"] if count_row else 0
            except Exception:
                pass
            tables[name] = table
        return tables
28.3.3 NL2SQL Agent (Core)
# app/agents/nl2sql_agent.py
"""Natural-language-to-SQL agent"""
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional

from app.utils.llm_client import llm_client
from app.utils.schema_loader import SchemaLoader


@dataclass
class SQLResult:
    sql: str
    intent: str  # query/aggregate/trend/compare/rank
    confidence: float
    tables_used: List[str] = field(default_factory=list)
    explanation: str = ""
    follow_up_questions: List[str] = field(default_factory=list)


DANGEROUS_KEYWORDS = re.compile(
    r'\b(DROP|DELETE|TRUNCATE|ALTER|CREATE|INSERT|UPDATE|GRANT|REVOKE)\b',
    re.IGNORECASE
)


class NL2SQLAgent:
    SYSTEM_PROMPT = """You are a professional data analyst and SQL expert. Convert the user's natural-language question into an accurate SQL query.

Database schema:
{schema}

Rules:
1. Generate SELECT queries only; never any mutating statement
2. Use table aliases for readability
3. For aggregate queries, make sure GROUP BY is correct
4. Use standard date functions for time ranges
5. Add a LIMIT clause to bound the result size (default 1000)
6. If the question is ambiguous, list the possible ambiguities

Return JSON:
{{
  "sql": "the generated SQL",
  "intent": "query|aggregate|trend|compare|rank",
  "confidence": 0.0-1.0,
  "explanation": "a brief explanation of the SQL logic",
  "tables_used": ["tables referenced"],
  "follow_up_questions": ["suggested follow-up questions"]
}}"""

    def __init__(self, schema_loader: SchemaLoader):
        self._schema_loader = schema_loader
        self._schema_text = ""

    async def _ensure_schema(self):
        if not self._schema_text:
            self._schema_text = await self._schema_loader.load_schema()

    async def generate_sql(
        self, question: str, context: Optional[Dict] = None,
    ) -> SQLResult:
        """Convert natural language into SQL"""
        await self._ensure_schema()
        prompt = self.SYSTEM_PROMPT.format(schema=self._schema_text)
        messages = self._build_messages(question, context)
        try:
            result = await llm_client.chat_json(
                messages=messages, system_prompt=prompt)
            sql = self._sanitize_sql(result.get("sql", ""))
            return SQLResult(
                sql=sql,
                intent=result.get("intent", "query"),
                confidence=float(result.get("confidence", 0.8)),
                tables_used=result.get("tables_used", []),
                explanation=result.get("explanation", ""),
                follow_up_questions=result.get("follow_up_questions", []),
            )
        except Exception as e:
            return SQLResult(
                sql="", intent="query", confidence=0.0,
                explanation=f"SQL generation failed: {str(e)}")

    def _build_messages(
        self, question: str, context: Optional[Dict] = None,
    ) -> List[Dict]:
        messages = []
        if context and context.get("conversation"):
            messages.extend(context["conversation"][-4:])
        if context and context.get("last_sql"):
            messages.append({
                "role": "assistant",
                "content": f"Previous SQL: {context['last_sql']}",
            })
        if context and context.get("time_range"):
            messages.append({
                "role": "system",
                "content": f"Time range: {context['time_range']}",
            })
        messages.append({"role": "user", "content": question})
        return messages

    def _sanitize_sql(self, sql: str) -> str:
        """Sanitize the generated SQL"""
        sql = sql.strip()
        if sql.startswith("```"):
            sql = re.sub(r'^```\w*\n?', '', sql)
            sql = re.sub(r'\n?```$', '', sql)
        sql = sql.strip().rstrip(";")
        if DANGEROUS_KEYWORDS.search(sql):
            raise ValueError("Disallowed SQL operation detected (DDL/DML)")
        if not sql.upper().strip().startswith("SELECT"):
            raise ValueError("Only SELECT queries are allowed")
        if "LIMIT" not in sql.upper():
            sql += "\nLIMIT 1000"
        return sql

    async def refine_sql(
        self, question: str, prev_sql: str, error_msg: str,
    ) -> SQLResult:
        """Auto-repair the SQL after a failed execution"""
        await self._ensure_schema()
        prompt = self.SYSTEM_PROMPT.format(schema=self._schema_text)
        messages = [
            {"role": "user", "content": question},
            {"role": "assistant", "content": f"SQL:\n```sql\n{prev_sql}\n```"},
            {"role": "user", "content": f"Execution failed: {error_msg}\nPlease fix it."},
        ]
        try:
            result = await llm_client.chat_json(
                messages=messages, system_prompt=prompt)
            sql = self._sanitize_sql(result.get("sql", ""))
            return SQLResult(
                sql=sql, intent=result.get("intent", "query"),
                confidence=float(result.get("confidence", 0.6)),
                explanation=f"Repaired: {result.get('explanation', '')}",
            )
        except Exception:
            return SQLResult(sql="", intent="query", confidence=0.0,
                             explanation="Repair failed")
28.3.4 Safe SQL Executor
# app/services/sql_executor.py
"""Safe SQL executor"""
import logging
import re
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from app.config import settings

logger = logging.getLogger(__name__)


@dataclass
class QueryResult:
    columns: List[str]
    rows: List[Dict[str, Any]]
    row_count: int
    execution_time: float
    sql: str
    truncated: bool = False
    error: Optional[str] = None

    def to_dict_list(self) -> List[Dict]:
        return self.rows[:self.row_count]

    def summary_stats(self) -> Dict[str, Any]:
        """Compute summary statistics for numeric columns"""
        stats = {}
        for col in self.columns:
            values = [row[col] for row in self.rows
                      if row.get(col) is not None]
            if not values:
                continue
            try:
                nums = [float(v) for v in values]
                stats[col] = {
                    "min": round(min(nums), 2),
                    "max": round(max(nums), 2),
                    "avg": round(sum(nums) / len(nums), 2),
                    "count": len(nums),
                }
            except (ValueError, TypeError):
                unique = set(str(v) for v in values)
                stats[col] = {
                    "unique_count": len(unique),
                    "top_values": list(unique)[:5],
                }
        return stats


class SQLExecutor:
    def __init__(self, db_service):
        self._db = db_service

    async def execute(self, sql: str) -> QueryResult:
        self._validate_sql(sql)
        start = time.time()
        try:
            rows = await self._db.fetch_all(sql)
            elapsed = time.time() - start
            if not rows:
                return QueryResult(
                    columns=[], rows=[], row_count=0,
                    execution_time=elapsed, sql=sql)
            columns = list(rows[0].keys())
            truncated = len(rows) >= settings.MAX_ROWS_RETURNED
            return QueryResult(
                columns=columns, rows=rows,
                row_count=len(rows),
                execution_time=round(elapsed, 3),
                sql=sql, truncated=truncated,
            )
        except Exception as e:
            elapsed = time.time() - start
            logger.error(f"SQL execution failed: {e}")
            return QueryResult(
                columns=[], rows=[], row_count=0,
                execution_time=round(elapsed, 3),
                sql=sql, error=str(e),
            )

    def _validate_sql(self, sql: str):
        sql_upper = sql.upper().strip()
        if not sql_upper.startswith("SELECT"):
            raise ValueError("Only SELECT queries may be executed")
        if re.search(r';\s*(DROP|DELETE|INSERT|UPDATE|ALTER)',
                     sql, re.IGNORECASE):
            raise ValueError("Stacked dangerous statement detected")
        if len(sql) > settings.MAX_SQL_LENGTH:
            raise ValueError("SQL too long")
28.3.5 Visualization Agent
# app/agents/visualization_agent.py
"""Smart chart-recommendation agent"""
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List

from app.services.sql_executor import QueryResult


@dataclass
class ChartConfig:
    chart_type: str  # bar/line/pie/scatter/table/metric
    title: str
    x_axis: str = ""
    y_axis: str = ""
    x_label: str = ""
    y_label: str = ""
    series: List[str] = field(default_factory=list)


@dataclass
class VisualizationResult:
    charts: List[ChartConfig]
    recommendation_reason: str


class VisualizationAgent:
    """Recommend the best visualization from the shape of the query result"""

    # Keyword triggers; matched against the lower-cased question, so keep them lower-case
    CHART_RULES = {
        "bar": {"triggers": ["category", "rank", "top", "compare"]},
        "line": {"triggers": ["trend", "time", "change", "growth"]},
        "pie": {"triggers": ["share", "proportion", "composition", "distribution"]},
        "scatter": {"triggers": ["correlation", "scatter", "relationship"]},
        "metric": {"triggers": ["total", "sum", "average"]},
    }

    def recommend(
        self, result: QueryResult, intent: str, question: str,
    ) -> VisualizationResult:
        charts = []
        analysis = self._analyze_data(result)
        rule_chart = self._match_rules(question, intent, analysis)
        if rule_chart:
            charts.append(rule_chart)
        feature_chart = self._infer_from_features(result, analysis)
        if feature_chart:
            charts.append(feature_chart)
        if not charts:
            charts.append(ChartConfig(chart_type="table", title="Query result"))
        return VisualizationResult(
            charts=charts, recommendation_reason="auto-recommended")

    def _analyze_data(self, result: QueryResult) -> Dict[str, Any]:
        if not result.columns or not result.rows:
            return {"is_empty": True, "row_count": 0}
        numeric_cols, category_cols, date_cols = [], [], []
        for col in result.columns:
            values = [row.get(col) for row in result.rows[:50]
                      if row.get(col) is not None]
            if not values:
                continue
            is_numeric = True
            for v in values[:10]:
                try:
                    float(v)
                except (ValueError, TypeError):
                    is_numeric = False
                    break
            if is_numeric:
                numeric_cols.append(col)
            elif self._is_date_col(values):
                date_cols.append(col)
            else:
                category_cols.append(col)
        return {
            "is_empty": False, "row_count": result.row_count,
            "numeric_cols": numeric_cols, "category_cols": category_cols,
            "date_cols": date_cols, "has_single_row": result.row_count == 1,
        }

    def _is_date_col(self, values: List) -> bool:
        patterns = [r'\d{4}-\d{2}-\d{2}', r'\d{4}/\d{2}/\d{2}']
        matched = sum(1 for v in values[:5]
                      if any(re.search(p, str(v)) for p in patterns))
        return matched >= 3

    def _match_rules(self, question, intent, analysis):
        q = question.lower()
        for chart_type, rules in self.CHART_RULES.items():
            if any(t in q for t in rules["triggers"]):
                return self._create_chart(chart_type, analysis, question)
        return None

    def _infer_from_features(self, result, analysis):
        if analysis.get("has_single_row") and analysis.get("numeric_cols"):
            return ChartConfig(
                chart_type="metric",
                title=f"{result.columns[0]} = {list(result.rows[0].values())[0]}")
        if analysis.get("date_cols") and analysis.get("numeric_cols"):
            return ChartConfig(
                chart_type="line", title="Trend",
                x_axis=analysis["date_cols"][0],
                y_axis=analysis["numeric_cols"][0])
        if (len(analysis.get("category_cols", [])) == 1
                and analysis.get("numeric_cols")):
            return ChartConfig(
                chart_type="bar", title="Comparison by category",
                x_axis=analysis["category_cols"][0],
                y_axis=analysis["numeric_cols"][0])
        return None

    def _create_chart(self, chart_type, analysis, question):
        date_cols = analysis.get("date_cols", [])
        category_cols = analysis.get("category_cols", [])
        numeric_cols = analysis.get("numeric_cols", [])
        config = {"chart_type": chart_type, "title": question}
        if chart_type == "line" and date_cols and numeric_cols:
            config.update({"x_axis": date_cols[0], "y_axis": numeric_cols[0]})
        elif chart_type in ("bar", "pie") and category_cols and numeric_cols:
            config.update({"x_axis": category_cols[0], "y_axis": numeric_cols[0]})
        elif chart_type == "metric" and numeric_cols:
            config.update({"series": numeric_cols})
        return ChartConfig(**config)
28.3.6 Anomaly Detection Agent
# app/agents/anomaly_agent.py
"""Anomaly-detection agent"""
from dataclasses import dataclass, field
from statistics import mean, stdev
from typing import List

from app.services.sql_executor import QueryResult


@dataclass
class AnomalyPoint:
    timestamp: str
    value: float
    expected_range: tuple
    deviation_percent: float
    severity: str  # info/warning/critical
    description: str


@dataclass
class AnomalyReport:
    has_anomaly: bool
    anomaly_points: List[AnomalyPoint]
    overall_trend: str
    summary: str
    recommendations: List[str] = field(default_factory=list)


class AnomalyAgent:
    """Statistical anomaly detection for time series"""

    def detect(
        self, result: QueryResult, date_col: str, value_col: str,
        sensitivity: float = 2.0,
    ) -> AnomalyReport:
        if not result.rows or date_col not in result.columns:
            return AnomalyReport(
                has_anomaly=False, anomaly_points=[],
                overall_trend="unknown", summary="Not enough data")
        series = []
        for row in result.rows:
            try:
                val = float(row.get(value_col, 0))
                ts = str(row.get(date_col, ""))
                series.append({"timestamp": ts, "value": val})
            except (ValueError, TypeError):
                continue
        if len(series) < 3:
            return AnomalyReport(
                has_anomaly=False, anomaly_points=[],
                overall_trend="unknown", summary="Too few data points")
        values = [p["value"] for p in series]
        avg = mean(values)
        std = stdev(values) if len(values) > 1 else 0
        anomalies = []
        for point in series:
            z_score = abs(point["value"] - avg) / std if std > 0 else 0
            if z_score >= sensitivity:
                lower = round(avg - sensitivity * std, 2)
                upper = round(avg + sensitivity * std, 2)
                deviation = (abs(point["value"] - avg) / avg * 100
                             if avg != 0 else 0)
                anomalies.append(AnomalyPoint(
                    timestamp=point["timestamp"],
                    value=point["value"],
                    expected_range=(lower, upper),
                    deviation_percent=round(deviation, 1),
                    severity="critical" if z_score >= 3 else "warning",
                    description=(f"{point['timestamp']}: value "
                                 f"{point['value']:.2f} deviates from the mean by "
                                 f"{deviation:.1f}%"),
                ))
        trend = self._detect_trend(values)
        summary = (f"Analyzed {len(series)} points; mean: {avg:.2f}, "
                   f"stddev: {std:.2f}; trend: {trend}; "
                   f"anomalies: {len(anomalies)}")
        recs = []
        if anomalies:
            recs.append(f"{len(anomalies)} anomalous points need attention")
        if trend == "decreasing":
            recs.append("Overall downward trend; investigate the cause")
        return AnomalyReport(
            has_anomaly=len(anomalies) > 0,
            anomaly_points=anomalies,
            overall_trend=trend,
            summary=summary,
            recommendations=recs,
        )

    def _detect_trend(self, values: List[float]) -> str:
        if len(values) < 2:
            return "stable"
        mid = len(values) // 2
        first = mean(values[:mid])
        second = mean(values[mid:])
        if first == 0:
            return "stable"
        change = (second - first) / abs(first) * 100
        if change > 10:
            return "increasing"
        elif change < -10:
            return "decreasing"
        return "stable"
28.3.7 Report Generation Agent
# app/agents/report_agent.py
"""Analysis-report generation agent"""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional

from app.utils.llm_client import llm_client
from app.services.sql_executor import QueryResult
from app.agents.anomaly_agent import AnomalyReport


@dataclass
class ReportSection:
    title: str
    content: str
    data_summary: Dict[str, Any] = field(default_factory=dict)


@dataclass
class AnalysisReport:
    title: str
    sections: List[ReportSection]
    key_findings: List[str]
    recommendations: List[str]
    created_at: str = ""


class ReportAgent:
    SYSTEM_PROMPT = """You are a senior data analyst who turns query results into clear analysis reports.
Requirements: let the data speak, highlight key findings, give actionable recommendations, and use Markdown formatting."""

    async def generate(
        self, question: str, result: QueryResult,
        anomaly_report: Optional[AnomalyReport] = None,
    ) -> AnalysisReport:
        data_context = self._build_data_context(result)
        messages = [
            {"role": "user", "content": f"User question: {question}"},
            {"role": "assistant", "content": f"Executed SQL: {result.sql}"},
            {"role": "user", "content": f"Query result:\n{data_context}"},
        ]
        if anomaly_report and anomaly_report.has_anomaly:
            anomaly_text = "\n".join(
                f"- {a.description}" for a in anomaly_report.anomaly_points)
            messages.append({
                "role": "user",
                "content": (f"Anomaly detection:\n{anomaly_text}\n"
                            f"Trend: {anomaly_report.overall_trend}"),
            })
        messages.append({"role": "user", "content": "Please write the analysis report."})
        try:
            report_text = await llm_client.chat(
                messages=messages, system_prompt=self.SYSTEM_PROMPT,
                temperature=0.4, max_tokens=4096)
            return AnalysisReport(
                title=f"Analysis report: {question}",
                sections=[ReportSection(title="Analysis report", content=report_text)],
                key_findings=[], recommendations=[],
                created_at=datetime.now().isoformat(),
            )
        except Exception as e:
            return AnalysisReport(
                title=f"Analysis report: {question}",
                sections=[ReportSection(title="Error",
                                        content=f"Report generation failed: {str(e)}")],
                key_findings=[], recommendations=[])

    def _build_data_context(self, result: QueryResult) -> str:
        lines = [f"{result.row_count} rows, {len(result.columns)} columns"]
        lines.append(f"Columns: {', '.join(result.columns)}")
        stats = result.summary_stats()
        for col, s in stats.items():
            if "avg" in s:
                lines.append(f"  {col}: avg={s['avg']}, "
                             f"min={s['min']}, max={s['max']}")
        lines.append("Data preview:")
        for row in result.rows[:10]:
            lines.append("  " + str(dict(row)))
        return "\n".join(lines)
28.3.8 Database Service and FastAPI Entry Point
# app/services/db_service.py
"""Database connection service"""
from typing import Dict, List

import aiomysql

from app.config import settings


class DatabaseService:
    def __init__(self):
        self._pool = None

    async def connect(self):
        self._pool = await aiomysql.create_pool(
            host=settings.DB_HOST, port=settings.DB_PORT,
            user=settings.DB_USER, password=settings.DB_PASSWORD,
            db=settings.DB_NAME, maxsize=10,
            autocommit=True, charset='utf8mb4')

    async def close(self):
        if self._pool:
            self._pool.close()
            await self._pool.wait_closed()

    async def fetch_all(self, sql: str, args: tuple = ()) -> List[Dict]:
        if not self._pool:
            await self.connect()
        async with self._pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cur:
                await cur.execute(sql, args)
                return await cur.fetchall()

    async def fetch_one(self, sql: str, args: tuple = ()):
        if not self._pool:
            await self.connect()
        async with self._pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cur:
                await cur.execute(sql, args)
                return await cur.fetchone()
# app/main.py
"""AI data analysis platform: FastAPI entry point"""
from contextlib import asynccontextmanager
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from app.config import settings
from app.services.db_service import DatabaseService
from app.utils.schema_loader import SchemaLoader
from app.agents.nl2sql_agent import NL2SQLAgent
from app.agents.visualization_agent import VisualizationAgent
from app.agents.anomaly_agent import AnomalyAgent
from app.agents.report_agent import ReportAgent
from app.services.sql_executor import SQLExecutor

db_service = DatabaseService()
schema_loader = SchemaLoader(db_service)
nl2sql_agent = NL2SQLAgent(schema_loader)
viz_agent = VisualizationAgent()
anomaly_agent = AnomalyAgent()
report_agent = ReportAgent()
sql_executor = SQLExecutor(db_service)
sessions: dict = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    await db_service.connect()
    print(f"🚀 {settings.APP_NAME} v{settings.APP_VERSION} started")
    yield
    await db_service.close()


app = FastAPI(
    title=settings.APP_NAME, version=settings.APP_VERSION,
    lifespan=lifespan,
    description="Query data in natural language; AI generates the SQL, charts, and reports")
app.add_middleware(CORSMiddleware, allow_origins=["*"],
                   allow_credentials=True,
                   allow_methods=["*"], allow_headers=["*"])


class QueryRequest(BaseModel):
    question: str
    session_id: Optional[str] = None
    time_range: Optional[str] = None


class RefineRequest(BaseModel):
    question: str
    prev_sql: str
    error_msg: str


@app.get("/health")
async def health():
    return {"status": "ok", "version": settings.APP_VERSION}


@app.get("/api/v1/schema")
async def get_schema():
    schema_text = await schema_loader.load_schema()
    return {"schema": schema_text}


@app.post("/api/v1/query")
async def query_data(req: QueryRequest):
    """Full pipeline: NL2SQL → execute → visualize → anomaly detection"""
    context = sessions.get(req.session_id, {}) if req.session_id else {}
    if req.time_range:
        context["time_range"] = req.time_range
    # 1. NL2SQL
    sql_result = await nl2sql_agent.generate_sql(req.question, context)
    if not sql_result.sql:
        raise HTTPException(400, detail=f"Could not generate SQL: {sql_result.explanation}")
    # 2. Execute the SQL (auto-repair on failure)
    exec_result = await sql_executor.execute(sql_result.sql)
    if exec_result.error:
        refined = await nl2sql_agent.refine_sql(
            req.question, sql_result.sql, exec_result.error)
        if refined.sql:
            exec_result = await sql_executor.execute(refined.sql)
            sql_result = refined
    if exec_result.error:
        raise HTTPException(400, detail=f"SQL execution failed: {exec_result.error}")
    # 3. Chart recommendation
    viz_result = viz_agent.recommend(exec_result, sql_result.intent, req.question)
    # 4. Anomaly detection (time-series data only)
    anomaly_report = None
    date_cols = [c for c in exec_result.columns
                 if any(k in c.lower() for k in ["date", "time", "month"])]
    num_cols = [c for c in exec_result.columns
                if exec_result.rows
                and isinstance(exec_result.rows[0].get(c), (int, float))]
    if date_cols and num_cols and exec_result.row_count >= 5:
        anomaly_report = anomaly_agent.detect(
            exec_result, date_cols[0], num_cols[0])
    # 5. Save the session
    if req.session_id:
        sessions[req.session_id] = {
            "last_sql": sql_result.sql,
            "conversation": [
                {"role": "user", "content": req.question},
                {"role": "assistant", "content": sql_result.explanation},
            ],
        }
    return {
        "session_id": req.session_id,
        "sql": sql_result.sql,
        "explanation": sql_result.explanation,
        "confidence": sql_result.confidence,
        "data": exec_result.to_dict_list(),
        "row_count": exec_result.row_count,
        "execution_time": exec_result.execution_time,
        "visualization": [
            {"chart_type": c.chart_type, "title": c.title,
             "x_axis": c.x_axis, "y_axis": c.y_axis}
            for c in viz_result.charts
        ],
        "anomaly": {
            "has_anomaly": anomaly_report.has_anomaly,
            "summary": anomaly_report.summary,
        } if anomaly_report else None,
        "follow_up_questions": sql_result.follow_up_questions,
    }


@app.post("/api/v1/report")
async def generate_report(req: QueryRequest):
    """Generate an analysis report"""
    context = sessions.get(req.session_id, {}) if req.session_id else {}
    sql_result = await nl2sql_agent.generate_sql(req.question, context)
    exec_result = await sql_executor.execute(sql_result.sql)
    if exec_result.error:
        raise HTTPException(400, detail=exec_result.error)
    report = await report_agent.generate(req.question, exec_result)
    return {
        "title": report.title,
        "sections": [{"title": s.title, "content": s.content}
                     for s in report.sections],
        "key_findings": report.key_findings,
        "recommendations": report.recommendations,
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000,
                reload=settings.DEBUG)
28.4 Demo and Testing
28.4.1 Demo Database Initialization
-- demo_data/init_db.sql
CREATE TABLE IF NOT EXISTS users (
    id INT PRIMARY KEY AUTO_INCREMENT,
    username VARCHAR(50) COMMENT 'username',
    city VARCHAR(50) COMMENT 'city',
    register_date DATE COMMENT 'registration date',
    vip_level INT DEFAULT 0 COMMENT 'VIP level'
) COMMENT 'users';
CREATE TABLE IF NOT EXISTS products (
    id INT PRIMARY KEY AUTO_INCREMENT,
    name VARCHAR(200) COMMENT 'product name',
    category VARCHAR(50) COMMENT 'category',
    price DECIMAL(10,2) COMMENT 'price',
    stock INT COMMENT 'stock'
) COMMENT 'products';
CREATE TABLE IF NOT EXISTS orders (
    id INT PRIMARY KEY AUTO_INCREMENT,
    user_id INT COMMENT 'user id',
    product_id INT COMMENT 'product id',
    amount DECIMAL(10,2) COMMENT 'amount',
    quantity INT COMMENT 'quantity',
    status VARCHAR(20) COMMENT 'status',
    order_date DATETIME COMMENT 'order time'
) COMMENT 'orders';
-- Sample data (full INSERT statements abridged; structure as above)
INSERT INTO users VALUES
    (1,'张三','北京','2024-01-15',3),
    (2,'李四','上海','2024-02-20',1),
    (3,'王五','广州','2024-03-10',2),
    (4,'赵六','深圳','2024-04-05',0);
INSERT INTO products VALUES
    (1,'无线蓝牙耳机Pro','数码',299.00,500),
    (2,'智能手表S3','数码',1599.00,200),
    (3,'纯棉T恤','服饰',89.00,1000),
    (4,'运动跑鞋X1','运动',459.00,300);
INSERT INTO orders VALUES
    (1,1,1,299.00,1,'completed','2024-10-01 10:30:00'),
    (2,2,3,178.00,2,'completed','2024-10-05 09:15:00'),
    (3,3,4,459.00,1,'pending','2024-10-07 16:45:00');
28.4.2 Test Cases
# tests/test_nl2sql.py
"""NL2SQL accuracy tests"""
import pytest


class MockSchemaLoader:
    SCHEMA = """
Table `orders` (orders):
  - id: int [PK]
  - user_id: int [FK]
  - amount: decimal
  - status: varchar
  - order_date: datetime
Table `users` (users):
  - id: int [PK]
  - username: varchar
  - city: varchar
  - vip_level: int"""

    async def load_schema(self):
        return self.SCHEMA


@pytest.mark.asyncio
async def test_nl2sql_basic():
    """Smoke test for basic NL2SQL generation"""
    from app.agents.nl2sql_agent import NL2SQLAgent
    agent = NL2SQLAgent(MockSchemaLoader())
    # Requires a configured LLM API key to actually run:
    # result = await agent.generate_sql("number of users per city")
    # assert "users" in result.sql.lower()
    # assert "city" in result.sql.lower()


@pytest.mark.asyncio
async def test_sql_sanitize():
    """Dangerous SQL must be rejected"""
    from app.agents.nl2sql_agent import NL2SQLAgent
    agent = NL2SQLAgent(MockSchemaLoader())
    with pytest.raises(ValueError):
        agent._sanitize_sql("DROP TABLE users")
28.5 Deployment
28.5.1 Docker Deployment
FROM python:3.11-slim
WORKDIR /app
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc default-libmysqlclient-dev && rm -rf /var/lib/apt/lists/*
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY app/ ./app/
EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
# docker-compose.yml
version: '3.8'
services:
  data-platform:
    build: .
    ports: ["8000:8000"]
    environment:
      - DP_LLM_API_KEY=${LLM_API_KEY}
      - DP_DB_HOST=mysql
      - DP_DB_USER=data_analyst
      - DP_DB_PASSWORD=readonly_pass
      - DP_DB_NAME=ecommerce
    depends_on: [mysql]
  mysql:
    image: mysql:8
    environment:
      MYSQL_ROOT_PASSWORD: root123
      MYSQL_DATABASE: ecommerce
      MYSQL_USER: data_analyst
      MYSQL_PASSWORD: readonly_pass
    ports: ["3306:3306"]
    volumes:
      - ./demo_data/init_db.sql:/docker-entrypoint-initdb.d/init.sql
28.6 Lessons Learned
28.6.1 Pitfalls
Pitfall 1: schema overload in NL2SQL
Early on we passed the complete DDL (including indexes and constraints) to the LLM, which inflated token usage and actually hurt accuracy. The fix was to keep only table names, column names, types, and comments, dropping index and foreign-key constraint definitions. Sampling 3 rows as example values improved generation accuracy dramatically.
Pitfall 2: ambiguity in multi-table JOINs
"Sales" might mean a SUM over orders.amount, or it might require joining products; "sales by category" needs a JOIN between orders and products. We disambiguate with few-shot examples plus schema relationship descriptions; marking [FK] relationships in the schema helps the LLM understand how tables connect.
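A minimal sketch of that few-shot disambiguation. The example pairs below are hypothetical (they follow this chapter's demo schema); the real system would curate them from reviewed production queries:

```python
# Hypothetical few-shot pairs prepended to the NL2SQL conversation so the
# model learns what ambiguous metrics like "sales" mean in this warehouse.
FEW_SHOT = [
    {"role": "user", "content": "Total sales last month"},
    {"role": "assistant", "content":
        "SELECT SUM(o.amount) FROM orders o "
        "WHERE o.order_date >= DATE_SUB(CURDATE(), INTERVAL 1 MONTH)"},
    {"role": "user", "content": "Sales by category"},
    {"role": "assistant", "content":
        "SELECT p.category, SUM(o.amount) FROM orders o "
        "JOIN products p ON o.product_id = p.id GROUP BY p.category"},
]


def build_messages(question: str) -> list:
    # Examples first, the live question last, so the model imitates the
    # demonstrated metric definitions and join paths.
    return FEW_SHOT + [{"role": "user", "content": question}]


messages = build_messages("sales by city last week")
```

Keeping the examples in the message list (rather than the system prompt) makes it easy to swap in retrieval-selected examples per question later.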
Pitfall 3: SQL injection defense
A user's question may embed malicious SQL fragments. We defend with both a whitelist (SELECT only) and a regex blacklist (no DDL/DML). In production, a read-only database account is recommended as an extra safeguard.
Pitfall 4: baseline choice for anomaly detection
A plain Z-score misfires when the data has a clear trend (normal highs in a growing series get flagged as anomalies). We added a trend-detection step: if the overall trend is upward, we use a sliding-window Z-score rather than the global mean as the baseline.
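The sliding-window variant can be sketched as follows; the window size and threshold here are illustrative defaults, not the production values:

```python
from statistics import mean, stdev
from typing import List, Tuple


def rolling_anomalies(
    values: List[float], window: int = 7, sensitivity: float = 2.0,
) -> List[Tuple[int, float]]:
    """Flag points whose z-score against the trailing window exceeds the
    threshold. A local baseline tracks the trend, so normal highs in a
    steadily growing series are no longer flagged as anomalies."""
    anomalies = []
    for i in range(window, len(values)):
        baseline = values[i - window:i]          # trailing window only
        mu = mean(baseline)
        sigma = stdev(baseline)
        if sigma == 0:
            continue                             # flat window: skip
        z = abs(values[i] - mu) / sigma
        if z >= sensitivity:
            anomalies.append((i, round(z, 2)))
    return anomalies


# A steadily growing series with one genuine spike at index 10:
# the global-mean Z-score would also flag the trailing normal highs,
# while the rolling baseline flags only the spike.
series = [100, 102, 104, 106, 108, 110, 112, 114, 116, 118, 300, 122]
flagged = rolling_anomalies(series)
```

The trade-off is a warm-up period of `window` points at the start of the series that can never be flagged.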
28.6.2 Performance Optimizations
- Schema caching: the database structure rarely changes, so load it once at startup and cache it instead of hitting INFORMATION_SCHEMA on every query
- Query result caching: results for identical SQL are cached for 5 minutes, and identical questions do not regenerate SQL
- Low-temperature generation: SQL generation uses temperature=0.1 for deterministic output
- Auto-repair loop: failed SQL is corrected automatically; on average 1.2 repair rounds suffice
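The project layout lists a `cache_service.py` that is not shown in this chapter; the SQL result cache above could look roughly like this minimal in-process sketch (a real deployment would more likely back it with Redis):

```python
import hashlib
import time
from typing import Any, Dict, Optional, Tuple


class QueryCache:
    """Minimal in-process TTL cache keyed by normalized SQL text."""

    def __init__(self, ttl_seconds: int = 300):
        self._ttl = ttl_seconds
        self._store: Dict[str, Tuple[float, Any]] = {}

    @staticmethod
    def _key(sql: str) -> str:
        # Collapse whitespace and case so trivially different renderings
        # of the same query hit the same cache entry.
        normalized = " ".join(sql.lower().split())
        return hashlib.sha256(normalized.encode()).hexdigest()

    def get(self, sql: str) -> Optional[Any]:
        entry = self._store.get(self._key(sql))
        if entry is None:
            return None
        expires_at, value = entry
        if time.time() > expires_at:
            del self._store[self._key(sql)]      # lazy eviction on read
            return None
        return value

    def put(self, sql: str, value: Any) -> None:
        self._store[self._key(sql)] = (time.time() + self._ttl, value)


cache = QueryCache(ttl_seconds=300)
cache.put("SELECT 1", {"rows": [[1]]})
```

Checking `cache.get(sql)` before `SQLExecutor.execute` and calling `cache.put` after a successful run is enough to wire it into the pipeline.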
28.6.3 Key Design Patterns
| Pattern | Used for | Effect |
|---|---|---|
| Schema-as-Prompt | NL2SQL | +30% accuracy |
| Auto-repair loop | SQL execution fault tolerance | 85% failure recovery |
| Rules + feature hybrid | chart recommendation | 92% recommendation accuracy |
| Pipeline chaining | end-to-end analysis | one call covers the whole flow |
28.6.4 Future Directions
- Text-to-Chart: generate ECharts/Plotly code directly for interactive charts
- Natural-language data modeling: design table schemas automatically from business requirements
- Predictive analytics: integrate time-series forecasting (Prophet, ARIMA) and auto-generate forecast reports
- Federated queries: query MySQL, ClickHouse, and Elasticsearch in a single request
Chapter summary: the heart of an AI data analysis platform is NL2SQL plus an automated pipeline. What matters most is a high-quality representation of the schema, a safe SQL execution mechanism, and the ability to self-repair on failure. By chaining natural-language querying, SQL generation, safe execution, chart recommendation, and anomaly detection into one pipeline, business users can go from "ask a question" to "get the conclusion" in a single sentence, cutting analysis turnaround from days to seconds.