"""
2
Title: Migrating Keras 2 code to multi-backend Keras 3
3
Author: [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli)
4
Date created: 2023/10/23
5
Last modified: 2023/10/30
6
Description: Instructions & troubleshooting for migrating your Keras 2 code to multi-backend Keras 3.
7
Accelerator: None
8
"""

"""
This guide will help you migrate TensorFlow-only Keras 2 code to multi-backend Keras 3
code. The overhead for the migration is minimal. Once you have migrated,
you can run Keras workflows on top of either JAX, TensorFlow, or PyTorch.

This guide has two parts:

1. Migrating your legacy Keras 2 code to Keras 3, running on top of the TensorFlow backend.
This is generally very easy, though there are minor issues to be mindful of, which we will go over
in detail.
2. Further migrating your Keras 3 + TensorFlow code to multi-backend Keras 3, so that it can run on
JAX and PyTorch.

Let's get started.
"""

"""
## Setup

First, let's install `keras-nightly`.

This example uses the TensorFlow backend (`os.environ["KERAS_BACKEND"] = "tensorflow"`).
After you've migrated your code, you can change the `"tensorflow"` string to `"jax"` or `"torch"`,
click "Restart runtime" in Colab, and your code will run on the JAX or PyTorch backend.
"""

"""shell
pip install -q keras-nightly
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import tensorflow as tf
import numpy as np

"""
49
## Going from Keras 2 to Keras 3 with the TensorFlow backend
50
51
First, replace your imports:
52
53
1. Replace `from tensorflow import keras` to `import keras`
54
2. Replace `from tensorflow.keras import xyz` (e.g. `from tensorflow.keras import layers`)
55
to `from keras import xyz` (e.g. `from keras import layers`)
56
3. Replace `tf.keras.*` to `keras.*`
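
For example, a typical import block changes like this (the `layers` import is just an
illustration):

```python
# Keras 2 (TensorFlow-only):
# from tensorflow import keras
# from tensorflow.keras import layers

# Keras 3 (multi-backend):
import keras
from keras import layers
```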

Next, start running your tests. Most of the time, your code will execute on Keras 3 just fine.
All issues you might encounter are detailed below, with their fixes.
"""

"""
### `jit_compile` is set to `True` by default on GPU.

The default value of the `jit_compile` argument to the `Model` constructor has been set to
`True` on GPU in Keras 3. This means that models will be compiled with Just-In-Time (JIT)
compilation by default on GPU.

JIT compilation can improve the performance of some models. However, it may not work with
all TensorFlow operations. If you are using a custom model or layer and you see an
XLA-related error, you may need to set the `jit_compile` argument to `False`. Here is a list
of [known issues](https://www.tensorflow.org/xla/known_issues) encountered when
using XLA with TensorFlow. In addition to these issues, there are some
ops that are not supported by XLA.

The error message you could encounter would be as follows:

```
Detected unsupported operations when trying to compile graph
__inference_one_step_on_data_125[] on XLA_GPU_JIT
```

For example, the following snippet of code will reproduce the above error:

```python
class MyModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def call(self, inputs):
        string_input = tf.strings.as_string(inputs)
        return tf.strings.to_number(string_input)


subclass_model = MyModel()
x_train = np.array([[1, 2, 3], [4, 5, 6]])
subclass_model.compile(optimizer="sgd", loss="mse")
subclass_model.predict(x_train)
```
"""

"""
**How to fix it:** pass `jit_compile=False` to `model.compile()`, or set the
`jit_compile` attribute to `False`, like this:
"""


class MyModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def call(self, inputs):
        # tf.strings ops are not supported by XLA
        string_input = tf.strings.as_string(inputs)
        return tf.strings.to_number(string_input)


subclass_model = MyModel()
x_train = np.array([[1, 2, 3], [4, 5, 6]])
subclass_model.jit_compile = False
subclass_model.predict(x_train)
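
"""
Equivalently, you can disable JIT compilation at compile time; this is the same fix as
above, just applied through `compile()`:

```python
subclass_model.compile(optimizer="sgd", loss="mse", jit_compile=False)
subclass_model.predict(x_train)
```
"""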

"""
### Saving a model in the TF SavedModel format

Saving to the TF SavedModel format via `model.save()` is no longer supported in Keras 3.

The error message you could encounter would be as follows:

```
>>> model.save("mymodel")
ValueError: Invalid filepath extension for saving. Please add either a `.keras` extension
for the native Keras format (recommended) or a `.h5` extension. Use
`model.export(filepath)` if you want to export a SavedModel for use with
TFLite/TFServing/etc. Received: filepath=saved_model.
```

The following snippet of code will reproduce the above error:

```python
sequential_model = keras.Sequential([
    keras.layers.Dense(2)
])
sequential_model.save("saved_model")
```
"""

"""
**How to fix it:** use `model.export(filepath)` instead of `model.save(filepath)`.
"""

sequential_model = keras.Sequential([keras.layers.Dense(2)])
sequential_model(np.random.rand(3, 5))
sequential_model.export("saved_model")

"""
### Loading a TF SavedModel

Loading a TF SavedModel file via `keras.models.load_model()` is no longer supported.
If you try to use `keras.models.load_model()` on a TF SavedModel, you will get the following error:

```
ValueError: File format not supported: filepath=saved_model. Keras 3 only supports V3
`.keras` files and legacy H5 format files (`.h5` extension). Note that the legacy
SavedModel format is not supported by `load_model()` in Keras 3. In order to reload a
TensorFlow SavedModel as an inference-only layer in Keras 3, use
`keras.layers.TFSMLayer(saved_model, call_endpoint='serving_default')` (note that your
`call_endpoint` might have a different name).
```

The following snippet of code will reproduce the above error:

```python
keras.models.load_model("saved_model")
```
"""

"""
**How to fix it:** Use `keras.layers.TFSMLayer(filepath, call_endpoint="serving_default")` to reload a TF
SavedModel as a Keras layer. This is not limited to SavedModels that originate from Keras -- it will work
with any SavedModel, e.g. TF-Hub models.
"""

keras.layers.TFSMLayer("saved_model", call_endpoint="serving_default")
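
"""
As the error message notes, your `call_endpoint` might have a different name. If you are
unsure which endpoints a SavedModel exposes, you can inspect its signatures with the
TensorFlow API before wrapping it (a quick check, using the `"saved_model"` directory
exported above):

```python
loaded = tf.saved_model.load("saved_model")
print(list(loaded.signatures.keys()))  # e.g. ["serving_default"]
```
"""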

"""
### Using deeply nested inputs in Functional Models

`Model()` can no longer be passed deeply nested inputs/outputs (nested more than 1 level
deep, e.g. lists of lists of tensors).

You would encounter errors as follows:

```
ValueError: When providing `inputs` as a dict, all values in the dict must be
KerasTensors. Received: inputs={'foo': <KerasTensor shape=(None, 1), dtype=float32,
sparse=None, name=foo>, 'bar': {'baz': <KerasTensor shape=(None, 1), dtype=float32,
sparse=None, name=bar>}} including invalid value {'baz': <KerasTensor shape=(None, 1),
dtype=float32, sparse=None, name=bar>} of type <class 'dict'>
```

The following snippet of code will reproduce the above error:

```python
inputs = {
    "foo": keras.Input(shape=(1,), name="foo"),
    "bar": {
        "baz": keras.Input(shape=(1,), name="bar"),
    },
}
outputs = inputs["foo"] + inputs["bar"]["baz"]
keras.Model(inputs, outputs)
```
"""

"""
**How to fix it:** replace the deeply nested structure with a structure that is at most
one level deep: a dict, list, or tuple of input tensors.
"""

inputs = {
    "foo": keras.Input(shape=(1,), name="foo"),
    "bar": keras.Input(shape=(1,), name="bar"),
}
outputs = inputs["foo"] + inputs["bar"]
keras.Model(inputs, outputs)

"""
### TF autograph

In Keras 2, TF autograph is enabled by default on the `call()` method of custom
layers. In Keras 3, it is not. This means you may have to use explicit cond ops (such as
`tf.cond`) if you're using control flow, or alternatively you can decorate your `call()`
method with `@tf.function`.

You would encounter an error as follows:

```
OperatorNotAllowedInGraphError: Exception encountered when calling MyCustomLayer.call().

Using a symbolic `tf.Tensor` as a Python `bool` is not allowed. You can attempt the
following resolutions to the problem: If you are running in Graph mode, use Eager
execution mode or decorate this function with @tf.function. If you are using AutoGraph,
you can try decorating this function with @tf.function. If that does not work, then you
may be using an unsupported feature or your source code may not be visible to AutoGraph.
Here is a [link for more information](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#access-to-source-code).
```

The following snippet of code will reproduce the above error:

```python
class MyCustomLayer(keras.layers.Layer):

    def call(self, inputs):
        if tf.random.uniform(()) > 0.5:
            return inputs * 2
        else:
            return inputs / 2


layer = MyCustomLayer()
data = np.random.uniform(size=[3, 3])
model = keras.models.Sequential([layer])
model.compile(optimizer="adam", loss="mse")
model.predict(data)
```
"""

"""
**How to fix it:** decorate your `call()` method with `@tf.function`
"""


class MyCustomLayer(keras.layers.Layer):
    @tf.function()
    def call(self, inputs):
        if tf.random.uniform(()) > 0.5:
            return inputs * 2
        else:
            return inputs / 2


layer = MyCustomLayer()
data = np.random.uniform(size=[3, 3])
model = keras.models.Sequential([layer])
model.compile(optimizer="adam", loss="mse")
model.predict(data)
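
"""
Alternatively, if you would rather not rely on `@tf.function`, you can express the branch
with an explicit cond op. A minimal sketch of the same layer using `tf.cond` (the
backend-agnostic `keras.ops.cond` takes the same arguments):

```python
class MyCondLayer(keras.layers.Layer):
    def call(self, inputs):
        # Both branches are traced; the predicate picks one at runtime.
        return tf.cond(
            tf.random.uniform(()) > 0.5,
            lambda: inputs * 2,
            lambda: inputs / 2,
        )
```
"""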

"""
### Calling TF ops with a `KerasTensor`

Using a TF op on a Keras tensor during functional model construction is disallowed: "A
KerasTensor cannot be used as input to a TensorFlow function".

The error you would encounter would be as follows:

```
ValueError: A KerasTensor cannot be used as input to a TensorFlow function. A KerasTensor
is a symbolic placeholder for a shape and dtype, used when constructing Keras Functional
models or Keras Functions. You can only use it as input to a Keras layer or a Keras
operation (from the namespaces `keras.layers` and `keras.operations`).
```

The following snippet of code will reproduce the error:

```python
input = keras.layers.Input([2, 2, 1])
tf.squeeze(input)
```
"""

"""
**How to fix it:** use an equivalent op from `keras.ops`.
"""

input = keras.layers.Input([2, 2, 1])
keras.ops.squeeze(input)

"""
320
### Multi-output model `evaluate()`
321
322
The `evaluate()` method of a multi-output model no longer returns individual output
323
losses separately. Instead, you should utilize the `metrics` argument in the `compile()`
324
method to keep track of these losses.
325
326
327
When dealing with multiple named outputs, such as output_a and output_b, the legacy
328
`tf.keras` would include <output_a>_loss, <output_b>_loss, and similar entries in
329
metrics. However, in keras 3.0, these entries are not automatically added to metrics.
330
They must be explicitly provided in the metrics list for each individual output.
331
332
The following snippet of code will reproduce the above behavior:
333
334
```python
335
from keras import layers
336
# A functional model with multiple outputs
337
inputs = layers.Input(shape=(10,))
338
x1 = layers.Dense(5, activation='relu')(inputs)
339
x2 = layers.Dense(5, activation='relu')(x1)
340
output_1 = layers.Dense(5, activation='softmax', name="output_1")(x1)
341
output_2 = layers.Dense(5, activation='softmax', name="output_2")(x2)
342
model = keras.Model(inputs=inputs, outputs=[output_1, output_2])
343
model.compile(optimizer='adam', loss='categorical_crossentropy')
344
# dummy data
345
x_test = np.random.uniform(size=[10, 10])
346
y_test = np.random.uniform(size=[10, 5])
347
348
model.evaluate(x_test, y_test)
349
```
350
"""
351
352
from keras import layers
353
354
# A functional model with multiple outputs
355
inputs = layers.Input(shape=(10,))
356
x1 = layers.Dense(5, activation="relu")(inputs)
357
x2 = layers.Dense(5, activation="relu")(x1)
358
output_1 = layers.Dense(5, activation="softmax", name="output_1")(x1)
359
output_2 = layers.Dense(5, activation="softmax", name="output_2")(x2)
360
# dummy data
361
x_test = np.random.uniform(size=[10, 10])
362
y_test = np.random.uniform(size=[10, 5])
363
multi_output_model = keras.Model(inputs=inputs, outputs=[output_1, output_2])
364
multi_output_model.compile(
365
optimizer="adam",
366
loss="categorical_crossentropy",
367
metrics=["categorical_crossentropy", "categorical_crossentropy"],
368
)
369
multi_output_model.evaluate(x_test, y_test)
370
371
372
"""
373
### TensorFlow variables tracking
374
375
Setting a `tf.Variable` as an attribute of a Keras 3 layer or model will not automatically
376
track the variable, unlike in Keras 2. The following snippet of code will show that the `tf.Variables`
377
are not being tracked.
378
379
```python
380
class MyCustomLayer(keras.layers.Layer):
381
def __init__(self, units):
382
super().__init__()
383
self.units = units
384
385
def build(self, input_shape):
386
input_dim = input_shape[-1]
387
self.w = tf.Variable(initial_value=tf.zeros([input_dim, self.units]))
388
self.b = tf.Variable(initial_value=tf.zeros([self.units,]))
389
390
def call(self, inputs):
391
return keras.ops.matmul(inputs, self.w) + self.b
392
393
394
layer = MyCustomLayer(3)
395
data = np.random.uniform(size=[3, 3])
396
model = keras.models.Sequential([layer])
397
model.compile(optimizer="adam", loss="mse")
398
model.predict(data)
399
# The model does not have any trainable variables
400
for layer in model.layers:
401
print(layer.trainable_variables)
402
```
403
404
You will see the following warning:
405
406
```
407
UserWarning: The model does not have any trainable weights.
408
warnings.warn("The model does not have any trainable weights.")
409
```
410
411
**How to fix it:** use `self.add_weight()` method or opt for a `keras.Variable` instead. If you
412
are currently using `tf.variable`, you can switch to `keras.Variable`.
413
"""
414
415
416
class MyCustomLayer(keras.layers.Layer):
417
def __init__(self, units):
418
super().__init__()
419
self.units = units
420
421
def build(self, input_shape):
422
input_dim = input_shape[-1]
423
self.w = self.add_weight(
424
shape=[input_dim, self.units],
425
initializer="zeros",
426
)
427
self.b = self.add_weight(
428
shape=[
429
self.units,
430
],
431
initializer="zeros",
432
)
433
434
def call(self, inputs):
435
return keras.ops.matmul(inputs, self.w) + self.b
436
437
438
layer = MyCustomLayer(3)
439
data = np.random.uniform(size=[3, 3])
440
model = keras.models.Sequential([layer])
441
model.compile(optimizer="adam", loss="mse")
442
model.predict(data)
443
# Verify that the variables are now being tracked
444
for layer in model.layers:
445
print(layer.trainable_variables)
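
"""
As mentioned above, `keras.Variable` is the other way to fix this. A minimal sketch of the
same `build()` method using `keras.Variable` instead of `self.add_weight()` (the rest of
the layer is unchanged):

```python
def build(self, input_shape):
    input_dim = input_shape[-1]
    # keras.Variable is tracked by Keras 3, unlike a raw tf.Variable
    self.w = keras.Variable(initializer="zeros", shape=[input_dim, self.units])
    self.b = keras.Variable(initializer="zeros", shape=[self.units])
```
"""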

"""
### `None` entries in nested `call()` arguments

`None` entries are not allowed as part of nested (e.g. list/tuple) tensor
arguments in `Layer.call()`, nor as part of `call()`'s nested return values.

If the `None` in the argument is intentional and serves a specific purpose,
ensure that the argument is optional and structure it as a separate parameter.
For example, consider defining the `call()` method with an optional argument.

The following snippet of code will reproduce the error.

```python
class CustomLayer(keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def call(self, inputs):
        foo = inputs["foo"]
        baz = inputs["bar"]["baz"]
        if baz is not None:
            return foo + baz
        return foo


layer = CustomLayer()
inputs = {
    "foo": keras.Input(shape=(1,), name="foo"),
    "bar": {
        "baz": None,
    },
}
layer(inputs)
```
"""

"""
**How to fix it:**

**Solution 1:** Replace `None` with a value, like this:
"""


class CustomLayer(keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def call(self, inputs):
        foo = inputs["foo"]
        baz = inputs["bar"]["baz"]
        return foo + baz


layer = CustomLayer()
inputs = {
    "foo": keras.Input(shape=(1,), name="foo"),
    "bar": {
        "baz": keras.Input(shape=(1,), name="bar"),
    },
}
layer(inputs)


"""
**Solution 2:** Define the `call()` method with an optional argument.
Here is an example of this fix:
"""


class CustomLayer(keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def call(self, foo, baz=None):
        if baz is not None:
            return foo + baz
        return foo


layer = CustomLayer()
foo = keras.Input(shape=(1,), name="foo")
baz = None
layer(foo, baz=baz)

"""
531
### State-building issues
532
533
Keras 3 is significantly stricter than Keras 2 about when state (e.g. numerical weight variables)
534
can be created. Keras 3 wants all state to be created before the model can be trained. This is a requirement
535
for using JAX (whereas TensorFlow was very lenient about state creation timing).
536
537
Keras layers should create their state either in their constructor (`__init__()` method) or in their `build()` method.
538
They should avoid creating state in `call()`.
539
540
If you ignore this recommendation and create state in `call()`
541
anyway (e.g. by calling a previously unbuilt layer), then Keras will attempt to build the layer automatically
542
by calling the `call()` method on symbolic inputs before training.
543
However, this attempt at automatic state creation may fail in certain cases.
544
This will cause an error that looks like like this:
545
546
```
547
Layer 'frame_position_embedding' looks like it has unbuilt state,
548
but Keras is not able to trace the layer `call()` in order to build it automatically.
549
Possible causes:
550
1. The `call()` method of your layer may be crashing.
551
Try to `__call__()` the layer eagerly on some test input first to see if it works.
552
E.g. `x = np.random.random((3, 4)); y = layer(x)`
553
2. If the `call()` method is correct, then you may need to implement
554
the `def build(self, input_shape)` method on your layer.
555
It should create all variables used by the layer
556
(e.g. by calling `layer.build()` on all its children layers).
557
```
558
559
You could reproduce this error with the following layer, when used with the JAX backend:
560
561
```python
562
class PositionalEmbedding(keras.layers.Layer):
563
def __init__(self, sequence_length, output_dim, **kwargs):
564
super().__init__(**kwargs)
565
self.position_embeddings = layers.Embedding(
566
input_dim=sequence_length, output_dim=output_dim
567
)
568
self.sequence_length = sequence_length
569
self.output_dim = output_dim
570
571
def call(self, inputs):
572
inputs = keras.ops.cast(inputs, self.compute_dtype)
573
length = keras.ops.shape(inputs)[1]
574
positions = keras.ops.arange(start=0, stop=length, step=1)
575
embedded_positions = self.position_embeddings(positions)
576
return inputs + embedded_positions
577
```
578
579
**How to fix it:** Do exactly what the error message asks. First, try to run the layer eagerly
580
to see if the `call()` method is in fact correct (note: if it was working in Keras 2, then it is correct
581
and does not need to be changed). If it is indeed correct, then you should implement a `build(self, input_shape)`
582
method that creates all of the layer's state, including the state of sublayers. Here's the fix as applied for the layer above
583
(note the `build()` method):
584
585
```python
586
class PositionalEmbedding(keras.layers.Layer):
587
def __init__(self, sequence_length, output_dim, **kwargs):
588
super().__init__(**kwargs)
589
self.position_embeddings = layers.Embedding(
590
input_dim=sequence_length, output_dim=output_dim
591
)
592
self.sequence_length = sequence_length
593
self.output_dim = output_dim
594
595
def build(self, input_shape):
596
self.position_embeddings.build(input_shape)
597
598
def call(self, inputs):
599
inputs = keras.ops.cast(inputs, self.compute_dtype)
600
length = keras.ops.shape(inputs)[1]
601
positions = keras.ops.arange(start=0, stop=length, step=1)
602
embedded_positions = self.position_embeddings(positions)
603
return inputs + embedded_positions
604
```
605
"""
606
607
608
"""
609
### Removed features
610
611
A small number of legacy features with very low usage were removed from Keras 3 as a cleanup measure:
612
613
* `keras.layers.ThresholdedReLU` is removed. Instead, you can simply use the `ReLU` layer
614
with the argument `threshold`.
615
* Symbolic `Layer.add_loss()`: Symbolic `add_loss()` is removed (you can still use
616
`add_loss()` inside the `call()` method of a layer/model).
617
* Locally connected layers (`LocallyConnected1D`, `LocallyConnected2D`
618
are removed due to very low usage. To
619
use locally connected layers, copy the layer implementation into your own codebase.
620
* `keras.layers.experimental.RandomFourierFeatures` is removed due to very low usage.
621
To use it, copy the layer implementation into your own codebase.
622
* Removed layer attributes: Layer attributes `metrics`, `dynamic` are removed. `metrics` is still
623
available on the `Model` class.
624
* The `constants` and `time_major` arguments in RNN layers are removed.
625
The `constants` argument was a remnant of Theano and had very low usage. The `time_major`
626
argument also had very low usage.
627
* `reset_metrics` argument: The `reset_metrics` argument is removed from `model.*_on_batch()`
628
methods. This argument had very low usage.
629
* The `keras.constraints.RadialConstraint` object is removed. This object had very low usage.
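
For the first item, a rough `ThresholdedReLU` replacement looks like this (a sketch; note
the boundary difference: `ReLU` keeps values with `x >= threshold`, while `ThresholdedReLU`
used a strict `x > theta`):

```python
# Keras 2:
# layer = keras.layers.ThresholdedReLU(theta=1.0)
# Keras 3:
layer = keras.layers.ReLU(threshold=1.0)
```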
"""

"""
## Transitioning to backend-agnostic Keras 3

Keras 3 code with the TensorFlow backend will work with native TensorFlow APIs.
However, if you want your code to be backend-agnostic, you will need to:

- Replace all of the `tf.*` API calls with their equivalent Keras APIs.
- Convert your custom `train_step`/`test_step` methods to a multi-framework
implementation.
- Make sure you're using stateless `keras.random` ops correctly in your layers.

Let's go over each point in detail.

### Switching to Keras ops

In many cases, this is the only thing you need to do to start being able to run
your custom layers and metrics with JAX and PyTorch:
replace any `tf.*`, `tf.math.*`, `tf.linalg.*`, etc. with `keras.ops.*`. Most TF ops
should be consistent with Keras 3. If the names differ, they are
highlighted in this guide.

#### NumPy ops

Keras implements the NumPy API as part of `keras.ops`.

The table below only lists a small subset of TensorFlow and Keras ops; ops not listed
are usually named the same in both frameworks (e.g. `reshape`, `matmul`, `cast`, etc.)

| TensorFlow                                 | Keras 3.0                                 |
|--------------------------------------------|-------------------------------------------|
| `tf.abs`                                   | `keras.ops.absolute`                      |
| `tf.reduce_all`                            | `keras.ops.all`                           |
| `tf.reduce_max`                            | `keras.ops.amax`                          |
| `tf.reduce_min`                            | `keras.ops.amin`                          |
| `tf.reduce_any`                            | `keras.ops.any`                           |
| `tf.concat`                                | `keras.ops.concatenate`                   |
| `tf.range`                                 | `keras.ops.arange`                        |
| `tf.acos`                                  | `keras.ops.arccos`                        |
| `tf.asin`                                  | `keras.ops.arcsin`                        |
| `tf.asinh`                                 | `keras.ops.arcsinh`                       |
| `tf.atan`                                  | `keras.ops.arctan`                        |
| `tf.atan2`                                 | `keras.ops.arctan2`                       |
| `tf.atanh`                                 | `keras.ops.arctanh`                       |
| `tf.convert_to_tensor`                     | `keras.ops.convert_to_tensor`             |
| `tf.reduce_mean`                           | `keras.ops.mean`                          |
| `tf.clip_by_value`                         | `keras.ops.clip`                          |
| `tf.math.conj`                             | `keras.ops.conjugate`                     |
| `tf.linalg.diag_part`                      | `keras.ops.diagonal`                      |
| `tf.reverse`                               | `keras.ops.flip`                          |
| `tf.gather`                                | `keras.ops.take`                          |
| `tf.math.is_finite`                        | `keras.ops.isfinite`                      |
| `tf.math.is_inf`                           | `keras.ops.isinf`                         |
| `tf.math.is_nan`                           | `keras.ops.isnan`                         |
| `tf.reduce_max`                            | `keras.ops.max`                           |
| `tf.reduce_min`                            | `keras.ops.min`                           |
| `tf.rank`                                  | `keras.ops.ndim`                          |
| `tf.math.pow`                              | `keras.ops.power`                         |
| `tf.reduce_prod`                           | `keras.ops.prod`                          |
| `tf.math.reduce_std`                       | `keras.ops.std`                           |
| `tf.reduce_sum`                            | `keras.ops.sum`                           |
| `tf.gather_nd`                             | `keras.ops.take_along_axis`               |
| `tf.math.reduce_variance`                  | `keras.ops.var`                           |

#### Other ops

| TensorFlow | Keras 3.0 |
|----------------------------------------------------|-------------------------------------------------------------------|
| `tf.nn.sigmoid_cross_entropy_with_logits` | `keras.ops.binary_crossentropy` (mind the `from_logits` argument) |
| `tf.nn.sparse_softmax_cross_entropy_with_logits` | `keras.ops.sparse_categorical_crossentropy` (mind the `from_logits` argument)|
| `tf.nn.softmax_cross_entropy_with_logits` | `keras.ops.categorical_crossentropy(target, output, from_logits=False, axis=-1)`|
| `tf.nn.conv1d`, `tf.nn.conv2d`, `tf.nn.conv3d`, `tf.nn.convolution` | `keras.ops.conv` |
| `tf.nn.conv_transpose`, `tf.nn.conv1d_transpose`, `tf.nn.conv2d_transpose`, `tf.nn.conv3d_transpose` | `keras.ops.conv_transpose` |
| `tf.nn.depthwise_conv2d` | `keras.ops.depthwise_conv` |
| `tf.nn.separable_conv2d` | `keras.ops.separable_conv` |
| `tf.nn.batch_normalization` | No direct equivalent; use `keras.layers.BatchNormalization` |
| `tf.nn.dropout` | `keras.random.dropout` |
| `tf.nn.embedding_lookup` | `keras.ops.take` |
| `tf.nn.l2_normalize` | `keras.utils.normalize` (not an op) |
| `x.numpy` | `keras.ops.convert_to_numpy` |
| `tf.scatter_nd_update` | `keras.ops.scatter_update` |
| `tf.tensor_scatter_nd_update` | `keras.ops.slice_update` |
| `tf.signal.fft2d` | `keras.ops.fft2` |
| `tf.signal.inverse_stft` | `keras.ops.istft` |
| `tf.image.crop_to_bounding_box` | `keras.ops.image.crop_images` |
| `tf.image.pad_to_bounding_box` | `keras.ops.image.pad_images` |
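
As an illustration, here is a small TensorFlow-only computation converted using the table
above (a sketch; which ops you need to swap depends on your own code):

```python
# TensorFlow-only:
# y = tf.reduce_mean(tf.abs(x), axis=-1)
# Backend-agnostic Keras 3:
y = keras.ops.mean(keras.ops.absolute(x), axis=-1)
```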

"""

"""
### Custom `train_step()` methods

Your models may include a custom `train_step()` or `test_step()` method, which relies
on TensorFlow-only APIs -- for instance, your `train_step()` method may leverage TensorFlow's `tf.GradientTape`.
To convert such models to run on JAX or PyTorch, you will have to write a different `train_step()` implementation
for each backend you want to support.

In some cases, you might be able to simply override the `Model.compute_loss()` method and make it fully backend-agnostic,
instead of overriding `train_step()`. Here's an example of a model with a custom `compute_loss()` method which works
across JAX, TensorFlow, and PyTorch:
"""


class MyModel(keras.Model):
    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
        loss = keras.ops.sum(keras.losses.mean_squared_error(y, y_pred, sample_weight))
        return loss


"""
If you need to modify the optimization mechanism itself, beyond the loss computation,
then you will need to override `train_step()`, and implement one `train_step()` method per backend, like below.

See the following guides for details on how each backend should be handled:

- [Customizing what happens in `fit()` with JAX](https://keras.io/guides/custom_train_step_in_jax/)
- [Customizing what happens in `fit()` with TensorFlow](https://keras.io/guides/custom_train_step_in_tensorflow/)
- [Customizing what happens in `fit()` with PyTorch](https://keras.io/guides/custom_train_step_in_torch/)
"""


class MyModel(keras.Model):
    def train_step(self, *args, **kwargs):
        if keras.backend.backend() == "jax":
            return self._jax_train_step(*args, **kwargs)
        elif keras.backend.backend() == "tensorflow":
            return self._tensorflow_train_step(*args, **kwargs)
        elif keras.backend.backend() == "torch":
            return self._torch_train_step(*args, **kwargs)

    def _jax_train_step(self, state, data):
        pass  # See guide: keras.io/guides/custom_train_step_in_jax/

    def _tensorflow_train_step(self, data):
        pass  # See guide: keras.io/guides/custom_train_step_in_tensorflow/

    def _torch_train_step(self, data):
        pass  # See guide: keras.io/guides/custom_train_step_in_torch/

"""
775
### RNG-using layers
776
777
Keras 3 has a new `keras.random` namespace, containing:
778
779
- `keras.random.normal`
780
- `keras.random.uniform`
781
- `keras.random.shuffle`
782
- etc.
783
784
These operations are **stateless**, which means that if you pass a `seed`
785
argument, they will return the same result every time. Like this:
786
"""
787
788
print(keras.random.normal(shape=(), seed=123))
789
print(keras.random.normal(shape=(), seed=123))
790
791
"""
792
Crucially, this differs from the behavior of stateful `tf.random` ops:
793
"""
794
795
print(tf.random.normal(shape=(), seed=123))
796
print(tf.random.normal(shape=(), seed=123))
797
798
"""
799
When you write a RNG-using layer, such as a custom dropout layer, you are
800
going to want to use a different seed value at layer call. However, you cannot
801
just increment a Python integer and pass it, because while this would work fine
802
when executed eagerly, it would not work as expected when using compilation
803
(which is available with JAX, TensorFlow, and PyTorch). When compiling the layer,
804
the first Python integer seed value seen by the layer would be hardcoded into the
805
compiled graph.
806
807
To address this, you should pass as the `seed` argument an instance of a
808
stateful `keras.random.SeedGenerator` object, like this:
809
"""
810
811
seed_generator = keras.random.SeedGenerator(1337)
812
print(keras.random.normal(shape=(), seed=seed_generator))
813
print(keras.random.normal(shape=(), seed=seed_generator))
814
815
816
"""
817
So when writing a RNG using layer, you would use the following pattern:
818
"""
819
820
821
class RandomNoiseLayer(keras.layers.Layer):
822
def __init__(self, noise_rate, **kwargs):
823
super().__init__(**kwargs)
824
self.noise_rate = noise_rate
825
self.seed_generator = keras.random.SeedGenerator(1337)
826
827
def call(self, inputs):
828
noise = keras.random.uniform(
829
minval=0, maxval=self.noise_rate, seed=self.seed_generator
830
)
831
return inputs + noise
832
833
834
"""
835
Such a layer is safe to use in any setting -- in eager execution or in a compiled model. Each
836
layer call will be using a different seed value, as expected.
837
"""
838
839