Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/ko/guide/jax2tf.ipynb
25115 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

JAX2TF를 사용하여 JAX 모델 가져오기

이 노트북은 JAX를 사용하는 모델을 생성하고 이를 TensorFlow로 가져와 훈련을 계속하는 완전하고 실행 가능한 예제를 제공합니다. 이 작업은 JAX 생태계에서 TensorFlow 생태계로 이동하는 길을 제공하는 경량 API인 JAX2TF를 통해 가능합니다.

JAX는 고성능 배열 컴퓨팅 라이브러리입니다. 모델을 생성하기 위해 이 노트북은 JAX용 신경망 라이브러리인 Flax를 사용합니다. 모델을 훈련하기 위해 JAX용 최적화 라이브러리인 Optax를 사용합니다.

JAX를 사용하는 연구자에게 JAX2TF는 TensorFlow의 검증된 도구를 사용하여 프로덕션으로 이동하는 길을 제공합니다.

이 기능을 유용하게 사용할 수 있는 방법은 여러 가지가 있지만 그 중 몇 가지만 소개하겠습니다.

  • 추론: JAX용으로 작성된 모델을 가져와 TF Serving을 사용해 서버에, TFLite를 사용해 디바이스(on-device)에, TensorFlow.js를 사용해 웹에 배포할 수 있습니다.

  • 미세 조정: JAX로 훈련한 모델의 구성 요소를 JAX2TF를 사용해 TF로 가져온 다음, 기존 훈련 데이터와 설정을 그대로 사용하여 TensorFlow에서 계속 훈련할 수 있습니다.

  • 융합: 유연성을 극대화하기 위해 JAX를 사용하여 훈련한 모델의 일부와 TensorFlow를 사용하여 훈련한 모델의 일부를 결합합니다.

JAX와 TensorFlow 사이의 이러한 상호 운용을 가능하게 하는 핵심은 jax2tf.convert입니다. 이 함수는 JAX로 작성된 모델 구성 요소(손실 함수, 예측 함수 등)를 받아 동등한 TensorFlow 함수로 변환하며, 변환된 함수는 TensorFlow SavedModel로 내보낼 수 있습니다.

설정하기

import tensorflow as tf import numpy as np import jax import jax.numpy as jnp import flax import optax import os from matplotlib import pyplot as plt from jax.experimental import jax2tf from threading import Lock # Only used in the visualization utility. from functools import partial
# Needed for TensorFlow and JAX to coexist in GPU memory.
# Ask JAX/XLA not to preallocate a large share of GPU memory up front; this must take
# effect before JAX initializes its GPU backend (no JAX computation has run yet here).
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"

# Let TensorFlow grow its GPU memory usage on demand instead of reserving it all.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as err:
        # Memory growth must be set before GPUs have been initialized.
        print(err)
#@title Visualization utilities

plt.rcParams["figure.figsize"] = (20, 8)


def display_train_curves(loss, avg_loss, eval_loss, eval_accuracy, epochs, steps_per_epochs, ignore_first_n=10):
    """Plot loss curves (left panel) and evaluation accuracy (right panel).

    Args:
      loss: per-step training losses, or None to plot only the per-epoch series.
      avg_loss: one averaged training loss per epoch.
      eval_loss: one evaluation loss per epoch.
      eval_accuracy: one evaluation accuracy per epoch.
      epochs: number of epochs; per-epoch series are drawn at x = 1..epochs.
      steps_per_epochs: training steps per epoch, used to place per-step losses on the epoch axis.
      ignore_first_n: number of leading training steps left out when choosing the y-axis range.
    """
    skipped_epochs = int(ignore_first_n / steps_per_epochs)
    epoch_axis = range(1, epochs + 1)

    def _padded_limits(values):
        # Y-axis range covering values[skipped_epochs:], padded by 10% of the span on each side.
        low = np.min(values[skipped_epochs:])
        high = np.max(values[skipped_epochs:])
        margin = (high - low) / 10
        return low - margin, high + margin

    # Left panel: the losses.
    loss_ax = plt.subplot(121)
    if loss is not None:
        loss_ax.plot(np.arange(len(loss)) / steps_per_epochs, loss)
    loss_ax.plot(epoch_axis, avg_loss, "-o", linewidth=3)
    loss_ax.plot(epoch_axis, eval_loss, "-o", linewidth=3)
    loss_ax.set_title('Loss')
    loss_ax.set_ylabel('loss')
    loss_ax.set_xlabel('epoch')
    if loss is None:
        loss_ax.set_ylim(*_padded_limits(avg_loss))
        loss_ax.legend(['avg train', 'eval'])
    else:
        loss_ax.set_ylim(0, np.max(loss[ignore_first_n:]))
        loss_ax.legend(['train', 'avg train', 'eval'])

    # Right panel: the evaluation accuracy.
    acc_ax = plt.subplot(122)
    acc_ax.set_title('Eval Accuracy')
    acc_ax.set_ylabel('accuracy')
    acc_ax.set_xlabel('epoch')
    acc_ax.set_ylim(*_padded_limits(eval_accuracy))
    acc_ax.plot(epoch_axis, eval_accuracy, "-o", linewidth=3)


class Progress:
    """Text-mode progress bar drawn with '=' characters.

    Usage:
      p = Progress(30)
      p.step()
      p.step()
      p.step(reset=True)  # to restart from 0%

    A fresh header line is printed every time the bar restarts.
    """

    def __init__(self, maxi, size=100, msg=""):
        """
        :param maxi: number of step() calls needed to reach 100%
        :param size: width of the bar on screen, in characters
        :param msg: text centred in the header line above the bar
        """
        self.maxi = maxi
        # __start_progress returns a generator function; calling it yields the iterator.
        self.p = self.__start_progress(maxi)()
        self.header_printed = False
        self.msg = msg
        self.size = size
        self.lock = Lock()

    def step(self, reset=False):
        """Advance the bar by one step; with reset=True, start a new bar first."""
        with self.lock:
            if reset:
                self.__init__(self.maxi, self.size, self.msg)
            if not self.header_printed:
                self.__print_header()
            next(self.p)

    def __print_header(self):
        print()
        header_format = "0%{: ^" + str(self.size - 6) + "}100%"
        print(header_format.format(self.msg))
        self.header_printed = True

    def __start_progress(self, maxi):
        def print_progress():
            # Bresenham-style line drawing: spreads exactly self.size '=' characters
            # evenly over maxi steps. Each step yields how many '=' it printed.
            total_steps = maxi
            bar_width = self.size
            error = bar_width - total_steps
            for _ in range(total_steps):
                printed = 0
                while error >= 0:
                    print('=', end="", flush=True)
                    printed += 1
                    error -= total_steps
                error += bar_width
                yield printed
            # Past maxi steps, keep repeating the last count without drawing more.
            while True:
                yield printed
        return print_progress

MNIST 데이터세트 다운로드 및 준비하기

(x_train, train_labels), (x_test, test_labels) = tf.keras.datasets.mnist.load_data()


def _to_model_inputs(image, label):
    # Scale pixel values to [0, 1] as float32 and add a trailing channel axis;
    # turn the integer class label into a one-hot vector over 10 classes.
    pixels = tf.cast(image, tf.float32) / 255.0
    return tf.expand_dims(pixels, axis=-1), tf.one_hot(label, depth=10)


BATCH_SIZE = 256
train_data = (
    tf.data.Dataset.from_tensor_slices((x_train, train_labels))
    .map(_to_model_inputs)
    .batch(BATCH_SIZE, drop_remainder=True)
    .cache()
    # Applied after batching, so this reorders whole batches each epoch.
    .shuffle(5000, reshuffle_each_iteration=True)
)
test_data = (
    tf.data.Dataset.from_tensor_slices((x_test, test_labels))
    .map(_to_model_inputs)
    .batch(10000)
    .cache()
)

one_batch, one_batch_labels = next(iter(train_data))  # just one batch
all_test_data, all_test_labels = next(iter(test_data))  # all in one batch since batch size is 10000

훈련 구성하기

이 노트북에서는 데모 목적으로 간단한 모델을 만들고 훈련합니다.

# Training hyperparameters.
JAX_EPOCHS = 3
TF_EPOCHS = 7
# One step per full batch (the training set is batched with drop_remainder=True).
STEPS_PER_EPOCH = len(train_labels) // BATCH_SIZE
LEARNING_RATE = 0.01
LEARNING_RATE_EXP_DECAY = 0.6

# Both schedules below are the same staircase decay: the learning rate is multiplied
# by LEARNING_RATE_EXP_DECAY once every STEPS_PER_EPOCH steps (i.e. once per epoch).

# The learning rate schedule for JAX (with Optax).
jlr_decay = optax.exponential_decay(
    init_value=LEARNING_RATE,
    transition_steps=STEPS_PER_EPOCH,
    decay_rate=LEARNING_RATE_EXP_DECAY,
    staircase=True,
)
# The learning rate schedule for TensorFlow.
tflr_decay = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=LEARNING_RATE,
    decay_steps=STEPS_PER_EPOCH,
    decay_rate=LEARNING_RATE_EXP_DECAY,
    staircase=True,
)

Flax를 사용하여 모델 만들기

class ConvModel(flax.linen.Module):
    """Small convolutional classifier producing 10 output logits.

    Layers: Conv(12, 3x3) -> BatchNorm -> flatten -> Dense(200) -> BatchNorm
    -> Dropout(0.3) -> ReLU -> Dense(10).
    """

    @flax.linen.compact
    def __call__(self, x, train):
        nn = flax.linen
        # With train=False, BatchNorm uses its accumulated running statistics and dropout is disabled.
        inference = not train
        x = nn.Conv(features=12, kernel_size=(3, 3), padding="SAME", use_bias=False)(x)
        x = nn.BatchNorm(use_running_average=inference, use_scale=False, use_bias=True)(x)
        x = x.reshape((x.shape[0], -1))  # flatten all but the leading (batch) axis
        x = nn.Dense(features=200, use_bias=True)(x)
        x = nn.BatchNorm(use_running_average=inference, use_scale=False, use_bias=True)(x)
        x = nn.Dropout(rate=0.3, deterministic=inference)(x)
        x = nn.relu(x)
        # Raw logits: no softmax here; `loss` and `predict` apply it as needed.
        return nn.Dense(features=10)(x)

    # JAX differentiation requires a function `f(params, other_state, data, labels)` -> `loss`
    # (a single number). `jax.grad` differentiates it with respect to the first argument, so
    # the caller must split trainable variables (`params`) from non-trainable ones (`other_state`).
    # A different RNG key must be passed on each call for the dropout mask to change.
    def loss(self, params, other_state, rng, data, labels, train):
        """Return (mean softmax cross-entropy, mutated {'batch_stats': ...} collections)."""
        variables = {'params': params, **other_state}
        logits, updated_collections = self.apply(
            variables, data, mutable=['batch_stats'], rngs={'dropout': rng}, train=train)
        # Cross-entropy averaged across the batch dimension.
        batch_loss = optax.softmax_cross_entropy(logits, labels).mean()
        return batch_loss, updated_collections

    def predict(self, state, data):
        """Return log-probabilities (log_softmax of the logits), in inference mode."""
        # train=False: dropout off, BatchNorm uses accumulated statistics.
        logits = self.apply(state, data, train=False)
        return flax.linen.log_softmax(logits)

    def accuracy(self, state, data, labels):
        """Fraction of examples whose predicted class matches the one-hot label."""
        predicted_class = jnp.argmax(self.predict(state, data), axis=-1)
        true_class = jnp.argmax(labels, axis=-1)
        return jnp.equal(predicted_class, true_class).mean()

훈련 단계 함수 작성하기

# The training step.
# `model` is a static argument: jax.jit keeps one compiled version per distinct model value
# (looked up by hash/equality), so a new model instance that compares equal to an earlier one
# reuses the earlier compilation rather than recompiling.
# NOTE(review): `optimizer` is read from the global scope while tracing, so it is baked into the
# compiled function; reassigning `optimizer` later does not affect an already cached compilation.
@partial(jax.jit, static_argnums=[0])
def train_step(model, state, optimizer_state, rng, data, labels):
    """Run one optimizer update of `model` on a single batch.

    Args:
      model: the Flax module (static argument, so it must be hashable).
      state: model variables: a 'params' entry plus non-trainable collections such as 'batch_stats'.
      optimizer_state: the Optax state matching the global `optimizer`.
      rng: JAX PRNG key. It is split here and never consumed directly.
      data: one batch of inputs.
      labels: the matching one-hot labels.

    Returns:
      (new_state, new_optimizer_state, next_rng, loss)
    """
    # Split first: `dropout_rng` is consumed by this step's dropout and `next_rng` is returned
    # for the next step, so no key is both used for sampling and split again.
    next_rng, dropout_rng = jax.random.split(rng)
    # Differentiate only against 'params', which represents the trainable variables.
    other_state, params = state.pop('params')
    (loss, batch_stats), grads = jax.value_and_grad(model.loss, has_aux=True)(
        params, other_state, dropout_rng, data, labels, train=True)
    updates, optimizer_state = optimizer.update(grads, optimizer_state)
    params = optax.apply_updates(params, updates)
    # `batch_stats` is the mutated-collections dict from `model.loss`, i.e. {'batch_stats': ...}.
    new_state = state.copy(add_or_replace={**batch_stats, 'params': params})
    return new_state, optimizer_state, next_rng, loss

훈련 루프 작성하기

def train(model, state, optimizer_state, train_data, epochs, losses, avg_losses, eval_losses, eval_accuracies):
    """Train `model` for `epochs` epochs, evaluating on the full test set after each epoch.

    Per-step training losses are appended to `losses`. One value per epoch is appended to
    `avg_losses`, `eval_losses` and `eval_accuracies`. All four lists are mutated in place.

    Returns:
      (state, optimizer_state) after the last epoch.

    Raises:
      ValueError: if `train_data` yields no batches in an epoch.
    """
    p = Progress(STEPS_PER_EPOCH)
    rng = jax.random.PRNGKey(0)
    # The test set is one batch; convert it to NumPy once instead of every epoch.
    test_images = all_test_data.numpy()
    test_targets = all_test_labels.numpy()
    for epoch in range(epochs):
        # This is where the learning rate schedule state is stored in the optimizer state.
        optimizer_step = optimizer_state[1].count
        # Run an epoch of training.
        steps_this_epoch = 0
        for step, (data, labels) in enumerate(train_data):
            p.step(reset=(step == 0))
            state, optimizer_state, rng, loss = train_step(model, state, optimizer_state, rng, data.numpy(), labels.numpy())
            losses.append(loss)
            steps_this_epoch = step + 1
        if steps_this_epoch == 0:
            raise ValueError("train_data yielded no batches")
        # Average exactly this epoch's losses. (losses[-step:] would drop one step, and
        # with a single-step epoch, losses[-0:] would average the whole history.)
        avg_loss = np.mean(losses[-steps_this_epoch:])
        avg_losses.append(avg_loss)
        # Run one epoch of evals (10,000 test images in a single batch).
        other_state, params = state.pop('params')
        # Gotcha: must discard modified batch_stats here; evaluation must not update the model state.
        eval_loss, _ = model.loss(params, other_state, rng, test_images, test_targets, train=False)
        eval_losses.append(eval_loss)
        eval_accuracy = model.accuracy(state, test_images, test_targets)
        eval_accuracies.append(eval_accuracy)
        print("\nEpoch", epoch, "train loss:", avg_loss, "eval loss:", eval_loss, "eval accuracy", eval_accuracy, "lr:", jlr_decay(optimizer_step))
    return state, optimizer_state

모델 및 옵티마이저 생성하기(Optax 사용)

# The model.
model = ConvModel()
# Flax takes a dict of named RNG streams: 'params' for weight initialization and a
# separate 'dropout' stream for the dropout masks.
init_rngs = {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(0)}
state = model.init(init_rngs, one_batch, train=True)

# The optimizer. Optax's `learning_rate` accepts either a number or any callable that maps
# the step count to a learning rate; here it is the staircase schedule `jlr_decay`.
optimizer = optax.adam(learning_rate=jlr_decay)
optimizer_state = optimizer.init(state['params'])

# Metric histories; `train` appends to these in place.
losses = []
avg_losses = []
eval_losses = []
eval_accuracies = []

모델 훈련하기

# Train entirely in JAX for JAX_EPOCHS + TF_EPOCHS epochs, the same total number of epochs
# as the JAX-then-TensorFlow run later in this notebook.
total_epochs = JAX_EPOCHS + TF_EPOCHS
new_state, new_optimizer_state = train(
    model, state, optimizer_state, train_data, total_epochs,
    losses, avg_losses, eval_losses, eval_accuracies)

# Skip the first epoch's steps when choosing the y-axis range.
display_train_curves(
    losses, avg_losses, eval_losses, eval_accuracies,
    len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=STEPS_PER_EPOCH)

모델을 부분적으로 훈련하기

먼저 JAX에서 처음 몇 에포크(JAX_EPOCHS)만 모델을 훈련합니다. 이후 훈련은 TensorFlow에서 이어서 진행합니다.

# Start again from a freshly initialized model and optimizer.
model = ConvModel()
state = model.init(
    {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(0)},  # separate 'dropout' RNG stream
    one_batch, train=True)
# The optimizer. `learning_rate` may be a number or any step -> learning-rate callable;
# `jlr_decay` is the Optax staircase schedule.
optimizer = optax.adam(learning_rate=jlr_decay)
optimizer_state = optimizer.init(state['params'])
losses, avg_losses, eval_losses, eval_accuracies = [], [], [], []

# Train only the first JAX_EPOCHS epochs in JAX; training continues in TensorFlow later.
state, optimizer_state = train(
    model, state, optimizer_state, train_data, JAX_EPOCHS,
    losses, avg_losses, eval_losses, eval_accuracies)
display_train_curves(
    losses, avg_losses, eval_losses, eval_accuracies,
    len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=STEPS_PER_EPOCH)

추론에 필요한 만큼만 저장하기

JAX 모델을 배포하는 것이 목표인 경우(model.predict()를 사용하여 추론을 실행할 수 있도록) 단순히 SavedModel로 이를 내보내는 것만으로도 충분합니다. 이 섹션에서는 이를 수행하는 방법을 설명합니다.

# Test data with a different batch size (13) to test polymorphic shapes.
x, y = next(iter(train_data.unbatch().batch(13)))

m = tf.Module()
# Wrap the JAX state in `tf.Variable`s (needed when calling the converted JAX function).
state_vars = tf.nest.map_structure(tf.Variable, state)
# Keep the wrapped state as a flat list (needed in TensorFlow fine-tuning).
m.vars = tf.nest.flatten(state_vars)
# Convert the desired JAX function (`model.predict`); "b" marks a symbolic batch dimension.
predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=["...", "(b, 28, 28, 1)"])


# Wrap the converted function in `tf.function` with the correct `tf.TensorSpec`
# (necessary for dynamic shapes to work).
@tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])
def predict(data):
    # Reads the weights from `state_vars` (tf.Variables), so they are saved as variables.
    return predict_fn(state_vars, data)


m.predict = predict
# Export to a dedicated directory instead of the notebook's working directory "./",
# so the SavedModel files are not mixed with other files there (or overwritten by
# another export to "./").
predict_export_dir = "./jax_predict_model"
tf.saved_model.save(m, predict_export_dir)

# Test the converted function.
print("Converted function predictions:", np.argmax(m.predict(x).numpy(), axis=-1))
# Reload the model.
reloaded_model = tf.saved_model.load(predict_export_dir)
# Test the reloaded converted function (the result should be the same).
print("Reloaded function predictions:", np.argmax(reloaded_model.predict(x).numpy(), axis=-1))

모두 저장하기

전체 내보내기가 목표인 경우(미세 조정, 융합 등을 위해 모델을 TensorFlow로 가져올 계획이라면 유용함), 이 섹션에서는 다음 메서드에 액세스할 수 있도록 모델을 저장하는 방법을 설명합니다.

  • model.predict

  • model.accuracy

  • model.loss(train=True/False 부울, 드롭아웃을 위한 RNG 및 BatchNorm 상태 업데이트 포함)

from collections import abc


def _fix_frozen(d):
    """Recursively convert any mapping (e.g. a frozendict) back to a plain dict.

    Lists and tuples are traversed and rebuilt. Namedtuples (e.g. Optax optimizer
    states) keep their own type instead of being flattened to a plain tuple. Any
    other value is returned unchanged.
    """
    if isinstance(d, list):
        return [_fix_frozen(v) for v in d]
    if isinstance(d, tuple):
        items = [_fix_frozen(v) for v in d]
        if hasattr(d, "_fields"):
            # Namedtuple: rebuild with the same class so field access keeps working.
            return type(d)(*items)
        return tuple(items)
    if not isinstance(d, abc.Mapping):
        return d
    return {k: _fix_frozen(v) for k, v in d.items()}
class TFModel(tf.Module):
    """tf.Module exposing a converted JAX/Flax model's predict, accuracy and loss functions.

    Model variables are held as `tf.Variable`s (`state_vars`, flattened in `vars`) so the
    exported SavedModel can be fine-tuned in TensorFlow. The JAX PRNG key used for dropout
    is stored in `jax_rng`, and every `train_loss` call advances it.
    """

    def __init__(self, state, model):
        super().__init__()

        # Special care needed for the train=True/False parameter in the loss: it is a traced
        # (non-static) argument of this jitted function, so a Python `if` cannot select the
        # branch and jax.lax.cond is used instead.
        @jax.jit
        def loss_with_train_bool(state, rng, data, labels, train):
            other_state, params = state.pop('params')
            # Must use JAX to split the RNG, therefore, must do it in a @jax.jit function.
            # Split *before* use: `dropout_rng` drives dropout and `new_rng` becomes the next
            # stored key, so no key is both consumed and used to derive another.
            new_rng, dropout_rng = jax.random.split(rng)
            loss, batch_stats = jax.lax.cond(
                train,
                lambda state, data, labels: model.loss(params, other_state, dropout_rng, data, labels, train=True),
                lambda state, data, labels: model.loss(params, other_state, dropout_rng, data, labels, train=False),
                state, data, labels)
            return loss, batch_stats, new_rng

        self.state_vars = tf.nest.map_structure(tf.Variable, state)
        self.vars = tf.nest.flatten(self.state_vars)
        self.jax_rng = tf.Variable(jax.random.PRNGKey(0))

        self.loss_fn = jax2tf.convert(loss_with_train_bool, polymorphic_shapes=["...", "...", "(b, 28, 28, 1)", "(b, 10)", "..."])
        self.accuracy_fn = jax2tf.convert(model.accuracy, polymorphic_shapes=["...", "(b, 28, 28, 1)", "(b, 10)"])
        self.predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=["...", "(b, 28, 28, 1)"])

    # Must specify TensorSpec manually for variable batch size to work
    @tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])
    def predict(self, data):
        # Make sure TFModel.predict implicitly uses self.state_vars and not the JAX state directly;
        # otherwise, all model weights would be embedded in the TF graph as constants.
        return self.predict_fn(self.state_vars, data)

    @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32), tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], autograph=False)
    def train_loss(self, data, labels):
        """Training-mode loss; also updates the BatchNorm statistics and advances `jax_rng`."""
        loss, batch_stats, new_rng = self.loss_fn(self.state_vars, self.jax_rng, data, labels, True)
        # update batch norm stats
        flat_vars = tf.nest.flatten(self.state_vars['batch_stats'])
        flat_values = tf.nest.flatten(batch_stats['batch_stats'])
        for var, val in zip(flat_vars, flat_values):
            var.assign(val)
        # update RNG
        self.jax_rng.assign(new_rng)
        return loss

    @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32), tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], autograph=False)
    def eval_loss(self, data, labels):
        """Inference-mode loss; leaves the model variables and `jax_rng` unchanged."""
        loss, _, _ = self.loss_fn(self.state_vars, self.jax_rng, data, labels, False)
        return loss

    @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32), tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], autograph=False)
    def accuracy(self, data, labels):
        return self.accuracy_fn(self.state_vars, data, labels)
# Wrap the JAX model and its current state in a tf.Module, then export it as a SavedModel.
export_dir = "./"
tf_model = TFModel(state, model)
tf.saved_model.save(tf_model, export_dir)

모델 다시 로드하기

reloaded_model = tf.saved_model.load("./")

# Test that it works and that the batch size is indeed variable.
for batch_size in (13, 20):
    x, y = next(iter(train_data.unbatch().batch(batch_size)))
    print(np.argmax(reloaded_model.predict(x).numpy(), axis=-1))

print(reloaded_model.accuracy(one_batch, one_batch_labels))
print(reloaded_model.accuracy(all_test_data, all_test_labels))

변환된 JAX 모델을 TensorFlow에서 계속 훈련하기

# Continue training the reloaded model in TensorFlow with the equivalent LR schedule.
optimizer = tf.keras.optimizers.Adam(learning_rate=tflr_decay)
# Set the iteration step for the learning rate to resume from where it left off in JAX.
optimizer.iterations.assign(len(eval_losses) * STEPS_PER_EPOCH)

p = Progress(STEPS_PER_EPOCH)
# The test set is a single batch; convert it to NumPy once.
eval_images = all_test_data.numpy()
eval_targets = all_test_labels.numpy()
for epoch in range(JAX_EPOCHS, JAX_EPOCHS + TF_EPOCHS):
    # Snapshot the step count at the start of the epoch. With a staircase schedule, this
    # step's learning rate is the one used throughout the epoch, which matches what the
    # JAX training loop reports. (Reading `optimizer.iterations` after the epoch would
    # report the next epoch's rate.)
    optimizer_step = int(optimizer.iterations.numpy())
    steps_this_epoch = 0
    for step, (data, labels) in enumerate(train_data):
        p.step(reset=(step == 0))
        with tf.GradientTape() as tape:
            loss = reloaded_model.train_loss(data, labels)
        grads = tape.gradient(loss, reloaded_model.vars)
        optimizer.apply_gradients(zip(grads, reloaded_model.vars))
        losses.append(loss)
        steps_this_epoch = step + 1
    if steps_this_epoch == 0:
        raise ValueError("train_data yielded no batches")
    # Average exactly this epoch's losses (losses[-step:] would drop one of them).
    avg_loss = np.mean(losses[-steps_this_epoch:])
    avg_losses.append(avg_loss)
    eval_loss = reloaded_model.eval_loss(eval_images, eval_targets).numpy()
    eval_losses.append(eval_loss)
    eval_accuracy = reloaded_model.accuracy(eval_images, eval_targets).numpy()
    eval_accuracies.append(eval_accuracy)
    print("\nEpoch", epoch, "train loss:", avg_loss, "eval loss:", eval_loss, "eval accuracy", eval_accuracy, "lr:", tflr_decay(optimizer_step).numpy())

display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=2 * STEPS_PER_EPOCH)
# The loss takes a hit when the training restarts, but does not go back to random levels.
# This is likely caused by the optimizer momentum being reinitialized.

다음 단계

JAX와 Flax에 대한 자세한 내용은 상세 가이드와 예제가 포함된 각 문서 웹사이트에서 확인할 수 있습니다. JAX를 처음 접하는 경우 JAX 101 튜토리얼과 Flax 퀵스타트를 확인하세요. JAX 모델을 TensorFlow 형식으로 변환하는 방법에 대한 자세한 내용은 GitHub에서 jax2tf 유틸리티를 확인하세요. 브라우저에서 실행할 수 있도록 JAX 모델을 변환하는 데 관심이 있는 경우 JAX on the Web with TensorFlow.js를 방문하세요. TensorFlow Lite에서 실행할 JAX 모델을 준비하려면 TFLite용 JAX 모델 변환 가이드를 참조하세요.