"""
2
Title: Customizing what happens in `fit()` with PyTorch
3
Author: [fchollet](https://twitter.com/fchollet)
4
Date created: 2023/06/27
5
Last modified: 2024/08/01
6
Description: Overriding the training step of the Model class with PyTorch.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
When you're doing supervised learning, you can use `fit()` and everything works
14
smoothly.
15
16
When you need to take control of every little detail, you can write your own training
17
loop entirely from scratch.
18
19
But what if you need a custom training algorithm, but you still want to benefit from
20
the convenient features of `fit()`, such as callbacks, built-in distribution support,
21
or step fusing?
22
23
A core principle of Keras is **progressive disclosure of complexity**. You should
24
always be able to get into lower-level workflows in a gradual way. You shouldn't fall
25
off a cliff if the high-level functionality doesn't exactly match your use case. You
26
should be able to gain more control over the small details while retaining a
27
commensurate amount of high-level convenience.
28
29
When you need to customize what `fit()` does, you should **override the training step
30
function of the `Model` class**. This is the function that is called by `fit()` for
31
every batch of data. You will then be able to call `fit()` as usual -- and it will be
32
running your own learning algorithm.
33
34
Note that this pattern does not prevent you from building models with the Functional
35
API. You can do this whether you're building `Sequential` models, Functional API
36
models, or subclassed models.
37
38
Let's see how that works.
39
"""
40
41
"""
42
## Setup
43
"""
44
45
import os
46
47
# This guide can only be run with the torch backend.
48
os.environ["KERAS_BACKEND"] = "torch"
49
50
import torch
51
import keras
52
from keras import layers
53
import numpy as np
54
55
"""
56
## A first simple example
57
58
Let's start from a simple example:
59
60
- We create a new class that subclasses `keras.Model`.
61
- We just override the method `train_step(self, data)`.
62
- We return a dictionary mapping metric names (including the loss) to their current
63
value.
64
65
The input argument `data` is what gets passed to fit as training data:
66
67
- If you pass NumPy arrays, by calling `fit(x, y, ...)`, then `data` will be the tuple
68
`(x, y)`
69
- If you pass a `torch.utils.data.DataLoader` or a `tf.data.Dataset`,
70
by calling `fit(dataset, ...)`, then `data` will be what gets yielded
71
by `dataset` at each batch.
72
73
In the body of the `train_step()` method, we implement a regular training update,
74
similar to what you are already familiar with. Importantly, **we compute the loss via
75
`self.compute_loss()`**, which wraps the loss(es) function(s) that were passed to
76
`compile()`.
77
78
Similarly, we call `metric.update_state(y, y_pred)` on metrics from `self.metrics`,
79
to update the state of the metrics that were passed in `compile()`,
80
and we query results from `self.metrics` at the end to retrieve their current value.
81
"""
82
83
84
class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.compute_loss(y=y, y_pred=y_pred)

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}

"""
122
Let's try this out:
123
"""
124
125
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
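
"""
Because `train_step()` just unpacks whatever `fit()` yields for each batch, the same
model also accepts a `torch.utils.data.DataLoader` directly. A minimal sketch reusing
the arrays above (an illustrative addition, not part of the original guide):
"""

dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x).float(), torch.from_numpy(y).float()
)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
model.fit(loader, epochs=1)
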
"""
## Going lower-level

Naturally, you could just skip passing a loss function in `compile()`, and instead do
everything *manually* in `train_step`. Likewise for metrics.

Here's a lower-level example that only uses `compile()` to configure the optimizer:

- We start by creating `Metric` instances to track our loss and a MAE score (in `__init__()`).
- We implement a custom `train_step()` that updates the state of these metrics
(by calling `update_state()` on them), then queries them (via `result()`) to return their current average value,
to be displayed by the progress bar and to be passed to any callback.
- Note that we would need to call `reset_state()` on our metrics between each epoch! Otherwise
calling `result()` would return an average since the start of training, whereas we usually work
with per-epoch averages. Thankfully, the framework can do that for us: just list any metric
you want to reset in the `metrics` property of the model. The model will call `reset_state()`
on any object listed here at the beginning of each `fit()` epoch or at the beginning of a call to
`evaluate()`.
"""

class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def train_step(self, data):
        x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.loss_fn(y, y_pred)

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Compute our own metrics
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {
            "loss": self.loss_tracker.result(),
            "mae": self.mae_metric.result(),
        }

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_state()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        return [self.loss_tracker, self.mae_metric]


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
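
"""
Callbacks work as usual with this custom `train_step()`. For instance (an illustrative
addition, not part of the original guide), you could stop early based on the tracked
loss:
"""

model.fit(
    x,
    y,
    epochs=20,
    callbacks=[keras.callbacks.EarlyStopping(monitor="loss", patience=2)],
)
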
"""
## Supporting `sample_weight` & `class_weight`

You may have noticed that our first basic example didn't make any mention of sample
weighting. If you want to support the `fit()` arguments `sample_weight` and
`class_weight`, you'd simply do the following:

- Unpack `sample_weight` from the `data` argument
- Pass it to `compute_loss` & `update_state` (of course, you could also just apply
it manually if you don't rely on `compile()` for losses & metrics)
- That's it.
"""

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.compute_loss(
            y=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
        )

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# You can now use the `sample_weight` argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)
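
"""
`class_weight` is handled by `fit()` itself: the class weights are converted into
per-sample weights before each batch reaches `train_step()`, so the unpacking logic
above works unchanged. A minimal sketch with a small hypothetical classifier (the
names below are illustrative and not part of the original guide):
"""

clf_inputs = keras.Input(shape=(32,))
clf_outputs = keras.layers.Dense(2, activation="softmax")(clf_inputs)
clf = CustomModel(clf_inputs, clf_outputs)
clf.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
)

x_clf = np.random.random((1000, 32))
y_clf = np.random.randint(0, 2, size=(1000,))
clf.fit(x_clf, y_clf, class_weight={0: 1.0, 1: 2.0}, epochs=1)
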
"""
## Providing your own evaluation step

What if you want to do the same for calls to `model.evaluate()`? Then you would
override `test_step` in exactly the same way. Here's what it looks like:
"""

class CustomModel(keras.Model):
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_pred = self(x, training=False)
        # Compute the loss
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # Update the metrics (includes the metric that tracks the loss).
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
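
"""
The same custom `test_step()` also runs during the validation pass of `fit()`. A
minimal sketch with a hypothetical held-out split (an illustrative addition, not part
of the original guide):
"""

x_val = np.random.random((200, 32))
y_val = np.random.random((200, 1))
model.fit(x, y, validation_data=(x_val, y_val), epochs=1)
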
"""
## Wrapping up: an end-to-end GAN example

Let's walk through an end-to-end example that leverages everything you just learned.

Let's consider:

- A generator network meant to generate 28x28x1 images.
- A discriminator network meant to classify 28x28x1 images into two classes ("fake" and
"real").
- One optimizer for each.
- A loss function to train the discriminator.
"""

# Create the discriminator
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Create the generator
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)
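
"""
As a quick sanity check (an optional addition, not part of the original guide), you can
inspect both networks before wiring them into the GAN:
"""

discriminator.summary()
generator.summary()
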
"""
Here's a feature-complete GAN class, overriding `compile()` to use its own signature,
and implementing the entire GAN algorithm in `train_step`:
"""

class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.seed_generator = keras.random.SeedGenerator(1337)
        self.built = True

    @property
    def metrics(self):
        return [self.d_loss_tracker, self.g_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if isinstance(real_images, (tuple, list)):
            real_images = real_images[0]
        # Sample random points in the latent space
        batch_size = real_images.shape[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        real_images = torch.tensor(real_images, device=device)
        combined_images = torch.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        labels = torch.concat(
            [
                torch.ones((batch_size, 1), device=device),
                torch.zeros((batch_size, 1), device=device),
            ],
            axis=0,
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * keras.random.uniform(labels.shape, seed=self.seed_generator)

        # Train the discriminator
        self.zero_grad()
        predictions = self.discriminator(combined_images)
        d_loss = self.loss_fn(labels, predictions)
        d_loss.backward()
        grads = [v.value.grad for v in self.discriminator.trainable_weights]
        with torch.no_grad():
            self.d_optimizer.apply(grads, self.discriminator.trainable_weights)

        # Sample random points in the latent space
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Assemble labels that say "all real images"
        misleading_labels = torch.zeros((batch_size, 1), device=device)

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        self.zero_grad()
        predictions = self.discriminator(self.generator(random_latent_vectors))
        g_loss = self.loss_fn(misleading_labels, predictions)
        g_loss.backward()
        grads = [v.value.grad for v in self.generator.trainable_weights]
        with torch.no_grad():
            self.g_optimizer.apply(grads, self.generator.trainable_weights)

        # Update metrics and return their value.
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
        }

"""
463
Let's test-drive it:
464
"""
465
466
# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))

# Create a TensorDataset
dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(all_digits), torch.from_numpy(all_digits)
)
# Create a DataLoader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

gan.fit(dataloader, epochs=1)
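
"""
After training, new images can be sampled by drawing latent vectors and calling the
generator directly. A minimal sketch (an illustrative addition, not part of the
original guide):
"""

random_latent_vectors = keras.random.normal(shape=(16, latent_dim))
generated_images = generator(random_latent_vectors)
# `generated_images` has shape (16, 28, 28, 1), with values in [0, 1].
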
"""
The ideas behind deep learning are simple, so why should their implementation be painful?
"""