Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/ja/guide/jax2tf.ipynb
25115 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

JAX2TF を使って JAX モデルをインポートする

このノートブックは、JAX を用いてモデルを作成し、そのトレーニングを TensorFlow で行う、実行可能な完全版サンプルです。これは、JAX エコシステムから TensorFlow エコシステムへのパスを提供する軽量な JAX2TF という API によって実現しています。

JAX は、高性能の配列計算ライブラリです。モデルを作成するために、このノートブックでは JAX 用の Flax というニューラルネットワークライブラリを使用しています。このトレーニングには、Optax という JAX 用の最適化ライブラリを使用しています。

JAX を使用している研究者の方は、JAX2TF を使用することで、TensorFlow の実証されたツールを使ったプロダクションを実現できます。

これを活用できる方法には多数あります。以下はその一部です。

  • 推論: JAX 用に書かれたモデルを、TF Serving、TFLite を使ったオンデバイス、または TensorFlow.js を使ったウェブにデプロイします。

  • ファインチューニング: JAX でトレーニングされたモデルを使って、そのコンポーネントを JAX2TF 経由で TF に持ち込み、既存のトレーニングデータとセットアップを使って TensorFlow でトレーニングを続けます。

  • 合成: JAX でトレーニングされた部分と TensorFlow でトレーニングされた部分を合わせて、最大限の柔軟性を得ます。

このような JAX と TensorFlow の相互運用性を実現するための鍵は jax2tf.convert にあります。これは JAX 上に作成されたモデルコンポーネント(損失関数、予測関数など)を取って、同等の表現を TensorFlow 関数として作成し、それを TensorFlow SavedModel としてエクスポートすることができます。

セットアップ

import tensorflow as tf import numpy as np import jax import jax.numpy as jnp import flax import optax import os from matplotlib import pyplot as plt from jax.experimental import jax2tf from threading import Lock # Only used in the visualization utility. from functools import partial
# TensorFlow and JAX have to coexist in GPU memory: turn off XLA client
# preallocation and ask TensorFlow to grow its GPU memory on demand.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"
gpus = tf.config.list_physical_devices('GPU')
try:
  # Iterating an empty device list is a no-op, so no emptiness check is needed.
  for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
  # Memory growth must be set before GPUs have been initialized.
  print(e)
#@title Visualization utilities

plt.rcParams["figure.figsize"] = (20,8)


def display_train_curves(loss, avg_loss, eval_loss, eval_accuracy, epochs, steps_per_epochs, ignore_first_n=10):
  """Plot the loss curves (left subplot) and the eval accuracy (right subplot).

  :param loss: per-step training losses, or None to skip the per-step curve
  :param avg_loss: one average training loss per epoch
  :param eval_loss: one evaluation loss per epoch
  :param eval_accuracy: one evaluation accuracy per epoch
  :param epochs: number of epochs; per-epoch values are plotted at x = 1..epochs
  :param steps_per_epochs: number of training steps per epoch, used to put
      the per-step losses on the same epoch-based x axis
  :param ignore_first_n: number of initial steps ignored when choosing the
      y-axis limits (converted to whole epochs for the per-epoch curves)
  """
  ignore_first_n_epochs = int(ignore_first_n/steps_per_epochs)

  # The losses.
  ax = plt.subplot(121)
  if loss is not None:
    x = np.arange(len(loss)) / steps_per_epochs
    ax.plot(x, loss)
  ax.plot(range(1, epochs+1), avg_loss, "-o", linewidth=3)
  ax.plot(range(1, epochs+1), eval_loss, "-o", linewidth=3)
  ax.set_title('Loss')
  ax.set_ylabel('loss')
  ax.set_xlabel('epoch')
  if loss is not None:
    ax.set_ylim(0, np.max(loss[ignore_first_n:]))
    ax.legend(['train', 'avg train', 'eval'])
  else:
    # Without per-step losses, frame the per-epoch averages with a 10% margin.
    ymin = np.min(avg_loss[ignore_first_n_epochs:])
    ymax = np.max(avg_loss[ignore_first_n_epochs:])
    ax.set_ylim(ymin-(ymax-ymin)/10, ymax+(ymax-ymin)/10)
    ax.legend(['avg train', 'eval'])

  # The accuracy.
  ax = plt.subplot(122)
  ax.set_title('Eval Accuracy')
  ax.set_ylabel('accuracy')
  ax.set_xlabel('epoch')
  ymin = np.min(eval_accuracy[ignore_first_n_epochs:])
  ymax = np.max(eval_accuracy[ignore_first_n_epochs:])
  ax.set_ylim(ymin-(ymax-ymin)/10, ymax+(ymax-ymin)/10)
  ax.plot(range(1, epochs+1), eval_accuracy, "-o", linewidth=3)


class Progress:
  """Text mode progress bar.

  Usage:
          p = Progress(30)
          p.step()
          p.step()
          p.step(reset=True) # to restart from 0%

  The progress bar displays a new header at each restart.
  """

  def __init__(self, maxi, size=100, msg=""):
    """
    :param maxi: the number of steps required to reach 100%
    :param size: the number of characters taken on the screen by the progress bar
    :param msg: the message displayed in the header of the progress bar
    """
    self.maxi = maxi
    self.msg = msg
    self.size = size
    self.lock = Lock()
    self.__restart()

  def step(self, reset=False):
    """Advance the bar by one step, optionally restarting it from 0% first."""
    with self.lock:
      if reset:
        # Only the drawing state is reset. The lock object is deliberately
        # kept: replacing it while it is held would let a concurrent caller
        # enter through the new lock.
        self.__restart()
      if not self.header_printed:
        self.__print_header()
      next(self.p)

  def __restart(self):
    """Create a fresh dot iterator and arrange for the header to be reprinted."""
    self.p = self.__start_progress(self.maxi)()  # `()`: to get the iterator from the generator.
    self.header_printed = False

  def __print_header(self):
    print()
    format_string = "0%{: ^" + str(self.size - 6) + "}100%"
    print(format_string.format(self.msg))
    self.header_printed = True

  def __start_progress(self, maxi):
    def print_progress():
      # Bresenham's algorithm. Yields the number of dots printed.
      # This will always print `self.size` dots in `maxi` invocations.
      dx = maxi
      dy = self.size
      d = dy - dx
      k = 0  # defined up front so that maxi == 0 does not leave it unbound
      for x in range(maxi):
        k = 0
        while d >= 0:
          print('=', end="", flush=True)
          k += 1
          d -= dx
        d += dy
        yield k
      # Keep yielding the last result if there are too many steps.
      while True:
        yield k

    return print_progress

MNIST データセットをダウンロードして準備する

def _to_model_inputs(image, label):
  """Cast the image to float32, divide by 255 and add a trailing axis; one-hot encode the label over 10 classes."""
  scaled = tf.cast(image, tf.float32)/255.0
  return tf.expand_dims(scaled, axis=-1), tf.one_hot(label, depth=10)


BATCH_SIZE = 256

(x_train, train_labels), (x_test, test_labels) = tf.keras.datasets.mnist.load_data()

# Training pipeline: preprocess, batch (dropping the incomplete last batch),
# cache, then shuffle with reshuffling at every iteration.
train_data = (
    tf.data.Dataset.from_tensor_slices((x_train, train_labels))
    .map(_to_model_inputs)
    .batch(BATCH_SIZE, drop_remainder=True)
    .cache()
    .shuffle(5000, reshuffle_each_iteration=True)
)

# Evaluation pipeline: same preprocessing, batched with batch size 10000.
test_data = (
    tf.data.Dataset.from_tensor_slices((x_test, test_labels))
    .map(_to_model_inputs)
    .batch(10000)
    .cache()
)

one_batch, one_batch_labels = next(iter(train_data))  # just one batch
all_test_data, all_test_labels = next(iter(test_data))  # all in one batch since batch size is 10000

トレーニングを構成する

このノートブックは、デモの目的で、単純なモデルを作成してトレーニングします。

# Training hyperparameters. JAX_EPOCHS = 3 TF_EPOCHS = 7 STEPS_PER_EPOCH = len(train_labels)//BATCH_SIZE LEARNING_RATE = 0.01 LEARNING_RATE_EXP_DECAY = 0.6 # The learning rate schedule for JAX (with Optax). jlr_decay = optax.exponential_decay(LEARNING_RATE, transition_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True) # THe learning rate schedule for TensorFlow. tflr_decay = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=LEARNING_RATE, decay_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)

Flax を使ってモデルを作成する

class ConvModel(flax.linen.Module):
  """Small convolutional classifier producing 10 logits per example."""

  @flax.linen.compact
  def __call__(self, x, train):
    """Return the logits for `x`; `train` toggles dropout and batch-norm statistics."""
    # When not training, batch norm uses its running averages and dropout is disabled.
    inference = not train

    conv_out = flax.linen.Conv(features=12, kernel_size=(3,3), padding="SAME", use_bias=False)(x)
    conv_out = flax.linen.BatchNorm(use_running_average=inference, use_scale=False, use_bias=True)(conv_out)

    flat = conv_out.reshape((conv_out.shape[0], -1))  # flatten

    hidden = flax.linen.Dense(features=200, use_bias=True)(flat)
    hidden = flax.linen.BatchNorm(use_running_average=inference, use_scale=False, use_bias=True)(hidden)
    hidden = flax.linen.Dropout(rate=0.3, deterministic=inference)(hidden)
    hidden = flax.linen.relu(hidden)

    # Raw logits are returned; the softmax is applied in `loss` and `predict`.
    return flax.linen.Dense(features=10)(hidden)

  # JAX differentiation requires a function `f(params, other_state, data, labels)` -> `loss` (as a single number).
  # `jax.grad` will differentiate it against the first argument.
  # The user must split trainable and non-trainable variables into `params` and `other_state`.
  # Must pass a different RNG key each time for the dropout mask to be different.
  def loss(self, params, other_state, rng, data, labels, train):
    """Return `(mean softmax cross-entropy, mutated batch_stats collection)`."""
    variables = {'params': params, **other_state}
    logits, batch_stats = self.apply(variables,
                                     data,
                                     mutable=['batch_stats'],
                                     rngs={'dropout': rng},
                                     train=train)
    # The loss averaged across the batch dimension.
    per_example = optax.softmax_cross_entropy(logits, labels)
    return jnp.mean(per_example), batch_stats

  def predict(self, state, data):
    """Return the log-softmax of the logits computed with train=False."""
    # predict and accuracy disable dropout and use accumulated batch norm stats (train=False)
    logits = self.apply(state, data, train=False)
    return flax.linen.log_softmax(logits)

  def accuracy(self, state, data, labels):
    """Return the fraction of examples whose arg-max prediction matches the one-hot `labels`."""
    log_probs = self.predict(state, data)
    predicted = jnp.argmax(log_probs, axis=-1)
    expected = jnp.argmax(labels, axis=-1)
    return jnp.mean(jnp.equal(predicted, expected))

トレーニングステップの関数を記述する

# The training step.
@partial(jax.jit, static_argnums=[0])  # this forces jax.jit to recompile for every new model
def train_step(model, state, optimizer_state, rng, data, labels):
  """Run one optimization step.

  Returns `(new_state, optimizer_state, rng, loss)`, where `rng` is a key
  derived from the input key for use in the next step.
  """
  # Differentiate only against 'params', which represents the trainable variables.
  other_state, params = state.pop('params')

  loss_and_grad = jax.value_and_grad(model.loss, has_aux=True)
  (loss, batch_stats), grads = loss_and_grad(params, other_state, rng, data, labels, train=True)

  # `optimizer` is the module-level Optax optimizer.
  updates, next_optimizer_state = optimizer.update(grads, optimizer_state)
  new_params = optax.apply_updates(params, updates)

  # Merge the refreshed batch-norm statistics and the updated parameters into the state.
  replacements = dict(batch_stats)
  replacements['params'] = new_params
  new_state = state.copy(add_or_replace=replacements)

  next_rng = jax.random.split(rng)[0]
  return new_state, next_optimizer_state, next_rng, loss

トレーニングループを書く

def train(model, state, optimizer_state, train_data, epochs, losses, avg_losses, eval_losses, eval_accuracies):
  """Train `model` for `epochs` epochs and evaluate it after each one.

  :param model: the model, providing `loss` and `accuracy`
  :param state: the model variables (including 'params')
  :param optimizer_state: the Optax optimizer state
  :param train_data: iterable of `(data, labels)` batches exposing `.numpy()`
  :param epochs: the number of epochs to run
  :param losses: list extended in place with every per-step training loss
  :param avg_losses: list extended in place with one average loss per epoch
  :param eval_losses: list extended in place with one eval loss per epoch
  :param eval_accuracies: list extended in place with one eval accuracy per epoch
  :return: the final `(state, optimizer_state)`
  """
  p = Progress(STEPS_PER_EPOCH)
  rng = jax.random.PRNGKey(0)
  for epoch in range(epochs):

    # This is where the learning rate schedule state is stored in the optimizer state.
    optimizer_step = optimizer_state[1].count

    # Remember where this epoch's losses start. Slicing with the last loop
    # index (`losses[-step:]`) would drop the first loss of the epoch and
    # would select the whole history when `step` is 0.
    epoch_start = len(losses)

    # Run an epoch of training.
    for step, (data, labels) in enumerate(train_data):
      p.step(reset=(step==0))
      state, optimizer_state, rng, loss = train_step(model, state, optimizer_state, rng, data.numpy(), labels.numpy())
      losses.append(loss)
    avg_loss = np.mean(losses[epoch_start:])
    avg_losses.append(avg_loss)

    # Run one epoch of evals (10,000 test images in a single batch).
    other_state, params = state.pop('params')
    # Gotcha: must discard modified batch_stats here
    eval_loss, _ = model.loss(params, other_state, rng, all_test_data.numpy(), all_test_labels.numpy(), train=False)
    eval_losses.append(eval_loss)
    eval_accuracy = model.accuracy(state, all_test_data.numpy(), all_test_labels.numpy())
    eval_accuracies.append(eval_accuracy)

    print("\nEpoch", epoch, "train loss:", avg_loss, "eval loss:", eval_loss, "eval accuracy", eval_accuracy, "lr:", jlr_decay(optimizer_step))

  return state, optimizer_state

モデルとオプティマイザを作成する(Optax を使用)

# The model.
model = ConvModel()
# Flax allows a separate RNG for "dropout"
init_rngs = {'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}
state = model.init(init_rngs, one_batch, train=True)

# The optimizer.
# Gotcha: it does not seem to be possible to pass just a callable as LR, must be an Optax Schedule
optimizer = optax.adam(learning_rate=jlr_decay)
optimizer_state = optimizer.init(state['params'])

# Histories filled in by the training loop.
losses, avg_losses, eval_losses, eval_accuracies = [], [], [], []

モデルをトレーニングする

# Train for the whole epoch budget in JAX and plot the resulting curves.
total_epochs = JAX_EPOCHS+TF_EPOCHS
new_state, new_optimizer_state = train(
    model, state, optimizer_state, train_data, total_epochs,
    losses, avg_losses, eval_losses, eval_accuracies)

display_train_curves(
    losses, avg_losses, eval_losses, eval_accuracies,
    epochs=len(eval_losses),
    steps_per_epochs=STEPS_PER_EPOCH,
    ignore_first_n=1*STEPS_PER_EPOCH)

モデルを部分的にトレーニングする

まもなく、モデルのトレーニングを TensorFlow で続けます。

# Start again from a freshly initialized model.
model = ConvModel()
# Flax allows a separate RNG for "dropout"
init_rngs = {'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}
state = model.init(init_rngs, one_batch, train=True)

# The optimizer.
optimizer = optax.adam(learning_rate=jlr_decay)  # LR must be an Optax LR Schedule
optimizer_state = optimizer.init(state['params'])

# Fresh histories for this run.
losses = []
avg_losses = []
eval_losses = []
eval_accuracies = []

# Train only the JAX share of the epochs.
state, optimizer_state = train(
    model, state, optimizer_state, train_data, JAX_EPOCHS,
    losses, avg_losses, eval_losses, eval_accuracies)

display_train_curves(
    losses, avg_losses, eval_losses, eval_accuracies,
    epochs=len(eval_losses),
    steps_per_epochs=STEPS_PER_EPOCH,
    ignore_first_n=1*STEPS_PER_EPOCH)

推論に十分な分だけを保存する

JAX モデルのデプロイが目標である場合は(model.predict() を使って推論するために)、それを SavedModel にエクスポートするだけで十分です。このセクションでは、それを行う方法を説明します。

# Test data with a different batch size to test polymorphic shapes. x, y = next(iter(train_data.unbatch().batch(13))) m = tf.Module() # Wrap the JAX state in `tf.Variable` (needed when calling the converted JAX function. state_vars = tf.nest.map_structure(tf.Variable, state) # Keep the wrapped state as flat list (needed in TensorFlow fine-tuning). m.vars = tf.nest.flatten(state_vars) # Convert the desired JAX function (`model.predict`). predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=["...", "(b, 28, 28, 1)"]) # Wrap the converted function in `tf.function` with the correct `tf.TensorSpec` (necessary for dynamic shapes to work). @tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)]) def predict(data): return predict_fn(state_vars, data) m.predict = predict tf.saved_model.save(m, "./")
def _predicted_classes(module, images):
  """Return the arg-max class index of `module.predict(images)` for every example."""
  return np.argmax(module.predict(images).numpy(), axis=-1)


# Test the converted function.
print("Converted function predictions:", _predicted_classes(m, x))

# Reload the model.
reloaded_model = tf.saved_model.load("./")

# Test the reloaded converted function (the result should be the same).
print("Reloaded function predictions:", _predicted_classes(reloaded_model, x))

すべてを保存する

完全なエクスポートが目標である場合は(モデルを TensorFlow に持ってきてファインチューニングや合成などを行う場合に便利)、このセクションで、モデルを保存して以下を含むメソッドでアクセスできるようにする方法を説明します。

  • model.predict

  • model.accuracy

  • model.loss(train=True/False ブール値、ドロップアウトの RNG、BatchNorm 状態の更新を含む)

from collections import abc def _fix_frozen(d): """Changes any mappings (e.g. frozendict) back to dict.""" if isinstance(d, list): return [_fix_frozen(v) for v in d] elif isinstance(d, tuple): return tuple(_fix_frozen(v) for v in d) elif not isinstance(d, abc.Mapping): return d d = dict(d) for k, v in d.items(): d[k] = _fix_frozen(v) return d
class TFModel(tf.Module):
  """`tf.Module` exposing the converted JAX model's predict, loss and accuracy functions.

  The model state is held in `tf.Variable`s (`self.state_vars`, also
  available as the flat list `self.vars`), and the JAX RNG key is held in
  the variable `self.jax_rng`.
  """

  def __init__(self, state, model):
    """
    :param state: the JAX model state; every leaf is wrapped in a `tf.Variable`
    :param model: the JAX model whose `loss`, `accuracy` and `predict` are converted
    """
    super().__init__()

    # Special care needed for the train=True/False parameter in the loss
    @jax.jit
    def loss_with_train_bool(state, rng, data, labels, train):
      other_state, params = state.pop('params')
      # `train` is a value passed in at call time, so the choice between the
      # two model.loss variants is made with jax.lax.cond. Both branches read
      # params/other_state/rng from the enclosing scope.
      loss, batch_stats = jax.lax.cond(train,
                                       lambda state, data, labels: model.loss(params, other_state, rng, data, labels, train=True),
                                       lambda state, data, labels: model.loss(params, other_state, rng, data, labels, train=False),
                                       state, data, labels)
      # must use JAX to split the RNG, therefore, must do it in a @jax.jit function
      new_rng, _ = jax.random.split(rng)
      return loss, batch_stats, new_rng

    self.state_vars = tf.nest.map_structure(tf.Variable, state)
    self.vars = tf.nest.flatten(self.state_vars)
    self.jax_rng = tf.Variable(jax.random.PRNGKey(0))

    # One polymorphic shape spec per positional argument: "..." leaves the
    # argument unconstrained, "b" is the symbolic batch dimension.
    self.loss_fn = jax2tf.convert(loss_with_train_bool, polymorphic_shapes=["...", "...", "(b, 28, 28, 1)", "(b, 10)", "..."])
    self.accuracy_fn = jax2tf.convert(model.accuracy, polymorphic_shapes=["...", "(b, 28, 28, 1)", "(b, 10)"])
    self.predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=["...", "(b, 28, 28, 1)"])

  # Must specify TensorSpec manually for variable batch size to work
  @tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])
  def predict(self, data):
    """Return the converted `model.predict` output for `data`."""
    # Make sure the TfModel.predict function implicitly use self.state_vars and not the JAX state directly
    # otherwise, all model weights would be embedded in the TF graph as constants.
    return self.predict_fn(self.state_vars, data)

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),
                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
               autograph=False)
  def train_loss(self, data, labels):
    """Return the training loss; also updates the batch-norm statistics and the RNG variable."""
    loss, batch_stats, new_rng = self.loss_fn(self.state_vars, self.jax_rng, data, labels, True)
    # update batch norm stats
    # Both structures are flattened the same way and paired up positionally.
    flat_vars = tf.nest.flatten(self.state_vars['batch_stats'])
    flat_values = tf.nest.flatten(batch_stats['batch_stats'])
    for var, val in zip(flat_vars, flat_values):
      var.assign(val)
    # update RNG
    self.jax_rng.assign(new_rng)
    return loss

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),
                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
               autograph=False)
  def eval_loss(self, data, labels):
    """Return the loss computed with train=False; no variable is assigned."""
    # The returned batch_stats and new_rng are intentionally discarded.
    loss, batch_stats, new_rng = self.loss_fn(self.state_vars, self.jax_rng, data, labels, False)
    return loss

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),
                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
               autograph=False)
  def accuracy(self, data, labels):
    """Return the converted `model.accuracy` output for `data` and one-hot `labels`."""
    return self.accuracy_fn(self.state_vars, data, labels)
# Instantiate the model.
tf_model = TFModel(state=state, model=model)

# Save the model.
tf.saved_model.save(tf_model, export_dir="./")

モデルを再読み込みする

reloaded_model = tf.saved_model.load("./")

# Test if it works and that the batch size is indeed variable.
for batch_size in (13, 20):
  x, y = next(iter(train_data.unbatch().batch(batch_size)))
  print(np.argmax(reloaded_model.predict(x).numpy(), axis=-1))

# Accuracy on one training batch, then on the whole test set.
for images, one_hot_labels in ((one_batch, one_batch_labels), (all_test_data, all_test_labels)):
  print(reloaded_model.accuracy(images, one_hot_labels))

変換した JAX モデルのトレーニングを TensorFlow で続行する

optimizer = tf.keras.optimizers.Adam(learning_rate=tflr_decay)

# Set the iteration step for the learning rate to resume from where it left off in JAX.
optimizer.iterations.assign(len(eval_losses)*STEPS_PER_EPOCH)

p = Progress(STEPS_PER_EPOCH)

for epoch in range(JAX_EPOCHS, JAX_EPOCHS+TF_EPOCHS):

  # This is where the learning rate schedule state is stored in the optimizer state.
  # `optimizer.iterations` is a live variable, so take a snapshot of its value
  # to report the learning rate in effect at the start of this epoch, as the
  # JAX training loop does.
  optimizer_step = optimizer.iterations.numpy()

  # Remember where this epoch's losses start. Slicing with the last loop
  # index (`losses[-step:]`) would drop the first loss of the epoch and
  # would select the whole history when `step` is 0.
  epoch_start = len(losses)

  for step, (data, labels) in enumerate(train_data):
    p.step(reset=(step==0))
    with tf.GradientTape() as tape:
      loss = reloaded_model.train_loss(data, labels)
    grads = tape.gradient(loss, reloaded_model.vars)
    optimizer.apply_gradients(zip(grads, reloaded_model.vars))
    losses.append(loss)
  avg_loss = np.mean(losses[epoch_start:])
  avg_losses.append(avg_loss)

  # Evaluate on the whole test set (a single batch).
  eval_loss = reloaded_model.eval_loss(all_test_data.numpy(), all_test_labels.numpy()).numpy()
  eval_losses.append(eval_loss)
  eval_accuracy = reloaded_model.accuracy(all_test_data.numpy(), all_test_labels.numpy()).numpy()
  eval_accuracies.append(eval_accuracy)

  print("\nEpoch", epoch, "train loss:", avg_loss, "eval loss:", eval_loss, "eval accuracy", eval_accuracy, "lr:", tflr_decay(optimizer_step).numpy())
# Plot the histories accumulated so far; the first two epochs' worth of steps
# are ignored when the y-axis limits are chosen.
display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=2*STEPS_PER_EPOCH)

# The loss takes a hit when the training restarts, but does not go back to random levels.
# This is likely caused by the optimizer momentum being reinitialized.

次のステップ

JAX と Flax について、それぞれのドキュメントサイトをご覧ください。詳しいガイドや例が掲載されています。JAX に不慣れな方は、JAX 101 チュートリアルと Flax クイックスタートを必ずご覧ください。JAX モデルを TensorFlow 形式に変換する方法については、GitHub の jax2tf ユーティリティをご覧ください。JAX モデルを変換してブラウザで TensorFlow.js を使って実行することに興味のある方は、TensorFlow.js を使ったウェブでの JAX をご覧ください。TensorFlow Lite で実行できるように JAX モデルを準備する場合は、TFLite 向けの JAX モデルの変換ガイドをご覧ください。