"""
Title: Introduction to Keras for engineers
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2023/07/10
Last modified: 2023/07/10
Description: First contact with Keras 3.
Accelerator: GPU
"""

"""
## Introduction

Keras 3 is a deep learning framework that works with TensorFlow, JAX,
and PyTorch interchangeably. This notebook will walk you through key
Keras 3 workflows.

Let's start by installing Keras 3:
"""

"""shell
pip install keras --upgrade --quiet
"""

"""
## Setup

We're going to be using the JAX backend here -- but you can
edit the string below to `"tensorflow"` or `"torch"` and hit
"Restart runtime", and the whole notebook will run just the same!
This entire guide is backend-agnostic.
"""

import numpy as np
import os

os.environ["KERAS_BACKEND"] = "jax"

# Note that Keras should only be imported after the backend
# has been configured. The backend cannot be changed once the
# package is imported.
import keras
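
"""
As a quick check, let's confirm which backend is active:
"""

print("Keras backend:", keras.backend.backend())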
"""
## A first example: An MNIST convnet

Let's start with the Hello World of ML: training a convnet
to classify MNIST digits.

Here's the data:
"""

# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print("y_train shape:", y_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

"""
Here's our model.

Different model-building options that Keras offers include:

- [The Sequential API](https://keras.io/guides/sequential_model/) (what we use below)
- [The Functional API](https://keras.io/guides/functional_api/) (the most typical choice; see the short sketch after this list)
- [Writing your own models yourself via subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) (for advanced use cases)
"""
# Model parameters
num_classes = 10
input_shape = (28, 28, 1)

model = keras.Sequential(
    [
        keras.layers.Input(shape=input_shape),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
        keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(num_classes, activation="softmax"),
    ]
)

"""
Here's our model summary:
"""

model.summary()

"""
We use the `compile()` method to specify the optimizer, loss function,
and the metrics to monitor. Note that with the JAX and TensorFlow backends,
XLA compilation is turned on by default.
"""

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
)
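
"""
If XLA compilation ever gets in the way (for instance, if one of your ops
isn't XLA-compatible), you can turn it off by passing `jit_compile=False` to
`compile()` -- the same call as above, sketched with XLA disabled:

```python
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
    jit_compile=False,
)
```
"""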
"""
Let's train and evaluate the model. During training, we'll set aside 15% of
the data as a validation split, to monitor generalization on unseen data.
"""

batch_size = 128
epochs = 20

callbacks = [
    keras.callbacks.ModelCheckpoint(filepath="model_at_epoch_{epoch}.keras"),
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=2),
]

model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.15,
    callbacks=callbacks,
)
score = model.evaluate(x_test, y_test, verbose=0)
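
"""
`evaluate()` returns the test loss followed by the metrics we passed to
`compile()` -- here, the accuracy:
"""

print("Test loss:", score[0])
print("Test accuracy:", score[1])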
"""
During training, we were saving a model at the end of each epoch. You
can also save the model in its latest state like this:
"""

model.save("final_model.keras")

"""
And reload it like this:
"""

model = keras.saving.load_model("final_model.keras")
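
"""
The reloaded model comes back compiled, so evaluating it should reproduce the
same test metrics as before:
"""

reloaded_score = model.evaluate(x_test, y_test, verbose=0)
print("Reloaded test accuracy:", reloaded_score[1])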
"""
Next, you can query predictions of class probabilities with `predict()`:
"""

predictions = model.predict(x_test)
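
"""
`predictions` has shape `(10000, 10)`: one row of class probabilities per
test image. To recover predicted labels, take the argmax over the last axis:
"""

predicted_labels = np.argmax(predictions, axis=-1)
print("First 10 predicted labels:", predicted_labels[:10])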
"""
That's it for the basics!
"""

"""
## Writing cross-framework custom components

Keras enables you to write custom Layers, Models, Metrics, Losses, and Optimizers
that work across TensorFlow, JAX, and PyTorch with the same codebase. Let's take a look
at custom layers first.

The `keras.ops` namespace contains:

- An implementation of the NumPy API, e.g. `keras.ops.stack` or `keras.ops.matmul`.
- A set of neural network-specific ops that are absent from NumPy, such as
  `keras.ops.conv` or `keras.ops.binary_crossentropy`.

These ops dispatch to whichever backend is active, so code written against
them runs on all three frameworks; see the quick check below.
"""
class MyDense(keras.layers.Layer):
    def __init__(self, units, activation=None, name=None):
        super().__init__(name=name)
        self.units = units
        self.activation = keras.activations.get(activation)

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.w = self.add_weight(
            shape=(input_dim, self.units),
            initializer=keras.initializers.GlorotNormal(),
            name="kernel",
            trainable=True,
        )

        self.b = self.add_weight(
            shape=(self.units,),
            initializer=keras.initializers.Zeros(),
            name="bias",
            trainable=True,
        )

    def call(self, inputs):
        # Use Keras ops to create backend-agnostic layers/metrics/etc.
        x = keras.ops.matmul(inputs, self.w) + self.b
        return self.activation(x)
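
"""
Before using it in a model, let's smoke-test the layer on its own (an
illustrative check; any input with a trailing feature dimension works):
"""

dense = MyDense(units=4, activation="relu")
out = dense(keras.ops.ones((2, 8)))  # First call builds the (8, 4) kernel
print("MyDense output shape:", out.shape)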
"""
Next, let's make a custom `Dropout` layer that relies on the `keras.random`
namespace:
"""


class MyDropout(keras.layers.Layer):
    def __init__(self, rate, name=None):
        super().__init__(name=name)
        self.rate = rate
        # Use seed_generator for managing RNG state.
        # It is a state element and its seed variable is
        # tracked as part of `layer.variables`.
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        # Use `keras.random` for random ops.
        return keras.random.dropout(inputs, self.rate, seed=self.seed_generator)
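
"""
One caveat worth noting: this minimal version applies dropout on every call,
whereas the built-in `Dropout` layer is only active during training. A quick
illustrative check:
"""

dropout = MyDropout(rate=0.5)
print(keras.ops.convert_to_numpy(dropout(keras.ops.ones((2, 4)))))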
"""
Next, let's write a custom subclassed model that uses our two custom layers:
"""


class MyModel(keras.Model):
    def __init__(self, num_classes):
        super().__init__()
        self.conv_base = keras.Sequential(
            [
                keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
                keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
                keras.layers.MaxPooling2D(pool_size=(2, 2)),
                keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
                keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
                keras.layers.GlobalAveragePooling2D(),
            ]
        )
        self.dp = MyDropout(0.5)
        self.dense = MyDense(num_classes, activation="softmax")

    def call(self, x):
        x = self.conv_base(x)
        x = self.dp(x)
        return self.dense(x)


"""
Let's compile it and fit it:
"""

model = MyModel(num_classes=10)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
)

model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=1,  # For speed
    validation_split=0.15,
)
"""
## Training models on arbitrary data sources

All Keras models can be trained and evaluated on a wide variety of data sources,
independently of the backend you're using. This includes:

- NumPy arrays
- Pandas dataframes
- TensorFlow `tf.data.Dataset` objects
- PyTorch `DataLoader` objects
- Keras `PyDataset` objects (see the sketch at the end of this section)

They all work whether you're using TensorFlow, JAX, or PyTorch as your Keras backend.

Let's try it out with PyTorch `DataLoaders`:
"""

import torch

# Create a TensorDataset
train_torch_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_torch_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_test), torch.from_numpy(y_test)
)

# Create a DataLoader
train_dataloader = torch.utils.data.DataLoader(
    train_torch_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    val_torch_dataset, batch_size=batch_size, shuffle=False
)

model = MyModel(num_classes=10)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
)
model.fit(train_dataloader, epochs=1, validation_data=val_dataloader)

"""
Now let's try this out with `tf.data`:
"""

import tensorflow as tf

train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
test_dataset = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

model = MyModel(num_classes=10)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
)
model.fit(train_dataset, epochs=1, validation_data=test_dataset)
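
"""
Finally, here's the `PyDataset` option from the list above. The wrapper class
below is an illustrative sketch of our own (only `keras.utils.PyDataset` and
its `__len__`/`__getitem__` contract come from Keras):
"""

import math


class NumpyPyDataset(keras.utils.PyDataset):
    def __init__(self, x, y, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.x, self.y = x, y
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        # Return one (inputs, targets) batch.
        low = idx * self.batch_size
        high = min(low + self.batch_size, len(self.x))
        return self.x[low:high], self.y[low:high]


train_py_dataset = NumpyPyDataset(x_train, y_train, batch_size)
model.fit(train_py_dataset, epochs=1)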
"""
## Further reading

This concludes our short overview of the new multi-backend capabilities
of Keras 3. Next, you can learn about:

### How to customize what happens in `fit()`

Want to implement a non-standard training algorithm yourself but still want to
benefit from the power and usability of `fit()`? It's easy to customize
`fit()` to support arbitrary use cases:

- [Customizing what happens in `fit()` with TensorFlow](http://keras.io/guides/custom_train_step_in_tensorflow/)
- [Customizing what happens in `fit()` with JAX](http://keras.io/guides/custom_train_step_in_jax/)
- [Customizing what happens in `fit()` with PyTorch](http://keras.io/guides/custom_train_step_in_torch/)

### How to write custom training loops

- [Writing a training loop from scratch in TensorFlow](http://keras.io/guides/writing_a_custom_training_loop_in_tensorflow/)
- [Writing a training loop from scratch in JAX](http://keras.io/guides/writing_a_custom_training_loop_in_jax/)
- [Writing a training loop from scratch in PyTorch](http://keras.io/guides/writing_a_custom_training_loop_in_torch/)

### How to distribute training

- [Guide to distributed training with TensorFlow](http://keras.io/guides/distributed_training_with_tensorflow/)
- [JAX distributed training example](https://github.com/keras-team/keras/blob/master/examples/demo_jax_distributed.py)
- [PyTorch distributed training example](https://github.com/keras-team/keras/blob/master/examples/demo_torch_multi_gpu.py)

Enjoy the library! 🚀
"""