Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/ja/lite/guide/signatures.ipynb
25118 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

TensorFlow Lite のシグネチャ

TensorFlow Lite では、TensorFlow モデルの入出力仕様を TensorFlow Lite モデルに変換できます。入出力仕様は「シグネチャ」と呼ばれます。署名は、SavedModel の構築時または具体的な関数の作成時に指定できます。

TensorFlow Lite のシグネチャには次の機能があります。

  • TensorFlow モデルのシグネチャを守ることで、変換された TensorFlow Lite モデルの入出力を指定する。

  • 1 つの TensorFlow Lite モデルで複数の入力点をサポートできる。

シグネチャには次の 3 つの要素があります。

  • 入力: シグネチャの入力名から入力テンソルへの入力のマッピング。

  • 出力: シグネチャの出力名から出力テンソルへの出力のマッピング。

  • シグネチャキー: グラフの入力点を識別する名前。

MNIST モデルをビルドする

import tensorflow as tf

サンプル モデル

エンコードとデコードといった 2 つのタスクが TensorFlow モデルとして存在するとします。

class Model(tf.Module): @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)]) def encode(self, x): result = tf.strings.as_string(x) return { "encoded_result": result } @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)]) def decode(self, x): result = tf.strings.to_number(x) return { "decoded_result": result }

シグネチャという点では、上記の TensorFlow モデルは次のように要約することができます。

  • シグネチャ

    • キー: encode

    • 入力: {"x"}

    • 出力: {"encoded_result"}

  • シグネチャ

    • キー: decode

    • 入力: {"x"}

    • 出力: {"decoded_result"}

シグネチャを使用したモデルの変換

TensorFlow Lite コンバータ API は、上記のシグネチャ情報を変換された TensorFlow Lite モデルに渡します。

この変換機能は、TensorFlow バージョン 2.7.0 以降のすべてのコンバータ API で提供されています。使用例を参照してください。

保存されたモデルから変換

model = Model() # Save the model SAVED_MODEL_PATH = 'content/saved_models/coding' tf.saved_model.save( model, SAVED_MODEL_PATH, signatures={ 'encode': model.encode.get_concrete_function(), 'decode': model.decode.get_concrete_function() }) # Convert the saved model using TFLiteConverter converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_PATH) converter.target_spec.supported_ops = [ tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops. tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops. ] tflite_model = converter.convert() # Print the signatures from the converted model interpreter = tf.lite.Interpreter(model_content=tflite_model) signatures = interpreter.get_signature_list() print(signatures)

Keras モデルから変換

# Generate a Keras model. keras_model = tf.keras.Sequential( [ tf.keras.layers.Dense(2, input_dim=4, activation='relu', name='x'), tf.keras.layers.Dense(1, activation='relu', name='output'), ] ) # Convert the keras model using TFLiteConverter. # Keras model converter API uses the default signature automatically. converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) tflite_model = converter.convert() # Print the signatures from the converted model interpreter = tf.lite.Interpreter(model_content=tflite_model) signatures = interpreter.get_signature_list() print(signatures)

Concrete 関数から変換

model = Model() # Convert the concrete functions using TFLiteConverter converter = tf.lite.TFLiteConverter.from_concrete_functions( [model.encode.get_concrete_function(), model.decode.get_concrete_function()], model) converter.target_spec.supported_ops = [ tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops. tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops. ] tflite_model = converter.convert() # Print the signatures from the converted model interpreter = tf.lite.Interpreter(model_content=tflite_model) signatures = interpreter.get_signature_list() print(signatures)

シグネチャの実行

TensorFlow の推論 API は、シグネチャに基づく実行をサポートします。

  • シグネチャで指定された入出力の名前を使用して、入出力テンソルにアクセスします。

  • シグネチャキーで指定されたグラフの各入力点を個別に実行します。

  • SavedModel の初期化手順をサポートします。

Java、C++、Python 言語バインディングは現在使用できます。次のセクションの例を参照してください。

Java

try (Interpreter interpreter = new Interpreter(file_of_tensorflowlite_model)) { // Run encoding signature. Map<String, Object> inputs = new HashMap<>(); inputs.put("x", input); Map<String, Object> outputs = new HashMap<>(); outputs.put("encoded_result", encoded_result); interpreter.runSignature(inputs, outputs, "encode"); // Run decoding signature. Map<String, Object> inputs = new HashMap<>(); inputs.put("x", encoded_result); Map<String, Object> outputs = new HashMap<>(); outputs.put("decoded_result", decoded_result); interpreter.runSignature(inputs, outputs, "decode"); }

C++

SignatureRunner* encode_runner = interpreter->GetSignatureRunner("encode"); encode_runner->ResizeInputTensor("x", {100}); encode_runner->AllocateTensors(); TfLiteTensor* input_tensor = encode_runner->input_tensor("x"); float* input = GetTensorData<float>(input_tensor); // Fill `input`. encode_runner->Invoke(); const TfLiteTensor* output_tensor = encode_runner->output_tensor( "encoded_result"); float* output = GetTensorData<float>(output_tensor); // Access `output`.

Python

# Load the TFLite model in TFLite Interpreter interpreter = tf.lite.Interpreter(model_content=tflite_model) # Print the signatures from the converted model signatures = interpreter.get_signature_list() print('Signature:', signatures) # encode and decode are callable with input as arguments. encode = interpreter.get_signature_runner('encode') decode = interpreter.get_signature_runner('decode') # 'encoded' and 'decoded' are dictionaries with all outputs from the inference. input = tf.constant([1, 2, 3], dtype=tf.float32) print('Input:', input) encoded = encode(x=input) print('Encoded result:', encoded) decoded = decode(x=encoded['encoded_result']) print('Decoded result:', decoded)

既知の制限

  • TFLite インタープリタはスレッドの安全を保証しないため、同じインタープリタからのシグネチャランナーは同時に実行されません。

  • C/iOS/Swift のサポートはまだ提供されていません。

更新

  • バージョン 2.7

    • 複数のシグネチャ機能が実装されました。

    • バージョン 2 以降のすべてのコンバータ API は、シグネチャ対応 TensorFlow Lite モデルを生成します。

  • バージョン 2.5

    • シグネチャ機能は、from_saved_model コンバータ API から利用できます。