"""
Title: Writing a training loop from scratch in PyTorch
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2023/06/25
Last modified: 2023/06/25
Description: Writing low-level training & evaluation loops in PyTorch.
Accelerator: None
"""

"""
## Setup
"""

import os

# This guide can only be run with the torch backend.
os.environ["KERAS_BACKEND"] = "torch"

import torch
import keras
import numpy as np

"""
24
## Introduction
25
26
Keras provides default training and evaluation loops, `fit()` and `evaluate()`.
27
Their usage is covered in the guide
28
[Training & evaluation with the built-in methods](/guides/training_with_built_in_methods/).
29
30
If you want to customize the learning algorithm of your model while still leveraging
31
the convenience of `fit()`
32
(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and
33
implement your own `train_step()` method, which
34
is called repeatedly during `fit()`.
35
36
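For instance, with the torch backend, a `train_step()` override follows the same
pattern as the loops written in this guide. Here is a minimal sketch, with the
metric bookkeeping simplified (see the dedicated guide on customizing `fit()` for
the complete version):

```python
class CustomModel(keras.Model):
    def train_step(self, data):
        x, y = data
        # Clear leftover gradients from the previous step.
        self.zero_grad()
        # Forward pass and loss computation.
        y_pred = self(x, training=True)
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # Backward pass: compute gradients of the loss w.r.t. the weights.
        loss.backward()
        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]
        # Apply the gradients with the model's optimizer.
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)
        # Update and report the model's metrics.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
```
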
Now, if you want very low-level control over training & evaluation, you should write
your own training & evaluation loops from scratch. This is what this guide is about.
"""

"""
## A first end-to-end example

To write a custom training loop, we need the following ingredients:

- A model to train, of course.
- An optimizer. You could either use a `keras.optimizers` optimizer,
or a native PyTorch optimizer from `torch.optim`.
- A loss function. You could either use a `keras.losses` loss,
or a native PyTorch loss from `torch.nn`.
- A dataset. You could use any format: a `tf.data.Dataset`,
a PyTorch `DataLoader`, a Python generator, etc.

Let's line them up. We'll use torch-native objects in each case --
except, of course, for the Keras model.

First, let's get the model and the MNIST dataset:
"""


# Let's consider a simple MNIST model
def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


# Load up the MNIST dataset and put it in torch DataLoaders
# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Create torch Datasets
train_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_val), torch.from_numpy(y_val)
)

# Create DataLoaders for the Datasets
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False
)

"""
102
Next, here's our PyTorch optimizer and our PyTorch loss function:
103
"""
104
105
# Instantiate a torch optimizer
106
model = get_model()
107
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
108
109
# Instantiate a torch loss function
110
loss_fn = torch.nn.CrossEntropyLoss()
111
112
"""
113
Let's train our model using mini-batch gradient with a custom training loop.
114
115
Calling `loss.backward()` on a loss tensor triggers backpropagation.
116
Once that's done, your optimizer is magically aware of the gradients for each variable
117
and can update its variables, which is done via `optimizer.step()`.
118
Tensors, variables, optimizers are all interconnected to one another via hidden global state.
119
Also, don't forget to call `model.zero_grad()` before `loss.backward()`, or you won't
120
get the right gradients for your variables.
121
122
Here's our training loop, step by step:
123
124
- We open a `for` loop that iterates over epochs
125
- For each epoch, we open a `for` loop that iterates over the dataset, in batches
126
- For each batch, we call the model on the input data to retrieve the predictions,
127
then we use them to compute a loss value
128
- We call `loss.backward()` to
129
- Outside the scope, we retrieve the gradients of the weights
130
of the model with regard to the loss
131
- Finally, we use the optimizer to update the weights of the model based on the
132
gradients
133
"""
134
135
epochs = 3
for epoch in range(epochs):
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(logits, targets)

        # Backward pass
        model.zero_grad()
        loss.backward()

        # Optimizer variable updates
        optimizer.step()

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

"""
157
As an alternative, let's look at what the loop looks like when using a Keras optimizer
158
and a Keras loss function.
159
160
Important differences:
161
162
- You retrieve the gradients for the variables via `v.value.grad`,
163
called on each trainable variable.
164
- You update your variables via `optimizer.apply()`, which must be
165
called in a `torch.no_grad()` scope.
166
167
**Also, a big gotcha:** while all NumPy/TensorFlow/JAX/Keras APIs
168
as well as Python `unittest` APIs use the argument order convention
169
`fn(y_true, y_pred)` (reference values first, predicted values second),
170
PyTorch actually uses `fn(y_pred, y_true)` for its losses.
171
So make sure to invert the order of `logits` and `targets`.
172
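
For example, assuming `logits` and `targets` hold a batch of predictions and one-hot
labels, the two conventions look like this side by side:

```python
# PyTorch-native loss: predictions first, targets second.
torch_loss_fn = torch.nn.CrossEntropyLoss()
loss = torch_loss_fn(logits, targets)

# Keras loss: targets first, predictions second.
keras_loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
loss = keras_loss_fn(targets, logits)
```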
"""
173
174
model = get_model()
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)

        # Backward pass
        model.zero_grad()
        trainable_weights = [v for v in model.trainable_weights]

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply(gradients, trainable_weights)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

"""
206
## Low-level handling of metrics
207
208
Let's add metrics monitoring to this basic training loop.
209
210
You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training
211
loops written from scratch. Here's the flow:
212
213
- Instantiate the metric at the start of the loop
214
- Call `metric.update_state()` after each batch
215
- Call `metric.result()` when you need to display the current value of the metric
216
- Call `metric.reset_state()` when you need to clear the state of the metric
217
(typically at the end of an epoch)
218
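In isolation, with hypothetical placeholder tensors `y_true_batch` and `y_pred_batch`
standing in for a batch of targets and predictions, that flow looks like this:

```python
# Instantiate the metric once, before the loop starts.
metric = keras.metrics.CategoricalAccuracy()

# After each batch: accumulate statistics.
metric.update_state(y_true_batch, y_pred_batch)

# Whenever you need to display it: read the current value.
print(f"Accuracy so far: {float(metric.result()):.4f}")

# At the end of the epoch: clear the accumulated state.
metric.reset_state()
```
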
Let's use this knowledge to compute `CategoricalAccuracy` on training and
validation data at the end of each epoch:
"""

# Get a fresh model
model = get_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()

"""
Here's our training & evaluation loop:
"""

for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)

        # Backward pass
        model.zero_grad()
        trainable_weights = [v for v in model.trainable_weights]

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply(gradients, trainable_weights)

        # Update training metric.
        train_acc_metric.update_state(targets, logits)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataloader:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")


"""
287
## Low-level handling of losses tracked by the model
288
289
Layers & models recursively track any losses created during the forward pass
290
by layers that call `self.add_loss(value)`. The resulting list of scalar loss
291
values are available via the property `model.losses`
292
at the end of the forward pass.
293
294
If you want to be using these loss components, you should sum them
295
and add them to the main loss in your training step.
296
297
Consider this layer, that creates an activity regularization loss:
298
"""
299
300
301
class ActivityRegularizationLayer(keras.layers.Layer):
302
def call(self, inputs):
303
self.add_loss(1e-2 * torch.sum(inputs))
304
return inputs
305
306
307
"""
308
Let's build a really simple model that uses it:
309
"""
310
311
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

"""
Here's what our training loop should look like now:
"""

# Keep the model with activity regularization that we just built,
# so that `model.losses` is populated during the forward pass.

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()

for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)
        if model.losses:
            # Sum the tracked losses and add them to the main loss.
            loss = loss + sum(model.losses)

        # Backward pass
        model.zero_grad()
        trainable_weights = [v for v in model.trainable_weights]

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply(gradients, trainable_weights)

        # Update training metric.
        train_acc_metric.update_state(targets, logits)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataloader:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")

"""
385
That's it!
386
"""
387
388