Path: blob/master/keras/text_classification/transformers/gensim_word2vec.py
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from gensim.models import Word2Vec


class GensimWord2VecVectorizer(BaseEstimator, TransformerMixin):
    """
    Word vectors are averaged across each document's tokens to create the
    document-level vectors/features.

    gensim's own gensim.sklearn_api.W2VTransformer doesn't support out-of-vocabulary
    words, hence we roll our own.

    All the parameters are gensim.models.Word2Vec's parameters (gensim < 4.0
    naming, e.g. size/iter rather than vector_size/epochs).

    https://radimrehurek.com/gensim/models/word2vec.html#gensim.models.word2vec.Word2Vec
    """

    def __init__(self, size=100, alpha=0.025, window=5, min_count=5, max_vocab_size=None,
                 sample=0.001, seed=1, workers=3, min_alpha=0.0001, sg=0, hs=0, negative=5,
                 ns_exponent=0.75, cbow_mean=1, hashfxn=hash, iter=5, null_word=0,
                 trim_rule=None, sorted_vocab=1, batch_words=10000, compute_loss=False,
                 callbacks=(), max_final_vocab=None):
        self.size = size
        self.alpha = alpha
        self.window = window
        self.min_count = min_count
        self.max_vocab_size = max_vocab_size
        self.sample = sample
        self.seed = seed
        self.workers = workers
        self.min_alpha = min_alpha
        self.sg = sg
        self.hs = hs
        self.negative = negative
        self.ns_exponent = ns_exponent
        self.cbow_mean = cbow_mean
        self.hashfxn = hashfxn
        self.iter = iter
        self.null_word = null_word
        self.trim_rule = trim_rule
        self.sorted_vocab = sorted_vocab
        self.batch_words = batch_words
        self.compute_loss = compute_loss
        self.callbacks = callbacks
        self.max_final_vocab = max_final_vocab

    def fit(self, X, y=None):
        # X is expected to be an iterable of tokenized documents (lists of str).
        self.model_ = Word2Vec(
            sentences=X, corpus_file=None,
            size=self.size, alpha=self.alpha, window=self.window, min_count=self.min_count,
            max_vocab_size=self.max_vocab_size, sample=self.sample, seed=self.seed,
            workers=self.workers, min_alpha=self.min_alpha, sg=self.sg, hs=self.hs,
            negative=self.negative, ns_exponent=self.ns_exponent, cbow_mean=self.cbow_mean,
            hashfxn=self.hashfxn, iter=self.iter, null_word=self.null_word,
            trim_rule=self.trim_rule, sorted_vocab=self.sorted_vocab, batch_words=self.batch_words,
            compute_loss=self.compute_loss, callbacks=self.callbacks,
            max_final_vocab=self.max_final_vocab)
        return self

    def transform(self, X):
        X_embeddings = np.array([self._get_embedding(words) for words in X])
        return X_embeddings

    def _get_embedding(self, words):
        # Average the vectors of all in-vocabulary words; out-of-vocabulary
        # words are skipped, and a zero vector is returned when none remain.
        valid_words = [word for word in words if word in self.model_.wv.vocab]
        if valid_words:
            embedding = np.zeros((len(valid_words), self.size), dtype=np.float32)
            for idx, word in enumerate(valid_words):
                embedding[idx] = self.model_.wv[word]

            return np.mean(embedding, axis=0)
        else:
            return np.zeros(self.size, dtype=np.float32)
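

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). Assumes gensim < 4.0
# and scikit-learn are installed; the toy corpus, the hyperparameter values,
# and the LogisticRegression classifier are illustrative choices only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import Pipeline

    # Tokenized documents: the vectorizer expects lists of tokens, not raw strings.
    X_train = [
        ['machine', 'learning', 'is', 'fun'],
        ['deep', 'learning', 'with', 'keras'],
        ['cooking', 'pasta', 'for', 'dinner'],
        ['baking', 'bread', 'at', 'home'],
    ]
    y_train = [1, 1, 0, 0]

    # min_count=1 keeps every word of the tiny corpus in the vocabulary.
    pipeline = Pipeline([
        ('w2v', GensimWord2VecVectorizer(size=50, min_count=1, iter=10, seed=1)),
        ('clf', LogisticRegression(solver='liblinear')),
    ])
    pipeline.fit(X_train, y_train)

    # Unseen words (e.g. 'tensorflow') are simply skipped when averaging.
    print(pipeline.predict([['keras', 'tensorflow', 'fun'], ['dinner', 'at', 'home']]))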