Path: blob/master/notebooks/book1/20/skipgram_jax.ipynb
Please find torch implementation of this notebook here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/20/skipgram_torch.ipynb
Learning word embeddings using skipgram with negative sampling.
Based on D2L sec. 14.3 (http://d2l.ai/chapter_natural-language-processing-pretraining/word-embedding-dataset.html) and sec. 14.4 (http://d2l.ai/chapter_natural-language-processing-pretraining/word2vec-pretraining.html).
Data
We use the Penn Tree Bank (PTB), which is a small but commonly used corpus derived from the Wall Street Journal.
We make a vocabulary, replacing any word that occurs fewer than 10 times with the special token `unk`.
Mikolov et al. suggested keeping word $w$ with probability $\min\left(1, \sqrt{t / f(w)}\right)$, where $t$ is a threshold and $f(w)$ is the empirical frequency of word $w$.
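As a rough sketch of this rule (the helper name `keep`, the `counter`/`num_tokens` arguments, and the common default threshold $t = 10^{-4}$ are assumptions):

```python
import math
import random

def keep(token, counter, num_tokens, t=1e-4):
    """Keep `token` with probability min(1, sqrt(t / f(token)))."""
    f = counter[token] / num_tokens      # empirical frequency of the token
    return random.uniform(0, 1) < math.sqrt(t / f)

# e.g. subsampled = [[tok for tok in line if keep(tok, counter, num_tokens)]
#                    for line in sentences]
```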
We compare the frequency of certain common and rare words in the original and subsampled data below.
Let's tokenize the subsampled data.
Extracting central target words and their contexts
We randomly sample a context length for each central word, up to some maximum length, and then extract all the context words as a list of lists.
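A minimal sketch of such an extraction routine, following the D2L recipe (the name `get_centers_and_contexts` and the corpus format, a list of token-index lists, are assumptions):

```python
import random

def get_centers_and_contexts(corpus, max_window_size):
    """Return all center words and, for each, a randomly-sized context window."""
    centers, contexts = [], []
    for line in corpus:                 # each line is a list of token indices
        if len(line) < 2:               # need at least one center-context pair
            continue
        centers += line
        for i in range(len(line)):      # context window centered at position i
            window_size = random.randint(1, max_window_size)
            indices = list(range(max(0, i - window_size),
                                 min(len(line), i + 1 + window_size)))
            indices.remove(i)           # exclude the center word itself
            contexts.append([line[idx] for idx in indices])
    return centers, contexts
```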
Example. Suppose we have a corpus with 2 sentences of length 7 and 3, and we use a max context of size 2. Here are the centers and contexts.
Extract contexts for the full dataset.
Negative sampling
For speed, we define a sampling class that pre-computes 10,000 random indices from the weighted distribution, using a single call to `random.choices`, and then sequentially returns elements of this list. If we reach the end of the cache, we refill it.
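A minimal sketch of such a cached sampler (the class name `RandomGenerator` and the exact indexing convention are assumptions):

```python
import random

class RandomGenerator:
    """Cache draws from a weighted categorical distribution over indices."""
    def __init__(self, sampling_weights, cache_size=10000):
        self.population = list(range(len(sampling_weights)))
        self.sampling_weights = sampling_weights
        self.cache_size = cache_size
        self.candidates = []
        self.i = 0

    def draw(self):
        if self.i == len(self.candidates):
            # Refill the cache with a single call to random.choices
            self.candidates = random.choices(
                self.population, self.sampling_weights, k=self.cache_size)
            self.i = 0
        self.i += 1
        return self.candidates[self.i - 1]
```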
Example.
Now we generate negatives for each context. These are drawn from $P(w) \propto \mathrm{freq}(w)^{0.75}$.
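A minimal sketch of the negative-sampling step (assuming the `RandomGenerator` sketch above; the function name `get_negatives` and its arguments are assumptions):

```python
def get_negatives(all_contexts, sampling_weights, K):
    """Draw K noise words per context word, skipping true context words."""
    all_negatives = []
    generator = RandomGenerator(sampling_weights)
    for contexts in all_contexts:
        negatives = []
        while len(negatives) < len(contexts) * K:
            neg = generator.draw()
            if neg not in contexts:     # noise words must not be context words
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives

# sampling_weights[i] should be proportional to freq(word i)**0.75
```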
Minibatching
Suppose the $i$'th central word has $n_i$ context words and $m_i$ noise words. Since $n_i + m_i$ might be different for each $i$ (due to edge effects), the minibatch will be ragged. To fix this, we pad each row to a maximum length $L = \max_i (n_i + m_i)$, and then create a validity mask of length $L$, where 0 means an invalid location (to be ignored when computing the loss) and 1 means a valid location. We assign the label vector to have $n_i$ 1's and $L - n_i$ 0's. (Some of these labels will be masked out.)
Example. We make a ragged minibatch with 2 examples, and then pad them to a standard size.
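A minimal sketch of this padding step (the function name `batchify` and the use of NumPy arrays are assumptions):

```python
import numpy as np

def batchify(data):
    """Pad (center, contexts, negatives) triples to a common length L."""
    max_len = max(len(c) + len(n) for _, c, n in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        cur_len = len(context) + len(negative)
        centers.append(center)
        contexts_negatives.append(context + negative + [0] * (max_len - cur_len))
        masks.append([1] * cur_len + [0] * (max_len - cur_len))             # 1 = valid
        labels.append([1] * len(context) + [0] * (max_len - len(context)))  # 1 = true context
    return (np.array(centers).reshape(-1, 1), np.array(contexts_negatives),
            np.array(masks), np.array(labels))

# e.g. batchify([(1, [2, 2], [3, 3, 3, 3]), (1, [2, 2, 2], [3, 3])])
```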
Dataloader
Now we put it all together.
Let's print the first minibatch.
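A minimal sketch of a data pipeline built on the `batchify` sketch above (a plain Python generator; the names `data_iter`, `all_centers`, `all_contexts`, and `all_negatives` are assumptions):

```python
import random

def data_iter(all_centers, all_contexts, all_negatives, batch_size):
    """Shuffle the examples and yield padded minibatches."""
    examples = list(zip(all_centers, all_contexts, all_negatives))
    random.shuffle(examples)
    for i in range(0, len(examples) - batch_size + 1, batch_size):
        yield batchify(examples[i:i + batch_size])

# Inspect the first minibatch:
# batch = next(data_iter(all_centers, all_contexts, all_negatives, batch_size=512))
# for name, arr in zip(['centers', 'contexts_negatives', 'masks', 'labels'], batch):
#     print(name, arr.shape)
```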
Model
The model just has two embedding matrices, one for the center words and one for the context (and noise) words. The core computation is computing the logits, as shown below. The `center` variable has the shape (batch_size, 1), while the `contexts_and_negatives` variable has the shape (batch_size, max_len). These get embedded into size (batch_size, 1, embed_size) and (batch_size, max_len, embed_size). We permute the latter to (batch_size, embed_size, max_len) and use batched matrix multiplication to get a (batch_size, 1, max_len) matrix of inner products between each center's embedding and each context's embedding.
Example. Assume the vocab size is 20 and we use a small embedding dimension. We compute the logits for a small minibatch of padded sequences.
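A minimal JAX sketch of this computation (the function name `skip_gram`, the parameter names, and the toy sizes below are assumptions):

```python
import jax.numpy as jnp
from jax import random as jrandom

def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
    """Inner products between each center embedding and its context/noise embeddings."""
    v = embed_v[center]                          # (batch_size, 1, embed_size)
    u = embed_u[contexts_and_negatives]          # (batch_size, max_len, embed_size)
    # Batched matmul: (batch, 1, E) @ (batch, E, max_len) -> (batch, 1, max_len)
    return jnp.matmul(v, u.transpose(0, 2, 1))

key1, key2 = jrandom.split(jrandom.PRNGKey(0))
vocab_size, embed_size = 20, 4
embed_v = 0.01 * jrandom.normal(key1, (vocab_size, embed_size))
embed_u = 0.01 * jrandom.normal(key2, (vocab_size, embed_size))
center = jnp.ones((2, 1), dtype=jnp.int32)
contexts_and_negatives = jnp.ones((2, 4), dtype=jnp.int32)
print(skip_gram(center, contexts_and_negatives, embed_v, embed_u).shape)  # (2, 1, 4)
```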
Loss
We use masked binary cross entropy loss.
Different masks can lead to different loss values, even for the same predictions.
If we normalize each example by its number of valid (unmasked) entries, then predictions that are equally good on the valid positions will score the same, regardless of the amount of padding.
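A minimal sketch of such a masked loss on the logits (the name `masked_bce` and the toy values below are assumptions):

```python
import jax.numpy as jnp

def masked_bce(logits, labels, mask):
    """Per-example sigmoid cross entropy, averaged over the valid (mask == 1) positions."""
    # Numerically stable form: max(x, 0) - x*z + log(1 + exp(-|x|))
    loss = jnp.maximum(logits, 0) - logits * labels + jnp.log1p(jnp.exp(-jnp.abs(logits)))
    return (loss * mask).sum(axis=1) / mask.sum(axis=1)

logits = jnp.array([[1.5, 0.3, -1.0, 2.0], [1.1, -0.6, 2.2, 0.4]])
labels = jnp.array([[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]])
mask = jnp.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0]])
print(masked_bce(logits, labels, mask))  # one loss value per example
```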
Training
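As a rough sketch of what one training step computes (reusing the `skip_gram` and `masked_bce` sketches above; plain SGD and the learning rate are assumptions):

```python
from jax import grad, jit

@jit
def train_step(params, center, contexts_negatives, mask, label, lr=0.002):
    """One SGD step on the masked skip-gram loss."""
    def loss_fn(params):
        embed_v, embed_u = params
        logits = skip_gram(center, contexts_negatives, embed_v, embed_u)[:, 0, :]
        return masked_bce(logits, label, mask).mean()
    grads = grad(loss_fn)(params)
    return tuple(p - lr * g for p, g in zip(params, grads))

# params = (embed_v, embed_u); loop over minibatches, calling train_step on each.
```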
Test
We find the nearest words to a query word, measuring similarity using cosine similarity.
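A minimal sketch of this lookup (assumes a trained center-word embedding matrix `embed_v` and a D2L-style `vocab` supporting `vocab[token]` and `vocab.to_tokens(i)`; all names are assumptions):

```python
import jax.numpy as jnp

def get_similar_tokens(query_token, k, embed_v, vocab):
    """Print the k tokens whose embeddings have the highest cosine similarity to the query."""
    W = embed_v
    x = W[vocab[query_token]]
    cos = (W @ x) / jnp.sqrt(jnp.sum(W * W, axis=1) * jnp.sum(x * x) + 1e-9)
    top = jnp.argsort(-cos)[:k + 1]      # the query itself ranks first, so skip it
    for i in top[1:]:
        print(f'cosine sim={float(cos[i]):.3f}: {vocab.to_tokens(int(i))}')

# e.g. get_similar_tokens('chip', 3, embed_v, vocab)
```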
Pre-trained models
For better results, you should use a larger model that is trained on more data, such as those provided by the spaCy library. For a demo, see this script.