Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
zackhy
GitHub Repository: zackhy/TextClassification
Path: blob/master/data_helper.py
1 views
1
# -*- coding: utf-8 -*-
2
import re
3
import os
4
import sys
5
import csv
6
import time
7
import json
8
import collections
9
10
import numpy as np
11
from tensorflow.contrib import learn
12
13
14
def load_data(file_path, sw_path=None, min_frequency=0, max_length=0, language='ch', vocab_processor=None, shuffle=True):
    """
    Build dataset for mini-batch iterator
    :param file_path: Data file path. A CSV file whose header contains 'label' and 'content' columns.
    :param sw_path: Stop word file path (one stop word per line); None disables stop-word removal
    :param language: 'ch' for Chinese and 'en' for English
    :param min_frequency: the minimal frequency of words to keep
    :param max_length: the max document length; 0 means use the longest cleaned sentence in the data
    :param vocab_processor: the predefined vocabulary processor; None builds a new one from this data
    :param shuffle: whether to shuffle the data
    :return data, labels, lengths, vocabulary processor
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        print('Building dataset ...')
        start = time.time()
        incsv = csv.reader(f)
        header = next(incsv)  # Header
        label_idx = header.index('label')
        content_idx = header.index('content')

        labels = []
        sentences = []

        # Load the stop-word set once, up front, so the row loop reuses it
        if sw_path is not None:
            sw = _stop_words(sw_path)
        else:
            sw = None

        for line in incsv:
            sent = line[content_idx].strip()

            if language == 'ch':
                sent = _tradition_2_simple(sent)  # Convert traditional Chinese to simplified Chinese
            elif language == 'en':
                sent = sent.lower()
            else:
                raise ValueError('language should be one of [ch, en].')

            sent = _clean_data(sent, sw, language=language)  # Remove stop words and special characters

            # Skip rows whose content is empty after cleaning (keeps labels aligned with sentences)
            if len(sent) < 1:
                continue

            if language == 'ch':
                sent = _word_segmentation(sent)
            sentences.append(sent)

            # NOTE(review): negative labels are remapped to class id 2 — presumably the raw
            # data encodes the negative class as -1; confirm against the dataset's label scheme.
            if int(line[label_idx]) < 0:
                labels.append(2)
            else:
                labels.append(int(line[label_idx]))

        labels = np.array(labels)
        # Real lengths: token count of each space-separated (segmented) sentence
        lengths = np.array(list(map(len, [sent.strip().split(' ') for sent in sentences])))

        if max_length == 0:
            max_length = max(lengths)

        # Extract vocabulary from sentences and map words to indices
        if vocab_processor is None:
            vocab_processor = learn.preprocessing.VocabularyProcessor(max_length, min_frequency=min_frequency)
            data = np.array(list(vocab_processor.fit_transform(sentences)))
        else:
            # Reuse a processor fitted on other data (e.g. transform test set with train vocab)
            data = np.array(list(vocab_processor.transform(sentences)))

        data_size = len(data)

        # Shuffle data, labels and lengths with the same permutation so rows stay aligned
        if shuffle:
            shuffle_indices = np.random.permutation(np.arange(data_size))
            data = data[shuffle_indices]
            labels = labels[shuffle_indices]
            lengths = lengths[shuffle_indices]

        end = time.time()

        print('Dataset has been built successfully.')
        print('Run time: {}'.format(end - start))
        print('Number of sentences: {}'.format(len(data)))
        print('Vocabulary size: {}'.format(len(vocab_processor.vocabulary_._mapping)))
        print('Max document length: {}\n'.format(vocab_processor.max_document_length))

        return data, labels, lengths, vocab_processor
def batch_iter(data, labels, lengths, batch_size, num_epochs):
    """
    A mini-batch iterator to generate mini-batches for training neural network
    :param data: a list of sentences. each sentence is a vector of integers
    :param labels: a list of labels
    :param lengths: a list of real (pre-padding) sequence lengths, aligned with data
    :param batch_size: the size of mini-batch
    :param num_epochs: number of epochs
    :return: a mini-batch iterator yielding (data_batch, label_batch, length_batch) tuples
    :raises ValueError: if data, labels and lengths do not have the same size
    """
    # Explicit validation instead of `assert`, which is stripped under `python -O`
    if not len(data) == len(labels) == len(lengths):
        raise ValueError('data, labels and lengths must have the same size.')

    data_size = len(data)
    # Only full batches are yielded: the last data_size % batch_size samples
    # of each epoch are dropped, matching the original behavior.
    epoch_length = data_size // batch_size

    for _ in range(num_epochs):
        for i in range(epoch_length):
            start_index = i * batch_size
            end_index = start_index + batch_size

            xdata = data[start_index: end_index]
            ydata = labels[start_index: end_index]
            sequence_length = lengths[start_index: end_index]

            yield xdata, ydata, sequence_length
# --------------- Private Methods ---------------
125
126
def _tradition_2_simple(sent):
    """ Convert Traditional Chinese to Simplified Chinese """
    # Conversion relies on langconv.py + zh_wiki.py, which are not packaged on
    # PyPI and must be downloaded manually; exit with a hint if they are absent.
    try:
        import langconv
    except ImportError as err:
        hint = "Please download langconv.py and zh_wiki.py at "
        hint += "https://github.com/skydark/nstools/tree/master/zhtools."
        print(str(err) + ': ' + hint)
        sys.exit()

    converter = langconv.Converter('zh-hans')
    return converter.convert(sent)
def _word_segmentation(sent):
    """ Tokenizer for Chinese """
    import jieba
    # Precise-mode segmentation with HMM enabled for out-of-vocabulary words
    tokens = jieba.cut(sent, cut_all=False, HMM=True)
    segmented = ' '.join(tokens)
    # Collapse any whitespace runs introduced by segmentation into single spaces
    return re.sub(r'\s+', ' ', segmented)
def _stop_words(path):
149
with open(path, 'r', encoding='utf-8') as f:
150
sw = list()
151
for line in f:
152
sw.append(line.strip())
153
154
return set(sw)
155
156
157
def _clean_data(sent, sw, language='ch'):
158
""" Remove special characters and stop words """
159
if language == 'ch':
160
sent = re.sub(r"[^\u4e00-\u9fa5A-z0-9!?,。]", " ", sent)
161
sent = re.sub('!{2,}', '!', sent)
162
sent = re.sub('?{2,}', '!', sent)
163
sent = re.sub('。{2,}', '。', sent)
164
sent = re.sub(',{2,}', ',', sent)
165
sent = re.sub('\s{2,}', ' ', sent)
166
if language == 'en':
167
sent = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", sent)
168
sent = re.sub(r"\'s", " \'s", sent)
169
sent = re.sub(r"\'ve", " \'ve", sent)
170
sent = re.sub(r"n\'t", " n\'t", sent)
171
sent = re.sub(r"\'re", " \'re", sent)
172
sent = re.sub(r"\'d", " \'d", sent)
173
sent = re.sub(r"\'ll", " \'ll", sent)
174
sent = re.sub(r",", " , ", sent)
175
sent = re.sub(r"!", " ! ", sent)
176
sent = re.sub(r"\(", " \( ", sent)
177
sent = re.sub(r"\)", " \) ", sent)
178
sent = re.sub(r"\?", " \? ", sent)
179
sent = re.sub(r"\s{2,}", " ", sent)
180
if sw is not None:
181
sent = "".join([word for word in sent if word not in sw])
182
183
return sent
184
185