# Migrating Keras 2 code to multi-backend Keras 3

**Author:** Divyashree Sreepathihalli<br>
**Date created:** 2023/10/23<br>
**Last modified:** 2023/10/30<br>
**Description:** Instructions & troubleshooting for migrating your Keras 2 code to multi-backend Keras 3.
This guide will help you migrate TensorFlow-only Keras 2 code to multi-backend Keras 3 code. The overhead for the migration is minimal. Once you have migrated, you can run Keras workflows on top of either JAX, TensorFlow, or PyTorch.
This guide has two parts:

1. Migrating your legacy Keras 2 code to Keras 3, running on top of the TensorFlow backend. This is generally very easy, though there are minor issues to be mindful of, which we will go over in detail.
2. Further migrating your Keras 3 + TensorFlow code to multi-backend Keras 3, so that it can run on JAX and PyTorch.
Let's get started.
## Setup
First, let's install `keras-nightly`.

This example uses the TensorFlow backend (`os.environ["KERAS_BACKEND"] = "tensorflow"`). After you've migrated your code, you can change the `"tensorflow"` string to `"jax"` or `"torch"` and click "Restart runtime" in Colab, and your code will run on the JAX or PyTorch backend.
```python
!pip install -q keras-nightly
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import tensorflow as tf
import numpy as np
```
## Migrating from Keras 2 to Keras 3 with the TensorFlow backend
First, replace your imports:

1. Replace `from tensorflow import keras` with `import keras`
2. Replace `from tensorflow.keras import xyz` (e.g. `from tensorflow.keras import layers`) with `from keras import xyz` (e.g. `from keras import layers`)
3. Replace `tf.keras.*` with `keras.*`
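In a large codebase, these three substitutions can be scripted. Below is a minimal sketch of a hypothetical helper, `migrate_imports` (not part of Keras), that applies them with regular expressions. It is only a starting point:

- It does not handle aliased imports such as `from tensorflow import keras as k`.
- A file that only used `tf.keras.*` will still need an `import keras` line added by hand.

```python
import re

# Hypothetical helper (not part of Keras): each rule is (pattern, replacement).
_IMPORT_RULES = [
    # 1. `from tensorflow import keras` -> `import keras`
    (
        re.compile(r"^([ \t]*)from tensorflow import keras[ \t]*$", re.MULTILINE),
        r"\1import keras",
    ),
    # 2. `from tensorflow.keras[.submodule] import xyz` -> `from keras[.submodule] import xyz`
    (
        re.compile(r"\bfrom tensorflow\.keras((?:\.\w+)*) import\b"),
        r"from keras\1 import",
    ),
    # 3. `tf.keras.*` -> `keras.*`
    (re.compile(r"\btf\.keras\."), "keras."),
]


def migrate_imports(source: str) -> str:
    """Apply the three Keras 2 -> Keras 3 import rewrites to Python source text."""
    for pattern, replacement in _IMPORT_RULES:
        source = pattern.sub(replacement, source)
    return source


legacy = (
    "from tensorflow import keras\n"
    "from tensorflow.keras import layers\n"
    "from tensorflow.keras.layers import Dense\n"
    "x = tf.keras.layers.Dense(2)\n"
)
print(migrate_imports(legacy))
```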
Next, start running your tests. Most of the time, your code will execute on Keras 3 just fine.
All issues you might encounter are detailed below, with their fixes.
### `jit_compile` is enabled by default on GPU

In Keras 3, the `jit_compile` argument of `Model.compile()` defaults to `"auto"`, which turns on Just-In-Time (JIT) compilation with XLA when a GPU is available. This means that models will be JIT-compiled by default on GPU.

JIT compilation can improve the performance of some models. However, it may not work with all TensorFlow operations. If you are using a custom model or layer and you see an XLA-related error, you may need to set the `jit_compile` argument to `False`. Here is a list of [known issues](https://www.tensorflow.org/xla/known_issues) encountered when using XLA with TensorFlow. In addition to these issues, some ops are not supported by XLA at all.
You may encounter an error like the following:

```
Detected unsupported operations when trying to compile graph __inference_one_step_on_data_125[] on XLA_GPU_JIT
```

For example, the following snippet of code will reproduce the above error:

```python
class MyModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def call(self, inputs):
        string_input = tf.strings.as_string(inputs)
        return tf.strings.to_number(string_input)


subclass_model = MyModel()
x_train = np.array([[1, 2, 3], [4, 5, 6]])
subclass_model.compile(optimizer="sgd", loss="mse")
subclass_model.predict(x_train)
```

**How to fix it:** set `jit_compile=False` in `model.compile(..., jit_compile=False)`, or set the `jit_compile` attribute to `False`, like this:

```python
class MyModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def call(self, inputs):
        string_input = tf.strings.as_string(inputs)
        return tf.strings.to_number(string_input)


subclass_model = MyModel()
x_train = np.array([[1, 2, 3], [4, 5, 6]])
subclass_model.jit_compile = False
subclass_model.predict(x_train)
```

```
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 51ms/step

array([[1., 2., 3.],
       [4., 5., 6.]], dtype=float32)
```
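The first option, passing the flag to `compile()` instead of setting the attribute afterwards, looks like this. This is a minimal sketch:

- It assumes the TensorFlow backend, because the model calls `tf.strings` ops directly.
- `verbose=0` only silences the progress bar.

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # the model below uses tf.strings ops

import keras
import numpy as np
import tensorflow as tf


class MyModel(keras.Model):
    def call(self, inputs):
        # String ops like these cannot be compiled by XLA, so XLA must stay off.
        string_input = tf.strings.as_string(inputs)
        return tf.strings.to_number(string_input)


subclass_model = MyModel()
x_train = np.array([[1, 2, 3], [4, 5, 6]])
# Opt out of XLA at compile time.
subclass_model.compile(optimizer="sgd", loss="mse", jit_compile=False)
predictions = subclass_model.predict(x_train, verbose=0)
print(predictions)
```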
### Saving a model in the TF SavedModel format
Saving to the TF SavedModel format via `model.save()` is no longer supported in Keras 3.
You may encounter an error like the following:

```
>>> model.save("mymodel")
ValueError: Invalid filepath extension for saving. Please add either a .keras extension for the native Keras format (recommended) or a .h5 extension. Use model.export(filepath) if you want to export a SavedModel for use with TFLite/TFServing/etc. Received: filepath=saved_model.
```

The following snippet of code will reproduce the above error:

```python
sequential_model = keras.Sequential([keras.layers.Dense(2)])
sequential_model.save("saved_model")
```

**How to fix it:** use `model.export(filepath)` instead of `model.save(filepath)`:

```python
sequential_model = keras.Sequential([keras.layers.Dense(2)])
sequential_model(np.random.rand(3, 5))
sequential_model.export("saved_model")
```

```
INFO:tensorflow:Assets written to: saved_model/assets
INFO:tensorflow:Assets written to: saved_model/assets
Saved artifact at 'saved_model'. The following endpoints are available:

Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(3, 5), dtype=tf.float32, name='keras_tensor')
Output Type:
  TensorSpec(shape=(3, 2), dtype=tf.float32, name=None)
Captures:
  14428321600: TensorSpec(shape=(), dtype=tf.resource, name=None)
  14439128528: TensorSpec(shape=(), dtype=tf.resource, name=None)
```
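For saving and reloading within Keras (as opposed to exporting for TF Serving or TFLite), the error message recommends the native `.keras` format. Here is a minimal sketch of that round trip:

- It uses a temporary directory so that nothing is left on disk.
- The file name `my_model.keras` is arbitrary; only the extension matters.

```python
import os
import tempfile

import keras
import numpy as np

model = keras.Sequential([keras.Input(shape=(5,)), keras.layers.Dense(2)])
x = np.random.rand(3, 5).astype("float32")

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "my_model.keras")  # the extension selects the format
    model.save(path)
    restored = keras.models.load_model(path)
    # The reloaded model should produce the same outputs as the original.
    outputs_match = np.allclose(
        model.predict(x, verbose=0), restored.predict(x, verbose=0), atol=1e-6
    )

print(outputs_match)
```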
### Loading a TF SavedModel

Loading a TF SavedModel file via `keras.models.load_model()` is no longer supported. If you try to use `keras.models.load_model()` with a TF SavedModel, you will get the following error:

```
ValueError: File format not supported: filepath=saved_model. Keras 3 only supports V3 `.keras` files and legacy H5 format files (`.h5` extension). Note that the legacy SavedModel format is not supported by `load_model()` in Keras 3. In order to reload a TensorFlow SavedModel as an inference-only layer in Keras 3, use `keras.layers.TFSMLayer(saved_model, call_endpoint='serving_default')` (note that your `call_endpoint` might have a different name).
```

The following snippet of code will reproduce the above error:

```python
keras.models.load_model("saved_model")
```

**How to fix it:** use `keras.layers.TFSMLayer(filepath, call_endpoint="serving_default")` to reload a TF SavedModel as a Keras layer. This is not limited to SavedModels that originate from Keras: it will work with any SavedModel, e.g. TF-Hub models.

```python
keras.layers.TFSMLayer("saved_model", call_endpoint="serving_default")
```
### Using deeply nested inputs in Functional Models
`Model()` can no longer be passed deeply nested inputs/outputs (nested more than 1 level
deep, e.g. lists of lists of tensors).
You would encounter errors like the following:

```
ValueError: When providing inputs as a dict, all values in the dict must be KerasTensors. Received: inputs={'foo': <KerasTensor shape=(None, 1), dtype=float32, sparse=None, name=foo>, 'bar': {'baz': <KerasTensor shape=(None, 1), dtype=float32, sparse=None, name=bar>}} including invalid value {'baz': <KerasTensor shape=(None, 1), dtype=float32, sparse=None, name=bar>} of type <class 'dict'>
```

The following snippet of code will reproduce the above error:

```python
inputs = {
    "foo": keras.Input(shape=(1,), name="foo"),
    "bar": {
        "baz": keras.Input(shape=(1,), name="bar"),
    },
}
outputs = inputs["foo"] + inputs["bar"]["baz"]
keras.Model(inputs, outputs)
```

**How to fix it:** replace the nested inputs with a flat dict, list, or tuple of input tensors:

```python
inputs = {
    "foo": keras.Input(shape=(1,), name="foo"),
    "bar": keras.Input(shape=(1,), name="bar"),
}
outputs = inputs["foo"] + inputs["bar"]
keras.Model(inputs, outputs)
```
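When the nesting comes from existing code that is hard to restructure by hand, you can flatten the structure programmatically before it reaches `keras.Model`. Here is a minimal sketch using a hypothetical helper, `flatten_inputs` (not a Keras API). It joins nested keys with an underscore, so `inputs["bar"]["baz"]` becomes `inputs["bar_baz"]`. The data passed to the model must then use the same flat keys.

```python
import keras
import numpy as np


def flatten_inputs(nested, prefix=""):
    """Hypothetical helper: turns {"bar": {"baz": t}} into {"bar_baz": t}."""
    flat = {}
    for key, value in nested.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten_inputs(value, prefix=f"{name}_"))
        else:
            flat[name] = value
    return flat


nested_inputs = {
    "foo": keras.Input(shape=(1,), name="foo"),
    "bar": {"baz": keras.Input(shape=(1,), name="bar_baz")},
}
inputs = flatten_inputs(nested_inputs)  # {"foo": ..., "bar_baz": ...}
outputs = inputs["foo"] + inputs["bar_baz"]
model = keras.Model(inputs, outputs)

# Feed data using the same flat keys.
result = model.predict(
    {"foo": np.ones((2, 1)), "bar_baz": np.full((2, 1), 2.0)}, verbose=0
)
print(result)
```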
### TF autograph

In Keras 2, TF autograph is enabled by default on the `call()` method of custom layers. In Keras 3, it is not. This means you may have to use `cond` ops if you're using control flow, or alternatively you can decorate your `call()` method with `@tf.function`.

You would encounter an error like the following:

```
OperatorNotAllowedInGraphError: Exception encountered when calling MyCustomLayer.call().

Using a symbolic `tf.Tensor` as a Python `bool` is not allowed. You can attempt the following resolutions to the problem: If you are running in Graph mode, use Eager execution mode or decorate this function with @tf.function. If you are using AutoGraph, you can try decorating this function with @tf.function. If that does not work, then you may be using an unsupported feature or your source code may not be visible to AutoGraph.
```

Here is a [link for more information](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#access-to-source-code).

The following snippet of code will reproduce the above error:

```python
class MyCustomLayer(keras.layers.Layer):
    def call(self, inputs):
        if tf.random.uniform(()) > 0.5:
            return inputs * 2
        else:
            return inputs / 2


layer = MyCustomLayer()
data = np.random.uniform(size=[3, 3])
model = keras.models.Sequential([layer])
model.compile(optimizer="adam", loss="mse")
model.predict(data)
```

**How to fix it:** decorate your `call()` method with `@tf.function`:

```python
class MyCustomLayer(keras.layers.Layer):
    @tf.function()
    def call(self, inputs):
        if tf.random.uniform(()) > 0.5:
            return inputs * 2
        else:
            return inputs / 2


layer = MyCustomLayer()
data = np.random.uniform(size=[3, 3])
model = keras.models.Sequential([layer])
model.compile(optimizer="adam", loss="mse")
model.predict(data)
```

```
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 43ms/step

array([[0.59727275, 1.9986179 , 1.5514829 ],
       [0.56239295, 1.6529864 , 0.33085832],
       [0.67086476, 1.5208522 , 1.99276   ]], dtype=float32)
```
### Calling TF ops with a `KerasTensor`
Using a TF op on a Keras tensor during functional model construction is disallowed: "A
KerasTensor cannot be used as input to a TensorFlow function".
You would encounter the following error:

```
ValueError: A KerasTensor cannot be used as input to a TensorFlow function. A KerasTensor is a symbolic placeholder for a shape and dtype, used when constructing Keras Functional models or Keras Functions. You can only use it as input to a Keras layer or a Keras operation (from the namespaces `keras.layers` and `keras.operations`).
```

The following snippet of code will reproduce the error:

```python
input = keras.layers.Input([2, 2, 1])
tf.squeeze(input)
```

**How to fix it:** use an equivalent op from `keras.ops`.

```python
input = keras.layers.Input([2, 2, 1])
keras.ops.squeeze(input)
```
### Multi-output model `evaluate()`

The `evaluate()` method of a multi-output model no longer returns individual output losses separately. Instead, you should use the `metrics` argument in the `compile()` method to keep track of these losses.

When dealing with multiple named outputs, such as `output_a` and `output_b`, the legacy `tf.keras` would include `<output_a>_loss`, `<output_b>_loss`, and similar entries in metrics. However, in Keras 3.0, these entries are not automatically added to metrics. They must be explicitly provided in the metrics list for each individual output.
The following snippet of code will reproduce the above behavior:

```python
from keras import layers

inputs = layers.Input(shape=(10,))
x1 = layers.Dense(5, activation="relu")(inputs)
x2 = layers.Dense(5, activation="relu")(x1)
output_1 = layers.Dense(5, activation="softmax", name="output_1")(x1)
output_2 = layers.Dense(5, activation="softmax", name="output_2")(x2)
model = keras.Model(inputs=inputs, outputs=[output_1, output_2])
model.compile(optimizer="adam", loss="categorical_crossentropy")
x_test = np.random.uniform(size=[10, 10])
y_test = np.random.uniform(size=[10, 5])

model.evaluate(x_test, y_test)
```

**How to fix it:** pass one metric per output in the `metrics` argument of `compile()`:

```python
from keras import layers

inputs = layers.Input(shape=(10,))
x1 = layers.Dense(5, activation="relu")(inputs)
x2 = layers.Dense(5, activation="relu")(x1)
output_1 = layers.Dense(5, activation="softmax", name="output_1")(x1)
output_2 = layers.Dense(5, activation="softmax", name="output_2")(x2)
x_test = np.random.uniform(size=[10, 10])
y_test = np.random.uniform(size=[10, 5])
multi_output_model = keras.Model(inputs=inputs, outputs=[output_1, output_2])
multi_output_model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["categorical_crossentropy", "categorical_crossentropy"],
)
multi_output_model.evaluate(x_test, y_test)
```

```
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 112ms/step - loss: 4.0217 - output_1_categorical_crossentropy: 4.0217

[4.021683692932129, 4.021683692932129]
```
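Metrics can also be keyed by output name with a dict, which avoids relying on the order of the outputs. Here is a minimal sketch. It assumes two outputs named `output_1` and `output_2`, as above, and passes one target array per output. `return_dict=True` makes `evaluate()` return a name-to-value mapping; the exact metric names in that mapping may vary between Keras 3 releases.

```python
import keras
import numpy as np
from keras import layers

inputs = layers.Input(shape=(10,))
x1 = layers.Dense(5, activation="relu")(inputs)
x2 = layers.Dense(5, activation="relu")(x1)
output_1 = layers.Dense(5, activation="softmax", name="output_1")(x1)
output_2 = layers.Dense(5, activation="softmax", name="output_2")(x2)
model = keras.Model(inputs=inputs, outputs=[output_1, output_2])

model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    # One entry per output, keyed by the output layer's name.
    metrics={
        "output_1": ["categorical_crossentropy"],
        "output_2": ["categorical_crossentropy"],
    },
)

x_test = np.random.uniform(size=[10, 10])
y_test = np.random.uniform(size=[10, 5])
results = model.evaluate(x_test, [y_test, y_test], return_dict=True, verbose=0)
for name, value in results.items():
    print(name, value)
```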
### TensorFlow variables tracking

Setting a `tf.Variable` as an attribute of a Keras 3 layer or model will not automatically track the variable, unlike in Keras 2. The following snippet of code shows that the `tf.Variable`s are not being tracked.
```python
class MyCustomLayer(keras.layers.Layer):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.w = tf.Variable(initial_value=tf.zeros([input_dim, self.units]))
        self.b = tf.Variable(initial_value=tf.zeros([self.units]))

    def call(self, inputs):
        return keras.ops.matmul(inputs, self.w) + self.b


layer = MyCustomLayer(3)
data = np.random.uniform(size=[3, 3])
model = keras.models.Sequential([layer])
model.compile(optimizer="adam", loss="mse")
model.predict(data)
for layer in model.layers:
    print(layer.trainable_variables)
```

You will see the following warning:

```
UserWarning: The model does not have any trainable weights.
  warnings.warn("The model does not have any trainable weights.")
```

**How to fix it:** use the `self.add_weight()` method or opt for a `keras.Variable` instead. If you are currently using `tf.Variable`, you can switch to `keras.Variable`.

```python
class MyCustomLayer(keras.layers.Layer):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.w = self.add_weight(
            shape=[input_dim, self.units],
            initializer="zeros",
        )
        self.b = self.add_weight(
            shape=[self.units],
            initializer="zeros",
        )

    def call(self, inputs):
        return keras.ops.matmul(inputs, self.w) + self.b


layer = MyCustomLayer(3)
data = np.random.uniform(size=[3, 3])
model = keras.models.Sequential([layer])
model.compile(optimizer="adam", loss="mse")
model.predict(data)
for layer in model.layers:
    print(layer.trainable_variables)
```

```
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 33ms/step
[, ]
```
### `None` entries in nested `call()` arguments

`None` entries are not allowed as part of nested (e.g. list/tuple) tensor arguments in `Layer.call()`, nor as part of `call()`'s nested return values.

If the `None` in the argument is intentional and serves a specific purpose, make the argument optional and pass it as a separate parameter. For example, consider defining the `call` method with an optional argument.

The following snippet of code will reproduce the error.
```python
class CustomLayer(keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def call(self, inputs):
        foo = inputs["foo"]
        baz = inputs["bar"]["baz"]
        if baz is not None:
            return foo + baz
        return foo


layer = CustomLayer()
inputs = {
    "foo": keras.Input(shape=(1,), name="foo"),
    "bar": {
        "baz": None,
    },
}
layer(inputs)
```

**How to fix it:**

**Solution 1:** Replace `None` with a value, like this:

```python
class CustomLayer(keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def call(self, inputs):
        foo = inputs["foo"]
        baz = inputs["bar"]["baz"]
        return foo + baz


layer = CustomLayer()
inputs = {
    "foo": keras.Input(shape=(1,), name="foo"),
    "bar": {
        "baz": keras.Input(shape=(1,), name="bar"),
    },
}
layer(inputs)
```

**Solution 2:** Define the call method with an optional argument. Here is an example of this fix:

```python
class CustomLayer(keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def call(self, foo, baz=None):
        if baz is not None:
            return foo + baz
        return foo


layer = CustomLayer()
foo = keras.Input(shape=(1,), name="foo")
baz = None
layer(foo, baz=baz)
```
### State-building issues
Keras 3 is significantly stricter than Keras 2 about when state (e.g. numerical weight variables)
can be created. Keras 3 wants all state to be created before the model can be trained. This is a requirement
for using JAX (whereas TensorFlow was very lenient about state creation timing).
Keras layers should create their state either in their constructor (`__init__()` method) or in their `build()` method.
They should avoid creating state in `call()`.
If you ignore this recommendation and create state in `call()`
anyway (e.g. by calling a previously unbuilt layer), then Keras will attempt to build the layer automatically
by calling the `call()` method on symbolic inputs before training.
However, this attempt at automatic state creation may fail in certain cases.
This will cause an error that looks like this:

```
Layer 'frame_position_embedding' looks like it has unbuilt state, but Keras is not able to trace the layer `call()` in order to build it automatically. Possible causes:
1. The `call()` method of your layer may be crashing. Try to `__call__()` the layer eagerly on some test input first to see if it works. E.g. `x = np.random.random((3, 4)); y = layer(x)`
2. If the `call()` method is correct, then you may need to implement the `def build(self, input_shape)` method on your layer. It should create all variables used by the layer (e.g. by calling `layer.build()` on all its children layers).
```
You could reproduce this error with the following layer, when used with the JAX backend:

```python
class PositionalEmbedding(keras.layers.Layer):
    def __init__(self, sequence_length, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.position_embeddings = keras.layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim
        )
        self.sequence_length = sequence_length
        self.output_dim = output_dim

    def call(self, inputs):
        inputs = keras.ops.cast(inputs, self.compute_dtype)
        length = keras.ops.shape(inputs)[1]
        positions = keras.ops.arange(start=0, stop=length, step=1)
        embedded_positions = self.position_embeddings(positions)
        return inputs + embedded_positions
```

**How to fix it:** Do exactly what the error message asks. First, try to run the layer eagerly to see if the `call()` method is in fact correct (note: if it was working in Keras 2, then it is correct and does not need to be changed). If it is indeed correct, then you should implement a `build(self, input_shape)` method that creates all of the layer's state, including the state of sublayers. Here's the fix as applied to the layer above (note the `build()` method):

```python
class PositionalEmbedding(keras.layers.Layer):
    def __init__(self, sequence_length, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.position_embeddings = keras.layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim
        )
        self.sequence_length = sequence_length
        self.output_dim = output_dim

    def build(self, input_shape):
        self.position_embeddings.build(input_shape)

    def call(self, inputs):
        inputs = keras.ops.cast(inputs, self.compute_dtype)
        length = keras.ops.shape(inputs)[1]
        positions = keras.ops.arange(start=0, stop=length, step=1)
        embedded_positions = self.position_embeddings(positions)
        return inputs + embedded_positions
```
## Removed features
A small number of legacy features with very low usage were removed from Keras 3 as a cleanup measure:
- `keras.layers.ThresholdedReLU` is removed. Instead, you can simply use the `ReLU` layer with the argument `threshold`.
- Symbolic `Layer.add_loss()`: symbolic `add_loss()` is removed (you can still use `add_loss()` inside the `call()` method of a layer/model).
- Locally connected layers (`LocallyConnected1D`, `LocallyConnected2D`) are removed due to very low usage. To use locally connected layers, copy the layer implementation into your own codebase.
- `keras.layers.experimental.RandomFourierFeatures` is removed due to very low usage. To use it, copy the layer implementation into your own codebase.
- Removed layer attributes: the layer attributes `metrics` and `dynamic` are removed. `metrics` is still available on the `Model` class.
- The `constants` and `time_major` arguments in RNN layers are removed. The `constants` argument was a remnant of Theano and had very low usage. The `time_major` argument also had very low usage.
- `reset_metrics` argument: the `reset_metrics` argument is removed from `model.*_on_batch()` methods. This argument had very low usage.
- The `keras.constraints.RadialConstraint` object is removed. This object had very low usage.
## Transitioning to backend-agnostic Keras 3
Keras 3 code with the TensorFlow backend will work with native TensorFlow APIs. However, if you want your code to be backend-agnostic, you will need to:
1. Replace all of the `tf.*` API calls with their equivalent Keras APIs.
2. Convert your custom `train_step`/`test_step` methods to a multi-framework implementation.
3. Make sure you're using stateless `keras.random` ops correctly in your layers.
Let's go over each point in detail.
### Switching to Keras ops

In many cases, this is the only thing you need to do to start being able to run your custom layers and metrics with JAX and PyTorch: replace any `tf.*`, `tf.math.*`, `tf.linalg.*`, etc. with `keras.ops.*`. Most TF ops should be consistent with Keras 3. Where the names differ, they are highlighted in this guide.
#### NumPy ops

Keras implements the NumPy API as part of `keras.ops`.

The table below only lists a small subset of TensorFlow and Keras ops; ops not listed are usually named the same in both frameworks (e.g. `reshape`, `matmul`, `cast`, etc.).
| TensorFlow | Keras 3.0 |
|---|---|
| `tf.abs` | `keras.ops.absolute` |
| `tf.reduce_all` | `keras.ops.all` |
| `tf.reduce_max` | `keras.ops.amax` |
| `tf.reduce_min` | `keras.ops.amin` |
| `tf.reduce_any` | `keras.ops.any` |
| `tf.concat` | `keras.ops.concatenate` |
| `tf.range` | `keras.ops.arange` |
| `tf.acos` | `keras.ops.arccos` |
| `tf.asin` | `keras.ops.arcsin` |
| `tf.asinh` | `keras.ops.arcsinh` |
| `tf.atan` | `keras.ops.arctan` |
| `tf.atan2` | `keras.ops.arctan2` |
| `tf.atanh` | `keras.ops.arctanh` |
| `tf.convert_to_tensor` | `keras.ops.convert_to_tensor` |
| `tf.reduce_mean` | `keras.ops.mean` |
| `tf.clip_by_value` | `keras.ops.clip` |
| `tf.math.conj` | `keras.ops.conjugate` |
| `tf.linalg.diag_part` | `keras.ops.diagonal` |
| `tf.reverse` | `keras.ops.flip` |
| `tf.gather` | `keras.ops.take` |
| `tf.math.is_finite` | `keras.ops.isfinite` |
| `tf.math.is_inf` | `keras.ops.isinf` |
| `tf.math.is_nan` | `keras.ops.isnan` |
| `tf.reduce_max` | `keras.ops.max` |
| `tf.reduce_min` | `keras.ops.min` |
| `tf.rank` | `keras.ops.ndim` |
| `tf.math.pow` | `keras.ops.power` |
| `tf.reduce_prod` | `keras.ops.prod` |
| `tf.math.reduce_std` | `keras.ops.std` |
| `tf.reduce_sum` | `keras.ops.sum` |
| `tf.gather_nd` | `keras.ops.take_along_axis` |
| `tf.math.reduce_variance` | `keras.ops.var` |
#### Other ops

| TensorFlow | Keras 3.0 |
|---|---|
| `tf.nn.sigmoid_cross_entropy_with_logits` | `keras.ops.binary_crossentropy` (mind the `from_logits` argument) |
| `tf.nn.sparse_softmax_cross_entropy_with_logits` | `keras.ops.sparse_categorical_crossentropy` (mind the `from_logits` argument) |
| `tf.nn.softmax_cross_entropy_with_logits` | `keras.ops.categorical_crossentropy(target, output, from_logits=False, axis=-1)` |
| `tf.nn.conv1d`, `tf.nn.conv2d`, `tf.nn.conv3d`, `tf.nn.convolution` | `keras.ops.conv` |
| `tf.nn.conv_transpose`, `tf.nn.conv1d_transpose`, `tf.nn.conv2d_transpose`, `tf.nn.conv3d_transpose` | `keras.ops.conv_transpose` |
| `tf.nn.depthwise_conv2d` | `keras.ops.depthwise_conv` |
| `tf.nn.separable_conv2d` | `keras.ops.separable_conv` |
| `tf.nn.batch_normalization` | No direct equivalent; use `keras.layers.BatchNormalization` |
| `tf.nn.dropout` | `keras.random.dropout` |
| `tf.nn.embedding_lookup` | `keras.ops.take` |
| `tf.nn.l2_normalize` | `keras.utils.normalize` (not an op) |
| `x.numpy()` | `keras.ops.convert_to_numpy` |
| `tf.scatter_nd_update` | `keras.ops.scatter_update` |
| `tf.tensor_scatter_nd_update` | `keras.ops.slice_update` |
| `tf.signal.fft2d` | `keras.ops.fft2` |
| `tf.signal.inverse_stft` | `keras.ops.istft` |
| `tf.image.crop_to_bounding_box` | `keras.ops.image.crop_images` |
| `tf.image.pad_to_bounding_box` | `keras.ops.image.pad_images` |
### Custom `train_step()` methods

Your models may include a custom `train_step()` or `test_step()` method that relies on TensorFlow-only APIs. For instance, your `train_step()` method may leverage TensorFlow's `tf.GradientTape`. To convert such models to run on JAX or PyTorch, you will have to write a different `train_step()` implementation for each backend you want to support.
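In some cases, you might be able to simply override the `Model.compute_loss()` method and make it fully backend-agnostic, instead of overriding `train_step()`. Here's an example of a model with a custom `compute_loss()` method that works across JAX, TensorFlow, and PyTorch:

```python
class MyModel(keras.Model):
    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
        # keras.losses.mean_squared_error() only takes (y_true, y_pred), so
        # sample weights (assumed to have shape (batch_size,)) are applied here.
        per_sample_loss = keras.losses.mean_squared_error(y, y_pred)
        if sample_weight is not None:
            sample_weight = keras.ops.cast(sample_weight, per_sample_loss.dtype)
            per_sample_loss = per_sample_loss * sample_weight
        loss = keras.ops.sum(per_sample_loss)
        return loss
```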
If you need to modify the optimization mechanism itself, beyond the loss computation, then you will need to override `train_step()` and implement one `train_step` method per backend, as below. See the Keras guides on customizing what happens in `fit()` with JAX, TensorFlow, and PyTorch for details on how each backend should be handled.

```python
class MyModel(keras.Model):
    def train_step(self, *args, **kwargs):
        if keras.backend.backend() == "jax":
            return self._jax_train_step(*args, **kwargs)
        elif keras.backend.backend() == "tensorflow":
            return self._tensorflow_train_step(*args, **kwargs)
        elif keras.backend.backend() == "torch":
            return self._torch_train_step(*args, **kwargs)

    def _jax_train_step(self, state, data):
        # With JAX, train_step() receives the model state explicitly.
        pass  # JAX-specific implementation goes here

    def _tensorflow_train_step(self, data):
        pass  # TensorFlow-specific implementation goes here

    def _torch_train_step(self, data):
        pass  # PyTorch-specific implementation goes here
```
### RNG-using layers

Keras 3 has a new `keras.random` namespace, containing:

- `keras.random.normal`
- `keras.random.uniform`
- `keras.random.shuffle`
- etc.

These operations are stateless, which means that if you pass a `seed` argument, they will return the same result every time. Like this:

```python
print(keras.random.normal(shape=(), seed=123))
print(keras.random.normal(shape=(), seed=123))
```

```
tf.Tensor(0.7832616, shape=(), dtype=float32)
tf.Tensor(0.7832616, shape=(), dtype=float32)
```
Crucially, this differs from the behavior of stateful `tf.random` ops:
```python
print(tf.random.normal(shape=(), seed=123))
print(tf.random.normal(shape=(), seed=123))
```

```
tf.Tensor(2.4435377, shape=(), dtype=float32)
tf.Tensor(-0.6386405, shape=(), dtype=float32)
```
When you write an RNG-using layer, such as a custom dropout layer, you are going to want to use a different seed value at each layer call. However, you cannot just increment a Python integer and pass it. That would work fine when executed eagerly, but not as expected when using compilation (which is available with JAX, TensorFlow, and PyTorch). When compiling the layer, the first Python integer seed value seen by the layer would be hardcoded into the compiled graph.

To address this, you should pass as the `seed` argument an instance of a stateful `keras.random.SeedGenerator` object, like this:
```python
seed_generator = keras.random.SeedGenerator(1337)
print(keras.random.normal(shape=(), seed=seed_generator))
print(keras.random.normal(shape=(), seed=seed_generator))
```

```
tf.Tensor(0.6077996, shape=(), dtype=float32)
tf.Tensor(0.8211102, shape=(), dtype=float32)
```
So when writing an RNG-using layer, you would use the following pattern:

```python
class RandomNoiseLayer(keras.layers.Layer):
    def __init__(self, noise_rate, **kwargs):
        super().__init__(**kwargs)
        self.noise_rate = noise_rate
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        noise = keras.random.uniform(
            shape=keras.ops.shape(inputs),
            minval=0,
            maxval=self.noise_rate,
            seed=self.seed_generator,
        )
        return inputs + noise
```

Such a layer is safe to use in any setting: in eager execution or in a compiled model. Each layer call will use a different seed value, as expected.