Path: blob/master/site/zh-cn/tutorials/estimator/premade.ipynb
25118 views
Copyright 2019 The TensorFlow Authors.
预创建的 Estimators
首先要做的事
为了开始,您将首先导入 Tensorflow 和一系列您需要的库。
接下来,使用 Keras 与 Pandas 下载并解析鸢尾花数据集。注意为训练和测试保留不同的数据集。
通过检查数据您可以发现有四列浮点型特征和一列 int32 型标签。
对于每个数据集都分割出标签,模型将被训练来预测这些标签。
Estimator 编程概述
现在您已经设置了数据,可以使用 TensorFlow Estimator 定义模型。 Estimator 是从 tf.estimator.Estimator
派生的任何类。TensorFlow 提供了一组 tf.estimator
(例如 LinearRegressor
)来实现常见的 ML 算法。除此之外,您可以编写自己的自定义 Estimator。建议在刚开始时使用预制的 Estimator。
为了编写基于预创建的 Estimator 的 Tensorflow 项目,您必须完成以下工作:
创建一个或多个输入函数
定义模型的特征列
实例化一个 Estimator,指定特征列和各种超参数。
在 Estimator 对象上调用一个或多个方法,传递合适的输入函数以作为数据源。
我们来看看这些任务是如何在鸢尾花分类中实现的。
您的输入函数可以用您喜欢的任何方式生成 features
字典和label
列表。但是,推荐使用 TensorFlow 的 Dataset API,它可以解析各种数据。
Dataset API 可以为您处理很多常见情况。例如,使用 Dataset API,您可以轻松地从大量文件中并行读取记录,并将它们合并为单个数据流。
为了简化此示例,我们将使用 pandas 加载数据,并利用此内存数据构建输入管道。
定义特征列(feature columns)
特征列(feature columns)是一个对象,用于描述模型应该如何使用特征字典中的原始输入数据。当您构建一个 Estimator 模型的时候,您会向其传递一个特征列的列表,其中包含您希望模型使用的每个特征。tf.feature_column
模块提供了许多为模型表示数据的选项。
对于鸢尾花,4 个原始特征是数值,因此您将构建一个特征列列表来告诉 Estimator 模型将四个特征中的每一个表示为 32 位浮点值。因此,创建特征列的代码为:
特征列可能比这里显示的要复杂得多。您可以在此指南中阅读有关特征列的更多信息。
我们已经介绍了如何使模型表示原始特征,现在您可以构建 Estimator 了。
实例化 Estimator
鸢尾花为题是一个经典的分类问题。幸运的是,Tensorflow 提供了几个预创建的 Estimator 分类器,其中包括:
tf.estimator.DNNClassifier
用于多类别分类的深度模型tf.estimator.DNNLinearCombinedClassifier
用于广度与深度模型tf.estimator.LinearClassifier
用于基于线性模型的分类器
对于鸢尾花问题,tf.estimator.DNNClassifier
似乎是最好的选择。您可以这样实例化该 Estimator:
训练、评估和预测
我们已经有一个 Estimator 对象,现在可以调用方法来执行下列操作:
训练模型。
评估经过训练的模型。
使用经过训练的模型进行预测。
训练模型
通过调用 Estimator 的 Train
方法来训练模型,如下所示:
注意将 input_fn
调用封装在 lambda
中以获取参数,同时提供不带参数的输入函数,如 Estimator 所预期的那样。step
参数告知该方法在训练多少步后停止训练。
评估经过训练的模型
现在模型已经经过训练,您可以获取一些关于模型性能的统计信息。代码块将在测试数据上对经过训练的模型的准确率(accuracy)进行评估:
与对 train
方法的调用不同,我们没有传递 steps
参数来进行评估。用于评估的 input_fn
只生成一个 epoch 的数据。
eval_result
字典亦包含 average_loss
(每个样本的平均误差),loss
(每个 mini-batch 的平均误差)与 Estimator 的 global_step
(经历的训练迭代次数)值。
利用经过训练的模型进行预测(推理)
我们已经有一个经过训练的模型,可以生成准确的评估结果。我们现在可以使用经过训练的模型,根据一些无标签测量结果预测鸢尾花的品种。与训练和评估一样,我们使用单个函数调用进行预测:
predict
方法返回一个 Python 可迭代对象,为每个样本生成一个预测结果字典。以下代码输出了一些预测及其概率: