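"""
keras_word2vec.py

Skip-gram word2vec trained with Keras, wrapped in a scikit-learn style
transformer: the learned word vectors are averaged to form document-level
features.
"""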
import numpy as np
from tqdm import trange
from keras import layers, optimizers, Model
from sklearn.preprocessing import normalize
from sklearn.base import BaseEstimator, TransformerMixin
from keras.preprocessing.sequence import skipgrams, make_sampling_table


class KerasWord2VecVectorizer(BaseEstimator, TransformerMixin):
    """
    Word vectors are averaged to create the document-level vectors/features.

    Attributes
    ----------
    word2index_ : dict[str, int]
        Each distinct word in the corpus is mapped to a numeric index,
        e.g. {'unk': 0, 'film': 1}

    index2word_ : list[str]
        Reverse mapping of ``word2index_``, e.g. ['unk', 'film']

    vocab_size_ : int
        Number of distinct words retained in the vocabulary, including 'unk'.

    model_ : keras.models.Model
        The skip-gram Keras model trained during ``fit``.
    """

    def __init__(self, embed_size=100, window_size=5, batch_size=64, epochs=5000,
                 learning_rate=0.05, negative_samples=0.5, min_count=2,
                 use_sampling_table=True, sort_vocab=True):
        self.min_count = min_count
        self.embed_size = embed_size
        self.sort_vocab = sort_vocab
        self.window_size = window_size
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.negative_samples = negative_samples
        self.use_sampling_table = use_sampling_table

    def fit(self, X, y=None):
        self.build_vocab(X)
        self.build_graph()
        indexed_texts = self.texts_to_index(X)

        # the sampling table assumes word indices are sorted by descending
        # frequency, hence it is only used when the vocabulary is sorted
        sampling_table = None
        if self.sort_vocab and self.use_sampling_table:
            sampling_table = make_sampling_table(self.vocab_size_)

        for epoch in trange(self.epochs):
            (batch_center,
             batch_context,
             batch_label) = generate_batch_data(
                indexed_texts, self.batch_size, self.vocab_size_, self.window_size,
                self.negative_samples, sampling_table)
            self.model_.train_on_batch([batch_center, batch_context], batch_label)

        return self

    def transform(self, X):
        embed_in = self._get_word_vectors()
        X_embeddings = np.array([self._get_embedding(words, embed_in) for words in X])
        return X_embeddings

    def _get_word_vectors(self):
        return self.model_.get_layer('embed_in').get_weights()[0]

    def _get_embedding(self, words, embed_in):

        valid_words = [word for word in words if word in self.word2index_]
        if valid_words:
            embedding = np.zeros((len(valid_words), self.embed_size), dtype=np.float32)
            for idx, word in enumerate(valid_words):
                word_idx = self.word2index_[word]
                embedding[idx] = embed_in[word_idx]

            return np.mean(embedding, axis=0)
        else:
            return np.zeros(self.embed_size)

    def build_vocab(self, texts):

        # flatten the list of tokenized texts into a single list of words
        words = [token for text in texts for token in text]

        word_count = {}
        for word in words:
            word_count[word] = word_count.get(word, 0) + 1

        valid_word_count = [(word, count) for word, count in word_count.items()
                            if count >= self.min_count]
        if self.sort_vocab:
            from operator import itemgetter
            valid_word_count = sorted(valid_word_count, key=itemgetter(1), reverse=True)

        # index 0 is reserved for the 'unk' (unknown / infrequent) token
        index2word = ['unk']
        word2index = {'unk': 0}
        for word, _ in valid_word_count:
            word2index[word] = len(word2index)
            index2word.append(word)

        self.word2index_ = word2index
        self.index2word_ = index2word
        self.vocab_size_ = len(word2index)
        return self

    def texts_to_index(self, texts):
        """
        Returns
        -------
        texts_index : list[list[int]]
            e.g. [[0, 2], [3, 1]]
            Each element of the outer list is a sentence, e.g. [0, 2], and each
            element of the inner list is a word represented by its numeric index
            (words not in the vocabulary map to index 0, the 'unk' token).
        """
        word2index = self.word2index_
        texts_index = []
        for text in texts:
            text_index = [word2index.get(token, 0) for token in text]
            texts_index.append(text_index)

        return texts_index

    def build_graph(self):
        input_center = layers.Input((1,))
        input_context = layers.Input((1,))

        # a single shared embedding layer is used for both the center and context words
        embedding = layers.Embedding(self.vocab_size_, self.embed_size,
                                     input_length=1, name='embed_in')
        center = embedding(input_center)    # shape [batch_size, 1 (input_length), embed_size]
        context = embedding(input_context)

        center = layers.Reshape((self.embed_size,))(center)
        context = layers.Reshape((self.embed_size,))(context)

        dot_product = layers.dot([center, context], axes=1)
        output = layers.Dense(1, activation='sigmoid')(dot_product)
        self.model_ = Model(inputs=[input_center, input_context], outputs=output)
        self.model_.compile(loss='binary_crossentropy',
                            optimizer=optimizers.RMSprop(lr=self.learning_rate))
        return self

    # def build_graph(self):
    #     """
    #     A different way of building the graph where the center and context
    #     words each have their own embedding layer.
    #     """
    #     input_center = layers.Input((1,))
    #     input_context = layers.Input((1,))

    #     embedding_center = layers.Embedding(self.vocab_size_, self.embed_size,
    #                                         input_length=1, name='embed_in')
    #     embedding_context = layers.Embedding(self.vocab_size_, self.embed_size,
    #                                          input_length=1, name='embed_out')
    #     center = embedding_center(input_center)    # shape [batch_size, 1 (input_length), embed_size]
    #     context = embedding_context(input_context)

    #     center = layers.Reshape((self.embed_size,))(center)
    #     context = layers.Reshape((self.embed_size,))(context)

    #     dot_product = layers.dot([center, context], axes=1)
    #     output = layers.Dense(1, activation='sigmoid')(dot_product)
    #     self.model_ = Model(inputs=[input_center, input_context], outputs=output)
    #     self.model_.compile(loss='binary_crossentropy',
    #                         optimizer=optimizers.RMSprop(lr=self.learning_rate))
    #     return self

    def most_similar(self, positive, negative=None, topn=10):

        # normalize word vectors to make the cosine similarity calculation easier
        # normed_vectors = word_vectors / np.sqrt((word_vectors ** 2).sum(axis=-1))[..., np.newaxis]
        # TODO: consider caching the normed vectors, or replacing the original
        # ones, to speed up repeated queries
        word_vectors = self._get_word_vectors()
        normed_vectors = normalize(word_vectors)

        # assign weight to positive and negative query words
        positive = [] if positive is None else [(word, 1.0) for word in positive]
        negative = [] if negative is None else [(word, -1.0) for word in negative]

        # compute the weighted average of all the query words
        queries = []
        all_word_index = set()
        for word, weight in positive + negative:
            word_index = self.word2index_[word]
            word_vector = normed_vectors[word_index]
            queries.append(weight * word_vector)
            all_word_index.add(word_index)

        if not queries:
            raise ValueError('cannot compute similarity with no input')

        query_vector = np.mean(queries, axis=0).reshape(1, -1)
        normed_query_vector = normalize(query_vector).ravel()

        # cosine similarity between the query vector and all the existing word vectors
        scores = np.dot(normed_vectors, normed_query_vector)

        # retrieve extra candidates so the query words themselves can be excluded
        actual_len = topn + len(all_word_index)
        sorted_index = np.argpartition(scores, -actual_len)[-actual_len:]
        best = sorted_index[np.argsort(scores[sorted_index])[::-1]]

        result = [(self.index2word_[index], scores[index])
                  for index in best if index not in all_word_index]
        return result[:topn]


def generate_batch_data(indexed_texts, batch_size, vocab_size,
                        window_size, negative_samples, sampling_table):
    batch_label = []
    batch_center = []
    batch_context = []
    while len(batch_center) < batch_size:
        # randomly pick one indexed sentence, list[int]
        rand_indexed_texts = indexed_texts[np.random.randint(len(indexed_texts))]

        # couples: list of (center index, context index) word pairs
        couples, labels = skipgrams(rand_indexed_texts, vocab_size,
                                    window_size=window_size,
                                    sampling_table=sampling_table,
                                    negative_samples=negative_samples)
        if couples:
            centers, contexts = zip(*couples)
            batch_center.extend(centers)
            batch_context.extend(contexts)
            batch_label.extend(labels)

    # trim to batch size at the end and convert to numpy array
    batch_center = np.array(batch_center[:batch_size], dtype=np.int64)
    batch_context = np.array(batch_context[:batch_size], dtype=np.int64)
    batch_label = np.array(batch_label[:batch_size], dtype=np.int64)
    return batch_center, batch_context, batch_label
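

# A minimal usage sketch, not part of the original module: the toy corpus,
# hyperparameters and the queried word below are illustrative assumptions only.
if __name__ == '__main__':
    toy_corpus = [
        ['the', 'film', 'was', 'good'],
        ['the', 'film', 'was', 'bad'],
        ['a', 'good', 'movie'],
        ['a', 'bad', 'movie'],
    ] * 50

    # sampling table disabled since the toy vocabulary is tiny
    vectorizer = KerasWord2VecVectorizer(embed_size=16, epochs=200, min_count=1,
                                         use_sampling_table=False)
    vectorizer.fit(toy_corpus)

    # document-level features: one averaged word vector per tokenized document
    doc_features = vectorizer.transform(toy_corpus)
    print(doc_features.shape)  # (200, 16)

    # words closest to 'film' under the learned embedding
    print(vectorizer.most_similar(positive=['film'], topn=3))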