Path: blob/master/site/ja/guide/keras/custom_callback.ipynb
25118 views
Copyright 2020 The TensorFlow Authors.
Writing your own callbacks
はじめに
コールバックは、トレーニング、評価、推論の間に Keras モデルの動作をカスタマイズするための強力なツールです。例には、TensorBoard でトレーニングの進捗状況や結果を可視化できる tf.keras.callbacks.TensorBoard
や、トレーニング中にモデルを定期的に保存できる tf.keras.callbacks.ModelCheckpoint
などを含みます。
このガイドでは、Keras コールバックとは何か、それができること、そして独自のコールバックを構築する方法を学ぶことができます。まずは、簡単なコールバックアプリケーションのデモをいくつか紹介します。
Setup
Keras コールバックの概要
全てのコールバックは keras.callbacks.Callbacks.Callback
クラスをサブクラス化し、トレーニング、テスト、予測のさまざまな段階で呼び出される一連のメソッドをオーバーライドします。コールバックは、トレーニング中にモデルの内部状態や統計上のビューを取得するのに有用です。
以下のモデルメソッドには、(キーワード引数 callbacks
として)コールバックのリストを渡すことができます。
keras.Model.fit()
keras.Model.evaluate()
keras.Model.predict()
コールバックメソッドの概要
グローバルメソッド
on_(train|test|predict)_begin(self, logs=None)
fit
/evaluate
/predict
の先頭で呼び出されます。
on_(train|test|predict)_end(self, logs=None)
fit
/evaluate
/predict
の最後に呼び出されます。
トレーニング/テスト/予測のためのバッチレベルのメソッド
on_(train|test|predict)_batch_begin(self, batch, logs=None)
トレーニング/テスト/予測中に、バッチを処理する直前に呼び出されます。
on_(train|test|predict)_batch_end(self, batch, logs=None)
バッチのトレーニング/テスト/予測の終了時に呼び出されます。このメソッド内では、logs
はメトリクスの結果を含むディクショナリです。
エポックレベルのメソッド(トレーニングのみ)
on_epoch_begin(self, epoch, logs=None)
トレーニング中に、エポックの最初に呼び出されます。
on_epoch_end(self, epoch, logs=None)
トレーニング中、エポックの最後に呼び出されます。
基本的な例
具体的な例を見てみましょう。まず最初に、TensorFlow をインポートして単純な Sequential Keras モデルを定義してみます。
次に、Keras データセット API からトレーニングとテスト用の MNIST データを読み込みます。
今度は、以下のログを記録する単純なカスタムコールバックを定義します。
When
fit
/evaluate
/predict
starts & endsWhen each epoch starts & ends
各トレーニングバッチの開始時と終了時
各評価(テスト)バッチの開始時と終了時
各推論(予測)バッチの開始時と終了時
試してみましょう。
logs
ディクショナリを使用する
logs
ディクショナリは、バッチまたはエポックの最後の損失値と全てのメトリクスを含みます。次の例は、損失値と平均絶対誤差を含んでいます。
self.model
属性を使用する
コールバックは、そのメソッドの 1 つが呼び出された時にログ情報を受け取ることに加え、現在のトレーニング/評価/推論のラウンドに関連付けられたモデルに、self.model
でアクセスすることができます。
コールバックで self.model
を使用してできることを幾つか次に示します。
self.model.stop_training = True
を設定して直ちにトレーニングを中断する。self.model.optimizer.learning_rate
など、オプティマイザ(self.model.optimizer
として使用可能)のハイパーパラメータを変化させる。一定間隔でモデルを保存する。
各エポックの終了時に幾つかのテストサンプルの
model.predict()
の出力を記録し、トレーニング中にサ二ティーチェックとして使用する。各エポックの終了時に中間特徴の可視化を抽出して、モデルが何を学習しているかを経時的に監視する。
など
これを確認するために、2 つの例で見てみましょう。
Keras コールバックアプリケーションの例
最小損失で Early stopping する
この最初の例は、属性 self.model.stop_training
(ブール)を設定して、損失の最小値に達した時点でトレーニングを停止する Callback
を作成しています。オプションで、ローカル最小値に到達した後、実際に停止するまでに幾つのエポックを待つべきか、引数 patience
で指定することが可能です。
tf.keras.callbacks.EarlyStopping
は、より完全で一般的な実装を提供します。
学習率をスケジューリングする
この例では、トレーニングの過程でカスタムコールバックを使用して、オプティマイザの学習率を動的に変更する方法を示します。
より一般的な実装については、callbacks.LearningRateScheduler
をご覧ください。
組み込みの Keras コールバック
既存の Keras コールバックについては、API ドキュメントを読んで必ず確認してください。アプリケーションには、CSV へのロギング、モデルの保存、TensorBoard でのメトリクスの可視化、その他多数があります。