GitHub Repository: keras-team/keras-io
Path: blob/master/examples/generative/text_generation_with_miniature_gpt.py
"""
Title: Text generation with a miniature GPT
Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)
Date created: 2020/05/29
Last modified: 2020/05/29
Description: Implement a miniature version of GPT and train it to generate text.
Accelerator: GPU
"""

"""
## Introduction

This example demonstrates how to implement an autoregressive language model
using a miniature version of the GPT model.
The model consists of a single Transformer block with causal masking
in its attention layer.
We use the text from the IMDB sentiment classification dataset for training
and generate new movie reviews for a given prompt.
When using this script with your own dataset, make sure it has at least
1 million words.

This example should be run with `tf-nightly>=2.3.0-dev20200531` or
with TensorFlow 2.3 or higher.

**References:**

- [GPT](https://www.semanticscholar.org/paper/Improving-Language-Understanding-by-Generative-Radford/cd18800a0fe0b668a1cc19f2ec95b5003d0a5035)
- [GPT-2](https://www.semanticscholar.org/paper/Language-Models-are-Unsupervised-Multitask-Learners-Radford-Wu/9405cc0d6169988371b2755e573cc28650d14dfe)
- [GPT-3](https://arxiv.org/abs/2005.14165)
"""
"""
## Setup
"""
# We set the backend to TensorFlow. The code works with
# both `tensorflow` and `torch`. It does not work with JAX
# due to the behavior of `jax.numpy.tile` in a jit scope
# (used in `causal_attention_mask()`): `tile` in JAX does
# not support a dynamic `reps` argument.
# You can make the code work in JAX by wrapping the body of
# the `causal_attention_mask` function in the
# `with jax.ensure_compile_time_eval():` context manager,
# which evaluates that code eagerly instead of tracing it
# under jit compilation.
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers
from keras import ops
from keras.layers import TextVectorization
import numpy as np
import string
import random
import tensorflow
import tensorflow.data as tf_data
import tensorflow.strings as tf_strings

"""
60
## Implement a Transformer block as a layer
61
"""
62
63
64
def causal_attention_mask(batch_size, n_dest, n_src, dtype):
65
"""
66
Mask the upper half of the dot product matrix in self attention.
67
This prevents flow of information from future tokens to current token.
68
1's in the lower triangle, counting from the lower right corner.
69
"""
70
i = ops.arange(n_dest)[:, None]
71
j = ops.arange(n_src)
72
m = i >= j - n_src + n_dest
73
mask = ops.cast(m, dtype)
74
mask = ops.reshape(mask, [1, n_dest, n_src])
75
mult = ops.concatenate(
76
[ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])], 0
77
)
78
return ops.tile(mask, mult)
79
80
81
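"""
As a quick sanity check (an addition to the original example, not required for
training), we can build the mask for a small case. Each query position should only
be able to attend to itself and to earlier positions, which shows up as a
lower-triangular matrix of ones.
"""

# Illustrative only: a 4x4 causal mask for a batch of 1.
demo_mask = causal_attention_mask(1, 4, 4, "int32")
print(ops.convert_to_numpy(demo_mask)[0])

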
class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads, embed_dim)
        self.ffn = keras.Sequential(
            [
                layers.Dense(ff_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs):
        input_shape = ops.shape(inputs)
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, "bool")
        attention_output = self.att(inputs, inputs, attention_mask=causal_mask)
        attention_output = self.dropout1(attention_output)
        out1 = self.layernorm1(inputs + attention_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        return self.layernorm2(out1 + ffn_output)


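"""
As a small illustration (an addition to the original example), the block maps a
batch of embedded sequences to an output of the same shape, so it can be stacked
or dropped into a larger model without any reshaping. The `demo_*` names below are
just for this check.
"""

# Illustrative only: random "embeddings" of shape (batch, sequence, embed_dim).
demo_inputs = np.random.uniform(size=(2, 10, 32)).astype("float32")
demo_block = TransformerBlock(embed_dim=32, num_heads=2, ff_dim=64)
print(demo_block(demo_inputs).shape)  # (2, 10, 32)

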
"""
110
## Implement an embedding layer
111
112
Create two separate embedding layers: one for tokens and one for token index
113
(positions).
114
"""
115
116
117
class TokenAndPositionEmbedding(layers.Layer):
118
def __init__(self, maxlen, vocab_size, embed_dim):
119
super().__init__()
120
self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
121
self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
122
123
def call(self, x):
124
maxlen = ops.shape(x)[-1]
125
positions = ops.arange(0, maxlen, 1)
126
positions = self.pos_emb(positions)
127
x = self.token_emb(x)
128
return x + positions
129
130
131
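"""
As an illustration (an addition to the original example), a batch of integer token
ids of shape `(batch, sequence)` comes out as embeddings of shape
`(batch, sequence, embed_dim)`, with the same position embedding added to every
sequence in the batch. The toy sizes below are arbitrary.
"""

# Illustrative only: a toy vocabulary of 50 tokens and sequences of length 8.
demo_emb = TokenAndPositionEmbedding(maxlen=8, vocab_size=50, embed_dim=16)
demo_ids = np.random.randint(0, 50, size=(2, 8))
print(demo_emb(demo_ids).shape)  # (2, 8, 16)

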
"""
132
## Implement the miniature GPT model
133
"""
134
vocab_size = 20000 # Only consider the top 20k words
135
maxlen = 80 # Max sequence size
136
embed_dim = 256 # Embedding size for each token
137
num_heads = 2 # Number of attention heads
138
feed_forward_dim = 256 # Hidden layer size in feed forward network inside transformer
139
140
141
def create_model():
142
inputs = layers.Input(shape=(maxlen,), dtype="int32")
143
embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
144
x = embedding_layer(inputs)
145
transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim)
146
x = transformer_block(x)
147
outputs = layers.Dense(vocab_size)(x)
148
model = keras.Model(inputs=inputs, outputs=[outputs, x])
149
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
150
model.compile(
151
"adam",
152
loss=[loss_fn, None],
153
) # No loss and optimization based on word embeddings from transformer block
154
return model
155
156
157
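"""
As a quick shape check (an addition to the original example, not needed for
training), we can run a batch of dummy token ids through a freshly built model.
The first output holds the next-token logits over the vocabulary; the second
output is the Transformer block activations that the model also exposes.
"""

# Illustrative only: logits have shape (batch, maxlen, vocab_size).
demo_model = create_model()
demo_logits, demo_activations = demo_model(np.zeros((1, maxlen), dtype="int32"))
print(demo_logits.shape, demo_activations.shape)  # (1, 80, 20000) (1, 80, 256)

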
"""
158
## Prepare the data for word-level language modelling
159
160
Download the IMDB dataset and combine training and validation sets for a text
161
generation task.
162
"""
163
164
"""shell
165
curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
166
tar -xf aclImdb_v1.tar.gz
167
"""
168
169
170
batch_size = 128

# The dataset contains each review in a separate text file
# The text files are present in four different folders
# Create a list of all files
filenames = []
directories = [
    "aclImdb/train/pos",
    "aclImdb/train/neg",
    "aclImdb/test/pos",
    "aclImdb/test/neg",
]
for dir in directories:
    for f in os.listdir(dir):
        filenames.append(os.path.join(dir, f))

print(f"{len(filenames)} files")

# Create a dataset from text files
random.shuffle(filenames)
text_ds = tf_data.TextLineDataset(filenames)
text_ds = text_ds.shuffle(buffer_size=256)
text_ds = text_ds.batch(batch_size)


def custom_standardization(input_string):
    """Remove html line-break tags and handle punctuation"""
    lowercased = tf_strings.lower(input_string)
    stripped_html = tf_strings.regex_replace(lowercased, "<br />", " ")
    return tf_strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1")


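"""
To see what the standardization does (an addition to the original example), we can
run it on a small string: the text is lowercased, `<br />` tags are replaced with a
space, and every punctuation character gets a space inserted in front of it so that
punctuation is separated from the preceding word when the text is later split on
whitespace.
"""

# Illustrative only.
print(custom_standardization("What a movie!<br />Loved it."))

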
# Create a vectorization layer and adapt it to the text
vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size - 1,
    output_mode="int",
    output_sequence_length=maxlen + 1,
)
vectorize_layer.adapt(text_ds)
vocab = vectorize_layer.get_vocabulary()  # To get words back from token indices


def prepare_lm_inputs_labels(text):
    """
    Shift word sequences by 1 position so that the target for position (i) is
    the word at position (i+1). The model will use all words up to position (i)
    to predict the next word.
    """
    text = tensorflow.expand_dims(text, -1)
    tokenized_sentences = vectorize_layer(text)
    x = tokenized_sentences[:, :-1]
    y = tokenized_sentences[:, 1:]
    return x, y


text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf_data.AUTOTUNE)
text_ds = text_ds.prefetch(tf_data.AUTOTUNE)


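"""
As a quick check (an addition to the original example), each element of the mapped
dataset is a pair of `(inputs, targets)` tensors of shape `(batch_size, maxlen)`,
with the targets shifted one position to the left relative to the inputs.
"""

# Illustrative only: look at the shapes of one batch.
for demo_x, demo_y in text_ds.take(1):
    print(demo_x.shape, demo_y.shape)  # (128, 80) (128, 80)

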
"""
231
## Implement a Keras callback for generating text
232
"""
233
234
235
class TextGenerator(keras.callbacks.Callback):
236
"""A callback to generate text from a trained model.
237
1. Feed some starting prompt to the model
238
2. Predict probabilities for the next token
239
3. Sample the next token and add it to the next input
240
241
Arguments:
242
max_tokens: Integer, the number of tokens to be generated after prompt.
243
start_tokens: List of integers, the token indices for the starting prompt.
244
index_to_word: List of strings, obtained from the TextVectorization layer.
245
top_k: Integer, sample from the `top_k` token predictions.
246
print_every: Integer, print after this many epochs.
247
"""
248
249
def __init__(
250
self, max_tokens, start_tokens, index_to_word, top_k=10, print_every=1
251
):
252
self.max_tokens = max_tokens
253
self.start_tokens = start_tokens
254
self.index_to_word = index_to_word
255
self.print_every = print_every
256
self.k = top_k
257
258
def sample_from(self, logits):
259
logits, indices = ops.top_k(logits, k=self.k, sorted=True)
260
indices = np.asarray(indices).astype("int32")
261
preds = keras.activations.softmax(ops.expand_dims(logits, 0))[0]
262
preds = np.asarray(preds).astype("float32")
263
return np.random.choice(indices, p=preds)
264
265
def detokenize(self, number):
266
return self.index_to_word[number]
267
268
def on_epoch_end(self, epoch, logs=None):
269
start_tokens = [_ for _ in self.start_tokens]
270
if (epoch + 1) % self.print_every != 0:
271
return
272
num_tokens_generated = 0
273
tokens_generated = []
274
while num_tokens_generated <= self.max_tokens:
275
pad_len = maxlen - len(start_tokens)
276
sample_index = len(start_tokens) - 1
277
if pad_len < 0:
278
x = start_tokens[:maxlen]
279
sample_index = maxlen - 1
280
elif pad_len > 0:
281
x = start_tokens + [0] * pad_len
282
else:
283
x = start_tokens
284
x = np.array([x])
285
y, _ = self.model.predict(x, verbose=0)
286
sample_token = self.sample_from(y[0][sample_index])
287
tokens_generated.append(sample_token)
288
start_tokens.append(sample_token)
289
num_tokens_generated = len(tokens_generated)
290
txt = " ".join(
291
[self.detokenize(_) for _ in self.start_tokens + tokens_generated]
292
)
293
print(f"generated text:\n{txt}\n")
294
295
296
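"""
The sampling step is the only part of the callback that is not a plain forward
pass, so here is a small standalone illustration of it (an addition to the original
example): keep the `top_k` largest logits, renormalize them with a softmax, and
draw one of the surviving token indices at random. The toy logits below are
arbitrary.
"""

# Illustrative only: top-k sampling over a toy distribution of 6 "tokens".
demo_logits = np.array([2.0, 1.0, 0.5, 0.1, -1.0, -3.0], dtype="float32")
top_logits, top_indices = ops.top_k(demo_logits, k=3, sorted=True)
top_probs = keras.activations.softmax(ops.expand_dims(top_logits, 0))[0]
print(np.random.choice(np.asarray(top_indices), p=np.asarray(top_probs)))

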
# Tokenize starting prompt
word_to_index = {}
for index, word in enumerate(vocab):
    word_to_index[word] = index

start_prompt = "this movie is"
start_tokens = [word_to_index.get(_, 1) for _ in start_prompt.split()]
num_tokens_generated = 40
text_gen_callback = TextGenerator(num_tokens_generated, start_tokens, vocab)


"""
308
## Train the model
309
310
Note: This code should preferably be run on GPU.
311
"""
312
313
model = create_model()
314
315
model.fit(text_ds, verbose=2, epochs=25, callbacks=[text_gen_callback])
316
317
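"""
After training, more text can be generated without running another epoch by
invoking the callback directly (an addition to the original example): `fit()` has
already attached the trained model to the callback, and `on_epoch_end` runs one
round of prompt completion.
"""

# Illustrative only: generate one more sample from the trained model.
text_gen_callback.on_epoch_end(epoch=0)
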
"""
318
## Relevant Chapters from Deep Learning with Python
319
- [Chapter 15: Language models and the Transformer](https://deeplearningwithpython.io/chapters/chapter15_language-models-and-the-transformer)
320
- [Chapter 16: Text generation](https://deeplearningwithpython.io/chapters/chapter16_text-generation)
321
"""
322
323