Path: blob/master/site/ja/tutorials/distribute/save_and_load.ipynb
38627 views
Copyright 2019 The TensorFlow Authors.
分散ストラテジーを使ってモデルを保存して読み込む
概要
このチュートリアルでは、トレーニング中またはトレーニング後に tf.distribute.Strategy を使用して SavedModel 形式でモデルを保存して読み込む方法を説明します。Keras モデルの保存と読み込みには、高レベル(tf.keras.Model.save と tf.keras.models.load_model)と低レベル(tf.saved_model.save と tf.saved_model.load)の 2 種類の API があります。
SavedModel とシリアル化の全般的な内容については、SavedModel ガイドと Keras モデルのシリアル化ガイドをお読みください。では、単純な例から始めましょう。
注意: TensorFlow モデルはコードであるため、信頼できないコードには注意する必要があります。詳細は、TensorFlow を安全に使用するをご覧ください。
依存関係をインポートします。
TensorFlow Datasets と tf.data でデータを読み込んで準備し、tf.distribute.MirroredStrategy を使ってモデルを作成します。
tf.keras.Model.fit を使用してモデルをトレーニングします。
モデルを保存して読み込む
作業に使用する単純なモデルを準備できたので、保存と読み込みに使用する API を見てみましょう。使用できる API には、以下の 2 種類があります。
高レベル(Keras):
Model.saveおよびtf.keras.models.load_model(.keraszip アーカイブ形式)低レベル:
tf.saved_model.saveおよびtf.saved_model.load(TF SavedModel 形式)
Keras API
Keras API を使用したモデルの保存と読み込みの例を以下に示します。
tf.distribute.Strategy を使用せずにモデルを復元します。
モデルを復元したら、Model.compile をもう一度呼び出さずにそのままトレーニングを続行できます。これは、保存前にすでにコンパイル済みであるためです。このモデルは、Keras zip アーカイブ形式で保存されており、.keras 拡張子で識別できます。詳細については、Keras の保存に関するガイドをご覧ください。
次に、tf.distribute.Strategy を使用してモデルを復元し、トレーニングします。
Model.fit 出力からわかるように、tf.distribute.Strategy を使って期待どおり読み込まれました。ここで使用されるストラテジーは、保存前と同じストラテジーである必要はありません。
tf.saved_model API
より低レベルの API を使用したモデルの保存方法は、Keras API を使う方法に似ています。
読み込みは、tf.saved_model.load を使用して行えますが、これは低レベル API(したがって、より幅広いユースケースのある API)であるため、Keras モデルを返しません。代わりに、推論を行うために使用できる関数を含むオブジェクトを返します。以下に例を示します。
読み込まれたオブジェクトには、それぞれにキーが関連付けられた複数の関数が含まれている可能性があります。"serving_default" キーは、保存された Keras モデルを使用した推論関数のデフォルトのキーです。この関数で推論するには、以下のようにします。
また、分散方法で読み込んで推論を実行することもできます。
復元された関数の呼び出しは、保存されたモデル(tf.keras.Model.predict)に対するフォワードパスです。読み込まれた関数をトレーニングし続ける場合はどうでしょうか。または読み込まれた関数をより大きなモデルに埋め込むには?一般的には、この読み込まれたオブジェクトを Keras レイヤーにラップして行うことができます。幸いにも、TF Hub には、以下に示すとおり、この目的に使用できる hub.KerasLayer が用意されています。
上記の例では、hub.KerasLayer は tf.saved_model.load() から読み込まれた結果を、別のモデルの構築に使用できる Keras レイヤーにラップしています。転移学習を行う際に非常に便利な手法です。
どの API を使用すべきですか?
保存においては、Keras モデルを使用している場合は、低レベル API が実現できる追加の制御が必要でない限り、Keras の Model.save API を使用します。保存しているものが Keras モデルでない場合は、低レベル API の tf.saved_model.save しか使用できません。
読み込みにおいては、使用する API はモデルの読み込みから得ようとしているものによって異なります。Keras モデルを使用できない場合(または使用したくない場合)は、tf.saved_model.load を使用し、使用できる場合は tf.keras.models.load_model を使用します。Keras モデルを保存した場合にのみ、Keras モデルを読み込めることに注意してください。
API を混在させることも可能です。model.save で Keras モデルを保存し、低レベルの tf.saved_model.load API を使用して、非 Keras モデルを読み込むことができます。
ローカルデバイスからの読み込みまたは保存
ローカル I/O デバイスから読み込みと保存を行い、リモートデバイスでトレーニングする場合(Cloud TPU を使用する場合など)、tf.saved_model.SaveOptions と tf.saved_model.LoadOptions に experimental_io_device を使用して、I/O デバイスを localhost に設定する必要があります。以下に例を示します。
警告
Keras モデルを特定の方法で作成してから、トレーニングする前に保存するという、以下のような特別なケースがあります。
SavedModel は tf.function をトレースする際に生成される tf.types.experimental.ConcreteFunction オブジェクトを保存します(詳細は、グラフと tf.function の基本ガイドの関数はいつトレースしますか? をご覧ください)。このような ValueError が発生した場合、Model.save がトレースされた ConcreteFunction を見つけられなかったか作成できなかったことが原因です。
注意: 少なくとも 1 つの ConcreteFunction がない場合にモデルを保存しないことをお勧めします。そうでない場合、低レベル API は、ConcreteFunction シグネチャのない状態で SavedModel を生成してしまうためです(SavedModel 形式については、こちらをご覧ください)。以下に例を示します。
通常、モデルのフォワードパス(call メソッド)は、モデルが Keras の Model.fit メソッドを通じて初めて呼び出されたときに、自動的にトレースされます。また、最初のレイヤーを tf.keras.layers.InputLayer などにして、input_shape キーワード引数に渡すことで入力形状を設定している場合、Keras の Sequential API と Functional API によって ConcreteFunction が生成されることもあります。
モデルにトレース済みの ConcreteFunction が存在するかを確認するには、Model.save_spec が None になっていることを確認します。
tf.keras.Model.fit を使ってモデルをトレーニングし、save_spec が定義され、モデルの保存が機能するかを確認しましょう。
TensorFlow.org で表示
Google Colab で実行
GitHub でソースを表示
ノートブックをダウンロード