#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

This guide demonstrates how to migrate your workflows running on TPUs from TensorFlow 1's TPUEstimator API to TensorFlow 2's TPUStrategy API.

  • In TensorFlow 1, the tf.compat.v1.estimator.tpu.TPUEstimator API lets you train and evaluate a model, as well as perform inference and save your model (for serving) on (Cloud) TPUs.

  • In TensorFlow 2, to perform synchronous training on TPUs and TPU Pods (a collection of TPU devices connected by dedicated high-speed network interfaces), you need to use a TPU distribution strategy—tf.distribute.TPUStrategy. The strategy can work with the Keras APIs—including for model building (tf.keras.Model), optimizers (tf.keras.optimizers.Optimizer), and training (Model.fit)—as well as a custom training loop (with tf.function and tf.GradientTape).

For end-to-end TensorFlow 2 examples, check out the Use TPUs guide—namely, the Classification on TPUs section—and the Solve GLUE tasks using BERT on TPU tutorial. You may also find the Distributed training guide useful, which covers all TensorFlow distribution strategies, including TPUStrategy.

Setup

Start with imports and a simple dataset for demonstration purposes:

import tensorflow as tf
import tensorflow.compat.v1 as tf1
features = [[1., 1.5]]
labels = [[0.3]]
eval_features = [[4., 4.5]]
eval_labels = [[0.8]]

TensorFlow 1: Drive a model on TPUs with TPUEstimator

This section of the guide demonstrates how to perform training and evaluation with tf.compat.v1.estimator.tpu.TPUEstimator in TensorFlow 1.

To use a TPUEstimator, first define a few functions: an input function for the training data, an evaluation input function for the evaluation data, and a model function that tells the TPUEstimator how the training op is defined with the features and labels:

def _input_fn(params):
  dataset = tf1.data.Dataset.from_tensor_slices((features, labels))
  dataset = dataset.repeat()
  return dataset.batch(params['batch_size'], drop_remainder=True)

def _eval_input_fn(params):
  dataset = tf1.data.Dataset.from_tensor_slices((eval_features, eval_labels))
  dataset = dataset.repeat()
  return dataset.batch(params['batch_size'], drop_remainder=True)

def _model_fn(features, labels, mode, params):
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.tpu.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)

With those functions defined, create a tf.distribute.cluster_resolver.TPUClusterResolver that provides the cluster information, and a tf.compat.v1.estimator.tpu.RunConfig object. Along with the model function you have defined, you can now create a TPUEstimator. Here, you will simplify the flow by skipping checkpoint saving. Then, you will specify the batch size for both training and evaluation for the TPUEstimator.

cluster_resolver = tf1.distribute.cluster_resolver.TPUClusterResolver(tpu='')
print("All devices: ", tf1.config.list_logical_devices('TPU'))
tpu_config = tf1.estimator.tpu.TPUConfig(iterations_per_loop=10)

config = tf1.estimator.tpu.RunConfig(
    cluster=cluster_resolver,
    save_checkpoints_steps=None,
    tpu_config=tpu_config)

estimator = tf1.estimator.tpu.TPUEstimator(
    model_fn=_model_fn,
    config=config,
    train_batch_size=8,
    eval_batch_size=8)

Call TPUEstimator.train to begin training the model:

estimator.train(_input_fn, steps=1)

Then, call TPUEstimator.evaluate to evaluate the model using the evaluation data:

estimator.evaluate(_eval_input_fn, steps=1)

TensorFlow 2: Drive a model on TPUs with Keras Model.fit and TPUStrategy

In TensorFlow 2, to train on the TPU workers, use tf.distribute.TPUStrategy together with the Keras APIs for model definition and training/evaluation. (Refer to the Use TPUs guide for more examples of training with Keras Model.fit and a custom training loop (with tf.function and tf.GradientTape).)

Since you need to perform some initialization work to connect to the remote cluster and initialize the TPU workers, start by creating a TPUClusterResolver to provide the cluster information and connect to the cluster. (Learn more in the TPU initialization section of the Use TPUs guide.)

cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))

Next, once your data is prepared, create a TPUStrategy, and define a model, metrics, and an optimizer under the scope of this strategy.

To achieve comparable training speed with TPUStrategy, make sure to pick a suitable number for steps_per_execution in Model.compile, because it specifies the number of batches to run during each tf.function call and is critical for performance. This argument is similar to iterations_per_loop used in a TPUEstimator. If you are using custom training loops, make sure multiple steps are run within the tf.function-ed training function, as shown in the sketch below. Go to the Improving performance with multiple steps inside tf.function section of the Use TPUs guide for more information.
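
For illustration, here is a minimal sketch (not part of the original notebook) of running multiple steps inside one tf.function with TPUStrategy. It assumes strategy, model, optimizer, and dataset are defined as in the cells below, and a global batch size of 8 matching the training dataset:

# A minimal sketch of a multi-step training function under TPUStrategy.
# Assumes `strategy`, `model`, `optimizer`, and `dataset` from the cells below.
@tf.function
def train_multiple_steps(iterator, steps):
  def step_fn(inputs):
    x, y = inputs
    with tf.GradientTape() as tape:
      logits = model(x, training=True)
      per_example_loss = tf.keras.losses.mean_squared_error(y, logits)
      # Scale by the global batch size (8 in the training dataset below).
      loss = tf.nn.compute_average_loss(per_example_loss, global_batch_size=8)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

  # Using `tf.range` keeps the loop inside the compiled function, so the TPU
  # runs `steps` batches per host-to-device dispatch.
  for _ in tf.range(steps):
    strategy.run(step_fn, args=(next(iterator),))

# Example usage (hypothetical):
# iterator = iter(strategy.experimental_distribute_dataset(dataset))
# train_multiple_steps(iterator, tf.constant(10))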

tf.distribute.TPUStrategy supports bounded dynamic shapes, meaning the upper bound of the dynamic shape computation can be inferred. However, dynamic shapes may introduce some performance overhead compared to static shapes, so it is generally recommended to make your input shapes static if possible, especially during training. One common op that returns a dynamic shape is tf.data.Dataset.batch(batch_size), since the number of samples remaining in a stream might be less than the batch size. Therefore, when training on a TPU, you should use tf.data.Dataset.batch(..., drop_remainder=True) for the best training performance.
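
As a quick illustration (not part of the original notebook), you can inspect element_spec to see how drop_remainder affects the static shape of the batch dimension:

ds = tf.data.Dataset.range(10)
print(ds.batch(4).element_spec)                       # shape=(None,): dynamic batch dimension
print(ds.batch(4, drop_remainder=True).element_spec)  # shape=(4,): static batch dimension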

dataset = tf.data.Dataset.from_tensor_slices(
    (features, labels)).shuffle(10).repeat().batch(
        8, drop_remainder=True).prefetch(2)
eval_dataset = tf.data.Dataset.from_tensor_slices(
    (eval_features, eval_labels)).batch(1, drop_remainder=True)

strategy = tf.distribute.TPUStrategy(cluster_resolver)
with strategy.scope():
  model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
  model.compile(optimizer, "mse", steps_per_execution=10)

With that, you are ready to train the model with the training dataset:

model.fit(dataset, epochs=5, steps_per_epoch=10)

Finally, evaluate the model using the evaluation dataset:

model.evaluate(eval_dataset, return_dict=True)
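
The trained Keras model can also be used for inference and exported for serving, covering the remaining TPUEstimator use cases mentioned at the start of this guide. A minimal sketch (not part of the original notebook; the export path is hypothetical):

predictions = model.predict(eval_dataset)       # run inference with the trained model
print(predictions)
tf.saved_model.save(model, '/tmp/saved_model')  # export a SavedModel for serving (hypothetical path)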

Next steps

To learn more about TPUStrategy in TensorFlow 2, refer to the Use TPUs guide and the Distributed training guide, as well as the Solve GLUE tasks using BERT on TPU tutorial mentioned earlier.

To learn more about customizing your training, refer to the Keras guides on writing custom training loops with tf.function and tf.GradientTape.

TPUs—Google's specialized ASICs for machine learning—are available through Google Colab, the TPU Research Cloud, and Cloud TPU.