GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/20/word_analogies_jax.ipynb


Solving word analogies using pre-trained word embeddings

Based on D2L 14.7

http://d2l.ai/chapter_natural-language-processing-pretraining/similarity-analogy.html

import os
import hashlib
import zipfile
import tarfile  # needed by download_extract for .tar/.gz archives
import requests
import jax
import jax.numpy as jnp
# Required functions
def download(name, cache_dir=os.path.join("..", "data")):
    """Download a file inserted into DATA_HUB, return the local filename."""
    assert name in DATA_HUB, f"{name} does not exist in {DATA_HUB}."
    url, sha1_hash = DATA_HUB[name]
    os.makedirs(cache_dir, exist_ok=True)
    fname = os.path.join(cache_dir, url.split("/")[-1])
    if os.path.exists(fname):
        sha1 = hashlib.sha1()
        with open(fname, "rb") as f:
            while True:
                data = f.read(1048576)
                if not data:
                    break
                sha1.update(data)
        if sha1.hexdigest() == sha1_hash:
            return fname  # Hit cache
    print(f"Downloading {fname} from {url}...")
    r = requests.get(url, stream=True, verify=True)
    with open(fname, "wb") as f:
        f.write(r.content)
    return fname


def download_extract(name, folder=None):
    """Download and extract a zip/tar file."""
    fname = download(name)
    base_dir = os.path.dirname(fname)
    data_dir, ext = os.path.splitext(fname)
    if ext == ".zip":
        fp = zipfile.ZipFile(fname, "r")
    elif ext in (".tar", ".gz"):
        fp = tarfile.open(fname, "r")
    else:
        assert False, "Only zip/tar files can be extracted."
    fp.extractall(base_dir)
    return os.path.join(base_dir, folder) if folder else data_dir

Get pre-trained word embeddings

Pre-trained embeddings are taken from:

GloVe website: https://nlp.stanford.edu/projects/glove/

fastText website: https://fasttext.cc/

DATA_HUB = dict()
DATA_URL = "http://d2l-data.s3-accelerate.amazonaws.com/"
DATA_HUB["glove.6b.50d"] = (DATA_URL + "glove.6B.50d.zip", "0b8703943ccdb6eb788e6f091b8946e82231bc4d")
DATA_HUB["glove.6b.100d"] = (DATA_URL + "glove.6B.100d.zip", "cd43bfb07e44e6f27cbcc7bc9ae3d80284fdaf5a")
DATA_HUB["glove.42b.300d"] = (DATA_URL + "glove.42B.300d.zip", "b5116e234e9eb9076672cfeabf5469f3eec904fa")
DATA_HUB["wiki.en"] = (DATA_URL + "wiki.en.zip", "c1816da3821ae9f43899be655002f6c723e91b88")
class TokenEmbedding:
    """Token embedding."""

    def __init__(self, embedding_name):
        self.idx_to_token, self.idx_to_vec = self._load_embedding(embedding_name)
        self.unknown_idx = 0
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}

    def _load_embedding(self, embedding_name):
        idx_to_token, idx_to_vec = ["<unk>"], []
        # data_dir = d2l.download_extract(embedding_name)
        data_dir = download_extract(embedding_name)
        # GloVe website: https://nlp.stanford.edu/projects/glove/
        # fastText website: https://fasttext.cc/
        with open(os.path.join(data_dir, "vec.txt"), "r") as f:
            for line in f:
                elems = line.rstrip().split(" ")
                token, elems = elems[0], [float(elem) for elem in elems[1:]]
                # Skip header information, such as the top row in fastText
                if len(elems) > 1:
                    idx_to_token.append(token)
                    idx_to_vec.append(elems)
        idx_to_vec = [[0] * len(idx_to_vec[0])] + idx_to_vec
        return idx_to_token, jnp.array(idx_to_vec)

    def __getitem__(self, tokens):
        indices = [self.token_to_idx.get(token, self.unknown_idx) for token in tokens]
        vecs = self.idx_to_vec[jnp.array(indices)]
        return vecs

    def __len__(self):
        return len(self.idx_to_token)

Get a 50-dimensional GloVe embedding, with a vocabulary size of 400k.

glove_6b50d = TokenEmbedding("glove.6b.50d")
Downloading ../data/glove.6B.50d.zip from http://d2l-data.s3-accelerate.amazonaws.com/glove.6B.50d.zip...
len(glove_6b50d)
400001

Map from word to index and vice versa.

glove_6b50d.token_to_idx["beautiful"], glove_6b50d.idx_to_token[3367]
(3367, 'beautiful')
embedder = glove_6b50d
# embedder = TokenEmbedding('glove.6b.100d')
embedder.idx_to_vec.shape
(400001, 50)
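Indexing a TokenEmbedding with a list of tokens returns the matching rows of idx_to_vec; any out-of-vocabulary token falls back to index 0, the all-zero <unk> vector. A quick sanity check (a minimal sketch; the nonsense token below is just a made-up out-of-vocabulary example, not from the original notebook):

vecs = embedder[["beautiful", "notarealtoken123"]]
print(vecs.shape)  # (2, 50): one row per queried token
print(float(jnp.abs(vecs[1]).sum()))  # 0.0, since the unknown token maps to the zero <unk> vector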

Finding most similar words

def knn(W, x, k):
    # The added 1e-9 is for numerical stability
    cos = (W @ x.reshape(-1, 1)).reshape(-1) / (
        jnp.sqrt(jnp.sum(W * W, axis=1) + 1e-9) * jnp.sqrt((x * x).sum())
    )
    _, topk = jax.lax.top_k(cos, k=k)
    return topk, [cos[int(i)] for i in topk]
def get_similar_tokens(query_token, k, embed):
    topk, cos = knn(embed.idx_to_vec, embed[[query_token]], k + 1)
    for i, c in zip(topk[1:], cos[1:]):  # Remove input word
        print(f"cosine sim={float(c):.3f}: {embed.idx_to_token[int(i)]}")
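Before querying the full GloVe matrix, here is a tiny sanity check of knn (a hedged sketch using made-up 2-D vectors, not part of the original notebook): cosine similarity should rank the exact match first, the near-duplicate second, and the orthogonal vector last.

W_toy = jnp.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]])
x_toy = jnp.array([1.0, 0.0])
idx, sims = knn(W_toy, x_toy, 2)
print(idx)  # [0 1]: exact match, then the near-duplicate
print([float(s) for s in sims])  # both cosine similarities close to 1.0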
get_similar_tokens("man", 3, embedder)
cosine sim=0.886: woman
cosine sim=0.856: boy
cosine sim=0.845: another
get_similar_tokens("banana", 3, embedder)
cosine sim=0.815: bananas
cosine sim=0.787: coconut
cosine sim=0.758: pineapple

Word analogies

# We slightly modify the D2L code so it works on the man:woman::king:queen example.
# The query vector is vec(token_b) - vec(token_a) + vec(token_c).
def get_analogy(token_a, token_b, token_c, embed):
    vecs = embed[[token_a, token_b, token_c]]
    x = vecs[1] - vecs[0] + vecs[2]
    topk, cos = knn(embed.idx_to_vec, x, 10)
    # Remove word c from the nearest neighbors
    idx_c = embed.token_to_idx[token_c]
    topk = list(topk)
    topk.remove(idx_c)
    return embed.idx_to_token[int(topk[0])]
get_analogy("man", "woman", "king", embedder)
'queen'
get_analogy("man", "woman", "son", embedder)
'daughter'
get_analogy("beijing", "china", "tokyo", embedder)
'japan'
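The queries above score one analogy at a time. Many analogies can be scored in one pass by batching the cosine computation over a matrix of query vectors. The helper below (batched_analogy) is not part of the original notebook; it is a sketch under the assumption that the full (queries x vocabulary) similarity matrix fits in memory.

def batched_analogy(triples, embed, k=10):
    """Return the best analogy completion for each (a, b, c) triple."""
    # One query vector per triple: vec(b) - vec(a) + vec(c)
    queries = jnp.stack([embed[[b]][0] - embed[[a]][0] + embed[[c]][0] for a, b, c in triples])
    W = embed.idx_to_vec
    # Cosine similarity between every query (rows) and every vocabulary vector (columns)
    sims = (queries @ W.T) / (
        jnp.linalg.norm(queries, axis=1, keepdims=True) * jnp.linalg.norm(W, axis=1) + 1e-9
    )
    _, topk = jax.lax.top_k(sims, k=k)  # top-k vocabulary indices per query
    answers = []
    for (a, b, c), row in zip(triples, topk):
        idx_c = embed.token_to_idx[c]
        # Skip token c itself, then take the best remaining candidate
        best = next(int(i) for i in row if int(i) != idx_c)
        answers.append(embed.idx_to_token[best])
    return answers

batched_analogy([("man", "woman", "king"), ("beijing", "china", "tokyo")], embedder)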