# GitHub Repository: ethen8181/machine-learning
# Path: blob/master/keras/text_classification/transformers/gensim_word2vec.py
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from gensim.models import Word2Vec


class GensimWord2VecVectorizer(BaseEstimator, TransformerMixin):
    """
    Word vectors are averaged across each document's words to create the
    document-level vectors/features.

    gensim's own gensim.sklearn_api.W2VTransformer doesn't support
    out-of-vocabulary words, hence we roll our own.

    All the parameters are gensim.models.Word2Vec's parameters (gensim 3.x
    naming, e.g. ``size`` and ``iter``). A usage sketch appears at the
    bottom of this module.

    https://radimrehurek.com/gensim/models/word2vec.html#gensim.models.word2vec.Word2Vec
    """

    def __init__(self, size=100, alpha=0.025, window=5, min_count=5, max_vocab_size=None,
                 sample=0.001, seed=1, workers=3, min_alpha=0.0001, sg=0, hs=0, negative=5,
                 ns_exponent=0.75, cbow_mean=1, hashfxn=hash, iter=5, null_word=0,
                 trim_rule=None, sorted_vocab=1, batch_words=10000, compute_loss=False,
                 callbacks=(), max_final_vocab=None):
        # per sklearn convention, __init__ only stores the hyperparameters;
        # the actual Word2Vec model is built in fit
        self.size = size
        self.alpha = alpha
        self.window = window
        self.min_count = min_count
        self.max_vocab_size = max_vocab_size
        self.sample = sample
        self.seed = seed
        self.workers = workers
        self.min_alpha = min_alpha
        self.sg = sg
        self.hs = hs
        self.negative = negative
        self.ns_exponent = ns_exponent
        self.cbow_mean = cbow_mean
        self.hashfxn = hashfxn
        self.iter = iter
        self.null_word = null_word
        self.trim_rule = trim_rule
        self.sorted_vocab = sorted_vocab
        self.batch_words = batch_words
        self.compute_loss = compute_loss
        self.callbacks = callbacks
        self.max_final_vocab = max_final_vocab

    def fit(self, X, y=None):
        # X is expected to be an iterable of tokenized documents,
        # i.e. a list of lists of words
        self.model_ = Word2Vec(
            sentences=X, corpus_file=None,
            size=self.size, alpha=self.alpha, window=self.window, min_count=self.min_count,
            max_vocab_size=self.max_vocab_size, sample=self.sample, seed=self.seed,
            workers=self.workers, min_alpha=self.min_alpha, sg=self.sg, hs=self.hs,
            negative=self.negative, ns_exponent=self.ns_exponent, cbow_mean=self.cbow_mean,
            hashfxn=self.hashfxn, iter=self.iter, null_word=self.null_word,
            trim_rule=self.trim_rule, sorted_vocab=self.sorted_vocab, batch_words=self.batch_words,
            compute_loss=self.compute_loss, callbacks=self.callbacks,
            max_final_vocab=self.max_final_vocab)
        return self

    def transform(self, X):
        # one averaged word vector per document, shape (n_documents, size)
        X_embeddings = np.array([self._get_embedding(words) for words in X])
        return X_embeddings

    def _get_embedding(self, words):
        # keep only in-vocabulary words; .wv.vocab is the gensim 3.x
        # vocabulary mapping (renamed to .wv.key_to_index in gensim 4)
        valid_words = [word for word in words if word in self.model_.wv.vocab]
        if valid_words:
            embedding = np.zeros((len(valid_words), self.size), dtype=np.float32)
            for idx, word in enumerate(valid_words):
                embedding[idx] = self.model_.wv[word]

            return np.mean(embedding, axis=0)
        else:
            # every word is out-of-vocabulary, fall back to a zero vector;
            # float32 keeps the dtype consistent with the branch above
            return np.zeros(self.size, dtype=np.float32)
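

# ---------------------------------------------------------------------------
# Usage sketch: a minimal example of how this vectorizer could slot into a
# scikit-learn Pipeline. The toy corpus, the LogisticRegression classifier
# and the hyperparameter values below are illustrative assumptions, not
# something this repository prescribes; the gensim 3.x API (size=, iter=)
# is assumed throughout.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import Pipeline

    # each document is a list of pre-tokenized words
    X_train = [
        ['hello', 'world'],
        ['machine', 'learning', 'is', 'fun'],
        ['hello', 'machine', 'learning', 'world'],
        ['text', 'classification', 'with', 'word', 'vectors'],
    ]
    y_train = [0, 1, 0, 1]

    # min_count=1 so the tiny toy vocabulary isn't pruned away
    pipeline = Pipeline([
        ('vectorizer', GensimWord2VecVectorizer(size=50, min_count=1, iter=10, seed=1)),
        ('classifier', LogisticRegression())
    ])
    pipeline.fit(X_train, y_train)

    # out-of-vocabulary tokens such as 'unseen' are silently skipped
    print(pipeline.predict([['hello', 'unseen', 'world']]))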