Path: blob/master/site/ja/tutorials/text/text_classification_rnn.ipynb
25118 views
Copyright 2019 The TensorFlow Authors.
RNN を使ったテキスト分類
Note: これらのドキュメントは私たちTensorFlowコミュニティが翻訳したものです。コミュニティによる 翻訳はベストエフォートであるため、この翻訳が正確であることや英語の公式ドキュメントの 最新の状態を反映したものであることを保証することはできません。 この翻訳の品質を向上させるためのご意見をお持ちの方は、GitHubリポジトリtensorflow/docsにプルリクエストをお送りください。 コミュニティによる翻訳やレビューに参加していただける方は、 [email protected] メーリングリストにご連絡ください。
このテキスト分類チュートリアルでは、感情分析のために IMDB 映画レビュー大型データセット を使って リカレントニューラルネットワーク を訓練します。
設定
matplotlib
をインポートしグラフを描画するためのヘルパー関数を作成します。
入力パイプラインの設定
IMDB 映画レビュー大型データセットは二値分類データセットです。すべてのレビューは、好意的(positive) または 非好意的(negative) のいずれかの感情を含んでいます。
TFDS を使ってこのデータセットをダウンロードします。
このデータセットの info
には、エンコーダー(tfds.features.text.SubwordTextEncoder
) が含まれています。
このテキストエンコーダーは、任意の文字列を可逆的にエンコードします。必要であればバイトエンコーディングにフォールバックします。
訓練用データの準備
次に、これらのエンコード済み文字列をバッチ化します。padded_batch
メソッドを使ってバッチ中の一番長い文字列の長さにゼロパディングを行います。
Note: TensorFlow 2.2 から、padded_shapes は必須ではなくなりました。デフォルトではすべての軸をバッチ中で最も長いものに合わせてパディングします。
モデルの作成
tf.keras.Sequential
モデルを構築しましょう。最初に Embedding レイヤーから始めます。Embedding レイヤーは単語一つに対して一つのベクトルを収容します。呼び出しを受けると、Embedding レイヤーは単語のインデックスのシーケンスを、ベクトルのシーケンスに変換します。これらのベクトルは訓練可能です。(十分なデータで)訓練されたあとは、おなじような意味をもつ単語は、しばしばおなじようなベクトルになります。
このインデックス参照は、ワンホットベクトルを tf.keras.layers.Dense
レイヤーを使って行うおなじような演算に比べてずっと効率的です。
リカレントニューラルネットワーク(RNN)は、シーケンスの入力を要素を一つずつ扱うことで処理します。RNN は、あるタイムステップでの出力を次のタイムステップの入力へと、次々に渡していきます。
RNN レイヤーとともに、tf.keras.layers.Bidirectional
ラッパーを使用することができます。このラッパーは、入力を RNN 層の順方向と逆方向に伝え、その後出力を結合します。これにより、RNN は長期的な依存関係を学習できます。
訓練プロセスを定義するため、Keras モデルをコンパイルします。
モデルの訓練
上記のモデルはシーケンスに適用されたパディングをマスクしていません。パディングされたシーケンスで訓練を行い、パディングをしていないシーケンスでテストするとすれば、このことが結果を歪める可能性があります。理想的にはこれを避けるために、 マスキングを使うべきですが、下記のように出力への影響は小さいものでしかありません。
予測値が 0.5 以上であればポジティブ、それ以外はネガティブです。
2つ以上の LSTM レイヤーを重ねる
Keras のリカレントレイヤーには、コンストラクタの return_sequences
引数でコントロールされる2つのモードがあります。
それぞれのタイムステップの連続した出力のシーケンス全体(shape が
(batch_size, timesteps, output_features)
の3階テンソル)を返す。それぞれの入力シーケンスの最後の出力だけ(shape が
(batch_size, output_features)
の2階テンソル)を返す。
GRU レイヤーなど既存のほかのレイヤーを調べてみましょう。
カスタム RNN の構築に興味があるのであれば、Keras RNN ガイド を参照してください。