Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/ja/guide/keras/custom_callback.ipynb
25118 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

Writing your own callbacks

はじめに

コールバックは、トレーニング、評価、推論の間に Keras モデルの動作をカスタマイズするための強力なツールです。例には、TensorBoard でトレーニングの進捗状況や結果を可視化できる tf.keras.callbacks.TensorBoard や、トレーニング中にモデルを定期的に保存できる tf.keras.callbacks.ModelCheckpoint などを含みます。

このガイドでは、Keras コールバックとは何か、それができること、そして独自のコールバックを構築する方法を学ぶことができます。まずは、簡単なコールバックアプリケーションのデモをいくつか紹介します。

Setup

import tensorflow as tf from tensorflow import keras

Keras コールバックの概要

全てのコールバックは keras.callbacks.Callbacks.Callback クラスをサブクラス化し、トレーニング、テスト、予測のさまざまな段階で呼び出される一連のメソッドをオーバーライドします。コールバックは、トレーニング中にモデルの内部状態や統計上のビューを取得するのに有用です。

以下のモデルメソッドには、(キーワード引数 callbacks として)コールバックのリストを渡すことができます。

  • keras.Model.fit()

  • keras.Model.evaluate()

  • keras.Model.predict()

コールバックメソッドの概要

グローバルメソッド

on_(train|test|predict)_begin(self, logs=None)

fit/evaluate/predict の先頭で呼び出されます。

on_(train|test|predict)_end(self, logs=None)

fit/evaluate/predict の最後に呼び出されます。

トレーニング/テスト/予測のためのバッチレベルのメソッド

on_(train|test|predict)_batch_begin(self, batch, logs=None)

トレーニング/テスト/予測中に、バッチを処理する直前に呼び出されます。

on_(train|test|predict)_batch_end(self, batch, logs=None)

バッチのトレーニング/テスト/予測の終了時に呼び出されます。このメソッド内では、logs はメトリクスの結果を含むディクショナリです。

エポックレベルのメソッド(トレーニングのみ)

on_epoch_begin(self, epoch, logs=None)

トレーニング中に、エポックの最初に呼び出されます。

on_epoch_end(self, epoch, logs=None)

トレーニング中、エポックの最後に呼び出されます。

基本的な例

具体的な例を見てみましょう。まず最初に、TensorFlow をインポートして単純な Sequential Keras モデルを定義してみます。

# Define the Keras model to add callbacks to def get_model(): model = keras.Sequential() model.add(keras.layers.Dense(1, input_dim=784)) model.compile( optimizer=keras.optimizers.RMSprop(learning_rate=0.1), loss="mean_squared_error", metrics=["mean_absolute_error"], ) return model

次に、Keras データセット API からトレーニングとテスト用の MNIST データを読み込みます。

# Load example MNIST data and pre-process it (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() x_train = x_train.reshape(-1, 784).astype("float32") / 255.0 x_test = x_test.reshape(-1, 784).astype("float32") / 255.0 # Limit the data to 1000 samples x_train = x_train[:1000] y_train = y_train[:1000] x_test = x_test[:1000] y_test = y_test[:1000]

今度は、以下のログを記録する単純なカスタムコールバックを定義します。

  • When fit/evaluate/predict starts & ends

  • When each epoch starts & ends

  • 各トレーニングバッチの開始時と終了時

  • 各評価(テスト)バッチの開始時と終了時

  • 各推論(予測)バッチの開始時と終了時

class CustomCallback(keras.callbacks.Callback): def on_train_begin(self, logs=None): keys = list(logs.keys()) print("Starting training; got log keys: {}".format(keys)) def on_train_end(self, logs=None): keys = list(logs.keys()) print("Stop training; got log keys: {}".format(keys)) def on_epoch_begin(self, epoch, logs=None): keys = list(logs.keys()) print("Start epoch {} of training; got log keys: {}".format(epoch, keys)) def on_epoch_end(self, epoch, logs=None): keys = list(logs.keys()) print("End epoch {} of training; got log keys: {}".format(epoch, keys)) def on_test_begin(self, logs=None): keys = list(logs.keys()) print("Start testing; got log keys: {}".format(keys)) def on_test_end(self, logs=None): keys = list(logs.keys()) print("Stop testing; got log keys: {}".format(keys)) def on_predict_begin(self, logs=None): keys = list(logs.keys()) print("Start predicting; got log keys: {}".format(keys)) def on_predict_end(self, logs=None): keys = list(logs.keys()) print("Stop predicting; got log keys: {}".format(keys)) def on_train_batch_begin(self, batch, logs=None): keys = list(logs.keys()) print("...Training: start of batch {}; got log keys: {}".format(batch, keys)) def on_train_batch_end(self, batch, logs=None): keys = list(logs.keys()) print("...Training: end of batch {}; got log keys: {}".format(batch, keys)) def on_test_batch_begin(self, batch, logs=None): keys = list(logs.keys()) print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys)) def on_test_batch_end(self, batch, logs=None): keys = list(logs.keys()) print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys)) def on_predict_batch_begin(self, batch, logs=None): keys = list(logs.keys()) print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys)) def on_predict_batch_end(self, batch, logs=None): keys = list(logs.keys()) print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))

試してみましょう。

model = get_model() model.fit( x_train, y_train, batch_size=128, epochs=1, verbose=0, validation_split=0.5, callbacks=[CustomCallback()], ) res = model.evaluate( x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()] ) res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])

logs ディクショナリを使用する

logs ディクショナリは、バッチまたはエポックの最後の損失値と全てのメトリクスを含みます。次の例は、損失値と平均絶対誤差を含んでいます。

class LossAndErrorPrintingCallback(keras.callbacks.Callback): def on_train_batch_end(self, batch, logs=None): print( "Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"]) ) def on_test_batch_end(self, batch, logs=None): print( "Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"]) ) def on_epoch_end(self, epoch, logs=None): print( "The average loss for epoch {} is {:7.2f} " "and mean absolute error is {:7.2f}.".format( epoch, logs["loss"], logs["mean_absolute_error"] ) ) model = get_model() model.fit( x_train, y_train, batch_size=128, epochs=2, verbose=0, callbacks=[LossAndErrorPrintingCallback()], ) res = model.evaluate( x_test, y_test, batch_size=128, verbose=0, callbacks=[LossAndErrorPrintingCallback()], )

self.model 属性を使用する

コールバックは、そのメソッドの 1 つが呼び出された時にログ情報を受け取ることに加え、現在のトレーニング/評価/推論のラウンドに関連付けられたモデルに、self.model でアクセスすることができます。

コールバックで self.model を使用してできることを幾つか次に示します。

  • self.model.stop_training = True を設定して直ちにトレーニングを中断する。

  • self.model.optimizer.learning_rate など、オプティマイザ(self.model.optimizer として使用可能)のハイパーパラメータを変化させる。

  • 一定間隔でモデルを保存する。

  • 各エポックの終了時に幾つかのテストサンプルの model.predict() の出力を記録し、トレーニング中にサ二ティーチェックとして使用する。

  • 各エポックの終了時に中間特徴の可視化を抽出して、モデルが何を学習しているかを経時的に監視する。

  • など

これを確認するために、2 つの例で見てみましょう。

Keras コールバックアプリケーションの例

最小損失で Early stopping する

この最初の例は、属性 self.model.stop_training(ブール)を設定して、損失の最小値に達した時点でトレーニングを停止する Callback を作成しています。オプションで、ローカル最小値に到達した後、実際に停止するまでに幾つのエポックを待つべきか、引数 patience で指定することが可能です。

tf.keras.callbacks.EarlyStopping は、より完全で一般的な実装を提供します。

import numpy as np class EarlyStoppingAtMinLoss(keras.callbacks.Callback): """Stop training when the loss is at its min, i.e. the loss stops decreasing. Arguments: patience: Number of epochs to wait after min has been hit. After this number of no improvement, training stops. """ def __init__(self, patience=0): super(EarlyStoppingAtMinLoss, self).__init__() self.patience = patience # best_weights to store the weights at which the minimum loss occurs. self.best_weights = None def on_train_begin(self, logs=None): # The number of epoch it has waited when loss is no longer minimum. self.wait = 0 # The epoch the training stops at. self.stopped_epoch = 0 # Initialize the best as infinity. self.best = np.Inf def on_epoch_end(self, epoch, logs=None): current = logs.get("loss") if np.less(current, self.best): self.best = current self.wait = 0 # Record the best weights if current results is better (less). self.best_weights = self.model.get_weights() else: self.wait += 1 if self.wait >= self.patience: self.stopped_epoch = epoch self.model.stop_training = True print("Restoring model weights from the end of the best epoch.") self.model.set_weights(self.best_weights) def on_train_end(self, logs=None): if self.stopped_epoch > 0: print("Epoch %05d: early stopping" % (self.stopped_epoch + 1)) model = get_model() model.fit( x_train, y_train, batch_size=64, steps_per_epoch=5, epochs=30, verbose=0, callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()], )

学習率をスケジューリングする

この例では、トレーニングの過程でカスタムコールバックを使用して、オプティマイザの学習率を動的に変更する方法を示します。

より一般的な実装については、callbacks.LearningRateScheduler をご覧ください。

class CustomLearningRateScheduler(keras.callbacks.Callback): """Learning rate scheduler which sets the learning rate according to schedule. Arguments: schedule: a function that takes an epoch index (integer, indexed from 0) and current learning rate as inputs and returns a new learning rate as output (float). """ def __init__(self, schedule): super(CustomLearningRateScheduler, self).__init__() self.schedule = schedule def on_epoch_begin(self, epoch, logs=None): if not hasattr(self.model.optimizer, "lr"): raise ValueError('Optimizer must have a "lr" attribute.') # Get the current learning rate from model's optimizer. lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate)) # Call schedule function to get the scheduled learning rate. scheduled_lr = self.schedule(epoch, lr) # Set the value back to the optimizer before this epoch starts tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr) print("\nEpoch %05d: Learning rate is %6.4f." % (epoch, scheduled_lr)) LR_SCHEDULE = [ # (epoch to start, learning rate) tuples (3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001), ] def lr_schedule(epoch, lr): """Helper function to retrieve the scheduled learning rate based on epoch.""" if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]: return lr for i in range(len(LR_SCHEDULE)): if epoch == LR_SCHEDULE[i][0]: return LR_SCHEDULE[i][1] return lr model = get_model() model.fit( x_train, y_train, batch_size=64, steps_per_epoch=5, epochs=15, verbose=0, callbacks=[ LossAndErrorPrintingCallback(), CustomLearningRateScheduler(lr_schedule), ], )

組み込みの Keras コールバック

既存の Keras コールバックについては、API ドキュメントを読んで必ず確認してください。アプリケーションには、CSV へのロギング、モデルの保存、TensorBoard でのメトリクスの可視化、その他多数があります。