Path: blob/master/keras/text_classification/transformers/keras_word2vec.py
import numpy as np
from operator import itemgetter
from tqdm import trange
from keras import layers, optimizers, Model
from sklearn.preprocessing import normalize
from sklearn.base import BaseEstimator, TransformerMixin
from keras.preprocessing.sequence import skipgrams, make_sampling_table


class KerasWord2VecVectorizer(BaseEstimator, TransformerMixin):
    """
    Skip-gram word2vec model implemented in Keras. After training, the word
    vectors of each document are averaged to create the document-level
    vectors/features.

    Attributes
    ----------
    word2index_ : dict[str, int]
        Each distinct word in the corpus gets mapped to a numeric index,
        e.g. {'unk': 0, 'film': 1}

    index2word_ : list[str]
        Reverse mapping of ``word2index_``, e.g. ['unk', 'film']

    vocab_size_ : int
        Number of distinct words in the vocabulary, including the 'unk' token.

    model_ : keras.models.Model
        Trained skip-gram model; the 'embed_in' layer holds the word vectors.
    """

    def __init__(self, embed_size=100, window_size=5, batch_size=64, epochs=5000,
                 learning_rate=0.05, negative_samples=0.5, min_count=2,
                 use_sampling_table=True, sort_vocab=True):
        self.min_count = min_count
        self.embed_size = embed_size
        self.sort_vocab = sort_vocab
        self.window_size = window_size
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.negative_samples = negative_samples
        self.use_sampling_table = use_sampling_table

    def fit(self, X, y=None):
        self.build_vocab(X)
        self.build_graph()
        indexed_texts = self.texts_to_index(X)

        # the frequency-based subsampling table only makes sense when the
        # vocabulary is sorted by word count
        sampling_table = None
        if self.sort_vocab and self.use_sampling_table:
            sampling_table = make_sampling_table(self.vocab_size_)

        for _ in trange(self.epochs):
            (batch_center,
             batch_context,
             batch_label) = generate_batch_data(
                indexed_texts, self.batch_size, self.vocab_size_, self.window_size,
                self.negative_samples, sampling_table)
            self.model_.train_on_batch([batch_center, batch_context], batch_label)

        return self

    def transform(self, X):
        embed_in = self._get_word_vectors()
        X_embeddings = np.array([self._get_embedding(words, embed_in) for words in X])
        return X_embeddings

    def _get_word_vectors(self):
        return self.model_.get_layer('embed_in').get_weights()[0]

    def _get_embedding(self, words, embed_in):
        # average the word vectors of the in-vocabulary words of a document
        valid_words = [word for word in words if word in self.word2index_]
        if valid_words:
            embedding = np.zeros((len(valid_words), self.embed_size), dtype=np.float32)
            for idx, word in enumerate(valid_words):
                word_idx = self.word2index_[word]
                embedding[idx] = embed_in[word_idx]

            return np.mean(embedding, axis=0)
        else:
            return np.zeros(self.embed_size)

    def build_vocab(self, texts):
        # flatten list[list[str]] into a single list of words
        words = [token for text in texts for token in text]

        word_count = {}
        for word in words:
            word_count[word] = word_count.get(word, 0) + 1

        valid_word_count = [(word, count) for word, count in word_count.items()
                            if count >= self.min_count]
        if self.sort_vocab:
            valid_word_count = sorted(valid_word_count, key=itemgetter(1), reverse=True)

        index2word = ['unk']
        word2index = {'unk': 0}
        for word, _ in valid_word_count:
            word2index[word] = len(word2index)
            index2word.append(word)

        self.word2index_ = word2index
        self.index2word_ = index2word
        self.vocab_size_ = len(word2index)
        return self

    def texts_to_index(self, texts):
        """
        Returns
        -------
        texts_index : list[list[int]]
            e.g. [[0, 2], [3, 1]], where each element of the outer list is a
            sentence, e.g. [0, 2], and each element of the inner list is a word
            represented by its numeric index (out-of-vocabulary words map to
            the 'unk' index 0).
        """
        word2index = self.word2index_
        texts_index = []
        for text in texts:
            text_index = [word2index.get(token, 0) for token in text]
            texts_index.append(text_index)

        return texts_index

    def build_graph(self):
        input_center = layers.Input((1,))
        input_context = layers.Input((1,))

        # the center and context words share a single embedding layer
        embedding = layers.Embedding(self.vocab_size_, self.embed_size,
                                     input_length=1, name='embed_in')
        center = embedding(input_center)    # shape [batch_size, # features (1), embed_size]
        context = embedding(input_context)

        center = layers.Reshape((self.embed_size,))(center)
        context = layers.Reshape((self.embed_size,))(context)

        dot_product = layers.dot([center, context], axes=1)
        output = layers.Dense(1, activation='sigmoid')(dot_product)
        self.model_ = Model(inputs=[input_center, input_context], outputs=output)
        self.model_.compile(loss='binary_crossentropy',
                            optimizer=optimizers.RMSprop(lr=self.learning_rate))
        return self

    # def build_graph(self):
    #     """
    #     A different way of building the graph where the center word and
    #     context word each have their own embedding layer.
    #     """
    #     input_center = layers.Input((1,))
    #     input_context = layers.Input((1,))

    #     embedding_center = layers.Embedding(self.vocab_size_, self.embed_size,
    #                                         input_length=1, name='embed_in')
    #     embedding_context = layers.Embedding(self.vocab_size_, self.embed_size,
    #                                          input_length=1, name='embed_out')
    #     center = embedding_center(input_center)    # shape [batch_size, # features (1), embed_size]
    #     context = embedding_context(input_context)

    #     center = layers.Reshape((self.embed_size,))(center)
    #     context = layers.Reshape((self.embed_size,))(context)

    #     dot_product = layers.dot([center, context], axes=1)
    #     output = layers.Dense(1, activation='sigmoid')(dot_product)
    #     self.model_ = Model(inputs=[input_center, input_context], outputs=output)
    #     self.model_.compile(loss='binary_crossentropy',
    #                         optimizer=optimizers.RMSprop(lr=self.learning_rate))
    #     return self

    def most_similar(self, positive, negative=None, topn=10):
        # normalize word vectors to make the cosine distance calculation easier
        # normed_vectors = word_vectors / np.sqrt((word_vectors ** 2).sum(axis=-1))[..., np.newaxis]
        # TODO: consider caching the normed vectors (or replacing the original
        # ones) to speed up repeated queries
        word_vectors = self._get_word_vectors()
        normed_vectors = normalize(word_vectors)

        # assign weight to positive and negative query words
        positive = [] if positive is None else [(word, 1.0) for word in positive]
        negative = [] if negative is None else [(word, -1.0) for word in negative]

        # compute the weighted average of all the query words
        queries = []
        all_word_index = set()
        for word, weight in positive + negative:
            word_index = self.word2index_[word]
            word_vector = normed_vectors[word_index]
            queries.append(weight * word_vector)
            all_word_index.add(word_index)

        if not queries:
            raise ValueError('cannot compute similarity with no input')

        query_vector = np.mean(queries, axis=0).reshape(1, -1)
        normed_query_vector = normalize(query_vector).ravel()

        # cosine similarity between the query vector and all the existing word vectors
        scores = np.dot(normed_vectors, normed_query_vector)

        # retrieve extra candidates so the query words themselves can be filtered out
        actual_len = topn + len(all_word_index)
        sorted_index = np.argpartition(scores, -actual_len)[-actual_len:]
        best = sorted_index[np.argsort(scores[sorted_index])[::-1]]

        result = [(self.index2word_[index], scores[index])
                  for index in best if index not in all_word_index]
        return result[:topn]


def generate_batch_data(indexed_texts, batch_size, vocab_size,
                        window_size, negative_samples, sampling_table):
    batch_label = []
    batch_center = []
    batch_context = []
    while len(batch_center) < batch_size:
        # pick a random indexed sentence, list[int]
        rand_indexed_texts = indexed_texts[np.random.randint(len(indexed_texts))]

        # couples: list[(int, int)], list of center/context word index pairs
        couples, labels = skipgrams(rand_indexed_texts, vocab_size,
                                    window_size=window_size,
                                    sampling_table=sampling_table,
                                    negative_samples=negative_samples)
        if couples:
            centers, contexts = zip(*couples)
            batch_center.extend(centers)
            batch_context.extend(contexts)
            batch_label.extend(labels)

    # trim to batch size at the end and convert to numpy array
    batch_center = np.array(batch_center[:batch_size], dtype=np.int64)
    batch_context = np.array(batch_context[:batch_size], dtype=np.int64)
    batch_label = np.array(batch_label[:batch_size], dtype=np.int64)
    return batch_center, batch_context, batch_label
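

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The toy corpus,
# its whitespace tokenization and every hyperparameter value below are
# illustrative assumptions chosen only to show how fit / transform /
# most_similar fit together; a real run would use a much larger corpus.
# The frequency-based subsampling table is tuned for large vocabularies,
# so it is disabled here to keep batch generation fast on a tiny corpus.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    toy_corpus = [
        'the film was a great film',
        'the movie was a bad movie',
        'a great movie with a great cast',
        'the cast of the film was bad'
    ]
    # the vectorizer expects pre-tokenized documents, i.e. list[list[str]]
    tokenized_corpus = [text.split() for text in toy_corpus]

    vectorizer = KerasWord2VecVectorizer(embed_size=50, window_size=2,
                                         batch_size=16, epochs=200,
                                         min_count=1, use_sampling_table=False)
    vectorizer.fit(tokenized_corpus)

    # document-level features: one averaged word vector per document,
    # shape [n_documents, embed_size]
    doc_features = vectorizer.transform(tokenized_corpus)
    print(doc_features.shape)

    # words ranked by cosine similarity to the query word
    print(vectorizer.most_similar(positive=['film'], topn=3))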