Path: blob/master/site/ja/tutorials/distribute/save_and_load.ipynb
25118 views
Copyright 2019 The TensorFlow Authors.
分散ストラテジーを使ってモデルを保存して読み込む
概要
このチュートリアルでは、トレーニング中またはトレーニング後に tf.distribute.Strategy
を使用して SavedModel 形式でモデルを保存して読み込む方法を説明します。Keras モデルの保存と読み込みには、高レベル(tf.keras.Model.save
と tf.keras.models.load_model
)と低レベル(tf.saved_model.save
と tf.saved_model.load
)の 2 種類の API があります。
SavedModel とシリアル化の全般的な内容については、SavedModel ガイドと Keras モデルのシリアル化ガイドをお読みください。では、単純な例から始めましょう。
注意: TensorFlow モデルはコードであるため、信頼できないコードには注意する必要があります。詳細は、TensorFlow を安全に使用するをご覧ください。
依存関係をインポートします。
TensorFlow Datasets と tf.data
でデータを読み込んで準備し、tf.distribute.MirroredStrategy
を使ってモデルを作成します。
tf.keras.Model.fit
を使用してモデルをトレーニングします。
モデルを保存して読み込む
作業に使用する単純なモデルを準備できたので、保存と読み込みに使用する API を見てみましょう。使用できる API には、以下の 2 種類があります。
高レベル(Keras):
Model.save
およびtf.keras.models.load_model
(.keras
zip アーカイブ形式)低レベル:
tf.saved_model.save
およびtf.saved_model.load
(TF SavedModel 形式)
Keras API
Keras API を使用したモデルの保存と読み込みの例を以下に示します。
tf.distribute.Strategy
を使用せずにモデルを復元します。
モデルを復元したら、Model.compile
をもう一度呼び出さずにそのままトレーニングを続行できます。これは、保存前にすでにコンパイル済みであるためです。このモデルは、Keras zip アーカイブ形式で保存されており、.keras
拡張子で識別できます。詳細については、Keras の保存に関するガイドをご覧ください。
次に、tf.distribute.Strategy
を使用してモデルを復元し、トレーニングします。
Model.fit
出力からわかるように、tf.distribute.Strategy
を使って期待どおり読み込まれました。ここで使用されるストラテジーは、保存前と同じストラテジーである必要はありません。
tf.saved_model
API
より低レベルの API を使用したモデルの保存方法は、Keras API を使う方法に似ています。
読み込みは、tf.saved_model.load
を使用して行えますが、これは低レベル API(したがって、より幅広いユースケースのある API)であるため、Keras モデルを返しません。代わりに、推論を行うために使用できる関数を含むオブジェクトを返します。以下に例を示します。
読み込まれたオブジェクトには、それぞれにキーが関連付けられた複数の関数が含まれている可能性があります。"serving_default"
キーは、保存された Keras モデルを使用した推論関数のデフォルトのキーです。この関数で推論するには、以下のようにします。
また、分散方法で読み込んで推論を実行することもできます。
復元された関数の呼び出しは、保存されたモデル(tf.keras.Model.predict
)に対するフォワードパスです。読み込まれた関数をトレーニングし続ける場合はどうでしょうか。または読み込まれた関数をより大きなモデルに埋め込むには?一般的には、この読み込まれたオブジェクトを Keras レイヤーにラップして行うことができます。幸いにも、TF Hub には、以下に示すとおり、この目的に使用できる hub.KerasLayer
が用意されています。
上記の例では、hub.KerasLayer
は tf.saved_model.load()
から読み込まれた結果を、別のモデルの構築に使用できる Keras レイヤーにラップしています。転移学習を行う際に非常に便利な手法です。
どの API を使用すべきですか?
保存においては、Keras モデルを使用している場合は、低レベル API が実現できる追加の制御が必要でない限り、Keras の Model.save
API を使用します。保存しているものが Keras モデルでない場合は、低レベル API の tf.saved_model.save
しか使用できません。
読み込みにおいては、使用する API はモデルの読み込みから得ようとしているものによって異なります。Keras モデルを使用できない場合(または使用したくない場合)は、tf.saved_model.load
を使用し、使用できる場合は tf.keras.models.load_model
を使用します。Keras モデルを保存した場合にのみ、Keras モデルを読み込めることに注意してください。
API を混在させることも可能です。model.save
で Keras モデルを保存し、低レベルの tf.saved_model.load
API を使用して、非 Keras モデルを読み込むことができます。
ローカルデバイスからの読み込みまたは保存
ローカル I/O デバイスから読み込みと保存を行い、リモートデバイスでトレーニングする場合(Cloud TPU を使用する場合など)、tf.saved_model.SaveOptions
と tf.saved_model.LoadOptions
に experimental_io_device
を使用して、I/O デバイスを localhost
に設定する必要があります。以下に例を示します。
警告
Keras モデルを特定の方法で作成してから、トレーニングする前に保存するという、以下のような特別なケースがあります。
SavedModel は tf.function
をトレースする際に生成される tf.types.experimental.ConcreteFunction
オブジェクトを保存します(詳細は、グラフと tf.function の基本ガイドの関数はいつトレースしますか? をご覧ください)。このような ValueError
が発生した場合、Model.save
がトレースされた ConcreteFunction
を見つけられなかったか作成できなかったことが原因です。
注意: 少なくとも 1 つの ConcreteFunction
がない場合にモデルを保存しないことをお勧めします。そうでない場合、低レベル API は、ConcreteFunction
シグネチャのない状態で SavedModel を生成してしまうためです(SavedModel 形式については、こちらをご覧ください)。以下に例を示します。
通常、モデルのフォワードパス(call
メソッド)は、モデルが Keras の Model.fit
メソッドを通じて初めて呼び出されたときに、自動的にトレースされます。また、最初のレイヤーを tf.keras.layers.InputLayer
などにして、input_shape
キーワード引数に渡すことで入力形状を設定している場合、Keras の Sequential API と Functional API によって ConcreteFunction
が生成されることもあります。
モデルにトレース済みの ConcreteFunction
が存在するかを確認するには、Model.save_spec
が None
になっていることを確認します。
tf.keras.Model.fit
を使ってモデルをトレーニングし、save_spec
が定義され、モデルの保存が機能するかを確認しましょう。