Trainer TFX パイプラインコンポーネント
Trainer TFX パイプラインコンポーネントは TensorFlow モデルをトレーニングします。
Trainer と TensorFlow
Trainer は、モデルのトレーニングに Python TensorFlow API を多大に使用しています。
注意: TFX は TensorFlow 1.15 と 2.x をサポートします。
コンポーネント
Trainer は次を取り込みます。
training と eval に使用される tf.Examples
Trainer のロジックを定義するユーザー指定のモジュールファイル
train args と eval args の Protobuf 定義
(オプション)SchemaGen パイプラインコンポーネントが作成し、開発者がオプションとして変更できるデータスキーマ
(オプション)上流の Transform コンポーネントが生成する transform グラフ
(オプション)warmstart などのシナリオに使用される事前トレーニング済みのモデル
(オプション)ユーザーモジュール関数に渡されるハイパーパラメータ。Tuner との統合に関する詳細は、こちらをご覧ください。
Trainer の出力: 少なくとも 1 つの推論/サービング用モデル(通常 SavedModel 形式)とオプションとして eval 用のモデル(通常 EvalSavedModel)
TFLite などの代替のモデル形式のサポートは Model Rewriting ライブラリを通じて提供しています。Estimator と Keras モデルの両方の変換方法の例については、Model Rewriting ライブラリへのリンクをご覧ください。
汎用 Trainer
汎用の Trainer を使用すると、開発者はあらゆる TensorFlow でもる API を Trainer コンポーネントと使用できるようになります。TensorFlow Estimator のほか、Keras モデルやカスタムトレーニングループを使用できます。詳細については、汎用 Trainer 用の RFCをご覧ください。
Trainer コンポーネントを構成する
以下は、汎用 Trainer の一般的なパイプライン DSL コードの例です。
Trainer は module_file
パラメーターに指定されているトレーニングモジュールを呼び出します。custom_executor_spec
に GenericExecutor
が指定されている場合、モジュールファイルには trainer_fn
の代わりに run_fn
が必要です。trainer_fn
はモデルの作成を行います。そのほか、run_fn
はトレーニングの部分を処理し、トレーニング済みのモデルを FnArgs で指定された目的の場所に出力する必要もあります。
上記の Example モジュールファイルでは run_fn
を使用しています。
Transform コンポーネントがパイプラインで使用されていない場合、Trainer は直接 ExampleGen の Example を取るところに注意してください。
詳細については、Trainer API リファレンスをご覧ください。