第16章:部署-16.1 API开发基础#
16.1 API开发基础#
在前面的章节中,我们深入探讨了如何训练和优化一个故事讲述AI模型,从监督式微调到强化学习技术。现在,我们已经拥有了一个高质量的模型,能够生成符合人类偏好的引人入胜的故事。然而,一个优秀的模型如果无法被用户便捷地使用,其价值就会大打折扣。本章将介绍如何将我们的故事讲述AI模型部署为实用的应用程序,使其能够服务于最终用户。
首先,我们将探讨API(应用程序编程接口)的开发,这是将AI模型转化为可用服务的第一步。API提供了一种标准化的方式,使其他应用程序能够与我们的模型进行交互,而无需了解其内部工作原理。
API设计原则#
设计一个好的API需要遵循一些基本原则,特别是对于AI模型这样的复杂系统:
简单性:
API应该易于理解和使用
隐藏模型的复杂性,只暴露必要的功能
提供合理的默认值,减少用户决策负担
一致性:
保持命名、参数和返回值的一致性
遵循RESTful或GraphQL等标准设计模式
错误处理和状态码应该遵循行业标准
可扩展性:
设计应考虑未来功能扩展
版本控制策略应该从一开始就考虑
允许灵活配置模型参数和行为
安全性:
实现适当的认证和授权机制
防止滥用和过度使用(如速率限制)
保护用户数据和隐私
可观测性:
提供详细的日志和监控能力
包含性能指标和使用统计
便于调试和问题排查
API功能规划#
对于故事讲述AI,我们的API应该提供哪些功能?以下是一个基本功能集:
故事生成:
根据提示或主题生成完整故事
支持不同的故事类型和风格
允许控制故事长度和复杂度
故事续写:
基于已有故事片段继续创作
保持风格和情节的一致性
支持多轮交互式创作
角色创建:
生成详细的角色描述
基于简单描述扩展角色背景
创建符合特定故事需求的角色
故事修改:
调整故事的风格或语调
简化或丰富故事内容
改变故事的结局或关键情节
元数据生成:
为故事创建标题
生成摘要或简介
提取关键主题和教育价值
API架构设计#
设计API架构时,我们需要考虑多个因素,包括性能、可扩展性和易用性。以下是一个适合故事讲述AI的架构设计:
整体架构#
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│ │ │ │ │ │
│ 客户端应用 │────▶│ API网关 │────▶│ 负载均衡器 │
│ │ │ │ │ │
└─────────────┘ └─────────────┘ └─────────────┘
│
▼
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│ │ │ │ │ │
│ 数据库 │◀───▶│ 应用服务器 │◀───▶│ 模型服务器 │
│ │ │ │ │ │
└─────────────┘ └─────────────┘ └─────────────┘
│ │
▼ ▼
┌─────────────┐ ┌─────────────┐
│ │ │ │
│ 缓存服务 │ │ 监控系统 │
│ │ │ │
└─────────────┘ └─────────────┘
组件说明#
API网关:
处理认证和授权
实现速率限制和配额管理
路由请求到适当的服务
负载均衡器:
分配请求到多个模型服务器
实现健康检查和故障转移
优化资源利用
应用服务器:
实现业务逻辑和API端点
处理请求验证和响应格式化
管理用户会话和状态
模型服务器:
运行AI模型推理
管理模型版本和配置
优化推理性能
数据库:
存储用户数据和生成的故事
管理模型配置和元数据
支持分析和报告
缓存服务:
缓存常用请求和响应
减少模型推理负载
提高响应速度
监控系统:
跟踪API使用情况和性能
检测异常和错误
生成报告和警报
API规范设计#
现在,让我们设计一个具体的API规范,使用RESTful风格:
基本端点#
生成故事
POST /api/v1/stories/generate
请求体:
{
"prompt": "一个关于勇敢小兔子的故事",
"style": "童话",
"target_age": 8,
"length": "medium",
"educational_theme": "勇气",
"parameters": {
"temperature": 0.7,
"top_p": 0.9,
"max_tokens": 1000
}
}
响应:
{
"story_id": "story_12345",
"title": "跳跳兔的冒险",
"content": "从前,有一只名叫跳跳的小兔子...",
"metadata": {
"word_count": 450,
"reading_time": "3分钟",
"themes": ["勇气", "友谊", "冒险"],
"educational_value": "教导孩子面对恐惧和挑战"
}
}
续写故事
POST /api/v1/stories/{story_id}/continue
请求体:
{
"current_content": "从前,有一只名叫跳跳的小兔子...",
"continuation_prompt": "跳跳遇到了一只狐狸",
"length": "medium",
"parameters": {
"temperature": 0.8,
"top_p": 0.9,
"max_tokens": 500
}
}
响应:
{
"continuation": "当跳跳正在森林里采集胡萝卜时,他突然遇到了一只狡猾的狐狸...",
"metadata": {
"word_count": 200,
"themes": ["危险", "智慧"]
}
}
创建角色
POST /api/v1/characters/create
请求体:
{
"name": "跳跳",
"brief_description": "一只勇敢但有点胆小的小兔子",
"story_context": "森林冒险故事",
"target_age": 8,
"parameters": {
"detail_level": "high",
"creativity": 0.8
}
}
响应:
{
"character_id": "char_6789",
"name": "跳跳",
"full_description": "跳跳是一只白色的小兔子,有着粉红色的鼻子和长长的耳朵...",
"personality": {
"traits": ["勇敢", "善良", "好奇", "有时胆小"],
"motivations": ["保护家人", "探索森林", "克服恐惧"],
"strengths": ["跑得快", "听力敏锐", "善于交朋友"],
"weaknesses": ["害怕黑暗", "有时优柔寡断"]
},
"background": "跳跳出生在森林边缘的一个兔子洞里,是家中最小的兔子..."
}
修改故事
POST /api/v1/stories/{story_id}/modify
请求体:
{
"content": "从前,有一只名叫跳跳的小兔子...",
"modification_type": "change_style",
"target_style": "冒险",
"instructions": "使故事更加刺激,增加一些冒险元素",
"parameters": {
"creativity": 0.7,
"preservation_rate": 0.5
}
}
响应:
{
"modified_content": "在茂密的魔法森林深处,生活着一只名叫跳跳的勇敢小兔子...",
"changes": {
"style_shift": "从平静的童话风格转变为更具冒险性的叙述",
"added_elements": ["神秘森林", "潜在危险", "冒险使命"],
"preservation_rate": 0.6
}
}
生成元数据
POST /api/v1/stories/{story_id}/metadata
请求体:
{
"content": "从前,有一只名叫跳跳的小兔子...",
"metadata_types": ["title", "summary", "themes", "educational_value"]
}
响应:
{
"title": "跳跳兔的森林冒险",
"summary": "这是一个关于小兔子跳跳克服恐惧,在森林中冒险的故事...",
"themes": ["勇气", "成长", "友谊", "冒险"],
"educational_value": {
"primary_lesson": "勇敢面对恐惧",
"secondary_lessons": ["友谊的价值", "不轻易放弃"],
"target_age_range": "5-9岁"
},
"keywords": ["兔子", "森林", "冒险", "勇气", "友谊"]
}
实现API服务器#
现在,让我们使用Python和Flask实现一个简单的API服务器,将我们的故事讲述AI模型暴露为Web服务。
基本项目结构#
storyteller-api/
├── app/
│ ├── __init__.py
│ ├── api/
│ │ ├── __init__.py
│ │ ├── routes.py
│ │ └── validators.py
│ ├── models/
│ │ ├── __init__.py
│ │ └── storyteller.py
│ ├── utils/
│ │ ├── __init__.py
│ │ ├── auth.py
│ │ └── helpers.py
│ └── config.py
├── instance/
│ └── config.py
├── logs/
├── tests/
│ ├── __init__.py
│ ├── test_api.py
│ └── test_models.py
├── .env
├── .gitignore
├── requirements.txt
└── run.py
核心代码实现#
app/__init__.py - 应用初始化
from flask import Flask
from flask_cors import CORS
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
import logging
from logging.handlers import RotatingFileHandler
import os
from app.config import Config
limiter = Limiter(key_func=get_remote_address)
def create_app(config_class=Config):
app = Flask(__name__, instance_relative_config=True)
app.config.from_object(config_class)
app.config.from_pyfile('config.py', silent=True)
# 初始化扩展
CORS(app)
limiter.init_app(app)
# 注册蓝图
from app.api import bp as api_bp
app.register_blueprint(api_bp, url_prefix='/api/v1')
# 设置日志
if not os.path.exists('logs'):
os.mkdir('logs')
file_handler = RotatingFileHandler('logs/storyteller.log', maxBytes=10240, backupCount=10)
file_handler.setFormatter(logging.Formatter(
'%(asctime)s %(levelname)s: %(message)s [in %(pathname)s:%(lineno)d]'
))
file_handler.setLevel(logging.INFO)
app.logger.addHandler(file_handler)
app.logger.setLevel(logging.INFO)
app.logger.info('Storyteller API startup')
return app
app/config.py - 配置文件
import os
from dotenv import load_dotenv
load_dotenv()
class Config:
SECRET_KEY = os.environ.get('SECRET_KEY') or 'you-will-never-guess'
MODEL_PATH = os.environ.get('MODEL_PATH') or './models/storyteller-model'
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB
RATELIMIT_DEFAULT = "100 per day, 10 per hour"
RATELIMIT_STORAGE_URL = "memory://"
LOG_LEVEL = os.environ.get('LOG_LEVEL') or 'INFO'
app/api/routes.py - API路由
from flask import request, jsonify, current_app
from app.api import bp
from app.models.storyteller import StorytellerModel
from app.utils.auth import token_required
from app.utils.helpers import validate_json
from app import limiter
import uuid
import time
# 初始化模型
model = StorytellerModel()
@bp.route('/stories/generate', methods=['POST'])
@token_required
@limiter.limit("10 per minute")
@validate_json(['prompt'])
def generate_story():
"""生成新故事的端点"""
data = request.get_json()
# 记录请求
request_id = str(uuid.uuid4())
current_app.logger.info(f"Story generation request {request_id}: {data['prompt'][:50]}...")
start_time = time.time()
try:
# 提取参数
prompt = data['prompt']
style = data.get('style', 'general')
target_age = data.get('target_age', 8)
length = data.get('length', 'medium')
educational_theme = data.get('educational_theme', None)
# 模型参数
params = data.get('parameters', {})
temperature = params.get('temperature', 0.7)
top_p = params.get('top_p', 0.9)
max_tokens = params.get('max_tokens', 1000)
# 生成故事
result = model.generate_story(
prompt=prompt,
style=style,
target_age=target_age,
length=length,
educational_theme=educational_theme,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens
)
# 构建响应
response = {
'story_id': f"story_{uuid.uuid4().hex[:8]}",
'title': result['title'],
'content': result['content'],
'metadata': result['metadata']
}
# 记录性能
elapsed_time = time.time() - start_time
current_app.logger.info(f"Request {request_id} completed in {elapsed_time:.2f}s")
return jsonify(response), 200
except Exception as e:
current_app.logger.error(f"Error in request {request_id}: {str(e)}")
return jsonify({'error': str(e)}), 500
@bp.route('/stories/<story_id>/continue', methods=['POST'])
@token_required
@validate_json(['current_content'])
def continue_story(story_id):
"""续写故事的端点"""
data = request.get_json()
try:
# 提取参数
current_content = data['current_content']
continuation_prompt = data.get('continuation_prompt', '')
length = data.get('length', 'medium')
# 模型参数
params = data.get('parameters', {})
temperature = params.get('temperature', 0.8)
top_p = params.get('top_p', 0.9)
max_tokens = params.get('max_tokens', 500)
# 生成续写
result = model.continue_story(
current_content=current_content,
continuation_prompt=continuation_prompt,
length=length,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens
)
# 构建响应
response = {
'continuation': result['continuation'],
'metadata': result['metadata']
}
return jsonify(response), 200
except Exception as e:
current_app.logger.error(f"Error continuing story {story_id}: {str(e)}")
return jsonify({'error': str(e)}), 500
# 其他端点实现类似...
app/models/storyteller.py - 模型封装
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import nltk
from nltk.tokenize import sent_tokenize
import re
import os
class StorytellerModel:
def __init__(self, model_path=None):
"""初始化故事讲述模型"""
if model_path is None:
from app.config import Config
model_path = Config.MODEL_PATH
# 下载必要的NLTK资源
nltk.download('punkt', quiet=True)
# 加载模型和分词器
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
).to(self.device)
# 长度映射
self.length_map = {
"short": 300,
"medium": 600,
"long": 1200
}
def generate_story(self, prompt, style="general", target_age=8, length="medium",
educational_theme=None, temperature=0.7, top_p=0.9, max_tokens=None):
"""生成一个完整的故事"""
# 准备提示
if max_tokens is None:
max_tokens = self.length_map.get(length, 600)
# 构建系统提示
system_prompt = f"你是一个专业的儿童故事作家。请创作一个适合{target_age}岁儿童的{style}风格故事。"
if educational_theme:
system_prompt += f" 故事应该包含关于'{educational_theme}'的教育主题。"
# 组合提示
full_prompt = f"{system_prompt}\n\n故事提示: {prompt}\n\n故事:"
# 生成故事
inputs = self.tokenizer(full_prompt, return_tensors="pt").to(self.device)
outputs = self.model.generate(
inputs.input_ids,
max_length=inputs.input_ids.shape[1] + max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id
)
# 解码故事
story_text = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
# 生成标题
title = self._generate_title(story_text)
# 分析元数据
metadata = self._analyze_story(story_text, target_age)
return {
"title": title,
"content": story_text,
"metadata": metadata
}
def continue_story(self, current_content, continuation_prompt="", length="medium",
temperature=0.8, top_p=0.9, max_tokens=None):
"""续写一个已有的故事"""
if max_tokens is None:
max_tokens = self.length_map.get(length, 500)
# 准备提示
prompt = f"以下是一个故事的开始:\n\n{current_content}\n\n"
if continuation_prompt:
prompt += f"继续这个故事,其中: {continuation_prompt}\n\n"
else:
prompt += "请继续这个故事:\n\n"
# 生成续写
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
outputs = self.model.generate(
inputs.input_ids,
max_length=inputs.input_ids.shape[1] + max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id
)
# 解码续写
continuation = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
# 分析元数据
metadata = {
"word_count": len(continuation.split()),
"themes": self._extract_themes(continuation)
}
return {
"continuation": continuation,
"metadata": metadata
}
def _generate_title(self, story_text):
"""为故事生成标题"""
prompt = f"为以下故事生成一个吸引人的标题:\n\n{story_text[:500]}...\n\n标题:"
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
outputs = self.model.generate(
inputs.input_ids,
max_length=inputs.input_ids.shape[1] + 20,
temperature=0.7,
top_p=0.9,
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id
)
title = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
# 清理标题
title = title.strip()
if ":" in title:
title = title.split(":")[0].strip()
return title
def _analyze_story(self, story_text, target_age):
"""分析故事并提取元数据"""
# 计算字数和阅读时间
word_count = len(story_text.split())
reading_time = f"{max(1, round(word_count / 200))}分钟"
# 提取主题
themes = self._extract_themes(story_text)
# 分析教育价值
educational_value = self._analyze_educational_value(story_text)
# 分析适龄性
age_appropriate = self._analyze_age_appropriateness(story_text, target_age)
return {
"word_count": word_count,
"reading_time": reading_time,
"themes": themes,
"educational_value": educational_value,
"age_appropriate": age_appropriate
}
def _extract_themes(self, text):
"""从文本中提取主题"""
# 这里使用简化的实现,实际应用中可能需要更复杂的主题提取算法
theme_keywords = {
"勇气": ["勇敢", "勇气", "克服", "挑战", "害怕", "恐惧", "面对"],
"友谊": ["朋友", "友谊", "帮助", "支持", "一起", "分享"],
"冒险": ["冒险", "探索", "发现", "旅程", "旅行", "未知"],
"家庭": ["家庭", "父母", "妈妈", "爸爸", "兄弟", "姐妹", "爱"],
"成长": ["成长", "学习", "变化", "进步", "经验", "教训"],
"诚实": ["诚实", "真相", "谎言", "欺骗", "真实", "坦白"],
"坚持": ["坚持", "努力", "不放弃", "坚定", "毅力", "继续"],
"想象力": ["想象", "创造", "梦想", "幻想", "魔法", "神奇"]
}
found_themes = []
for theme, keywords in theme_keywords.items():
for keyword in keywords:
if keyword in text:
found_themes.append(theme)
break
return found_themes[:5] # 最多返回5个主题
def _analyze_educational_value(self, text):
"""分析故事的教育价值"""
# 简化实现
educational_aspects = []
if any(word in text.lower() for word in ["学习", "知识", "教育", "学校", "老师"]):
educational_aspects.append("学习价值")
if any(word in text.lower() for word in ["分享", "给予", "帮助", "关心", "同情"]):
educational_aspects.append("社交情感学习")
if any(word in text.lower() for word in ["对不起", "道歉", "原谅", "理解", "接受"]):
educational_aspects.append("情感管理")
if any(word in text.lower() for word in ["动物", "植物", "自然", "环境", "地球"]):
educational_aspects.append("自然知识")
if any(word in text.lower() for word in ["数字", "计数", "形状", "大小", "比较"]):
educational_aspects.append("数学概念")
if not educational_aspects:
educational_aspects.append("一般性教育")
return educational_aspects
def _analyze_age_appropriateness(self, text, target_age):
"""分析故事的适龄性"""
# 简化实现
sentences = sent_tokenize(text)
avg_sentence_length = sum(len(s.split()) for s in sentences) / len(sentences)
# 词汇复杂度的简单估计
complex_words = len([w for w in text.split() if len(w) > 6])
complex_word_ratio = complex_words / len(text.split())
# 根据目标年龄调整期望值
if target_age <= 5:
expected_sentence_length = 5
expected_complex_ratio = 0.05
elif target_age <= 8:
expected_sentence_length = 8
expected_complex_ratio = 0.1
else:
expected_sentence_length = 12
expected_complex_ratio = 0.15
# 计算偏差
sentence_length_diff = abs(avg_sentence_length - expected_sentence_length)
complex_ratio_diff = abs(complex_word_ratio - expected_complex_ratio)
# 评估适龄性
if sentence_length_diff <= 2 and complex_ratio_diff <= 0.05:
appropriateness = "非常适合"
elif sentence_length_diff <= 4 and complex_ratio_diff <= 0.1:
appropriateness = "适合"
else:
appropriateness = "可能需要调整"
return {
"rating": appropriateness,
"avg_sentence_length": round(avg_sentence_length, 1),
"complex_word_ratio": round(complex_word_ratio, 2)
}
app/utils/auth.py - 认证工具
from functools import wraps
from flask import request, jsonify, current_app
import os
def token_required(f):
@wraps(f)
def decorated(*args, **kwargs):
token = None
# 检查是否存在Authorization头
if 'Authorization' in request.headers:
auth_header = request.headers['Authorization']
try:
token = auth_header.split(" ")[1]
except IndexError:
return jsonify({'message': 'Token is missing or invalid'}), 401
if not token:
return jsonify({'message': 'Token is missing!'}), 401
# 在生产环境中,应该使用更安全的方法验证令牌
# 这里使用简单的API密钥比较作为示例
api_key = os.environ.get('API_KEY')
if not api_key or token != api_key:
current_app.logger.warning(f"Invalid API key attempt: {token[:10]}...")
return jsonify({'message': 'Invalid token!'}), 401
return f(*args, **kwargs)
return decorated
app/utils/helpers.py - 辅助函数
from functools import wraps
from flask import request, jsonify
def validate_json(required_fields=[]):
"""验证JSON请求体并检查必需字段"""
def decorator(f):
@wraps(f)
def decorated_function(*args, **kwargs):
# 检查Content-Type
if not request.is_json:
return jsonify({'error': 'Missing JSON in request'}), 400
# 获取JSON数据
data = request.get_json()
if not data:
return jsonify({'error': 'Invalid JSON'}), 400
# 检查必需字段
missing_fields = [field for field in required_fields if field not in data]
if missing_fields:
return jsonify({
'error': f"Missing required fields: {', '.join(missing_fields)}"
}), 400
return f(*args, **kwargs)
return decorated_function
return decorator
run.py - 应用入口点
from app import create_app
app = create_app()
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000, debug=True)
API测试与文档#
开发API后,测试和文档是确保其可用性和可维护性的关键步骤。
单元测试#
以下是一个简单的测试示例,使用pytest测试我们的API端点:
# tests/test_api.py
import json
import pytest
from app import create_app
from unittest.mock import patch
@pytest.fixture
def client():
app = create_app()
app.config['TESTING'] = True
app.config['API_KEY'] = 'test_key'
with app.test_client() as client:
yield client
def test_generate_story(client):
"""测试故事生成端点"""
# 模拟StorytellerModel.generate_story的返回值
mock_result = {
"title": "测试标题",
"content": "这是一个测试故事内容。",
"metadata": {
"word_count": 6,
"reading_time": "1分钟",
"themes": ["测试"],
"educational_value": ["一般性教育"],
"age_appropriate": {"rating": "适合", "avg_sentence_length": 6.0, "complex_word_ratio": 0.0}
}
}
with patch('app.models.storyteller.StorytellerModel.generate_story', return_value=mock_result):
response = client.post(
'/api/v1/stories/generate',
headers={'Authorization': 'Bearer test_key'},
json={'prompt': '测试提示'}
)
assert response.status_code == 200
data = json.loads(response.data)
assert 'story_id' in data
assert data['title'] == "测试标题"
assert data['content'] == "这是一个测试故事内容。"
assert 'metadata' in data
def test_missing_token(client):
"""测试缺少认证令牌的情况"""
response = client.post(
'/api/v1/stories/generate',
json={'prompt': '测试提示'}
)
assert response.status_code == 401
data = json.loads(response.data)
assert 'message' in data
assert 'missing' in data['message'].lower()
def test_invalid_token(client):
"""测试无效认证令牌的情况"""
response = client.post(
'/api/v1/stories/generate',
headers={'Authorization': 'Bearer invalid_key'},
json={'prompt': '测试提示'}
)
assert response.status_code == 401
data = json.loads(response.data)
assert 'message' in data
assert 'invalid' in data['message'].lower()
def test_missing_required_field(client):
"""测试缺少必需字段的情况"""
response = client.post(
'/api/v1/stories/generate',
headers={'Authorization': 'Bearer test_key'},
json={} # 缺少prompt字段
)
assert response.status_code == 400
data = json.loads(response.data)
assert 'error' in data
assert 'prompt' in data['error'].lower()
API文档#
为API创建清晰的文档对于开发者使用至关重要。我们可以使用Swagger/OpenAPI规范来自动生成API文档。
首先,安装必要的包:
pip install flask-swagger-ui apispec marshmallow
然后,添加Swagger支持:
# app/__init__.py 中添加
from flask_swagger_ui import get_swaggerui_blueprint
def create_app(config_class=Config):
# ... 现有代码 ...
# 设置Swagger
SWAGGER_URL = '/api/docs'
API_URL = '/static/swagger.json'
swaggerui_blueprint = get_swaggerui_blueprint(
SWAGGER_URL,
API_URL,
config={
'app_name': "Storyteller API"
}
)
app.register_blueprint(swaggerui_blueprint, url_prefix=SWAGGER_URL)
# ... 现有代码 ...
创建Swagger规范文件:
# app/api/swagger.py
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from marshmallow import Schema, fields
from flask import jsonify, current_app
import json
import os
# 定义请求和响应模式
class GenerateStoryRequestSchema(Schema):
prompt = fields.Str(required=True, description="故事提示")
style = fields.Str(description="故事风格", default="童话")
target_age = fields.Int(description="目标年龄", default=8)
length = fields.Str(description="故事长度", default="medium")
educational_theme = fields.Str(description="教育主题")
parameters = fields.Dict(description="模型参数")
class StoryMetadataSchema(Schema):
word_count = fields.Int(description="字数")
reading_time = fields.Str(description="阅读时间")
themes = fields.List(fields.Str(), description="主题")
educational_value = fields.List(fields.Str(), description="教育价值")
class GenerateStoryResponseSchema(Schema):
story_id = fields.Str(description="故事ID")
title = fields.Str(description="故事标题")
content = fields.Str(description="故事内容")
metadata = fields.Nested(StoryMetadataSchema, description="元数据")
# 创建规范
spec = APISpec(
title="Storyteller API",
version="1.0.0",
openapi_version="3.0.2",
plugins=[MarshmallowPlugin()],
)
# 注册模式
spec.components.schema("GenerateStoryRequest", schema=GenerateStoryRequestSchema)
spec.components.schema("GenerateStoryResponse", schema=GenerateStoryResponseSchema)
# 添加路径
spec.path(
path="/api/v1/stories/generate",
operations={
"post": {
"summary": "生成新故事",
"description": "根据提示生成一个完整的故事",
"requestBody": {
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/GenerateStoryRequest"}
}
}
},
"responses": {
"200": {
"description": "成功生成故事",
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/GenerateStoryResponse"}
}
}
},
"400": {
"description": "无效请求"
},
"401": {
"description": "未授权"
},
"500": {
"description": "服务器错误"
}
},
"security": [{"ApiKeyAuth": []}]
}
}
)
# 添加安全定义
spec.components.security_scheme(
"ApiKeyAuth",
{
"type": "apiKey",
"in": "header",
"name": "Authorization",
"description": "API密钥认证。格式: Bearer {token}"
}
)
# 导出规范
def get_apispec():
return spec
def create_swagger_json():
"""创建swagger.json文件"""
with open(os.path.join(current_app.static_folder, 'swagger.json'), 'w') as f:
json.dump(spec.to_dict(), f)
最后,在应用启动时创建swagger.json文件:
# app/__init__.py 中添加
@app.before_first_request
def before_first_request():
# 创建静态目录(如果不存在)
if not os.path.exists(app.static_folder):
os.makedirs(app.static_folder)
# 创建swagger.json
from app.api.swagger import create_swagger_json
create_swagger_json()
部署考虑事项#
在将API部署到生产环境之前,需要考虑以下几个方面:
性能优化:
使用模型量化减少内存需求
实现请求批处理以提高吞吐量
考虑使用模型服务框架如TorchServe或Triton
可扩展性:
使用容器化(Docker)便于部署和扩展
实现水平扩展以处理高负载
使用负载均衡器分配请求
安全性:
实现适当的认证和授权机制
使用HTTPS加密所有通信
定期更新依赖项以修复安全漏洞
监控和日志:
实现详细的日志记录
设置性能监控和警报
跟踪API使用情况和错误率
成本管理:
优化资源使用以减少云服务成本
考虑按需扩展以处理流量峰值
实现缓存以减少模型推理次数
Docker化API服务#
使用Docker可以简化部署过程并确保环境一致性。以下是一个基本的Dockerfile:
# 使用官方Python镜像作为基础
FROM python:3.9-slim
# 设置工作目录
WORKDIR /app
# 复制依赖文件
COPY requirements.txt .
# 安装依赖
RUN pip install --no-cache-dir -r requirements.txt
# 复制应用代码
COPY . .
# 设置环境变量
ENV FLASK_APP=run.py
ENV FLASK_ENV=production
ENV MODEL_PATH=/app/models/storyteller-model
# 暴露端口
EXPOSE 5000
# 运行应用
CMD ["gunicorn", "--bind", "0.0.0.0:5000", "run:app"]
创建docker-compose.yml文件以简化部署:
version: '3'
services:
api:
build: .
ports:
- "5000:5000"
volumes:
- ./models:/app/models
- ./logs:/app/logs
environment:
- SECRET_KEY=your-secret-key
- API_KEY=your-api-key
restart: always
总结#
在本节中,我们探讨了如何设计和实现一个API,将我们的故事讲述AI模型转化为可用的服务。我们讨论了API设计原则、功能规划和架构设计,并提供了一个使用Flask实现的完整示例。我们还介绍了测试、文档和部署考虑事项,以确保API的可靠性和可用性。
在下一节中,我们将探讨如何构建一个Web应用程序,为最终用户提供一个友好的界面来与我们的故事讲述AI交互。