Path: blob/master/site/zh-cn/tfx/guide/trainer.md
25118 views
Trainer TFX 流水线组件
Trainer TFX 流水线组件用于训练 TensorFlow 模型。
Trainer 和 TensorFlow
Trainer 广泛使用 Python TensorFlow API 来训练模型。
注:TFX 支持 TensorFlow 1.15 和 2.x。
组件
Trainer 需要:
用于训练和评估的 tf.Examples。
由用户提供、用于定义 Trainer 逻辑的模块文件。
训练参数和评估参数的 Protobuf 定义。
(可选)由 SchemaGen 流水线组件创建并且可由开发者选择性更改的数据架构。
(可选)由上游 Transform 组件生成的转换计算图。
(可选)用于热启动等场景的预训练模型。
(可选)将传递给用户模块函数的超参数。可以在此处找到与 Tuner 集成的详细信息。
Trainer 发出:至少一个用于推断/应用的模型(通常在 SavedModelFormat 中),以及另一个用于评估的可选模型(通常是一个 EvalSavedModel)。
我们通过模型重写库为 TFLite 等替代模型格式提供支持。有关如何同时转换 Estimator 和 Keras 模型的示例,请点击模型重写库的链接。
通用 Trainer
通用 Trainer 使开发者可以将任何 TensorFlow 模型 API 与 Trainer 组件一起使用。除了 TensorFlow Estimator 外,开发者还可以使用 Keras 模型或自定义训练循环。有关详细信息,请参阅通用 Trainer 的 RFC。
配置 Trainer 组件
通用 Trainer 的典型流水线 DSL 代码如下所示:
Trainer 调用一个训练模块,该模块在 module_file
参数中指定。如果在 custom_executor_spec
中指定了 GenericExecutor
, 则模块文件中需要 run_fn
,而不是 trainer_fn
。trainer_fn
负责创建模型。除此之外,run_fn
还需要处理训练部分并将训练后的模型输出到 FnArgs 给出的指定位置:
这是一个带有 run_fn
的示例模块文件。
请注意,如果流水线中未使用 Transform 组件,则 Trainer 将直接从 ExampleGen 中获取样本:
有关更多详细信息,请参阅 Trainer API 参考。