Path: blob/master/examples/generative/text_generation_with_miniature_gpt.py
"""1Title: Text generation with a miniature GPT2Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)3Date created: 2020/05/294Last modified: 2020/05/295Description: Implement a miniature version of GPT and train it to generate text.6Accelerator: GPU7"""89"""10## Introduction1112This example demonstrates how to implement an autoregressive language model13using a miniature version of the GPT model.14The model consists of a single Transformer block with causal masking15in its attention layer.16We use the text from the IMDB sentiment classification dataset for training17and generate new movie reviews for a given prompt.18When using this script with your own dataset, make sure it has at least191 million words.2021This example should be run with `tf-nightly>=2.3.0-dev20200531` or22with TensorFlow 2.3 or higher.2324**References:**2526- [GPT](https://www.semanticscholar.org/paper/Improving-Language-Understanding-by-Generative-Radford/cd18800a0fe0b668a1cc19f2ec95b5003d0a5035)27- [GPT-2](https://www.semanticscholar.org/paper/Language-Models-are-Unsupervised-Multitask-Learners-Radford-Wu/9405cc0d6169988371b2755e573cc28650d14dfe)28- [GPT-3](https://arxiv.org/abs/2005.14165)29"""30"""31## Setup32"""33# We set the backend to TensorFlow. The code works with34# both `tensorflow` and `torch`. It does not work with JAX35# due to the behavior of `jax.numpy.tile` in a jit scope36# (used in `causal_attention_mask()`: `tile` in JAX does37# not support a dynamic `reps` argument.38# You can make the code work in JAX by wrapping the39# inside of the `causal_attention_mask` function in40# a decorator to prevent jit compilation:41# `with jax.ensure_compile_time_eval():`.42import os4344os.environ["KERAS_BACKEND"] = "tensorflow"4546import keras47from keras import layers48from keras import ops49from keras.layers import TextVectorization50import numpy as np51import os52import string53import random54import tensorflow55import tensorflow.data as tf_data56import tensorflow.strings as tf_strings5758"""59## Implement a Transformer block as a layer60"""616263def causal_attention_mask(batch_size, n_dest, n_src, dtype):64"""65Mask the upper half of the dot product matrix in self attention.66This prevents flow of information from future tokens to current token.671's in the lower triangle, counting from the lower right corner.68"""69i = ops.arange(n_dest)[:, None]70j = ops.arange(n_src)71m = i >= j - n_src + n_dest72mask = ops.cast(m, dtype)73mask = ops.reshape(mask, [1, n_dest, n_src])74mult = ops.concatenate(75[ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])], 076)77return ops.tile(mask, mult)787980class TransformerBlock(layers.Layer):81def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):82super().__init__()83self.att = layers.MultiHeadAttention(num_heads, embed_dim)84self.ffn = keras.Sequential(85[86layers.Dense(ff_dim, activation="relu"),87layers.Dense(embed_dim),88]89)90self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)91self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)92self.dropout1 = layers.Dropout(rate)93self.dropout2 = layers.Dropout(rate)9495def call(self, inputs):96input_shape = ops.shape(inputs)97batch_size = input_shape[0]98seq_len = input_shape[1]99causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, "bool")100attention_output = self.att(inputs, inputs, attention_mask=causal_mask)101attention_output = self.dropout1(attention_output)102out1 = self.layernorm1(inputs + attention_output)103ffn_output = self.ffn(out1)104ffn_output = self.dropout2(ffn_output)105return 

"""
## Implement an embedding layer

Create two separate embedding layers: one for tokens and one for token index
(positions).
"""


class TokenAndPositionEmbedding(layers.Layer):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        maxlen = ops.shape(x)[-1]
        positions = ops.arange(0, maxlen, 1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions


"""
## Implement the miniature GPT model
"""

vocab_size = 20000  # Only consider the top 20k words
maxlen = 80  # Max sequence size
embed_dim = 256  # Embedding size for each token
num_heads = 2  # Number of attention heads
feed_forward_dim = 256  # Hidden layer size in feed forward network inside transformer


def create_model():
    inputs = layers.Input(shape=(maxlen,), dtype="int32")
    embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
    x = embedding_layer(inputs)
    transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim)
    x = transformer_block(x)
    outputs = layers.Dense(vocab_size)(x)
    model = keras.Model(inputs=inputs, outputs=[outputs, x])
    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(
        "adam",
        loss=[loss_fn, None],
    )  # No loss is applied to the second output (the Transformer block's word embeddings)
    return model


"""
## Prepare the data for word-level language modelling

Download the IMDB dataset and combine training and validation sets for a text
generation task.
"""

"""shell
curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
tar -xf aclImdb_v1.tar.gz
"""


batch_size = 128

# The dataset contains each review in a separate text file
# The text files are present in four different folders
# Create a list of all files
filenames = []
directories = [
    "aclImdb/train/pos",
    "aclImdb/train/neg",
    "aclImdb/test/pos",
    "aclImdb/test/neg",
]
for dir in directories:
    for f in os.listdir(dir):
        filenames.append(os.path.join(dir, f))

print(f"{len(filenames)} files")

# Create a dataset from text files
random.shuffle(filenames)
text_ds = tf_data.TextLineDataset(filenames)
text_ds = text_ds.shuffle(buffer_size=256)
text_ds = text_ds.batch(batch_size)


def custom_standardization(input_string):
    """Remove html line-break tags and handle punctuation"""
    lowercased = tf_strings.lower(input_string)
    stripped_html = tf_strings.regex_replace(lowercased, "<br />", " ")
    return tf_strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1")


# Create a vectorization layer and adapt it to the text
vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size - 1,
    output_mode="int",
    output_sequence_length=maxlen + 1,
)
vectorize_layer.adapt(text_ds)
vocab = vectorize_layer.get_vocabulary()  # To get words back from token indices


def prepare_lm_inputs_labels(text):
    """
    Shift word sequences by 1 position so that the target for position (i) is
    the word at position (i+1). The model will use all words up to position (i)
    to predict the next word.
    """
    text = tensorflow.expand_dims(text, -1)
    tokenized_sentences = vectorize_layer(text)
    x = tokenized_sentences[:, :-1]
    y = tokenized_sentences[:, 1:]
    return x, y


text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf_data.AUTOTUNE)
text_ds = text_ds.prefetch(tf_data.AUTOTUNE)
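
"""
To make the input/target shift concrete, here is a minimal, illustrative sketch
(the call below is not part of the training pipeline, and the exact token ids
depend on the vocabulary learned by `adapt`): the targets `y` are simply the
inputs `x` shifted left by one position.
"""

# Illustrative only: tokenize a short prompt and show the one-token shift.
sample_x, sample_y = prepare_lm_inputs_labels(
    tensorflow.constant(["this movie is great"])
)
print(sample_x.numpy()[0, :6])  # token ids for positions 0..5
print(sample_y.numpy()[0, :6])  # token ids for positions 1..6 (the targets)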

"""
## Implement a Keras callback for generating text
"""


class TextGenerator(keras.callbacks.Callback):
    """A callback to generate text from a trained model.

    1. Feed some starting prompt to the model
    2. Predict probabilities for the next token
    3. Sample the next token and add it to the next input

    Arguments:
        max_tokens: Integer, the number of tokens to be generated after prompt.
        start_tokens: List of integers, the token indices for the starting prompt.
        index_to_word: List of strings, obtained from the TextVectorization layer.
        top_k: Integer, sample from the `top_k` token predictions.
        print_every: Integer, print after this many epochs.
    """

    def __init__(
        self, max_tokens, start_tokens, index_to_word, top_k=10, print_every=1
    ):
        self.max_tokens = max_tokens
        self.start_tokens = start_tokens
        self.index_to_word = index_to_word
        self.print_every = print_every
        self.k = top_k

    def sample_from(self, logits):
        logits, indices = ops.top_k(logits, k=self.k, sorted=True)
        indices = np.asarray(indices).astype("int32")
        preds = keras.activations.softmax(ops.expand_dims(logits, 0))[0]
        preds = np.asarray(preds).astype("float32")
        return np.random.choice(indices, p=preds)

    def detokenize(self, number):
        return self.index_to_word[number]

    def on_epoch_end(self, epoch, logs=None):
        start_tokens = [_ for _ in self.start_tokens]
        if (epoch + 1) % self.print_every != 0:
            return
        num_tokens_generated = 0
        tokens_generated = []
        while num_tokens_generated <= self.max_tokens:
            pad_len = maxlen - len(start_tokens)
            sample_index = len(start_tokens) - 1
            if pad_len < 0:
                x = start_tokens[:maxlen]
                sample_index = maxlen - 1
            elif pad_len > 0:
                x = start_tokens + [0] * pad_len
            else:
                x = start_tokens
            x = np.array([x])
            y, _ = self.model.predict(x, verbose=0)
            sample_token = self.sample_from(y[0][sample_index])
            tokens_generated.append(sample_token)
            start_tokens.append(sample_token)
            num_tokens_generated = len(tokens_generated)
        txt = " ".join(
            [self.detokenize(_) for _ in self.start_tokens + tokens_generated]
        )
        print(f"generated text:\n{txt}\n")


# Tokenize starting prompt
word_to_index = {}
for index, word in enumerate(vocab):
    word_to_index[word] = index

start_prompt = "this movie is"
start_tokens = [word_to_index.get(_, 1) for _ in start_prompt.split()]
num_tokens_generated = 40
text_gen_callback = TextGenerator(num_tokens_generated, start_tokens, vocab)


"""
## Train the model

Note: This code should preferably be run on GPU.
"""

model = create_model()

model.fit(text_ds, verbose=2, epochs=25, callbacks=[text_gen_callback])

"""
## Relevant Chapters from Deep Learning with Python

- [Chapter 15: Language models and the Transformer](https://deeplearningwithpython.io/chapters/chapter15_language-models-and-the-transformer)
- [Chapter 16: Text generation](https://deeplearningwithpython.io/chapters/chapter16_text-generation)
"""
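
"""
## Generate text from the trained model (optional sketch)

As a final usage sketch, the same top-k sampling loop used in the `TextGenerator`
callback can be run directly against the trained `model` for an arbitrary prompt.
The `generate_text` helper below is illustrative and not part of the original
training script; it simply mirrors the callback logic.
"""


def generate_text(prompt, max_tokens=40, top_k=10):
    """Top-k sampling from the trained model, mirroring the callback logic."""
    tokens = [word_to_index.get(w, 1) for w in prompt.lower().split()]
    generated = []
    while len(generated) < max_tokens:
        pad_len = maxlen - len(tokens)
        sample_index = len(tokens) - 1
        if pad_len < 0:
            x = tokens[:maxlen]
            sample_index = maxlen - 1
        elif pad_len > 0:
            x = tokens + [0] * pad_len
        else:
            x = tokens
        logits, _ = model.predict(np.array([x]), verbose=0)
        # Restrict sampling to the top_k most likely next tokens.
        top_logits, top_indices = ops.top_k(
            logits[0][sample_index], k=top_k, sorted=True
        )
        probs = np.asarray(
            keras.activations.softmax(ops.expand_dims(top_logits, 0))[0]
        ).astype("float32")
        next_token = np.random.choice(np.asarray(top_indices).astype("int32"), p=probs)
        tokens.append(int(next_token))
        generated.append(int(next_token))
    return " ".join(vocab[t] for t in tokens)


print(generate_text("the acting was"))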