"""
2
Title: Image classification with EANet (External Attention Transformer)
3
Author: [ZhiYong Chang](https://github.com/czy00000)
4
Date created: 2021/10/19
5
Last modified: 2023/07/18
6
Description: Image classification with a Transformer that leverages external attention.
7
Accelerator: GPU
8
Converted to Keras 3: [Muhammad Anas Raza](https://anasrz.com)
9
"""
10
11
"""
12
## Introduction
13
14
This example implements the [EANet](https://arxiv.org/abs/2105.02358)
15
model for image classification, and demonstrates it on the CIFAR-100 dataset.
16
EANet introduces a novel attention mechanism
17
named ***external attention***, based on two external, small, learnable, and
18
shared memories, which can be implemented easily by simply using two cascaded
19
linear layers and two normalization layers. It conveniently replaces self-attention
20
as used in existing architectures. External attention has linear complexity, as it only
21
implicitly considers the correlations between all samples.
22
"""

"""
## Setup
"""

import keras
from keras import layers
from keras import ops

import matplotlib.pyplot as plt

"""
## Prepare the data
"""

num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

"""
## Configure the hyperparameters
"""

weight_decay = 0.0001
learning_rate = 0.001
label_smoothing = 0.1
validation_split = 0.2
batch_size = 128
num_epochs = 50
patch_size = 2  # Size of the patches to be extracted from the input images.
num_patches = (input_shape[0] // patch_size) ** 2  # Number of patches per image.
embedding_dim = 64  # Number of hidden units.
mlp_dim = 64
dim_coefficient = 4
num_heads = 4
attention_dropout = 0.2
projection_dropout = 0.2
num_transformer_blocks = 8  # Number of repetitions of the Transformer layer.

print(f"Patch size: {patch_size} x {patch_size} = {patch_size ** 2}")
print(f"Patches per image: {num_patches}")
"""
72
## Use data augmentation
73
"""
74
75
data_augmentation = keras.Sequential(
76
[
77
layers.Normalization(),
78
layers.RandomFlip("horizontal"),
79
layers.RandomRotation(factor=0.1),
80
layers.RandomContrast(factor=0.1),
81
layers.RandomZoom(height_factor=0.2, width_factor=0.2),
82
],
83
name="data_augmentation",
84
)
85
# Compute the mean and the variance of the training data for normalization.
86
data_augmentation.layers[0].adapt(x_train)
87
88
"""
89
## Implement the patch extraction and encoding layer
90
"""
91
92
93
class PatchExtract(layers.Layer):
94
def __init__(self, patch_size, **kwargs):
95
super().__init__(**kwargs)
96
self.patch_size = patch_size
97
98
def call(self, x):
99
B, C = ops.shape(x)[0], ops.shape(x)[-1]
100
x = ops.image.extract_patches(x, self.patch_size)
101
x = ops.reshape(x, (B, -1, self.patch_size * self.patch_size * C))
102
return x
103
104
105
class PatchEmbedding(layers.Layer):
106
def __init__(self, num_patch, embed_dim, **kwargs):
107
super().__init__(**kwargs)
108
self.num_patch = num_patch
109
self.proj = layers.Dense(embed_dim)
110
self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)
111
112
def call(self, patch):
113
pos = ops.arange(start=0, stop=self.num_patch, step=1)
114
return self.proj(patch) + self.pos_embed(pos)
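

"""
Before assembling the full model, it can help to sanity-check the two layers above. The
snippet below is purely illustrative: it pushes a single CIFAR-100 image through
`PatchExtract` and `PatchEmbedding` and prints the resulting shapes, which should be
`(1, num_patches, patch_size * patch_size * 3)` and `(1, num_patches, embedding_dim)`.
"""

# Quick, illustrative shape check on one image.
sample = ops.convert_to_tensor(x_train[:1].astype("float32"))
sample_patches = PatchExtract(patch_size)(sample)
print(f"Patches shape: {sample_patches.shape}")  # (1, 256, 12)
sample_embeddings = PatchEmbedding(num_patches, embedding_dim)(sample_patches)
print(f"Embeddings shape: {sample_embeddings.shape}")  # (1, 256, 64)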

"""
## Implement the external attention block
"""


def external_attention(
    x,
    dim,
    num_heads,
    dim_coefficient=4,
    attention_dropout=0,
    projection_dropout=0,
):
    _, num_patch, channel = x.shape
    assert dim % num_heads == 0
    num_heads = num_heads * dim_coefficient

    x = layers.Dense(dim * dim_coefficient)(x)
    # Create a tensor of shape [batch_size, num_patches, num_heads, dim * dim_coefficient // num_heads].
    x = ops.reshape(x, (-1, num_patch, num_heads, dim * dim_coefficient // num_heads))
    x = ops.transpose(x, axes=[0, 2, 1, 3])
    # A linear layer M_k.
    attn = layers.Dense(dim // dim_coefficient)(x)
    # Normalize the attention map.
    attn = layers.Softmax(axis=2)(attn)
    # Double-normalization.
    attn = layers.Lambda(
        lambda attn: ops.divide(
            attn,
            ops.convert_to_tensor(1e-9) + ops.sum(attn, axis=-1, keepdims=True),
        )
    )(attn)
    attn = layers.Dropout(attention_dropout)(attn)
    # A linear layer M_v.
    x = layers.Dense(dim * dim_coefficient // num_heads)(attn)
    x = ops.transpose(x, axes=[0, 2, 1, 3])
    x = ops.reshape(x, [-1, num_patch, dim * dim_coefficient])
    # A linear layer to project back to the original embedding dimension.
    x = layers.Dense(dim)(x)
    x = layers.Dropout(projection_dropout)(x)
    return x
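

"""
External attention keeps the token sequence shape unchanged, so it can stand in wherever
self-attention would be used. As a quick illustrative check (not required for training),
we can run a random tensor of the expected shape through the block and confirm that the
output shape matches the input:
"""

# Illustrative only: the block maps (1, num_patches, embedding_dim) to the same shape.
random_tokens = keras.random.normal((1, num_patches, embedding_dim))
print(external_attention(random_tokens, embedding_dim, num_heads).shape)  # (1, 256, 64)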

"""
## Implement the MLP block
"""


def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2):
    x = layers.Dense(mlp_dim, activation=ops.gelu)(x)
    x = layers.Dropout(drop_rate)(x)
    x = layers.Dense(embedding_dim)(x)
    x = layers.Dropout(drop_rate)(x)
    return x


"""
## Implement the Transformer block
"""


def transformer_encoder(
    x,
    embedding_dim,
    mlp_dim,
    num_heads,
    dim_coefficient,
    attention_dropout,
    projection_dropout,
    attention_type="external_attention",
):
    residual_1 = x
    x = layers.LayerNormalization(epsilon=1e-5)(x)
    if attention_type == "external_attention":
        x = external_attention(
            x,
            embedding_dim,
            num_heads,
            dim_coefficient,
            attention_dropout,
            projection_dropout,
        )
    elif attention_type == "self_attention":
        x = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embedding_dim,
            dropout=attention_dropout,
        )(x, x)
    x = layers.add([x, residual_1])
    residual_2 = x
    x = layers.LayerNormalization(epsilon=1e-5)(x)
    x = mlp(x, embedding_dim, mlp_dim)
    x = layers.add([x, residual_2])
    return x

"""
## Implement the EANet model
"""

"""
The EANet model leverages external attention.
The computational complexity of traditional self-attention is `O(d * N ** 2)`,
where `d` is the embedding size and `N` is the number of patches.
The authors find that most pixels are closely related to just a few other
pixels, and an `N`-to-`N` attention matrix may be redundant.
As an alternative, they propose an external attention module whose computational
complexity is `O(d * S * N)`, where `S` is the size of the shared external memory.
As `d` and `S` are hyper-parameters,
the proposed algorithm is linear in the number of pixels. In fact, this is equivalent
to a drop-patch operation, because a lot of the information contained in an image
patch is redundant and unimportant.
"""


def get_model(attention_type="external_attention"):
    inputs = layers.Input(shape=input_shape)
    # Augment the input images.
    x = data_augmentation(inputs)
    # Extract patches.
    x = PatchExtract(patch_size)(x)
    # Create patch embeddings.
    x = PatchEmbedding(num_patches, embedding_dim)(x)
    # Create Transformer blocks.
    for _ in range(num_transformer_blocks):
        x = transformer_encoder(
            x,
            embedding_dim,
            mlp_dim,
            num_heads,
            dim_coefficient,
            attention_dropout,
            projection_dropout,
            attention_type,
        )

    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


"""
## Train on CIFAR-100
"""

model = get_model(attention_type="external_attention")

model.compile(
    loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
    optimizer=keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    metrics=[
        keras.metrics.CategoricalAccuracy(name="accuracy"),
        keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)

history = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_split=validation_split,
)

"""
### Let's visualize the training progress of the model.
"""

plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Train and Validation Losses Over Epochs", fontsize=14)
plt.legend()
plt.grid()
plt.show()

"""
### Let's display the final results of the test on CIFAR-100.
"""

loss, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f"Test loss: {round(loss, 2)}")
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

"""
EANet simply replaces self-attention in ViT with external attention.
The traditional ViT achieves ~73% test top-5 accuracy and ~41% top-1 accuracy after
training for 50 epochs, with 0.6M parameters. Under the same experimental conditions
and the same hyperparameters, the EANet model we just trained has only 0.3M parameters,
and it gets us to ~73% test top-5 accuracy and ~43% top-1 accuracy. This fully
demonstrates the effectiveness of external attention.

We only show the training process of EANet here; you can train ViT under the same
experimental conditions and observe the test results.
"""