Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/ja/guide/migrate/canned_estimators.ipynb
25118 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

Canned(または既製)Estimator は、TensorFlow 1 でさまざまな典型的なユースケースのモデルをトレーニングするための迅速かつ簡単な方法として従来使用されてきました。 TensorFlow 2 は、Keras モデルを介して、それらの多くの単純な近似代用を提供します。 TensorFlow 2 の代用が組み込まれていない Canned Estimator の場合でも、独自の置換をかなり簡単に構築できます。

このガイドでは、TensorFlow 1 の tf.estimator から派生したモデルを Keras を使用して TensorFlow 2 に移行する方法を示すために、直接相当するものとカスタム置換の例をいくつか紹介します。

すなわち、このガイドには移行の例が含まれています。

  • TensorFlow 1 の tf.estimatorLinearEstimatorClassifier または Regressor から、TensorFlow 2 の tf.compat.v1.keras.models.LinearModel

  • TensorFlow 1 の tf.estimatorDNNEstimatorClassifier または Regressor から、TensorFlow 2 のカスタム Keras DNN ModelKeras へ

  • TensorFlow 1 の tf.estimatorDNNLinearCombinedEstimatorClassifier または Regressor から、TensorFlow 2 の tf.compat.v1.keras.models.WideDeepModel

  • TensorFlow 1 の tf.estimatorBoostedTreesEstimatorClassifier または Regressor から、TensorFlow 2 の tfdf.keras.GradientBoostedTreesModel

モデルのトレーニングの一般的な前処理は、特徴量の前処理です。これは、tf.feature_column を使用して TensorFlow 1 Estimator モデルに対して行われます。TensorFlow 2 での特徴量の前処理の詳細については、特徴量列から Keras 前処理レイヤー API への移行に関するこのガイドをご覧ください。

セットアップ

いくつかの必要な TensorFlow インポートから始めます。

!pip install tensorflow_decision_forests
import keras import pandas as pd import tensorflow as tf import tensorflow.compat.v1 as tf1 import tensorflow_decision_forests as tfdf

標準のタイタニックのデータセットからデモンストレーション用のいくつかの簡単なデータを準備します。

x_train = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv') x_eval = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv') x_train['sex'].replace(('male', 'female'), (0, 1), inplace=True) x_eval['sex'].replace(('male', 'female'), (0, 1), inplace=True) x_train['alone'].replace(('n', 'y'), (0, 1), inplace=True) x_eval['alone'].replace(('n', 'y'), (0, 1), inplace=True) x_train['class'].replace(('First', 'Second', 'Third'), (1, 2, 3), inplace=True) x_eval['class'].replace(('First', 'Second', 'Third'), (1, 2, 3), inplace=True) x_train.drop(['embark_town', 'deck'], axis=1, inplace=True) x_eval.drop(['embark_town', 'deck'], axis=1, inplace=True) y_train = x_train.pop('survived') y_eval = x_eval.pop('survived')
# Data setup for TensorFlow 1 with `tf.estimator` def _input_fn(): return tf1.data.Dataset.from_tensor_slices((dict(x_train), y_train)).batch(32) def _eval_input_fn(): return tf1.data.Dataset.from_tensor_slices((dict(x_eval), y_eval)).batch(32) FEATURE_NAMES = [ 'age', 'fare', 'sex', 'n_siblings_spouses', 'parch', 'class', 'alone' ] feature_columns = [] for fn in FEATURE_NAMES: feat_col = tf1.feature_column.numeric_column(fn, dtype=tf.float32) feature_columns.append(feat_col)

そして、さまざまな TensorFlow 1 Estimator および TensorFlow 2 Keras モデルで使用する単純なサンプルオプティマイザをインスタンス化するメソッドを作成します。

def create_sample_optimizer(tf_version): if tf_version == 'tf1': optimizer = lambda: tf.keras.optimizers.legacy.Ftrl( l1_regularization_strength=0.001, learning_rate=tf1.train.exponential_decay( learning_rate=0.1, global_step=tf1.train.get_global_step(), decay_steps=10000, decay_rate=0.9)) elif tf_version == 'tf2': optimizer = tf.keras.optimizers.legacy.Ftrl( l1_regularization_strength=0.001, learning_rate=tf.keras.optimizers.schedules.ExponentialDecay( initial_learning_rate=0.1, decay_steps=10000, decay_rate=0.9)) return optimizer

例 1: LinearEstimator からの移行

TensorFlow 1: LinearEstimator の使用

TensorFlow 1 では、tf.estimator.LinearEstimator を使用して、回帰および分類問題のベースライン線形モデルを作成できます。

linear_estimator = tf.estimator.LinearEstimator( head=tf.estimator.BinaryClassHead(), feature_columns=feature_columns, optimizer=create_sample_optimizer('tf1'))
linear_estimator.train(input_fn=_input_fn, steps=100) linear_estimator.evaluate(input_fn=_eval_input_fn, steps=10)

TensorFlow 2: Keras LinearModel の使用

TensorFlow 2 では、tf.estimator.LinearEstimator の代替である Keras tf.compat.v1.keras.models.LinearModel のインスタンスを作成できます。tf.compat.v1.keras パスは、互換性のために事前に作成されたモデルが存在することを示すために使用されます。

linear_model = tf.compat.v1.keras.experimental.LinearModel() linear_model.compile(loss='mse', optimizer=create_sample_optimizer('tf2'), metrics=['accuracy']) linear_model.fit(x_train, y_train, epochs=10) linear_model.evaluate(x_eval, y_eval, return_dict=True)

例 2: DNNEstimator からの移行

TensorFlow 1: DNNEstimator の使用

TensorFlow 1 では、tf.estimator.DNNEstimator を使用して、回帰および分類問題のベースラインとなるディープニューラルネットワーク(DNN)モデルを作成できます。

dnn_estimator = tf.estimator.DNNEstimator( head=tf.estimator.BinaryClassHead(), feature_columns=feature_columns, hidden_units=[128], activation_fn=tf.nn.relu, optimizer=create_sample_optimizer('tf1'))
dnn_estimator.train(input_fn=_input_fn, steps=100) dnn_estimator.evaluate(input_fn=_eval_input_fn, steps=10)

TensorFlow 2: Keras を使用してカスタム DNN モデルを作成する

TensorFlow 2 では、カスタム DNN モデルを作成して、tf.estimator.DNNEstimator によって生成されたものを置き換えることができ、同様のレベルのユーザー指定のカスタマイズが可能です(例えば、前の例のように、選択したモデルオプティマイザをカスタマイズする機能)。

同様のワークフローを使用して、tf.estimator.experimental.RNNEstimator を Keras 再帰型ニューラルネットワーク(RNN)モデルに置き換えることができます。Keras は、tf.keras.layers.RNNtf.keras.layers.LSTM、および tf.keras.layers.GRU によって、多数の組み込みのカスタマイズ可能な選択肢を提供します。詳細については、Keras を使用した RNN ガイド組み込み RNN レイヤー: 簡単な例をご覧ください。

dnn_model = tf.keras.models.Sequential( [tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(1)]) dnn_model.compile(loss='mse', optimizer=create_sample_optimizer('tf2'), metrics=['accuracy'])
dnn_model.fit(x_train, y_train, epochs=10) dnn_model.evaluate(x_eval, y_eval, return_dict=True)

例 3: DNNLinearCombinedEstimator からの移行

TensorFlow 1: DNNLinearCombinedEstimator の使用

TensorFlow 1 では、tf.estimator.DNNLinearCombinedEstimator を使用して、線形コンポーネントと DNN コンポーネントの両方のカスタマイズ機能を備えた回帰および分類問題のベースライン結合モデルを作成できます。

optimizer = create_sample_optimizer('tf1') combined_estimator = tf.estimator.DNNLinearCombinedEstimator( head=tf.estimator.BinaryClassHead(), # Wide settings linear_feature_columns=feature_columns, linear_optimizer=optimizer, # Deep settings dnn_feature_columns=feature_columns, dnn_hidden_units=[128], dnn_optimizer=optimizer)
combined_estimator.train(input_fn=_input_fn, steps=100) combined_estimator.evaluate(input_fn=_eval_input_fn, steps=10)

TensorFlow 2: Keras WideDeepModel の使用

TensorFlow 2 では、Keras の tf.compat.v1.keras.models.WideDeepModel インスタンスを作成して、tf.estimator.DNNLinearCombinedEstimator によって生成されたものを置き換えることができ、同様のレベルのユーザー指定のカスタマイズが可能です(例えば、前の例のように、選択したモデルオプティマイザをカスタマイズする機能)。

この WideDeepModel は、構成要素である LinearModel とカスタム DNN モデルに基づいて構築されます。どちらも前の 2 つの例で説明されています。必要に応じて、組み込みの LinearModel の代わりにカスタム線形モデルを使用することもできます。

Canned Estimator の代わりに独自のモデルを構築したい場合は、 Keras Sequential モデルガイドをご覧ください。カスタムトレーニングとオプティマイザの詳細については、カスタムトレーニング: チュートリアルガイドをご覧ください。

# Create LinearModel and DNN Model as in Examples 1 and 2 optimizer = create_sample_optimizer('tf2') linear_model = tf.compat.v1.keras.experimental.LinearModel() linear_model.compile(loss='mse', optimizer=optimizer, metrics=['accuracy']) linear_model.fit(x_train, y_train, epochs=10, verbose=0) dnn_model = tf.keras.models.Sequential( [tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(1)]) dnn_model.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
combined_model = tf.compat.v1.keras.experimental.WideDeepModel(linear_model, dnn_model) combined_model.compile( optimizer=[optimizer, optimizer], loss='mse', metrics=['accuracy']) combined_model.fit([x_train, x_train], y_train, epochs=10) combined_model.evaluate(x_eval, y_eval, return_dict=True)

例 4: BoostedTreesEstimator からの移行

TensorFlow 1: BoostedTreesEstimator の使用

TensorFlow 1 では、tf.estimator.BoostedTreesEstimator を使用してベースラインを作成し、回帰および分類問題のデシジョンツリーのアンサンブルを使用してベースライン勾配ブースティングモデルを作成できました。この機能は、TensorFlow 2 には含まれなくなりました。

bt_estimator = tf1.estimator.BoostedTreesEstimator( head=tf.estimator.BinaryClassHead(), n_batches_per_layer=1, max_depth=10, n_trees=1000, feature_columns=feature_columns)
bt_estimator.train(input_fn=_input_fn, steps=1000) bt_estimator.evaluate(input_fn=_eval_input_fn, steps=100)

TensorFlow 2: TensorFlow Decision Forests の使用

TensorFlow 2 では、tf.estimator.BoostedTreesEstimator
TensorFlow Decision Forests パッケージの tfdf.keras.GradientBoostedTreesModel に置き換えられました。

TensorFlow Decision Forests は、tf.estimator.BoostedTreesEstimator に比べて、特に品質、速度、使いやすさ、および柔軟性に関してさまざまな利点を提供します。TensorFlow Decision Forests について学ぶには、初心者のための colab から始めてください。

次の例は、TensorFlow 2 を使用して勾配ブーストツリーモデルをトレーニングする方法を示しています。

TensorFlow Decision Forests のインストール

!pip install tensorflow_decision_forests

TensorFlow データセットを作成します。Decision Forests は多くの種類の特徴量をネイティブにサポートしており、前処理を必要としないことに注意してください。

train_dataframe = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv') eval_dataframe = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv') # Convert the Pandas Dataframes into TensorFlow datasets. train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(train_dataframe, label="survived") eval_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(eval_dataframe, label="survived")

train_dataset データセットでモデルをトレーニングします。

# Use the default hyper-parameters of the model. gbt_model = tfdf.keras.GradientBoostedTreesModel() gbt_model.fit(train_dataset)

eval_dataset データセットでモデルの品質を評価します。

gbt_model.compile(metrics=['accuracy']) gbt_evaluation = gbt_model.evaluate(eval_dataset, return_dict=True) print(gbt_evaluation)

勾配ブーストツリーは、TensorFlow Decision Forests で利用できる多くのデシジョンフォレストアルゴリズムの 1 つにすぎません。たとえば、Random Forests(tfdf.keras.GradientBoostedTreesModel として利用可能であり、オーバーフィッティングに対して非常に耐性があります)に対して、CART(tfdf.keras.CartModel として利用可能)はモデルの解釈に最適です。

次の例では、Random Forest モデルをトレーニングしてプロットします。

# Train a Random Forest model rf_model = tfdf.keras.RandomForestModel() rf_model.fit(train_dataset) # Evaluate the Random Forest model rf_model.compile(metrics=['accuracy']) rf_evaluation = rf_model.evaluate(eval_dataset, return_dict=True) print(rf_evaluation)

最後の例では、CART モデルをトレーニングして評価します。

# Train a CART model cart_model = tfdf.keras.CartModel() cart_model.fit(train_dataset) # Plot the CART model tfdf.model_plotter.plot_model_in_colab(cart_model, max_depth=2)