Copyright 2020 The TensorFlow Authors.
Keras の再帰型ニューラルネットワーク(RNN)
はじめに
再帰型ニューラルネットワーク(RNN)は、時系列や自然言語などのシーケンスデータのモデリングを強力に行うニューラルネットワークのクラスです。
概略的には、RNN レイヤーは for
ループを使用して、それまでに確認した時間ステップに関する情報をエンコードする内部状態を維持しながらシーケンスの時間ステップをイテレートします。
Keras RNN API は、次に焦点を当てて設計されています。
使いやすさ:
keras.layers.RNN
、keras.layers.LSTM
、keras.layers.GRU
レイヤーがビルトインされているため、難しい構成選択を行わずに、再帰型モデルを素早く構築できます。カスタマイズしやすさ: カスタムビヘイビアを使って独自の RNN セルレイヤーを構築し(
for
ループの内部)、一般的なkeras.layers.RNN
レイヤー(for
ループ自体)で使用することもできます。このため、異なるリサーチアイデアを最小限のコードで柔軟に素早くプロトタイプすることができます。
セットアップ
ビルトイン RNN レイヤー: 単純な例
Keras には、次の 3 つのビルトイン RNN レイヤーがあります。
keras.layers.SimpleRNN
: 前の時間ステップの出力が次の時間ステップにフィードされる、完全に連結された RNN です。keras.layers.GRU
: Cho et al., 2014 で初めて提案されたレイヤー。keras.layers.LSTM
: Hochreiter & Schmidhuber, 1997 で初めて提案されたレイヤー。
2015 年始めに、Keras に、LSTM および GRU の再利用可能なオープンソース Python 実装が導入されました。
整数のシーケンスを処理し、そのような整数を 64 次元ベクトルに埋め込み、LSTM
レイヤーを使用してベクトルのシーケンスを処理する Sequential
モデルの単純な例を次に示しています。
ビルトイン RNN は、多数の有益な特徴をサポートしています。
dropout
およびrecurrent_dropout
引数を介した再帰ドロップアウトgo_backwards
引数を介して、入力シーケンスを逆順に処理する能力unroll
引数を介したループ展開(CPU で短いシーケンスを処理する際に大幅な高速化が得られる)など。
詳細については、「RNN API ドキュメント」を参照してください。
出力と状態
デフォルトでは、RNN レイヤーの出力には、サンプル当たり 1 つのベクトルが含まれます。このベクトルは、最後の時間ステップに対応する RNN セル出力で、入力シーケンス全体の情報が含まれます。この出力の形状は (batch_size, units)
で、units
はレイヤーのコンストラクタに渡される units
引数に対応します。
RNN レイヤーは、return_sequences=True
に設定した場合、各サンプルに対する出力のシーケンス全体(各サンプルの時間ステップごとに 1 ベクトル)を返すこともできます。この出力の形状は (batch_size, timesteps, units)
です。
さらに、RNN レイヤーはその最終内部状態を返すことができます。返された状態は、後で RNN 実行を再開する際に使用するか、別の RNN を初期化するために使用できます。この設定は通常、エンコーダ・デコーダ方式の Sequence-to-Sequence モデルで使用され、エンコーダの最終状態がデコーダの初期状態として使用されます。
内部状態を返すように RNN レイヤーを構成するには、レイヤーを作成する際に、return_state
パラメータを True
に設定します。LSTM
には状態テンソルが 2 つあるのに対し、GRU
には 1 つしかないことに注意してください。
レイヤーの初期状態を構成するには、追加のキーワード引数 initial_state
を使ってレイヤーを呼び出します。次の例に示すように、状態の形状は、レイヤーのユニットサイズに一致する必要があることに注意してください。
RNN レイヤーと RNN セル
ビルトイン RNN レイヤーのほかに、RNN API は、セルレベルの API も提供しています。入力シーケンスの全バッチを処理する RNN レイヤーとは異なり、RNN セルは単一の時間ステップのみを処理します。
セルは、RNN レイヤーの for
ループ内にあります。keras.layers.RNN
レイヤー内のセルをラップすることで、シーケンスのバッチを処理できるレイヤー(RNN(LSTMCell(10))
など)を得られます。
数学的には、RNN(LSTMCell(10))
は LSTM(10)
と同じ結果を出します。実際、TF v1.x でのこのレイヤーの実装は、対応する RNN セルを作成し、それを RNN レイヤーにラップするだけでした。ただし、ビルトインの GRU
と LSTM
レイヤーを使用すれば、CuDNN が使用できるようになり、パフォーマンスの改善を確認できることがあります。
ビルトイン RNN セルには 3 つあり、それぞれ、それに一致する RNN レイヤーに対応しています。
keras.layers.SimpleRNNCell
はSimpleRNN
レイヤーに対応します。keras.layers.GRUCell
はGRU
レイヤーに対応します。keras.layers.LSTMCell
はLSTM
レイヤーに対応します。
セルの抽象化とジェネリックな keras.layers.RNN
クラスを合わせることで、リサーチ用のカスタム RNN アーキテクチャの実装を簡単に行えるようになります。
バッチ間のステートフルネス
非常に長い(無限の可能性のある)シーケンスを処理する場合は、バッチ間ステートフルネスのパターンを使用するとよいでしょう。
通常、RNN レイヤーの内部状態は、新しいバッチが確認されるたびにリセットされます(レイヤーが確認する各サンプルは、過去のサンプルとは無関係だと考えられます)。レイヤーは、あるサンプルを処理する間のみ状態を維持します。
ただし、非常に長いシーケンスがある場合、より短いシーケンスに分割し、レイヤーの状態をリセットせずにそれらの短いシーケンスを順次、RNN レイヤーにフィードすることができます。こうすると、レイヤーはサブシーケンスごとに確認していても、シーケンス全体の情報を維持することができます。
これは、コンストラクタに stateful=True
を設定して行います。
シーケンス s = [t0, t1, ... t1546, t1547]
があるとした場合、これを次のように分割します。
そして、次のようにして処理します。
状態をクリアする場合は、layer.reset_states()
を使用できます。
注意: このセットアップでは、あるバッチのサンプル
i
は前のバッチのサンプルi
の続きであることを前提としています。つまり、すべてのバッチには同じ数のサンプル(バッチサイズ)が含まれることになります。たとえば、バッチに[sequence_A_from_t0_to_t100, sequence_B_from_t0_to_t100]
が含まれるとした場合、次のバッチには、[sequence_A_from_t101_to_t200, sequence_B_from_t101_to_t200]
が含まれます。
完全な例を次に示します。
RNN の記録済みの状態は、layer.weights()
には含まれません。RNN レイヤーの状態を再利用する場合は、layer.states
によって状態の値を取得し、new_layer(inputs, initial_state=layer.states)
などの Keras Functional API またはモデルのサブクラス化を通じて新しいレイヤーの初期状態として使用することができます。
この場合には、単一の入力と出力を持つレイヤーのみをサポートする Sequential モデルを使用できない可能性があることにも注意してください。このモデルでは追加入力としての初期状態を使用することができません。
双方向性 RNN
時系列以外のシーケンスについては(テキストなど)、開始から終了までのシーケンスを処理だけでなく、逆順に処理する場合、RNN モデルの方がパフォーマンスに優れていることがほとんどです。たとえば、ある文で次に出現する単語を予測するには、その単語の前に出現した複数の単語だけでなく、その単語に関する文脈があると役立ちます。
Keras は、そのような双方向性のある RNN を構築するために、keras.layers.Bidirectional
ラッパーという簡単な API を提供しています。
内部的には、Bidirectional
は渡された RNN レイヤーをコピーし、新たにコピーされたレイヤーの go_backwards
フィールドを転換して、入力が逆順に処理されるようにします。
Bidirectional
RNN の出力は、デフォルトで、フォワードレイヤー出力とバックワードレイヤー出力の総和となります。これとは異なるマージ動作が必要な場合は(連結など)、Bidirectional
ラッパーコンストラクタの merge_mode
パラメータを変更します。Bidirectional
の詳細については、API ドキュメントをご覧ください。
パフォーマンス最適化と CuDNN カーネル
TensorFlow 2.0 では、ビルトインの LSTM と GRU レイヤーは、GPU が利用できる場合にデフォルトで CuDNN カーネルを活用するように更新されています。この変更により、以前の keras.layers.CuDNNLSTM/CuDNNGRU
レイヤーは使用廃止となったため、実行するハードウェアを気にせずにモデルを構築することができます。
CuDNN カーネルは、特定の前提を以って構築されており、レイヤーはビルトイン LSTM または GRU レイヤーのデフォルト値を変更しない場合は CuDNN カーネルを使用できません。これらには次のような例があります。
activation
関数をtanh
からほかのものに変更する。recurrent_activation
関数をsigmoid
からほかのものに変更する。recurrent_dropout
> 0 を使用する。unroll
を True に設定する。LSTM/GRU によって内部tf.while_loop
は展開済みfor
ループに分解されます。use_bias
を False に設定する。入力データが厳密に右詰でない場合にマスキングを使用する(マスクが厳密に右詰データに対応している場合でも、CuDNN は使用されます。これは最も一般的な事例です)。
利用できる場合に CuDNN カーネルを使用する
パフォーマンスの違いを確認するために、単純な LSTM モデルを構築してみましょう。
入力シーケンスとして、MNIST 番号の行のシーケンスを使用し(ピクセルの各行を時間ステップとして扱います)、番号のラベルを予測します。
MNIST データセットを読み込みましょう。
モデルのインスタンスを作成してトレーニングしましょう。
sparse_categorical_crossentropy
をモデルの損失関数として選択します。モデルの出力形状は [batch_size, 10]
です。モデルのターゲットは整数ベクトルで、各整数は 0 から 9 の範囲内にあります。
では、CuDNN カーネルを使用しないモデルと比較してみましょう。
NVIDIA GPU と CuDNN がインストールされたマシンで実行すると、CuDNN で構築されたモデルの方が、通常の TensorFlow カーネルを使用するモデルに比べて非常に高速に実行されます。
CPU のみの環境で推論を実行する場合でも、同じ CuDNN 対応モデルを使用できます。次の tf.device
注釈は単にデバイスの交換を強制しています。GPU が利用できないな場合は、デフォルトで CPU で実行されます。
実行するハードウェアを気にする必要がなくなったのです。素晴らしいと思いませんか?
リスト/ディクショナリ入力、またはネストされた入力を使う RNN
ネスト構造の場合、インプルメンターは単一の時間ステップにより多くの情報を含めることができます。たとえば、動画のフレームに、音声と動画の入力を同時に含めることができます。この場合のデータ形状は、次のようになります。
[batch, timestep, {"video": [height, width, channel], "audio": [frequency]}]
別の例では、手書きのデータに、現在のペンの位置を示す座標 x と y のほか、筆圧情報も含めることができます。データは次のように表現できます。
[batch, timestep, {"location": [x, y], "pressure": [force]}]
次のコードは、このような構造化された入力を受け入れるカスタム RNN セルの構築方法を例に示しています。
ネストされた入力/出力をサポートするカスタムセルを定義する
独自レイヤーの記述に関する詳細は、「サブクラス化による新規レイヤーとモデルの作成」を参照してください。
ネストされた入力/出力で RNN モデルを構築する
上記で定義した keras.layers.RNN
レイヤーとカスタムセルを使用する Keras モデルを構築しましょう。
ランダムに生成されたデータでモデルをトレーニングする
このモデルに適した候補データセットを持ち合わせていないため、ランダムな Numpy データを使って実演することにします。
Keras keras.layers.RNN
レイヤーでは、シーケンス内の個別のステップの数学ロジックを定義することだけが期待されています。シーケンスのイテレーションは、keras.layers.RNN
レイヤーによって処理されます。新しいタイプの RNN(LSTM など) を素早くプロトタイプ化する上で、非常に強力な方法です。
詳細については、API ドキュメントを参照してください。