"""
Title: Text classification with Transformer
Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)
Date created: 2020/05/10
Last modified: 2024/01/18
Description: Implement a Transformer block as a Keras layer and use it for text classification.
Accelerator: GPU
Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
"""

"""
## Setup
"""

import keras
from keras import ops
from keras import layers

"""
20
## Implement a Transformer block as a layer
21
"""
22
23
24
class TransformerBlock(layers.Layer):
25
def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
26
super().__init__()
27
self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
28
self.ffn = keras.Sequential(
29
[
30
layers.Dense(ff_dim, activation="relu"),
31
layers.Dense(embed_dim),
32
]
33
)
34
self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
35
self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
36
self.dropout1 = layers.Dropout(rate)
37
self.dropout2 = layers.Dropout(rate)
38
39
def call(self, inputs):
40
attn_output = self.att(inputs, inputs)
41
attn_output = self.dropout1(attn_output)
42
out1 = self.layernorm1(inputs + attn_output)
43
ffn_output = self.ffn(out1)
44
ffn_output = self.dropout2(ffn_output)
45
return self.layernorm2(out1 + ffn_output)
46
47
48
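"""
As a quick shape check, here is a minimal sketch (the sizes below are illustrative, not
the ones used by the model defined later): the block maps a `(batch, sequence, embed_dim)`
tensor to a tensor of the same shape.
"""

_demo_block = TransformerBlock(embed_dim=32, num_heads=2, ff_dim=32)
_demo_out = _demo_block(ops.ones((4, 10, 32)))  # dummy batch of 4 sequences of length 10
print("TransformerBlock output shape:", _demo_out.shape)  # (4, 10, 32)
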
"""
49
## Implement embedding layer
50
51
Two separate embedding layers, one for tokens, one for token index (positions).
52
"""
53
54
55
class TokenAndPositionEmbedding(layers.Layer):
56
def __init__(self, maxlen, vocab_size, embed_dim):
57
super().__init__()
58
self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
59
self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
60
61
def call(self, x):
62
maxlen = ops.shape(x)[-1]
63
positions = ops.arange(start=0, stop=maxlen, step=1)
64
positions = self.pos_emb(positions)
65
x = self.token_emb(x)
66
return x + positions
67
68
69
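"""
Another minimal shape check (illustrative values only): the layer turns a batch of
integer token ids of shape `(batch, sequence)` into embeddings of shape
`(batch, sequence, embed_dim)`.
"""

_demo_emb = TokenAndPositionEmbedding(maxlen=10, vocab_size=100, embed_dim=32)
_demo_ids = ops.zeros((4, 10), dtype="int32")  # 4 dummy sequences of length 10
print("TokenAndPositionEmbedding output shape:", _demo_emb(_demo_ids).shape)  # (4, 10, 32)
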
"""
70
## Download and prepare dataset
71
"""
72
73
vocab_size = 20000 # Only consider the top 20k words
74
maxlen = 200 # Only consider the first 200 words of each movie review
75
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)
76
print(len(x_train), "Training sequences")
77
print(len(x_val), "Validation sequences")
78
x_train = keras.utils.pad_sequences(x_train, maxlen=maxlen)
79
x_val = keras.utils.pad_sequences(x_val, maxlen=maxlen)
80
81
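"""
The reviews come back as lists of integer word indices, truncated or padded to `maxlen`.
As an illustrative aside (assuming the dataset's default index offset of 3, with indices
0, 1, and 2 reserved for padding, start-of-sequence, and out-of-vocabulary markers), one
review can be decoded back into words like this:
"""

word_index = keras.datasets.imdb.get_word_index()
index_to_word = {index + 3: word for word, index in word_index.items()}
decoded_review = " ".join(index_to_word.get(i, "[UNK]") for i in x_train[0] if i > 2)
print(decoded_review[:200], "...")
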
"""
82
## Create classifier model using transformer layer
83
84
Transformer layer outputs one vector for each time step of our input sequence.
85
Here, we take the mean across all time steps and
86
use a feed forward network on top of it to classify text.
87
"""
88
89
90
embed_dim = 32 # Embedding size for each token
91
num_heads = 2 # Number of attention heads
92
ff_dim = 32 # Hidden layer size in feed forward network inside transformer
93
94
inputs = layers.Input(shape=(maxlen,))
95
embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
96
x = embedding_layer(inputs)
97
transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim)
98
x = transformer_block(x)
99
x = layers.GlobalAveragePooling1D()(x)
100
x = layers.Dropout(0.1)(x)
101
x = layers.Dense(20, activation="relu")(x)
102
x = layers.Dropout(0.1)(x)
103
outputs = layers.Dense(2, activation="softmax")(x)
104
105
model = keras.Model(inputs=inputs, outputs=outputs)
106
107
108
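"""
Printing a summary is a quick way to confirm the layer output shapes and parameter
counts before training:
"""

model.summary()
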
"""
109
## Train and Evaluate
110
"""
111
112
model.compile(
113
optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
114
)
115
history = model.fit(
116
x_train, y_train, batch_size=32, epochs=2, validation_data=(x_val, y_val)
117
)
118
119
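"""
After training, a minimal sketch of how the fitted model could be used: evaluate it on
the validation split and inspect the predicted class probabilities (index 0 = negative,
index 1 = positive) for a few reviews.
"""

val_loss, val_acc = model.evaluate(x_val, y_val, verbose=0)
print(f"Validation accuracy: {val_acc:.3f}")

probs = model.predict(x_val[:3], verbose=0)
print("Predicted labels:", probs.argmax(axis=-1))
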
"""
120
## Relevant Chapters from Deep Learning with Python
121
- [Chapter 14: Text classification](https://deeplearningwithpython.io/chapters/chapter14_text-classification)
122
- [Chapter 15: Language models and the Transformer](https://deeplearningwithpython.io/chapters/chapter15_language-models-and-the-transformer)
123
"""
124
125