Author: Sayak Paul
Date created: 2022/04/05
Last modified: 2026/02/10
Description: Distillation of Vision Transformers through attention.
Introduction
In the original Vision Transformers (ViT) paper (Dosovitskiy et al.), the authors concluded that to perform on par with Convolutional Neural Networks (CNNs), ViTs need to be pre-trained on larger datasets. The larger the better. This is mainly due to the lack of inductive biases in the ViT architecture -- unlike CNNs, they don't have layers that exploit locality. In a follow-up paper (Steiner et al.), the authors show that it is possible to substantially improve the performance of ViTs with stronger regularization and longer training.
Many groups have proposed ways to deal with the data-hungry nature of ViT training. One of them was shown in the Data-efficient image Transformers (DeiT) paper (Touvron et al.). The authors introduced a distillation technique that is specific to transformer-based vision models. DeiT is among the first works to show that it's possible to train ViTs well without using larger datasets.
In this example, we implement the distillation recipe proposed in DeiT. This requires us to slightly tweak the original ViT architecture and write a custom training loop to implement the distillation recipe.
To follow this example comfortably, you should know how ViTs and knowledge distillation work. The following are good resources in case you need a refresher:
Imports
from typing import List
import tensorflow as tf
import tensorflow_datasets as tfds
import keras
from keras import layers
tfds.disable_progress_bar()
keras.utils.set_random_seed(42)
Constants
MODEL_TYPE = "deit_distilled_tiny_patch16_224"
RESOLUTION = 224
PATCH_SIZE = 16
NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 192
NUM_HEADS = 3
NUM_LAYERS = 12
MLP_UNITS = [
    PROJECTION_DIM * 4,
    PROJECTION_DIM,
]
DROPOUT_RATE = 0.0
DROP_PATH_RATE = 0.1
NUM_EPOCHS = 20
BASE_LR = 0.0005
WEIGHT_DECAY = 0.0001
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
NUM_CLASSES = 5
You probably noticed that DROPOUT_RATE has been set to 0.0. Dropout is still included in the implementation for completeness. Smaller models like the one used in this example don't need it, but bigger models benefit from it.
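The constants above fix the sequence length that every Transformer block sees. The following plain-Python sketch repeats that arithmetic, with values copied from the constants. The parameter count assumes the Conv2D patch projection uses a bias, which is the Keras default.

```python
# Token and projection arithmetic for the constants above (values copied from them).
RESOLUTION = 224
PATCH_SIZE = 16
PROJECTION_DIM = 192

patches_per_side = RESOLUTION // PATCH_SIZE  # 14 patches along each side
num_patches = patches_per_side ** 2  # same as NUM_PATCHES
vit_seq_len = num_patches + 1  # plus the class token
deit_seq_len = num_patches + 2  # plus the class and distillation tokens

# Each 16x16 RGB patch holds PATCH_SIZE * PATCH_SIZE * 3 values; the Conv2D
# projection maps it to PROJECTION_DIM features (kernel weights plus one bias per filter).
patch_values = PATCH_SIZE * PATCH_SIZE * 3
projection_params = patch_values * PROJECTION_DIM + PROJECTION_DIM

print(num_patches, vit_seq_len, deit_seq_len, patch_values, projection_params)
```

The 198-token length is why `transformer()` below checks whether MODEL_TYPE contains "distilled".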
Load the tf_flowers dataset and prepare preprocessing utilities
The DeiT authors use several augmentation techniques, including MixUp (Zhang et al.) and RandAugment (Cubuk et al.). To keep this example simple, we skip them and only use random cropping and horizontal flipping.
def preprocess_dataset(is_training=True):
    def fn(image, label):
        if is_training:
            # Resize slightly larger, then take a random RESOLUTION x RESOLUTION
            # crop and apply a random horizontal flip.
            image = keras.ops.image.resize(image, (RESOLUTION + 20, RESOLUTION + 20))
            crop_top = tf.random.uniform((), 0, 21, dtype=tf.int32)
            crop_left = tf.random.uniform((), 0, 21, dtype=tf.int32)
            image = tf.image.crop_to_bounding_box(
                image,
                offset_height=crop_top,
                offset_width=crop_left,
                target_height=RESOLUTION,
                target_width=RESOLUTION,
            )
            if tf.random.uniform(()) > 0.5:
                image = tf.image.flip_left_right(image)
        else:
            image = keras.ops.image.resize(image, (RESOLUTION, RESOLUTION))
        # Pixel values stay in [0, 255]; the trainer rescales them for the student.
        label = keras.ops.one_hot(label, num_classes=NUM_CLASSES)
        return image, label
    return fn
def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(BATCH_SIZE * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=AUTO)
    return dataset.batch(BATCH_SIZE).prefetch(AUTO)
train_dataset, val_dataset = tfds.load(
"tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")
train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)
```
Number of training examples: 3303
Number of validation examples: 367
```
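To make the training-time augmentation concrete, here is a NumPy sketch of the crop-and-flip branch of `preprocess_dataset`. It is an illustration, not the tf.data code itself. It assumes the image has already been resized to `RESOLUTION + 20` on each side, and the helper name `random_crop_and_flip` is hypothetical.

```python
import numpy as np

RESOLUTION = 224
PAD = 20  # preprocess_dataset resizes to RESOLUTION + 20 before cropping


def random_crop_and_flip(image, rng):
    """Hypothetical NumPy sketch of the training branch of preprocess_dataset.

    Assumes `image` was already resized to (RESOLUTION + PAD, RESOLUTION + PAD, C).
    """
    # Offsets in [0, PAD], like tf.random.uniform((), 0, 21, dtype=tf.int32).
    top = int(rng.integers(0, PAD + 1))
    left = int(rng.integers(0, PAD + 1))
    crop = image[top : top + RESOLUTION, left : left + RESOLUTION]
    # Horizontal (left-right) flip with probability 0.5.
    if rng.uniform() > 0.5:
        crop = crop[:, ::-1]
    return crop


rng = np.random.default_rng(0)
resized = rng.uniform(0, 255, size=(RESOLUTION + PAD, RESOLUTION + PAD, 3))
out = random_crop_and_flip(resized, rng)
print(out.shape)  # (224, 224, 3)
```

The largest offset is 20, and 20 + 224 = 244, so every crop stays inside the resized image.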
Implementing the DeiT variants of ViT
Since DeiT is an extension of ViT, it makes sense to first implement ViT and then extend it to support DeiT's components.
First, we'll implement a layer for Stochastic Depth (Huang et al.), which DeiT uses for regularization.
class StochasticDepth(layers.Layer):
    """Randomly drops the whole residual branch, per sample, during training."""
    def __init__(self, drop_prob, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob
        self.seed_generator = keras.random.SeedGenerator(1337)
    def call(self, x, training=True):
        if training:
            keep_prob = 1 - self.drop_prob
            # One random draw per sample, broadcast over the remaining dimensions.
            shape = (keras.ops.shape(x)[0],) + (1,) * (len(keras.ops.shape(x)) - 1)
            random_tensor = keep_prob + keras.random.uniform(
                shape, 0, 1, seed=self.seed_generator
            )
            random_tensor = keras.ops.floor(random_tensor)
            # Rescale kept samples so the expected output equals the input.
            return (x / keep_prob) * random_tensor
        return x
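The masking trick in `StochasticDepth` can be checked numerically. Below is a NumPy sketch that mirrors the layer's logic. The function name `stochastic_depth` is hypothetical. The sketch shows that each sample is either dropped entirely or kept and scaled by `1 / keep_prob`, so the mean stays close to the input's.

```python
import numpy as np


def stochastic_depth(x, drop_prob, rng, training=True):
    """Hypothetical NumPy mirror of the StochasticDepth layer above (per-sample drop)."""
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1.0 - drop_prob
    # One random value per sample, broadcast over all other axes.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    mask = np.floor(keep_prob + rng.uniform(0.0, 1.0, size=shape))
    # Dividing by keep_prob keeps the expected value equal to x.
    return (x / keep_prob) * mask


rng = np.random.default_rng(42)
x = np.ones((10_000, 4, 8))
y = stochastic_depth(x, drop_prob=0.1, rng=rng)
per_sample = y.reshape(len(y), -1)
print(np.unique(per_sample), float(y.mean()))
```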
Now, we'll implement the MLP and Transformer blocks.
def mlp(x, dropout_rate: float, hidden_units: List):
    """FFN for a Transformer block."""
    for idx, units in enumerate(hidden_units):
        x = layers.Dense(
            units,
            activation="gelu" if idx == 0 else None,
        )(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
def transformer(drop_prob: float, name: str) -> keras.Model:
    """Transformer block with pre-norm."""
    # One extra token (class) for ViT, two (class + distillation) for DeiT.
    num_patches = NUM_PATCHES + 2 if "distilled" in MODEL_TYPE else NUM_PATCHES + 1
    encoded_patches = layers.Input((num_patches, PROJECTION_DIM))
    # Attention sub-block: LayerNorm -> MHSA -> (stochastic depth) -> residual add.
    x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
    # Note: `key_dim` is the per-head size in Keras. The reference DeiT-Tiny uses
    # PROJECTION_DIM // NUM_HEADS (64) per head; this example keeps PROJECTION_DIM.
    attention_output = layers.MultiHeadAttention(
        num_heads=NUM_HEADS,
        key_dim=PROJECTION_DIM,
        dropout=DROPOUT_RATE,
    )(x1, x1)
    attention_output = (
        StochasticDepth(drop_prob)(attention_output) if drop_prob else attention_output
    )
    x2 = layers.Add()([attention_output, encoded_patches])
    # MLP sub-block: LayerNorm -> MLP -> (stochastic depth) -> residual add.
    x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)
    x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=DROPOUT_RATE)
    x4 = StochasticDepth(drop_prob)(x4) if drop_prob else x4
    outputs = layers.Add()([x2, x4])
    return keras.Model(encoded_patches, outputs, name=name)
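To see what each pre-norm block contains, and how much the `key_dim` choice matters, here is a plain-Python parameter count for one block built by `transformer()`. This is bookkeeping under assumptions, not a Keras API call. It assumes Keras defaults (`use_bias=True`, and `value_dim` equal to `key_dim` in `MultiHeadAttention`), and the helper names are hypothetical.

```python
# Parameter count of one pre-norm block from `transformer()`, assuming Keras
# defaults: use_bias=True and value_dim == key_dim in MultiHeadAttention.
PROJECTION_DIM = 192
NUM_HEADS = 3
MLP_UNITS = [PROJECTION_DIM * 4, PROJECTION_DIM]


def dense_params(n_in, n_out):
    return n_in * n_out + n_out  # kernel + bias


def mha_params(dim, num_heads, key_dim):
    inner = num_heads * key_dim
    qkv = 3 * dense_params(dim, inner)  # query, key and value projections
    out = dense_params(inner, dim)  # output projection back to `dim`
    return qkv + out


def block_params(key_dim):
    layer_norms = 2 * (2 * PROJECTION_DIM)  # gamma and beta for two LayerNorms
    mlp = 0
    n_in = PROJECTION_DIM
    for units in MLP_UNITS:
        mlp += dense_params(n_in, units)
        n_in = units
    return layer_norms + mha_params(PROJECTION_DIM, NUM_HEADS, key_dim) + mlp


print(block_params(PROJECTION_DIM))  # key_dim used in this example
print(block_params(PROJECTION_DIM // NUM_HEADS))  # 64 per head, as in DeiT-Tiny
```

Under these assumptions, one block has 740,928 parameters with `key_dim=PROJECTION_DIM` and 444,864 with 64 dimensions per head. The MLP accounts for 295,872 in both cases.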
We'll now implement a ViTClassifier class that builds on the components we just developed. It follows the original pooling strategy of the ViT paper: prepend a class token and use the final representation of that token for classification.
class ViTClassifier(keras.Model):
    """Vision Transformer base class."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Patchify and project in one step: a Conv2D whose kernel size and
        # stride both equal PATCH_SIZE, followed by flattening the patch grid.
        self.projection = keras.Sequential(
            [
                layers.Conv2D(
                    filters=PROJECTION_DIM,
                    kernel_size=(PATCH_SIZE, PATCH_SIZE),
                    strides=(PATCH_SIZE, PATCH_SIZE),
                    padding="VALID",
                    name="conv_projection",
                ),
                layers.Reshape(
                    target_shape=(NUM_PATCHES, PROJECTION_DIM),
                    name="flatten_projection",
                ),
            ],
            name="projection",
        )
        # Stochastic depth rate grows linearly from 0 to DROP_PATH_RATE with depth.
        dpr = [x for x in keras.ops.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS)]
        self.transformer_blocks = [
            transformer(drop_prob=dpr[i], name=f"transformer_block_{i}")
            for i in range(NUM_LAYERS)
        ]
        self.dropout = layers.Dropout(DROPOUT_RATE)
        self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)
        self.head = layers.Dense(
            NUM_CLASSES,
            name="classification_head",
        )
    def build(self, input_shape):
        self.positional_embedding = self.add_weight(
            shape=(1, NUM_PATCHES + 1, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="pos_embedding",
        )
        self.cls_token = self.add_weight(
            shape=(1, 1, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="cls",
        )
        super().build(input_shape)
    def call(self, inputs, training=True):
        n = keras.ops.shape(inputs)[0]
        # (batch, H, W, 3) -> (batch, NUM_PATCHES, PROJECTION_DIM)
        projected_patches = self.projection(inputs)
        # Prepend the learnable class token to every sequence in the batch.
        cls_token = keras.ops.tile(self.cls_token, (n, 1, 1))
        cls_token = keras.ops.cast(cls_token, projected_patches.dtype)
        projected_patches = keras.ops.concatenate(
            [cls_token, projected_patches], axis=1
        )
        encoded_patches = self.positional_embedding + projected_patches
        encoded_patches = self.dropout(encoded_patches)
        for transformer_module in self.transformer_blocks:
            encoded_patches = transformer_module(encoded_patches)
        representation = self.layer_norm(encoded_patches)
        # Classify from the final representation of the class token.
        encoded_patches = representation[:, 0]
        output = self.head(encoded_patches)
        return output
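The shape flow inside `ViTClassifier.call` can be traced with plain NumPy. The sketch below uses a hypothetical `patchify` helper and a random kernel standing in for the learned Conv2D weights. A Conv2D whose kernel size equals its stride is the same as splitting the image into non-overlapping patches and applying one matrix multiply, which is what this sketch does. The zero-initialized class token and positional embedding match the example's `Zeros` initializers.

```python
import numpy as np

RESOLUTION, PATCH_SIZE, PROJECTION_DIM = 224, 16, 192
NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2


def patchify(images, patch_size):
    """Split (batch, H, W, C) images into (batch, num_patches, patch_size**2 * C)."""
    b, h, w, c = images.shape
    gh, gw = h // patch_size, w // patch_size
    x = images.reshape(b, gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (b, gh, gw, p, p, c), row-major patch order
    return x.reshape(b, gh * gw, patch_size * patch_size * c)


rng = np.random.default_rng(0)
images = rng.uniform(size=(2, RESOLUTION, RESOLUTION, 3))
# Stand-in for the Conv2D kernel, flattened to (PATCH_SIZE * PATCH_SIZE * 3, PROJECTION_DIM).
kernel = rng.normal(scale=0.02, size=(PATCH_SIZE * PATCH_SIZE * 3, PROJECTION_DIM))
projected = patchify(images, PATCH_SIZE) @ kernel  # (2, 196, 192)

# Prepend a zero-initialized class token and add positional embeddings.
cls_token = np.zeros((1, 1, PROJECTION_DIM))
pos_embedding = np.zeros((1, NUM_PATCHES + 1, PROJECTION_DIM))
tokens = np.concatenate([np.tile(cls_token, (2, 1, 1)), projected], axis=1)
encoded = tokens + pos_embedding  # (2, 197, 192)

# (the Transformer blocks would run here)
cls_representation = encoded[:, 0]  # (2, 192), fed to the classification head
print(projected.shape, encoded.shape, cls_representation.shape)
```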
This class can be used standalone as a ViT and is end-to-end trainable. Because the transformer blocks read MODEL_TYPE to size their inputs, first remove the word "distilled" from MODEL_TYPE (so each block expects NUM_PATCHES + 1 tokens); then vit_tiny = ViTClassifier() should work. Let's now extend it to DeiT. The following figure presents the schematic of DeiT (taken from the DeiT paper):

Apart from the class token, DeiT has another token for distillation. During distillation, the logits corresponding to the class token are compared to the true labels, and the logits corresponding to the distillation token are compared to the teacher's predictions.
class ViTDistilled(ViTClassifier):
    def __init__(self, regular_training=False, **kwargs):
        super().__init__(**kwargs)
        self.num_tokens = 2
        self.regular_training = regular_training
        self.head = layers.Dense(
            NUM_CLASSES,
            name="classification_head",
        )
        self.head_dist = layers.Dense(
            NUM_CLASSES,
            name="distillation_head",
        )
    def build(self, input_shape):
        # Replaces ViTClassifier.build (no super().build() call) so that the
        # positional embedding covers both the class and distillation tokens.
        self.cls_token = self.add_weight(
            shape=(1, 1, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="cls",
        )
        self.dist_token = self.add_weight(
            shape=(1, 1, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="dist_token",
        )
        self.positional_embedding = self.add_weight(
            shape=(1, NUM_PATCHES + self.num_tokens, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="pos_embedding",
        )
    def call(self, inputs, training=True):
        n = keras.ops.shape(inputs)[0]
        projected_patches = self.projection(inputs)
        # Token layout: [cls, dist, patch_1, ..., patch_N].
        cls_token = keras.ops.tile(self.cls_token, (n, 1, 1))
        dist_token = keras.ops.tile(self.dist_token, (n, 1, 1))
        cls_token = keras.ops.cast(cls_token, projected_patches.dtype)
        dist_token = keras.ops.cast(dist_token, projected_patches.dtype)
        projected_patches = keras.ops.concatenate(
            [cls_token, dist_token, projected_patches], axis=1
        )
        encoded_patches = self.positional_embedding + projected_patches
        encoded_patches = self.dropout(encoded_patches)
        for transformer_module in self.transformer_blocks:
            encoded_patches = transformer_module(encoded_patches)
        representation = self.layer_norm(encoded_patches)
        x, x_dist = (
            self.head(representation[:, 0]),
            self.head_dist(representation[:, 1]),
        )
        # During distillation training, return both sets of logits so they can
        # receive separate losses; otherwise average the two heads.
        if training and not self.regular_training:
            return x, x_dist
        return (x + x_dist) / 2
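The two-head output logic at the end of `ViTDistilled.call` is small enough to mirror in NumPy. In the sketch below, the function `deit_heads` is hypothetical, biases are omitted, and random matrices stand in for the learned Dense kernels. Token 0 feeds the classification head and token 1 feeds the distillation head. The two logit sets are returned separately for distillation training and averaged otherwise.

```python
import numpy as np


def deit_heads(representation, w_cls, w_dist, training=True, regular_training=False):
    """Hypothetical NumPy mirror of the head logic in ViTDistilled.call (no biases).

    `representation` has shape (batch, NUM_PATCHES + 2, dim); token 0 is the
    class token and token 1 is the distillation token.
    """
    logits_cls = representation[:, 0] @ w_cls
    logits_dist = representation[:, 1] @ w_dist
    if training and not regular_training:
        return logits_cls, logits_dist  # separate losses during distillation
    return (logits_cls + logits_dist) / 2  # averaged at inference


rng = np.random.default_rng(0)
rep = rng.normal(size=(2, 198, 192))
w_cls, w_dist = rng.normal(size=(192, 5)), rng.normal(size=(192, 5))
cls_logits, dist_logits = deit_heads(rep, w_cls, w_dist, training=True)
averaged = deit_heads(rep, w_cls, w_dist, training=False)
print(cls_logits.shape, dist_logits.shape, averaged.shape)
```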
Let's verify that the ViTDistilled class can be initialized and called as expected.
deit_tiny_distilled = ViTDistilled()
dummy_inputs = tf.ones((2, 224, 224, 3))
outputs = deit_tiny_distilled(dummy_inputs, training=False)
print(outputs.shape)
Implementing the trainer
Unlike standard knowledge distillation (Hinton et al.), which combines a temperature-scaled softmax with a KL-divergence loss, the DeiT authors use the following loss function:

Here,
CE is cross-entropy
psi is the softmax function
Z_s denotes student predictions
y denotes true labels
y_t denotes teacher predictions
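For reference, here is the objective that `DeiT.train_step` below optimizes, written with the symbols above. The superscripts cls and dist are our own notation for the class-token and distillation-token logits, and Z_t (also ours) denotes the teacher's logits. In the DeiT paper's hard-label distillation, y_t = argmax_c Z_t(c), the teacher's top class. The trainer in this example instead uses the teacher's softmax output as a soft target and applies label smoothing of 0.1 to y.

```latex
\mathcal{L} = \frac{1}{2}\,\mathrm{CE}\big(\psi(Z_s^{\mathrm{cls}}),\, y\big)
            + \frac{1}{2}\,\mathrm{CE}\big(\psi(Z_s^{\mathrm{dist}}),\, y_t\big),
\qquad y_t = \psi(Z_t) \ \text{(soft target, as in train\_step)}
```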
class DeiT(keras.Model):
    def __init__(self, student, teacher, **kwargs):
        super().__init__(**kwargs)
        self.student = student
        self.teacher = teacher
        self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
        self.dist_loss_tracker = keras.metrics.Mean(name="distillation_loss")
        self.accuracy_metric = keras.metrics.CategoricalAccuracy(name="accuracy")
    @property
    def metrics(self):
        metrics = super().metrics
        metrics.append(self.student_loss_tracker)
        metrics.append(self.dist_loss_tracker)
        metrics.append(self.accuracy_metric)
        return metrics
    def compile(
        self,
        optimizer,
        student_loss_fn,
        distillation_loss_fn,
    ):
        super().compile(optimizer=optimizer)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
    def train_step(self, data):
        x, y = data
        # The student expects inputs in [0, 1]; the teacher receives raw [0, 255] pixels.
        x_student = keras.ops.cast(x, "float32") / 255.0
        x_teacher = keras.ops.cast(x, "float32")
        # The teacher is frozen; a SavedModel endpoint may return a dict of outputs.
        teacher_output = self.teacher(x_teacher, training=False)
        if isinstance(teacher_output, dict):
            teacher_output = list(teacher_output.values())[0]
        teacher_predictions = keras.ops.nn.softmax(teacher_output, -1)
        # tf.GradientTape ties this trainer to the TensorFlow backend.
        with tf.GradientTape() as tape:
            # Class-token logits vs. true labels; distillation-token logits vs. teacher.
            cls_predictions, dist_predictions = self.student(x_student, training=True)
            student_loss = self.student_loss_fn(y, cls_predictions)
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, dist_predictions
            )
            loss = (student_loss + distillation_loss) / 2
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update any built-in "loss" tracker that compile() may have created;
        # without this, fit() can report a constant `loss` of 0.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
        student_predictions = (cls_predictions + dist_predictions) / 2
        self.accuracy_metric.update_state(y, student_predictions)
        self.dist_loss_tracker.update_state(distillation_loss)
        self.student_loss_tracker.update_state(student_loss)
        return {
            "loss": loss,
            "student_loss": self.student_loss_tracker.result(),
            "distillation_loss": self.dist_loss_tracker.result(),
            "accuracy": self.accuracy_metric.result(),
        }
    def test_step(self, data):
        x, y = data
        x_normalized = keras.ops.cast(x, "float32") / 255.0
        # With training=False, the student returns the averaged logits of both heads.
        y_prediction = self.student(x_normalized, training=False)
        student_loss = self.student_loss_fn(y, y_prediction)
        self.accuracy_metric.update_state(y, y_prediction)
        self.student_loss_tracker.update_state(student_loss)
        return {
            "loss": student_loss,
            "student_loss": self.student_loss_tracker.result(),
            "accuracy": self.accuracy_metric.result(),
        }
    def call(self, inputs):
        inputs_normalized = keras.ops.cast(inputs, "float32") / 255.0
        return self.student(inputs_normalized, training=False)
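The loss computed in `train_step` can be reproduced in NumPy to see exactly what is averaged. This is a sketch under stated assumptions. It assumes Keras `CategoricalCrossentropy` with `from_logits=True`, the default mean-over-batch reduction, and label smoothing computed as `y * (1 - ls) + ls / num_classes`. The helper names are hypothetical. The `hard_teacher=True` branch swaps in the paper's hard-label variant (teacher argmax); this example does not train with it.

```python
import numpy as np


def log_softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def cross_entropy_from_logits(targets, logits, label_smoothing=0.0):
    """Mean categorical cross-entropy against (optionally smoothed) target distributions."""
    num_classes = targets.shape[-1]
    targets = targets * (1.0 - label_smoothing) + label_smoothing / num_classes
    return float(np.mean(-(targets * log_softmax(logits)).sum(axis=-1)))


def deit_loss(y, cls_logits, dist_logits, teacher_logits, hard_teacher=False):
    """Hypothetical NumPy mirror of the objective in DeiT.train_step."""
    if hard_teacher:  # paper's hard-label variant: one-hot of the teacher's top class
        teacher_targets = np.eye(teacher_logits.shape[-1])[teacher_logits.argmax(-1)]
    else:  # soft teacher distribution, as used by the trainer above
        teacher_targets = np.exp(log_softmax(teacher_logits))
    student_loss = cross_entropy_from_logits(y, cls_logits, label_smoothing=0.1)
    distillation_loss = cross_entropy_from_logits(teacher_targets, dist_logits)
    return (student_loss + distillation_loss) / 2, student_loss, distillation_loss


rng = np.random.default_rng(0)
y = np.eye(5)[rng.integers(0, 5, size=8)]  # one-hot labels for NUM_CLASSES = 5
cls_logits, dist_logits, teacher_logits = (rng.normal(size=(8, 5)) for _ in range(3))
total, student, distill = deit_loss(y, cls_logits, dist_logits, teacher_logits)
print(total, student, distill)
```

With all-zero student logits, both terms equal log(5), because a uniform prediction has that cross-entropy against any target distribution.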
Load the teacher model
This model is based on the BiT family of ResNets (Kolesnikov et al.), fine-tuned on the tf_flowers dataset. You can refer to this notebook to see how the training was performed. The teacher model has about 212 million parameters, which is about 40x more than the student.
!wget -q https://github.com/sayakpaul/deit-tf/releases/download/v0.1.0/bit_teacher_flowers.zip
!unzip -q bit_teacher_flowers.zip
bit_teacher_flowers = keras.layers.TFSMLayer(
"bit_teacher_flowers", call_endpoint="serving_default"
)
Training through distillation
deit_tiny = ViTDistilled()
deit_distiller = DeiT(student=deit_tiny, teacher=bit_teacher_flowers)
lr_scaled = (BASE_LR / 512) * BATCH_SIZE
deit_distiller.compile(
optimizer=keras.optimizers.AdamW(
weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled
),
student_loss_fn=keras.losses.CategoricalCrossentropy(
from_logits=True, label_smoothing=0.1
),
distillation_loss_fn=keras.losses.CategoricalCrossentropy(from_logits=True),
)
_ = deit_distiller.fit(train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS)
13/13 ━━━━━━━━━━━━━━━━━━━━ 130s 8s/step - accuracy: 0.2150 - distillation_loss: 2.1021 - loss: 0.0000e+00 - student_loss: 1.8120 - val_accuracy: 0.2616 - val_loss: 1.6223 - val_student_loss: 1.6278
Epoch 2/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.2416 - distillation_loss: 1.6185 - loss: 0.0000e+00 - student_loss: 1.6297 - val_accuracy: 0.1662 - val_loss: 1.6018 - val_student_loss: 1.6075
Epoch 3/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 104s 8s/step - accuracy: 0.2467 - distillation_loss: 1.6028 - loss: 0.0000e+00 - student_loss: 1.6087 - val_accuracy: 0.2316 - val_loss: 1.5954 - val_student_loss: 1.6009
Epoch 4/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.2349 - distillation_loss: 1.5968 - loss: 0.0000e+00 - student_loss: 1.6022 - val_accuracy: 0.2289 - val_loss: 1.5922 - val_student_loss: 1.6017
Epoch 5/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.2634 - distillation_loss: 1.5902 - loss: 0.0000e+00 - student_loss: 1.5928 - val_accuracy: 0.3025 - val_loss: 1.5703 - val_student_loss: 1.5795
Epoch 6/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.3279 - distillation_loss: 1.5441 - loss: 0.0000e+00 - student_loss: 1.5456 - val_accuracy: 0.3515 - val_loss: 1.4880 - val_student_loss: 1.4937
Epoch 7/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.3966 - distillation_loss: 1.4085 - loss: 0.0000e+00 - student_loss: 1.4534 - val_accuracy: 0.3706 - val_loss: 1.4348 - val_student_loss: 1.4335
Epoch 8/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.3890 - distillation_loss: 1.3647 - loss: 0.0000e+00 - student_loss: 1.4229 - val_accuracy: 0.3297 - val_loss: 1.4575 - val_student_loss: 1.4463
Epoch 9/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.4223 - distillation_loss: 1.3332 - loss: 0.0000e+00 - student_loss: 1.3850 - val_accuracy: 0.4114 - val_loss: 1.3888 - val_student_loss: 1.3763
Epoch 10/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.4475 - distillation_loss: 1.2577 - loss: 0.0000e+00 - student_loss: 1.3548 - val_accuracy: 0.4441 - val_loss: 1.3202 - val_student_loss: 1.3331
Epoch 11/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.4717 - distillation_loss: 1.2107 - loss: 0.0000e+00 - student_loss: 1.2995 - val_accuracy: 0.4632 - val_loss: 1.3016 - val_student_loss: 1.2872
Epoch 12/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.5017 - distillation_loss: 1.1562 - loss: 0.0000e+00 - student_loss: 1.2542 - val_accuracy: 0.5395 - val_loss: 1.2761 - val_student_loss: 1.2575
Epoch 13/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.5328 - distillation_loss: 1.1119 - loss: 0.0000e+00 - student_loss: 1.2223 - val_accuracy: 0.5068 - val_loss: 1.2102 - val_student_loss: 1.2383
Epoch 14/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 102s 8s/step - accuracy: 0.5655 - distillation_loss: 1.0595 - loss: 0.0000e+00 - student_loss: 1.1837 - val_accuracy: 0.5722 - val_loss: 1.1773 - val_student_loss: 1.1774
Epoch 15/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.5998 - distillation_loss: 1.0133 - loss: 0.0000e+00 - student_loss: 1.1465 - val_accuracy: 0.5204 - val_loss: 1.2519 - val_student_loss: 1.2340
Epoch 16/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.6110 - distillation_loss: 0.9992 - loss: 0.0000e+00 - student_loss: 1.1359 - val_accuracy: 0.6104 - val_loss: 1.0947 - val_student_loss: 1.1090
Epoch 17/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.6191 - distillation_loss: 0.9635 - loss: 0.0000e+00 - student_loss: 1.1101 - val_accuracy: 0.6076 - val_loss: 1.0678 - val_student_loss: 1.0952
Epoch 18/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.6400 - distillation_loss: 0.9460 - loss: 0.0000e+00 - student_loss: 1.0902 - val_accuracy: 0.6076 - val_loss: 1.0256 - val_student_loss: 1.0681
Epoch 19/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.6340 - distillation_loss: 0.9411 - loss: 0.0000e+00 - student_loss: 1.0943 - val_accuracy: 0.6213 - val_loss: 1.0353 - val_student_loss: 1.0702
Epoch 20/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.6506 - distillation_loss: 0.9121 - loss: 0.0000e+00 - student_loss: 1.0674 - val_accuracy: 0.6376 - val_loss: 1.0027 - val_student_loss: 1.0602
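The learning rate passed to AdamW above follows a linear scaling rule, in which the base rate is defined for a batch size of 512. The `13/13` in the log is the number of batches per epoch. A quick plain-Python check of both numbers, with values copied from the constants and the dataset printout:

```python
import math

BASE_LR = 0.0005
BATCH_SIZE = 256
NUM_EPOCHS = 20
num_train = 3303  # from the dataset printout above

# Linear scaling rule: BASE_LR corresponds to a batch size of 512.
lr_scaled = (BASE_LR / 512) * BATCH_SIZE
steps_per_epoch = math.ceil(num_train / BATCH_SIZE)  # the last batch is partial
total_steps = steps_per_epoch * NUM_EPOCHS
print(lr_scaled, steps_per_epoch, total_steps)
```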
If we had trained the same model (the `ViTClassifier`) from scratch with the exact same
hyperparameters, the model would have scored about 59% accuracy. You can adapt the following code
to reproduce this result:
# Build with a MODEL_TYPE that does not contain "distilled".
vit_tiny = ViTClassifier()
inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
x = keras.layers.Rescaling(scale=1./255)(inputs)
outputs = vit_tiny(x)
model = keras.Model(inputs, outputs)
model.compile(...)
model.fit(...)
---
## Notes
* Through the use of distillation, we're effectively transferring the inductive biases of
a CNN-based teacher model.
* Interestingly, as shown in the paper, this distillation strategy works better with a CNN
as the teacher model than with a Transformer.
* Regularization is very important when training DeiT models.
* ViT models are initialized with a combination of different initializers including
truncated normal, random normal, Glorot uniform, etc. If you're looking for
end-to-end reproduction of the original results, don't forget to initialize the ViTs well.
* If you want to explore the pre-trained DeiT models in Keras with code
for fine-tuning, [check out these models on TF-Hub](https:
---
## Acknowledgements
* Ross Wightman for keeping
[`timm`](https:
updated with readable implementations. I referred to its ViT and DeiT implementations
a lot while implementing them in Keras.
* [Aritra Roy Gosthipaty](https:
who implemented some portions of the `ViTClassifier` in another project.
* [Google Developers Experts](https:
program for supporting me with GCP credits which were used to run experiments for this
example.
Example available on HuggingFace:
| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https:
---
## Relevant Chapters from Deep Learning with Python
- [Chapter 8: Image classification](https:
- [Chapter 15: Language models and the Transformer](https: