Path: blob/master/site/zh-cn/tutorials/keras/text_classification.ipynb
25118 views
Copyright 2019 The TensorFlow Authors.
电影评论文本分类
本教程演示了从存储在磁盘上的纯文本文件开始的文本分类。您将训练一个二元分类器对 IMDB 数据集执行情感分析。在笔记本的最后,有一个练习供您尝试,您将在其中训练一个多类分类器来预测 Stack Overflow 上编程问题的标签。
情感分析
此笔记本训练了一个情感分析模型,利用评论文本将电影评论分类为正面或负面评价。这是一个二元(或二类)分类示例,也是一个重要且应用广泛的机器学习问题。
您将使用 Large Movie Review Dataset,其中包含 Internet Movie Database 中的 50,000 条电影评论文本 。我们将这些评论分为两组,其中 25,000 条用于训练,另外 25,000 条用于测试。训练集和测试集是均衡的,也就是说其中包含相等数量的正面评价和负面评价。
下载并探索 IMDB 数据集
我们下载并提取数据集,然后浏览一下目录结构。
aclImdb/train/pos
和 aclImdb/train/neg
目录包含许多文本文件,每个文件都是一条电影评论。我们来看看其中的一条评论。
要准备用于二元分类的数据集,磁盘上需要有两个文件夹,分别对应于 class_a
和 class_b
。这些将是正面和负面的电影评论,可以在 aclImdb/train/pos
和 aclImdb/train/neg
中找到。由于 IMDB 数据集包含其他文件夹,因此您需要在使用此实用工具之前将其移除。
如上所示,训练文件夹中有 25,000 个样本,您将使用其中的 80%(或 20,000 个)进行训练。稍后您将看到,您可以通过将数据集直接传递给 model.fit
来训练模型。如果您不熟悉 tf.data
,还可以遍历数据集并打印出一些样本,如下所示。
请注意,评论包含原始文本(带有标点符号和偶尔出现的 HTML 代码,如 <br/>
)。我们将在以下部分展示如何处理这些问题。
标签为 0 或 1。要查看它们与正面和负面电影评论的对应关系,可以查看数据集上的 class_names
属性。
接下来,您将创建验证数据集和测试数据集。您将使用训练集中剩余的 5,000 条评论进行验证。
注:使用 validation_split
和 subset
参数时,请确保要么指定随机种子,要么传递 shuffle=False
,这样验证拆分和训练拆分就不会重叠。
准备用于训练的数据集
接下来,您将使用有用的 tf.keras.layers.TextVectorization
层对数据进行标准化、词例化和向量化。
标准化是指对文本进行预处理,通常是移除标点符号或 HTML 元素以简化数据集。词例化是指将字符串分割成词例(例如,通过空格将句子分割成单个单词)。向量化是指将词例转换为数字,以便将它们输入神经网络。所有这些任务都可以通过这个层完成。
正如您在上面看到的,评论包含各种 HTML 代码,例如 <br />
。TextVectorization
层(默认情况下会将文本转换为小写并去除标点符号,但不会去除 HTML)中的默认标准化程序不会移除这些代码。您将编写一个自定义标准化函数来移除 HTML。
注:为了防止训练-测试偏差(也称为训练-应用偏差),在训练和测试时间对数据进行相同的预处理非常重要。为此,可以将 TextVectorization
层直接包含在模型中,如本教程后面所示。
接下来,您将创建一个 TextVectorization
层。您将使用该层对我们的数据进行标准化、词例化和向量化。您将 output_mode
设置为 int
以便为每个词例创建唯一的整数索引。
请注意,您使用的是默认拆分函数,以及您在上面定义的自定义标准化函数。您还将为模型定义一些常量,例如显式的最大 sequence_length
,这会使层将序列填充或截断为精确的 sequence_length
值。
接下来,您将调用 adapt
以使预处理层的状态适合数据集。这会使模型构建字符串到整数的索引。
注:在调用时请务必仅使用您的训练数据(使用测试集会泄漏信息)。
我们来创建一个函数来查看使用该层预处理一些数据的结果。
正如您在上面看到的,每个词例都被一个整数替换了。您可以通过在该层上调用 .get_vocabulary()
来查找每个整数对应的词例(字符串)。
你几乎已经准备好训练你的模型了。作为最后的预处理步骤,你将在训练、验证和测试数据集上应用之前创建的TextVectorization层。
配置数据集以提高性能
以下是加载数据时应该使用的两种重要方法,以确保 I/O 不会阻塞。
从磁盘加载后,.cache()
会将数据保存在内存中。这将确保数据集在训练模型时不会成为瓶颈。如果您的数据集太大而无法放入内存,也可以使用此方法创建高性能的磁盘缓存,这比许多小文件的读取效率更高。
prefetch()
会在训练时将数据预处理和模型执行重叠。
您可以在数据性能指南中深入了解这两种方法,以及如何将数据缓存到磁盘。
创建模型
是时候创建您的神经网络了:
层按顺序堆叠以构建分类器:
第一个层是
Embedding
层。此层采用整数编码的评论,并查找每个单词索引的嵌入向量。这些向量是通过模型训练学习到的。向量向输出数组增加了一个维度。得到的维度为:(batch, sequence, embedding)
。要详细了解嵌入向量,请参阅单词嵌入向量教程。接下来,
GlobalAveragePooling1D
将通过对序列维度求平均值来为每个样本返回一个定长输出向量。这允许模型以尽可能最简单的方式处理变长输入。最后一层与单个输出结点密集连接。
损失函数与优化器
模型训练需要一个损失函数和一个优化器。由于这是一个二元分类问题,并且模型输出概率(具有 Sigmoid 激活的单一单元层),我们将使用 losses.BinaryCrossentropy
损失函数。
现在,配置模型以使用优化器和损失函数:
训练模型
将 dataset
对象传递给 fit 方法,对模型进行训练。
评估模型
我们来看一下模型的性能如何。将返回两个值。损失值(loss)(一个表示误差的数字,值越低越好)与准确率(accuracy)。
这种十分简单的方式实现了约 86% 的准确率。
创建准确率和损失随时间变化的图表
model.fit()
会返回包含一个字典的 History
对象。该字典包含训练过程中产生的所有信息:
其中有四个条目:每个条目代表训练和验证过程中的一项监测指标。您可以使用这些指标来绘制用于比较的训练损失和验证损失图表,以及训练准确率和验证准确率图表:
在该图表中,虚线代表训练损失和准确率,实线代表验证损失和准确率。
请注意,训练损失会逐周期下降,而训练准确率则逐周期上升。使用梯度下降优化时,这是预期结果,它应该在每次迭代中最大限度减少所需的数量。
但是,对于验证损失和准确率来说则不然——它们似乎会在训练转确率之前达到顶点。这是过拟合的一个例子:模型在训练数据上的表现要好于在之前从未见过的数据上的表现。经过这一点之后,模型会过度优化和学习特定于训练数据的表示,但无法泛化到测试数据。
对于这种特殊情况,您可以通过在验证准确率不再增加时直接停止训练来防止过度拟合。一种方式是使用 tf.keras.callbacks.EarlyStopping
回调。
导出模型
在上面的代码中,您在向模型馈送文本之前对数据集应用了 TextVectorization
。 如果您想让模型能够处理原始字符串(例如,为了简化部署),您可以在模型中包含 TextVectorization
层。为此,您可以使用刚刚训练的权重创建一个新模型。
使用新数据进行推断
要获得对新样本的预测,只需调用 model.predict()
即可。
练习:对 Stack Overflow 问题进行多类分类
本教程展示了如何在 IMDB 数据集上从头开始训练二元分类器。作为练习,您可以修改此笔记本以训练多类分类器来预测 Stack Overflow 上的编程问题的标签。
我们已经准备好了一个数据集供您使用,其中包含了几千个发布在 Stack Overflow 上的编程问题(例如,"How can sort a dictionary by value in Python?")。每一个问题都只有一个标签(Python、CSharp、JavaScript 或 Java)。您的任务是将问题作为输入,并预测适当的标签,在本例中为 Python。
您将使用的数据集包含从 BigQuery 上更大的公共 Stack Overflow 数据集提取的数千个问题,其中包含超过 1700 万个帖子。
下载数据集后,您会发现它与您之前使用的 IMDB 数据集具有相似的目录结构:
注:为了增加分类问题的难度,编程问题中出现的 Python、CSharp、JavaScript 或 Java 等词已被替换为 blank(因为许多问题都包含它们所涉及的语言)。
要完成此练习,您应该对此笔记本进行以下修改以使用 Stack Overflow 数据集:
在笔记本顶部,将下载 IMDB 数据集的代码更新为下载前面准备好的 Stack Overflow 数据集的代码。由于 Stack Overflow 数据集具有类似的目录结构,因此您不需要进行太多修改。
将模型的最后一层修改为
Dense(4)
,因为现在有四个输出类。编译模型时,将损失更改为
tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
。当每个类的标签是整数(在本例中,它们可以是 0、1、2 或 3)时,这是用于多类分类问题的正确损失函数。 此外,将指标更改为metrics=['accuracy']
,因为这是一个多类分类问题(tf.metrics.BinaryAccuracy
仅用于二元分类器 )。在绘制随时间变化的准确率时,请将
binary_accuracy
和val_binary_accuracy
分别更改为accuracy
和val_accuracy
。完成这些更改后,就可以训练多类分类器了。
了解更多信息
本教程从头开始介绍了文本分类。要详细了解一般的文本分类工作流程,请查看 Google Developers 提供的文本分类指南。