GitHub Repository: ethen8181/machine-learning
Path: blob/master/deep_learning/word2vec/word2vec_workflow.py

"""
1) Basic preprocessing of the raw text.
2) Trains a Phrases model to glue words that commonly appear next to
   each other into bigrams.
3) Trains the Word2vec model (skip-gram + negative sampling); currently,
   no hyperparameter tuning is performed.
"""
import os
import re
import logging
from joblib import cpu_count
from string import punctuation
from logzero import setup_logger
from nltk.corpus import stopwords
from gensim.models import Phrases
from gensim.models import Word2Vec
from gensim.models.phrases import Phraser
from gensim.models.word2vec import LineSentence
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS

logger = setup_logger(name=__name__, logfile='word2vec.log', level=logging.INFO)


def main():
    # -------------------------------------------------------------------------------
    # Parameters

    # the script will most likely work if we swap the TEXTS variable
    # with any iterable of text (where one element represents a document,
    # and the whole iterable is the corpus)
    newsgroups_train = fetch_20newsgroups(subset='train')
    TEXTS = newsgroups_train.data
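
    # for example, any list of document strings could be dropped in instead
    # (the strings below are illustrative placeholders, not part of the repo):
    # TEXTS = ['first document about cats', 'second document about dogs']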

    # a set of stopwords built in to various packages;
    # we can always expand this set for the problem
    # that we are working on. here we also include
    # python's built-in string punctuation marks
    STOPWORDS = set(stopwords.words('english')) | set(punctuation) | set(ENGLISH_STOP_WORDS)

    # create a directory called 'model' to store all outputs of the later sections
    MODEL_DIR = 'model'
    UNIGRAM_PATH = os.path.join(MODEL_DIR, 'unigram.txt')
    PHRASE_MODEL_CHECKPOINT = os.path.join(MODEL_DIR, 'phrase_model')
    BIGRAM_PATH = os.path.join(MODEL_DIR, 'bigram.txt')
    WORD2VEC_CHECKPOINT = os.path.join(MODEL_DIR, 'word2vec')

    # -------------------------------------------------------------------------------
    logger.info('job started')
    if not os.path.isdir(MODEL_DIR):
        os.mkdir(MODEL_DIR)

    if not os.path.exists(UNIGRAM_PATH):
        logger.info('preprocessing text')
        export_unigrams(UNIGRAM_PATH, texts=TEXTS, stop_words=STOPWORDS)

    if os.path.exists(PHRASE_MODEL_CHECKPOINT):
        phrase_model = Phrases.load(PHRASE_MODEL_CHECKPOINT)
    else:
        logger.info('training phrase model')
        # use LineSentence to stream the text as opposed to loading it all into memory
        unigram_sentences = LineSentence(UNIGRAM_PATH)
        phrase_model = Phrases(unigram_sentences)
        phrase_model.save(PHRASE_MODEL_CHECKPOINT)
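
        # the Phrases defaults are used above; if phrase detection needs
        # tuning, min_count and threshold are the main knobs (the values
        # below are illustrative assumptions, not tuned recommendations):
        # phrase_model = Phrases(unigram_sentences, min_count=5, threshold=10.0)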

    if not os.path.exists(BIGRAM_PATH):
        logger.info('converting words to phrases')
        export_bigrams(UNIGRAM_PATH, BIGRAM_PATH, phrase_model)

    if os.path.exists(WORD2VEC_CHECKPOINT):
        word2vec = Word2Vec.load(WORD2VEC_CHECKPOINT)
    else:
        logger.info('training word2vec')
        # sg=1 trains the skip-gram variant noted in the module docstring
        # (gensim's default is CBOW); negative sampling is enabled by default
        word2vec = Word2Vec(corpus_file=BIGRAM_PATH, sg=1, workers=cpu_count())
        word2vec.save(WORD2VEC_CHECKPOINT)
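
        # the remaining Word2Vec hyperparameters are left at their defaults
        # above; a tuned run might set, e.g., the embedding size and context
        # window (the keyword is vector_size in gensim 4.x, size in 3.x;
        # the values here are illustrative assumptions):
        # word2vec = Word2Vec(corpus_file=BIGRAM_PATH, sg=1, vector_size=100,
        #                     window=5, negative=5, workers=cpu_count())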

    logger.info('job completed')
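
    # a quick illustrative sanity check of the trained embeddings; the query
    # word is an assumption about the corpus vocabulary (20 newsgroups has
    # plenty of computer-related text), hence the membership guard
    if 'computer' in word2vec.wv:
        logger.info('words similar to "computer": %s',
                    word2vec.wv.most_similar('computer', topn=5))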


def export_unigrams(unigram_path, texts, stop_words):
    """
    Preprocesses the raw text and exports it to a .txt file,
    where each line is one document; for the sort of preprocessing
    that is done, please refer to the `normalize_text` function.

    Parameters
    ----------
    unigram_path : str
        output file path of the preprocessed unigram text.

    texts : iterable
        iterable can be simply a list, but for larger corpora,
        consider an iterable that streams the sentences directly from
        disk/network using Gensim's LineSentence or something along
        those lines.

    stop_words : set
        stopword set that will be excluded from the corpus.
    """
    with open(unigram_path, 'w', encoding='utf_8') as f:
        for text in texts:
            cleaned_text = normalize_text(text, stop_words)
            f.write(cleaned_text + '\n')


def normalize_text(text, stop_words):
    # remove special characters, keeping only letters and whitespace
    text = re.sub(r'[^a-zA-Z\s]', '', text, flags=re.I | re.A)

    # lower-case & tokenize the text
    tokens = re.split(r'\s+', text.lower().strip())

    # filter stopwords out of the text &
    # re-create the text from the filtered tokens
    cleaned_text = ' '.join(token for token in tokens if token not in stop_words)
    return cleaned_text
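
# illustrative example of normalize_text (the input string and stopword
# set are made up for clarity, not taken from the corpus):
# normalize_text('Hello, World! 123', stop_words={'hello'}) -> 'world'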


def export_bigrams(unigram_path, bigram_path, phrase_model):
    """
    Uses the learned phrase model to create (potential) bigrams,
    and outputs the text that contains bigrams to disk.

    Parameters
    ----------
    unigram_path : str
        input file path of the preprocessed unigram text.

    bigram_path : str
        output file path of the transformed bigram text.

    phrase_model : gensim Phrases model object

    References
    ----------
    Gensim Phrase Detection
    - https://radimrehurek.com/gensim/models/phrases.html
    """

    # after training the Phrases model, create a performant
    # Phraser object to transform any sentence (list of
    # token strings) and glue unigrams together into bigrams
    phraser = Phraser(phrase_model)
    with open(bigram_path, 'w', encoding='utf_8') as fout, \
            open(unigram_path, encoding='utf_8') as fin:
        for text in fin:
            unigram = text.split()
            bigram = phraser[unigram]
            bigram_sentence = ' '.join(bigram)
            fout.write(bigram_sentence + '\n')
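
# illustrative example of the Phraser transform inside export_bigrams;
# whether a pair is glued depends on the corpus statistics, so the
# output shown is an assumption:
# phraser[['new', 'york', 'city']] -> ['new_york', 'city']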


if __name__ == '__main__':
    main()