GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/a2/utils/treebank.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pickle
import numpy as np
import os
import random


class StanfordSentiment:
    def __init__(self, path=None, tablesize=1000000):
        if not path:
            path = "utils/datasets/stanfordSentimentTreebank"

        self.path = path
        self.tablesize = tablesize
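
    # The readers below expect the standard SST distribution files inside
    # `path`: datasetSentences.txt, datasetSplit.txt, dictionary.txt and
    # sentiment_labels.txt.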

    def tokens(self):
        if hasattr(self, "_tokens") and self._tokens:
            return self._tokens

        # Build the word -> index vocabulary over all sentences, tracking
        # token frequencies and the reverse (index -> word) mapping.
        tokens = dict()
        tokenfreq = dict()
        wordcount = 0
        revtokens = []
        idx = 0

        for sentence in self.sentences():
            for w in sentence:
                wordcount += 1
                if w not in tokens:
                    tokens[w] = idx
                    revtokens += [w]
                    tokenfreq[w] = 1
                    idx += 1
                else:
                    tokenfreq[w] += 1

        tokens["UNK"] = idx
        revtokens += ["UNK"]
        tokenfreq["UNK"] = 1
        wordcount += 1

        self._tokens = tokens
        self._tokenfreq = tokenfreq
        self._wordcount = wordcount
        self._revtokens = revtokens
        return self._tokens

    def sentences(self):
        if hasattr(self, "_sentences") and self._sentences:
            return self._sentences

        sentences = []
        with open(self.path + "/datasetSentences.txt", "r") as f:
            first = True
            for line in f:
                if first:
                    # Skip the header line
                    first = False
                    continue

                # The first column is the sentence index; keep the lowercased tokens
                splitted = line.strip().split()[1:]
                sentences += [[w.lower() for w in splitted]]

        self._sentences = sentences
        self._sentlengths = np.array([len(s) for s in sentences])
        self._cumsentlen = np.cumsum(self._sentlengths)

        return self._sentences

    def numSentences(self):
        if hasattr(self, "_numSentences") and self._numSentences:
            return self._numSentences
        else:
            self._numSentences = len(self.sentences())
            return self._numSentences

    def allSentences(self):
        if hasattr(self, "_allsentences") and self._allsentences:
            return self._allsentences

        sentences = self.sentences()
        rejectProb = self.rejectProb()
        tokens = self.tokens()
        # Replicate the corpus 30 times, subsampling frequent words:
        # each occurrence is kept with probability 1 - rejectProb.
        allsentences = [[w for w in s
                         if 0 >= rejectProb[tokens[w]] or random.random() >= rejectProb[tokens[w]]]
                        for s in sentences * 30]

        allsentences = [s for s in allsentences if len(s) > 1]

        self._allsentences = allsentences

        return self._allsentences

    def getRandomContext(self, C=5):
        # Pick a random sentence and center word, returning up to C context
        # words on each side (excluding tokens equal to the center word).
        allsent = self.allSentences()
        sentID = random.randint(0, len(allsent) - 1)
        sent = allsent[sentID]
        wordID = random.randint(0, len(sent) - 1)

        context = sent[max(0, wordID - C):wordID]
        if wordID + 1 < len(sent):
            context += sent[wordID + 1:min(len(sent), wordID + C + 1)]

        centerword = sent[wordID]
        context = [w for w in context if w != centerword]

        if len(context) > 0:
            return centerword, context
        else:
            return self.getRandomContext(C)

    def sent_labels(self):
        if hasattr(self, "_sent_labels") and self._sent_labels:
            return self._sent_labels

        # Map every phrase to its id, then every id to its sentiment value,
        # and finally look up the label of each full sentence.
        dictionary = dict()
        phrases = 0
        with open(self.path + "/dictionary.txt", "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                splitted = line.split("|")
                dictionary[splitted[0].lower()] = int(splitted[1])
                phrases += 1

        labels = [0.0] * phrases
        with open(self.path + "/sentiment_labels.txt", "r") as f:
            first = True
            for line in f:
                if first:
                    first = False
                    continue

                line = line.strip()
                if not line:
                    continue
                splitted = line.split("|")
                labels[int(splitted[0])] = float(splitted[1])

        sent_labels = [0.0] * self.numSentences()
        sentences = self.sentences()
        for i in range(self.numSentences()):
            sentence = sentences[i]
            full_sent = " ".join(sentence).replace('-lrb-', '(').replace('-rrb-', ')')
            sent_labels[i] = labels[dictionary[full_sent]]

        self._sent_labels = sent_labels
        return self._sent_labels

    def dataset_split(self):
        if hasattr(self, "_split") and self._split:
            return self._split

        # split[0], split[1], split[2] hold 0-based sentence indices for the
        # train, test and dev portions respectively.
        split = [[] for _ in range(3)]
        with open(self.path + "/datasetSplit.txt", "r") as f:
            first = True
            for line in f:
                if first:
                    first = False
                    continue

                splitted = line.strip().split(",")
                split[int(splitted[1]) - 1] += [int(splitted[0]) - 1]

        self._split = split
        return self._split

    def getRandomTrainSentence(self):
        split = self.dataset_split()
        sentId = split[0][random.randint(0, len(split[0]) - 1)]
        return self.sentences()[sentId], self.categorify(self.sent_labels()[sentId])

    def categorify(self, label):
        # Bucket a continuous sentiment value in [0, 1] into one of 5 classes.
        if label <= 0.2:
            return 0
        elif label <= 0.4:
            return 1
        elif label <= 0.6:
            return 2
        elif label <= 0.8:
            return 3
        else:
            return 4

    def getDevSentences(self):
        return self.getSplitSentences(2)

    def getTestSentences(self):
        return self.getSplitSentences(1)

    def getTrainSentences(self):
        return self.getSplitSentences(0)

    def getSplitSentences(self, split=0):
        ds_split = self.dataset_split()
        return [(self.sentences()[i], self.categorify(self.sent_labels()[i]))
                for i in ds_split[split]]

    def sampleTable(self):
        if hasattr(self, '_sampleTable') and self._sampleTable is not None:
            return self._sampleTable

        nTokens = len(self.tokens())
        samplingFreq = np.zeros((nTokens,))
        self.allSentences()
        for i in range(nTokens):
            w = self._revtokens[i]
            if w in self._tokenfreq:
                freq = 1.0 * self._tokenfreq[w]
                # Reweight: raise the unigram count to the 3/4 power (word2vec)
                freq = freq ** 0.75
            else:
                freq = 0.0
            samplingFreq[i] = freq

        samplingFreq /= np.sum(samplingFreq)
        samplingFreq = np.cumsum(samplingFreq) * self.tablesize

        # Slot i of the table stores the index of the token whose cumulative
        # share of the reweighted frequency mass covers position i, so a
        # uniform draw over the table samples tokens ~ unigram^0.75.
        self._sampleTable = [0] * self.tablesize

        j = 0
        for i in range(self.tablesize):
            while i > samplingFreq[j]:
                j += 1
            self._sampleTable[i] = j

        return self._sampleTable

    def rejectProb(self):
        if hasattr(self, '_rejectProb') and self._rejectProb is not None:
            return self._rejectProb

        # Build the vocabulary first so that _wordcount, _revtokens and
        # _tokenfreq exist even if tokens() has not been called yet.
        nTokens = len(self.tokens())
        threshold = 1e-5 * self._wordcount

        rejectProb = np.zeros((nTokens,))
        for i in range(nTokens):
            w = self._revtokens[i]
            freq = 1.0 * self._tokenfreq[w]
            # Word2vec-style subsampling: discard frequent words with
            # probability max(0, 1 - sqrt(threshold / freq))
            rejectProb[i] = max(0, 1 - np.sqrt(threshold / freq))

        self._rejectProb = rejectProb
        return self._rejectProb

    def sampleTokenIdx(self):
        # Draw a token index from the unigram^0.75 negative-sampling table.
        return self.sampleTable()[random.randint(0, self.tablesize - 1)]
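

if __name__ == "__main__":
    # Minimal usage sketch (not part of the assignment code): exercises the
    # public methods above. It assumes the SST files are present at the
    # default path "utils/datasets/stanfordSentimentTreebank".
    dataset = StanfordSentiment()
    print("vocabulary size:", len(dataset.tokens()))
    print("number of sentences:", dataset.numSentences())
    centerword, context = dataset.getRandomContext(C=5)
    print("center word:", centerword)
    print("context words:", context)
    print("negative-sample token idx:", dataset.sampleTokenIdx())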