"""
2
Title: How to use Keras with NNX backend
3
Author: [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli)
4
Date created: 2025/08/07
5
Last modified: 2025/08/07
6
Description: How to use Keras with NNX backend.
7
Accelerator: GPU
8
"""

"""
# A Guide to the Keras & Flax NNX Integration

This tutorial will guide you through the integration of Keras with Flax's NNX
(Neural Networks JAX) module system, demonstrating how it significantly
enhances variable handling and opens up advanced training capabilities within
the JAX ecosystem. Whether you love the simplicity of model.fit() or the
fine-grained control of a custom training loop, this integration lets you have
the best of both worlds. Let's dive in!

# Why Keras and NNX Integration?

Keras is known for its user-friendliness and high-level API, making deep
learning accessible. JAX, on the other hand, provides high-performance
numerical computation, especially suited for machine learning research due to
its JIT compilation and automatic differentiation capabilities. NNX is Flax's
module system built on JAX, offering explicit state management and a Pythonic
programming model.

NNX is designed for simplicity. It is characterized by its Pythonic approach,
where modules are standard Python classes, promoting ease of use and
familiarity. NNX prioritizes user-friendliness and offers fine-grained control
over JAX transformations through typed Variable collections.

The integration of Keras with NNX allows you to leverage the best of both
worlds: the simplicity and modularity of Keras for model construction,
combined with the power and explicit control of NNX and JAX for variable
management and sophisticated training loops.

# Getting Started: Setting Up Your Environment
"""

"""shell
!pip install -q -U keras
!pip install -q -U flax==0.11.0
"""

"""
# Enabling NNX Mode

To activate the integration, we must set two environment variables before
importing Keras. This tells Keras to use the JAX backend and to enable NNX as
an opt-in feature.
"""

import os

os.environ["KERAS_BACKEND"] = "jax"
os.environ["KERAS_NNX_ENABLED"] = "true"
from flax import nnx
import keras
import jax.numpy as jnp

print("✅ Keras is now running on JAX with NNX enabled!")

"""
# The Core Integration: Keras Variables in NNX

The heart of this integration is the new keras.Variable, which is designed to
be a native citizen of the Flax NNX ecosystem. This means you can mix Keras
and NNX components freely, and NNX's tracing and state management tools will
understand your Keras variables.

Let's prove it. We'll create an nnx.Module that contains both a standard
nnx.Linear layer and a keras.Variable.
"""

from keras import Variable as KerasVariable


class MyNnxModel(nnx.Module):
    def __init__(self, rngs):
        self.linear = nnx.Linear(2, 3, rngs=rngs)
        self.custom_variable = KerasVariable(jnp.ones((1, 3)))

    def __call__(self, x):
        return self.linear(x) + self.custom_variable


# Instantiate the model
model = MyNnxModel(rngs=nnx.Rngs(0))

# --- Verification ---
# 1. Is the KerasVariable traced by NNX?
print(f"✅ Traced: {hasattr(model.custom_variable, '_trace_state')}")

# 2. Does NNX see the KerasVariable in the model's state?
print("✅ Variables:", nnx.variables(model))

# 3. Can we access its value directly?
print("✅ Value:", model.custom_variable.value)

"""
What this shows:

- The KerasVariable is successfully traced by NNX, just like any native
nnx.Variable.
- The nnx.variables() function correctly identifies and lists our
custom_variable as part of the model's state.

This confirms that Keras state and NNX state can live together in perfect
harmony.
"""
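
"""
As a further check, we can run the model through NNX's split/merge workflow
(a minimal sketch using the standard nnx.split and nnx.merge APIs): the Keras
variable should travel in the extracted state alongside the nnx.Linear
parameters, and merging should reconstruct an equivalent model.
"""

# Split the model into a static graph definition and a state pytree.
graphdef, state = nnx.split(model)
print("✅ State extracted by nnx.split:", state)

# Merging the pieces back together reconstructs an equivalent model.
restored = nnx.merge(graphdef, state)
x_probe = jnp.ones((1, 2))
print("✅ Round-trip outputs match:", jnp.allclose(model(x_probe), restored(x_probe)))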

"""
# The Best of Both Worlds: Training Workflows

Now for the exciting part: training models. This integration unlocks two
powerful workflows.

## Workflow 1: The Classic Keras Experience (model.fit)
"""

import numpy as np

"""
1. Create a Keras Model
"""
model = keras.Sequential(
    [keras.layers.Dense(units=1, input_shape=(10,), name="my_dense_layer")]
)

print("--- Initial Model Weights ---")
initial_weights = model.get_weights()
print(f"Initial Kernel: {initial_weights[0].T}")  # .T for better display
print(f"Initial Bias: {initial_weights[1]}")

"""
2. Create Dummy Data
"""
X_dummy = np.random.rand(100, 10)
y_dummy = np.random.rand(100, 1)

"""
3. Compile and Fit
"""
model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=0.01),
    loss="mean_squared_error",
)

print("\n--- Training with model.fit() ---")
history = model.fit(X_dummy, y_dummy, epochs=5, batch_size=32, verbose=1)

"""
4. Verify a change
"""
print("\n--- Weights After Training ---")
updated_weights = model.get_weights()
print(f"Updated Kernel: {updated_weights[0].T}")
print(f"Updated Bias: {updated_weights[1]}")

# Verification
if not np.array_equal(initial_weights[1], updated_weights[1]):
    print("\n✅ SUCCESS: Model variables were updated during training.")
else:
    print("\n❌ FAILURE: Model variables were not updated.")

"""
As you can see, your existing Keras code works out of the box, giving you a
high-level, productive experience powered by JAX and NNX under the hood.

## Workflow 2: The Power of NNX: Custom Training Loops

For maximum flexibility, you can treat any Keras layer or model as an
nnx.Module and write your own training loop using libraries like Optax.
This is perfect when you need fine-grained control over the gradient and
update process.
"""

import numpy as np
import optax

X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
Y = 0.8 * X + 0.1 + np.random.normal(0, 0.1, size=X.shape)


class MySimpleKerasModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Define the layers of your model
        self.dense_layer = keras.layers.Dense(1)

    def call(self, inputs):
        # Define the forward pass
        # The 'inputs' argument will receive the input tensor when the model is
        # called
        return self.dense_layer(inputs)


model = MySimpleKerasModel()
model(X)

tx = optax.sgd(1e-3)
trainable_var = nnx.All(keras.Variable, lambda path, x: x.trainable)
optimizer = nnx.Optimizer(model, tx, wrt=trainable_var)


@nnx.jit
def train_step(model, optimizer, batch):
    x, y = batch

    def loss_fn(model_):
        y_pred = model_(x)
        return jnp.mean((y - y_pred) ** 2)

    grads = nnx.grad(loss_fn, wrt=trainable_var)(model)
    optimizer.update(model, grads)


@nnx.jit
def test_step(model, batch):
    x, y = batch
    y_pred = model(x)
    loss = jnp.mean((y - y_pred) ** 2)
    return {"loss": loss}


def dataset(batch_size=10):
    while True:
        idx = np.random.choice(len(X), size=batch_size)
        yield X[idx], Y[idx]


for step, batch in enumerate(dataset()):
    train_step(model, optimizer, batch)

    if step % 100 == 0:
        logs = test_step(model, (X, Y))
        print(f"step: {step}, loss: {logs['loss']}")

    if step >= 500:
        break

"""
This example shows how a Keras model object is seamlessly passed to
nnx.Optimizer and differentiated by nnx.grad. This composition allows you to
integrate Keras components into sophisticated JAX/NNX workflows. The approach
works equally well with Sequential, functional, and subclassed Keras models,
or even with individual layers, as the sketch below illustrates.
"""
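
"""
As a minimal sketch of that last point (illustrative names; it reuses X, Y,
and the trainable_var filter defined above), a bare keras.layers.Dense can be
differentiated with nnx.grad directly, with no keras.Model wrapper:
"""

# Build a standalone Keras layer so that its kernel and bias variables exist.
layer = keras.layers.Dense(1)
layer.build((None, 1))


def layer_loss(layer_):
    # Mean squared error of the bare layer on the toy regression data.
    y_pred = layer_(X)
    return jnp.mean((Y - y_pred) ** 2)


# Differentiate with respect to the layer's trainable Keras variables.
layer_grads = nnx.grad(layer_loss, wrt=trainable_var)(layer)
print("✅ Gradients computed for a bare Keras layer:", layer_grads)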

"""
# Saving and Loading

Your investment in the Keras ecosystem is safe. Standard features like model
serialization work exactly as you'd expect.
"""

# Create a simple model
model = keras.Sequential([keras.layers.Dense(units=1, input_shape=(10,))])
dummy_input = np.random.rand(1, 10)

# Test call
print("Original model output:", model(dummy_input))

# Save and load
model.save("my_nnx_model.keras")
restored_model = keras.models.load_model("my_nnx_model.keras")

print("Restored model output:", restored_model(dummy_input))

# Verification
np.testing.assert_allclose(model(dummy_input), restored_model(dummy_input))
print("\n✅ SUCCESS: Restored model output matches original model output.")

"""
# Real-World Application: Training Gemma

Before trying out this KerasHub model, please make sure you have set up your
Kaggle credentials in Colab secrets. The Colab pulls in `KAGGLE_KEY` and
`KAGGLE_USERNAME` to authenticate and download the models.
"""

import keras_hub

# Set a float16 policy for memory efficiency
keras.config.set_dtype_policy("float16")

# Load Gemma from KerasHub
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")

# --- 1. Inference / Generation ---
print("--- Gemma Generation ---")
output = gemma_lm.generate("Keras is a", max_length=30)
print(output)

# --- 2. Fine-tuning ---
print("\n--- Gemma Fine-tuning ---")
# Dummy data for demonstration
features = np.array(["The quick brown fox jumped.", "I forgot my homework."])
# The model.fit() API works seamlessly!
gemma_lm.fit(x=features, batch_size=2)
print("\n✅ Gemma fine-tuning step completed successfully!")

"""
# Conclusion

The Keras-NNX integration represents a significant step forward, offering a
unified framework for both rapid prototyping and high-performance,
customizable research. You can now:

- Use familiar Keras APIs (Sequential, Model, fit, save) on a JAX backend.
- Integrate Keras layers and models directly into Flax NNX modules and training loops.
- Integrate Keras code and models with NNX-ecosystem libraries such as Qwix and Tunix.
- Leverage the entire JAX ecosystem (e.g., nnx.jit, optax) with your Keras models.
- Seamlessly work with large models from KerasHub.
"""