第11章:数据集(Datasets)#
11.1 数据集概述#
在构建故事讲述AI大语言模型的过程中,数据集的质量和规模直接决定了模型的性能和能力。大语言模型(LLM)需要海量的文本数据来学习语言的结构、语法、语义以及各种知识。对于专门用于讲故事的AI模型,我们需要特别关注那些包含丰富叙事结构、情节发展和人物塑造的文本数据。
语言模型的训练数据通常来源广泛,包括但不限于书籍、文章、网页内容、对话记录等。这些数据经过精心筛选和处理后,才能用于模型的训练。在本章中,我们将深入探讨数据集的收集、处理、加载以及合成数据生成的方法,为构建一个高质量的故事讲述AI模型奠定基础。
数据集的构建过程可以分为几个关键步骤:数据收集、数据清洗、数据标注、数据分割以及数据加载。每一步都至关重要,任何一个环节出现问题都可能导致最终模型性能的下降。因此,我们需要投入足够的时间和精力来确保数据集的质量。
11.2 数据收集#
11.2.1 公开数据集#
在开始构建自己的数据集之前,了解一些已有的公开数据集是非常有价值的。这些数据集不仅可以直接用于模型训练,还可以作为我们构建自己数据集的参考。以下是一些适用于故事讲述AI模型的公开数据集:
BookCorpus:包含超过11,000本未出版的书籍,涵盖了各种类型的小说,是训练语言模型的优质资源。
Project Gutenberg:提供了超过60,000本公版书籍的电子版本,包括大量经典文学作品。
Wikitext:从维基百科文章中提取的大规模语言建模数据集,包含了丰富的知识和叙事内容。
Children’s Book Test (CBT):专门用于评估模型理解儿童故事能力的数据集,对于构建故事讲述AI特别有价值。
ROCStories:包含50,000个日常生活的短故事,每个故事由五个句子组成,具有连贯的情节发展。
WritingPrompts:来自Reddit的写作提示和相应的故事,包含了各种创意写作内容。
LAMBADA:专门设计用于测试模型长距离依赖理解能力的数据集,对于故事生成尤为重要。
这些公开数据集可以通过各种渠道获取,如Hugging Face的Datasets库、TensorFlow Datasets、Kaggle等平台。在实际应用中,我们通常会结合多个数据集,以确保模型能够学习到多样化的语言表达和叙事结构。
11.2.2 自定义数据收集#
除了使用公开数据集外,针对特定的故事类型或风格,我们可能需要收集自定义数据。以下是一些自定义数据收集的方法:
网络爬虫:通过编写爬虫程序,从特定网站(如故事分享平台、文学网站等)收集故事内容。这种方法可以获取大量的数据,但需要注意版权问题和数据质量控制。
API访问:许多平台提供API接口,允许开发者以结构化的方式获取内容。例如,使用Reddit API获取WritingPrompts子版块的内容,或使用News API获取新闻故事。
合作收集:与作家、出版社或教育机构合作,获取专业创作的故事内容。这种方法可以获取高质量的数据,但可能需要支付费用或签订协议。
众包平台:通过众包平台(如Amazon Mechanical Turk、Prolific等)招募人员创作或收集故事。这种方法可以快速获取大量数据,但需要设计良好的任务指南和质量控制机制。
无论采用哪种方法,都需要确保收集的数据符合法律和伦理要求,并尊重原创作者的权益。同时,建立一个明确的数据收集计划,包括目标数据量、数据类型、质量标准等,可以帮助我们更有效地进行数据收集工作。
11.3 数据清洗与预处理#
收集到的原始数据通常包含各种噪声和不规则内容,需要进行清洗和预处理才能用于模型训练。数据清洗是提高模型训练效果的关键步骤,包括以下几个方面:
11.3.1 基础文本清洗#
基础文本清洗主要处理文本中的格式问题和明显的噪声:
去除HTML标签:如果数据来源于网页,可能包含HTML标签,需要使用正则表达式或专门的库(如BeautifulSoup)去除。
统一编码:确保所有文本使用相同的编码(通常是UTF-8),避免因编码不一致导致的乱码问题。
处理特殊字符:根据需要保留或去除特殊字符、表情符号等。
规范化空白字符:处理多余的空格、制表符、换行符等,使文本格式一致。
大小写处理:根据需要统一文本的大小写,或保持原有格式。
以下是一个简单的Python代码示例,展示了基础文本清洗的过程:
import re
from bs4 import BeautifulSoup
def basic_clean(text):
# 去除HTML标签
text = BeautifulSoup(text, "html.parser").get_text()
# 处理特殊字符和多余空白
text = re.sub(r'\s+', ' ', text) # 将多个空白字符替换为单个空格
text = text.strip() # 去除首尾空白
# 其他清洗步骤...
return text
# 应用到数据集
cleaned_texts = [basic_clean(text) for text in raw_texts]
11.3.2 语言学预处理#
语言学预处理涉及更深层次的文本分析和处理:
分词:将文本分割成单词或子词单元,这是后续处理的基础。
词干提取和词形还原:将单词转换为其基本形式,减少词汇的变化形式。
去除停用词:根据需要去除常见但信息量较少的词(如”the”、”and”等)。
句子分割:将文本分割成句子单元,便于后续的处理和分析。
以下是使用NLTK库进行语言学预处理的示例:
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer
from nltk.corpus import stopwords
# 下载必要的资源
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('stopwords')
def linguistic_preprocess(text):
# 句子分割
sentences = sent_tokenize(text)
# 分词、词形还原和去除停用词
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()
processed_sentences = []
for sentence in sentences:
# 分词
words = word_tokenize(sentence)
# 词形还原和去除停用词
words = [lemmatizer.lemmatize(word.lower()) for word in words if word.lower() not in stop_words]
processed_sentences.append(' '.join(words))
return processed_sentences
# 应用到数据集
processed_texts = [linguistic_preprocess(text) for text in cleaned_texts]
11.3.3 故事特定的预处理#
对于故事文本,我们可能需要进行一些特定的预处理:
章节和段落识别:识别故事的章节和段落结构,保留这些信息以便模型学习叙事结构。
对话提取:识别和标记故事中的对话内容,这对于模型学习人物对话风格很重要。
情节标记:识别故事中的关键情节点,如开端、发展、高潮、结局等。
人物识别:识别故事中的人物及其关系,这有助于模型理解人物互动和发展。
这些特定的预处理通常需要结合规则和机器学习方法来实现,可能需要一定的人工标注工作。
11.3.4 数据质量控制#
数据清洗的最后一步是进行质量控制,确保处理后的数据符合预期:
数据完整性检查:确保没有缺失或损坏的数据。
长度过滤:过滤掉过短或过长的文本,保持数据的一致性。
重复检测:识别并处理重复的内容,避免模型过度学习某些模式。
语言检测:确保数据是目标语言(如中文),过滤掉其他语言的内容。
内容审核:过滤不适当的内容,确保数据符合伦理和法律要求。
以下是一个简单的数据质量控制示例:
from langdetect import detect
def quality_control(texts, min_length=50, max_length=10000):
filtered_texts = []
for text in texts:
# 长度过滤
if len(text) < min_length or len(text) > max_length:
continue
# 语言检测(确保是英文)
try:
if detect(text) != 'en':
continue
except:
continue
# 其他质量控制步骤...
filtered_texts.append(text)
return filtered_texts
# 应用到数据集
quality_texts = quality_control(processed_texts)
通过这些数据清洗和预处理步骤,我们可以将原始的、杂乱的文本数据转换为结构化、高质量的训练数据,为后续的模型训练奠定基础。
11.4 数据加载与处理#
在准备好高质量的数据集后,下一步是设计高效的数据加载和处理流程,以便将数据输入到模型中进行训练。现代深度学习框架提供了多种工具和库来简化这一过程。
11.4.1 数据格式化#
首先,我们需要将清洗后的数据转换为适合模型训练的格式。对于语言模型,常见的格式包括:
文本文件:每行一个样本或文档,简单直观但处理大规模数据时效率较低。
JSON/JSONL:每个样本包含多个字段,如文本内容、元数据等,灵活性高。
TFRecord/WebDataset:二进制格式,专为高效的数据加载和处理设计。
HDF5:适合存储大规模、层次化的数据,支持随机访问。
Parquet:列式存储格式,适合处理结构化数据。
对于故事文本,我们可能会选择JSONL格式,每个样本包含故事文本、标题、作者、类型等信息。以下是一个示例:
{"title": "小红帽", "author": "格林兄弟", "type": "童话", "text": "从前有一个可爱的小女孩..."}
{"title": "灰姑娘", "author": "格林兄弟", "type": "童话", "text": "从前,有一个善良的女孩..."}
11.4.2 数据分割#
在模型训练前,我们通常需要将数据集分割为训练集、验证集和测试集:
训练集:用于模型的参数学习,通常占总数据的70-80%。
验证集:用于调整超参数和早停,通常占10-15%。
测试集:用于评估最终模型性能,通常占10-15%。
分割数据时,需要确保各个子集的分布相似,避免引入偏差。以下是使用scikit-learn进行数据分割的示例:
from sklearn.model_selection import train_test_split
# 假设stories是我们的故事数据列表
train_data, temp_data = train_test_split(stories, test_size=0.3, random_state=42)
val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)
print(f"训练集大小: {len(train_data)}")
print(f"验证集大小: {len(val_data)}")
print(f"测试集大小: {len(test_data)}")
11.4.3 数据加载库#
现代深度学习框架提供了专门的数据加载库,简化了数据处理流程:
PyTorch DataLoader:PyTorch的数据加载工具,支持批处理、多进程加载和自定义数据转换。
TensorFlow tf.data:TensorFlow的数据加载API,提供高效的数据处理和转换功能。
Hugging Face Datasets:专为NLP任务设计的库,支持多种数据格式和处理操作。
WebDataset:专为大规模数据设计的库,基于tar文件格式,支持流式处理。
以下是使用Hugging Face Datasets加载和处理故事数据集的示例:
from datasets import Dataset, DatasetDict
# 创建数据集字典
dataset_dict = DatasetDict({
'train': Dataset.from_list(train_data),
'validation': Dataset.from_list(val_data),
'test': Dataset.from_list(test_data)
})
# 查看数据集信息
print(dataset_dict)
# 应用数据转换
def tokenize_function(examples):
return tokenizer(examples["text"], padding="max_length", truncation=True)
tokenized_datasets = dataset_dict.map(tokenize_function, batched=True)
11.4.4 数据批处理与动态加载#
对于大规模数据集,我们通常采用批处理和动态加载的方式,避免将整个数据集加载到内存中:
批处理:将数据分成小批次进行处理,每次只加载一个批次的数据。
动态加载:在训练过程中动态加载数据,而不是预先加载所有数据。
预取:在处理当前批次的同时,预先加载下一批次的数据,减少等待时间。
缓存:缓存频繁使用的数据,减少重复加载的开销。
以下是使用PyTorch DataLoader进行批处理和动态加载的示例:
import torch
from torch.utils.data import Dataset, DataLoader
class StoryDataset(Dataset):
def __init__(self, stories, tokenizer, max_length=512):
self.stories = stories
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.stories)
def __getitem__(self, idx):
story = self.stories[idx]
encoding = self.tokenizer(
story["text"],
max_length=self.max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
)
# 将字典中的所有张量转换为一维
return {k: v.squeeze(0) for k, v in encoding.items()}
# 创建数据集和数据加载器
train_dataset = StoryDataset(train_data, tokenizer)
train_dataloader = DataLoader(
train_dataset,
batch_size=16,
shuffle=True,
num_workers=4,
pin_memory=True
)
# 使用数据加载器进行训练
for batch in train_dataloader:
# 将数据移动到GPU
batch = {k: v.to(device) for k, v in batch.items()}
# 模型前向传播、损失计算和反向传播
# ...
通过这种方式,我们可以高效地处理大规模数据集,同时保持内存使用在合理范围内。
11.5 合成数据生成#
除了收集和处理现有的数据外,合成数据生成是另一种增加训练数据的重要方法,特别是在特定领域的数据稀缺时。对于故事讲述AI模型,合成数据可以帮助增强模型在特定类型故事或特定叙事风格上的能力。
11.5.1 基于规则的数据生成#
基于规则的方法使用预定义的模板和规则来生成新的故事样本:
故事模板:创建基本的故事结构模板,如”主角遇到问题→尝试解决→遇到障碍→最终解决”。
元素替换:在模板中替换不同的角色、场景、问题等元素,生成多样化的故事。
语法规则:使用形式语法或其他语言规则来生成符合特定结构的文本。
以下是一个简单的基于模板的故事生成示例:
import random
# 故事元素
characters = ["小明", "小红", "小华", "小丽"]
settings = ["森林", "学校", "城堡", "海边"]
problems = ["迷路了", "遇到了一只神秘动物", "发现了一个神秘洞穴", "遭遇了暴风雨"]
solutions = ["得到了朋友的帮助", "凭借智慧解决了问题", "找到了回家的路", "学会了一项新技能"]
# 故事模板
template = "从前,{character}在{setting}里{problem}。经过一番努力,最终{solution}。"
# 生成故事
def generate_story():
return template.format(
character=random.choice(characters),
setting=random.choice(settings),
problem=random.choice(problems),
solution=random.choice(solutions)
)
# 生成100个故事样本
synthetic_stories = [generate_story() for _ in range(100)]
11.5.2 基于模型的数据生成#
使用现有的语言模型来生成新的故事样本,这种方法可以产生更自然、多样化的文本:
预训练模型生成:使用GPT、BART等预训练模型生成新的故事文本。
条件生成:基于特定的提示、风格或主题生成故事。
数据增强:对现有故事进行改写、扩展或变换,生成新的变体。
以下是使用Hugging Face Transformers库进行基于模型的故事生成示例:
from transformers import pipeline
# 初始化文本生成管道
generator = pipeline('text-generation', model='gpt2')
# 故事提示
prompts = [
"从前有一个小女孩,她住在森林边缘,",
"在一个遥远的王国,有一位年轻的王子,",
"太空站的警报突然响起,宇航员们发现,"
]
# 生成故事
synthetic_stories = []
for prompt in prompts:
# 为每个提示生成10个不同的故事
for _ in range(10):
result = generator(prompt, max_length=200, num_return_sequences=1)[0]['generated_text']
synthetic_stories.append(result)
11.5.3 数据增强技术#
数据增强是一种特殊的合成数据生成方法,通过对现有数据进行变换来创建新的样本:
同义词替换:用同义词替换文本中的某些词,保持语义不变但创造新的表达。
回译:将文本翻译成另一种语言,然后再翻译回来,产生表达方式的变化。
句子重排:改变句子的顺序,创造不同的叙事流程。
插入和删除:随机插入或删除某些词或短语,增加文本的多样性。
EDA (Easy Data Augmentation):结合上述多种方法的简单数据增强技术。
以下是使用nlpaug库进行文本数据增强的示例:
import nlpaug.augmenter.word as naw
import nlpaug.augmenter.sentence as nas
# 同义词替换增强器
synonym_aug = naw.SynonymAug()
# 回译增强器
back_translation_aug = naw.BackTranslationAug(
from_model_name='facebook/wmt19-en-de',
to_model_name='facebook/wmt19-de-en'
)
# 句子重排增强器
sentence_shuffle_aug = nas.RandomSentAug()
# 应用数据增强
augmented_stories = []
for story in stories[:10]: # 只对前10个故事进行增强示例
# 同义词替换
aug_syn = synonym_aug.augment(story)
augmented_stories.append(aug_syn)
# 回译
aug_back = back_translation_aug.augment(story)
augmented_stories.append(aug_back)
# 句子重排
aug_shuffle = sentence_shuffle_aug.augment(story)
augmented_stories.append(aug_shuffle)
11.5.4 合成数据的质量控制#
合成数据生成后,同样需要进行质量控制,确保数据的质量和多样性:
人工审核:对生成的样本进行抽样审核,确保质量和适当性。
自动评估:使用自动化指标(如困惑度、BLEU分数等)评估生成文本的质量。
多样性检查:确保生成的样本具有足够的多样性,避免重复和单一模式。
与原始数据比较:确保合成数据与原始数据在分布上相似,但不是简单的复制。
以下是一个简单的合成数据质量评估示例:
from nltk.translate.bleu_score import sentence_bleu
from collections import Counter
import numpy as np
def evaluate_synthetic_data(original_stories, synthetic_stories, sample_size=100):
# 抽样原始故事和合成故事
orig_sample = random.sample(original_stories, min(sample_size, len(original_stories)))
synth_sample = random.sample(synthetic_stories, min(sample_size, len(synthetic_stories)))
# 计算平均长度
orig_lengths = [len(story.split()) for story in orig_sample]
synth_lengths = [len(story.split()) for story in synth_sample]
print(f"原始故事平均长度: {np.mean(orig_lengths):.2f} ± {np.std(orig_lengths):.2f}")
print(f"合成故事平均长度: {np.mean(synth_lengths):.2f} ± {np.std(synth_lengths):.2f}")
# 计算词汇多样性
orig_vocab = Counter(word for story in orig_sample for word in story.split())
synth_vocab = Counter(word for story in synth_sample for word in story.split())
print(f"原始故事词汇量: {len(orig_vocab)}")
print(f"合成故事词汇量: {len(synth_vocab)}")
# 计算BLEU分数(评估合成故事与原始故事的相似度)
bleu_scores = []
for synth_story in synth_sample[:10]: # 只对前10个样本计算BLEU分数
synth_tokens = synth_story.split()
scores = [sentence_bleu([orig_story.split()], synth_tokens) for orig_story in orig_sample[:10]]
bleu_scores.append(max(scores)) # 取最相似的原始故事的BLEU分数
print(f"平均最大BLEU分数: {np.mean(bleu_scores):.4f}")
# 其他评估指标...
# 评估合成数据
evaluate_synthetic_data(original_stories, synthetic_stories)
通过合成数据生成,我们可以显著扩充训练数据集,特别是在特定领域或风格的数据稀缺时。但需要注意的是,合成数据应该作为原始数据的补充,而不是替代,两者结合使用通常能获得最佳效果。
11.6 数据集管理与版本控制#
随着项目的发展,数据集可能会不断更新和扩充,因此建立良好的数据集管理和版本控制机制非常重要。
11.6.1 数据集元数据#
为数据集创建详细的元数据,记录数据集的来源、处理方法、统计信息等:
基本信息:数据集名称、版本、创建日期、作者等。
来源信息:数据的来源、收集方法、原始格式等。
处理信息:清洗和预处理的步骤、使用的工具和参数等。
统计信息:样本数量、词汇量、平均长度、类别分布等。
使用说明:数据集的适用场景、限制条件、许可证信息等。
以下是一个数据集元数据的JSON示例:
{
"name": "StoryTeller Dataset",
"version": "1.0.0",
"created_at": "2023-05-15",
"authors": ["张三", "李四"],
"description": "用于训练故事讲述AI模型的中文故事数据集",
"sources": [
{"name": "Project Gutenberg", "url": "https://www.gutenberg.org/"},
{"name": "中国民间故事集", "type": "book", "publisher": "人民文学出版社"}
],
"processing": [
{"step": "HTML清洗", "tool": "BeautifulSoup", "version": "4.9.3"},
{"step": "分词", "tool": "jieba", "version": "0.42.1"},
{"step": "质量过滤", "min_length": 100, "max_length": 5000}
],
"statistics": {
"total_samples": 10000,
"train_samples": 8000,
"val_samples": 1000,
"test_samples": 1000,
"avg_length": 1250,
"vocabulary_size": 35000,
"genres": {
"童话": 3000,
"民间故事": 2500,
"寓言": 1500,
"科幻": 1500,
"其他": 1500
}
},
"license": "CC BY-NC-SA 4.0",
"usage_notes": "本数据集仅用于研究和非商业用途。使用时请注明出处。"
}
11.6.2 数据集版本控制#
随着项目的发展,数据集可能会经历多次更新和迭代,建立版本控制机制可以帮助追踪这些变化:
语义化版本号:使用类似软件的版本号系统(如1.0.0、1.1.0等)来标识不同版本的数据集。
变更日志:记录每个版本的变更内容,包括新增、修改和删除的数据。
数据集快照:为每个版本创建不可变的快照,确保实验的可重复性。
Git LFS:使用Git Large File Storage等工具进行数据集的版本控制。
数据集注册表:使用专门的数据集注册工具(如DVC、Weights & Biases等)管理数据集版本。
以下是使用DVC (Data Version Control) 进行数据集版本控制的示例:
# 安装DVC
pip install dvc
# 初始化DVC仓库
dvc init
# 添加数据集到DVC跟踪
dvc add data/storyteller_dataset_v1.0.0
# 提交更改到Git
git add data/storyteller_dataset_v1.0.0.dvc .gitignore
git commit -m "Add StoryTeller Dataset v1.0.0"
# 修改数据集后,更新DVC跟踪
dvc add data/storyteller_dataset_v1.1.0
# 提交新版本到Git
git add data/storyteller_dataset_v1.1.0.dvc
git commit -m "Update to StoryTeller Dataset v1.1.0"
# 切换到特定版本的数据集
git checkout <commit_hash>
dvc checkout
11.6.3 数据集共享与发布#
对于有价值的数据集,可以考虑共享和发布,使其能够被更广泛的研究社区使用:
数据托管平台:使用Hugging Face Datasets、Kaggle、Zenodo等平台托管和分享数据集。
文档和示例:提供详细的文档和使用示例,帮助其他研究者理解和使用数据集。
引用信息:提供正确的引用格式,使其他研究者可以在论文中引用你的数据集。
许可证:选择适当的许可证(如CC BY-NC-SA、MIT等),明确数据集的使用条件。
以下是在Hugging Face Datasets上发布数据集的示例:
from datasets import Dataset, DatasetDict
from huggingface_hub import HfApi
# 准备数据集
dataset_dict = DatasetDict({
'train': Dataset.from_list(train_data),
'validation': Dataset.from_list(val_data),
'test': Dataset.from_list(test_data)
})
# 保存数据集到本地
dataset_dict.save_to_disk("storyteller_dataset")
# 登录Hugging Face
api = HfApi()
api.login()
# 上传数据集到Hugging Face Hub
dataset_dict.push_to_hub("username/storyteller_dataset", private=False)
# 更新数据集卡片(README.md)
with open("README.md", "w") as f:
f.write("""
# StoryTeller Dataset
用于训练故事讲述AI模型的中文故事数据集。
## 数据集描述
本数据集包含10,000个中文故事样本,涵盖童话、民间故事、寓言、科幻等多种类型。
## 使用方法
```python
from datasets import load_dataset
dataset = load_dataset("username/storyteller_dataset")
引用#
如果您在研究中使用了本数据集,请引用:
@dataset{storyteller_dataset,
author = {张三, 李四},
title = {StoryTeller Dataset},
year = {2023},
url = {https://huggingface.co/datasets/username/storyteller_dataset}
}
上传README.md#
api.upload_file(
path_or_fileobj="README.md",
path_in_repo="README.md",
repo_id="username/storyteller_dataset",
repo_type="dataset"
)
通过良好的数据集管理和版本控制,我们可以确保数据集的质量、可追溯性和可重用性,为模型训练和研究提供坚实的基础。