GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/minBERT/datasets.py

#!/usr/bin/env python3

'''
This module contains our Dataset classes and functions to load the 3 datasets we're using.

You should only need to call load_multitask_data to get the training and dev examples
to train your model.
'''

import csv

import torch
from torch.utils.data import Dataset
from tokenizer import BertTokenizer


def preprocess_string(s):
    return ' '.join(s.lower()
                    .replace('.', ' .')
                    .replace('?', ' ?')
                    .replace(',', ' ,')
                    .replace('\'', ' \'')
                    .split())
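
# Illustrative example (not part of the original file): preprocess_string lowercases
# its input and pads punctuation with spaces so each mark becomes its own token, e.g.
#   preprocess_string("Hello, world.")  ->  "hello , world ."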


class SentenceClassificationDataset(Dataset):
    """Labeled single-sentence examples stored as (sentence, label, sent_id) triples."""

    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sents = [x[0] for x in data]
        labels = [x[1] for x in data]
        sent_ids = [x[2] for x in data]

        encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])
        labels = torch.LongTensor(labels)

        return token_ids, attention_mask, labels, sents, sent_ids

    def collate_fn(self, all_data):
        token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)

        batched_data = {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'sents': sents,
            'sent_ids': sent_ids
        }

        return batched_data
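
# A minimal usage sketch (not part of the original file): the dataset is meant to be
# wrapped in a torch DataLoader that uses its collate_fn to pad each batch. The batch
# size and the `args` namespace below are illustrative assumptions.
#
#   from torch.utils.data import DataLoader
#   sst_dataset = SentenceClassificationDataset(sst_train_data, args)
#   sst_loader = DataLoader(sst_dataset, shuffle=True, batch_size=8,
#                           collate_fn=sst_dataset.collate_fn)
#   for batch in sst_loader:
#       token_ids = batch['token_ids']            # (batch_size, seq_len)
#       attention_mask = batch['attention_mask']  # (batch_size, seq_len)
#       labels = batch['labels']                  # (batch_size,)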


class SentenceClassificationTestDataset(Dataset):
    """Unlabeled single-sentence test examples stored as (sentence, sent_id) pairs."""

    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sents = [x[0] for x in data]
        sent_ids = [x[1] for x in data]

        encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])

        return token_ids, attention_mask, sents, sent_ids

    def collate_fn(self, all_data):
        token_ids, attention_mask, sents, sent_ids = self.pad_data(all_data)

        batched_data = {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'sents': sents,
            'sent_ids': sent_ids
        }

        return batched_data


class SentencePairDataset(Dataset):
    """Labeled sentence-pair examples stored as (sent1, sent2, label, sent_id) tuples.

    Set isRegression=True for real-valued labels (e.g. STS similarity scores), which
    are returned as a DoubleTensor instead of a LongTensor.
    """

    def __init__(self, dataset, args, isRegression=False):
        self.dataset = dataset
        self.p = args
        self.isRegression = isRegression
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sent1 = [x[0] for x in data]
        sent2 = [x[1] for x in data]
        labels = [x[2] for x in data]
        sent_ids = [x[3] for x in data]

        encoding1 = self.tokenizer(sent1, return_tensors='pt', padding=True, truncation=True)
        encoding2 = self.tokenizer(sent2, return_tensors='pt', padding=True, truncation=True)

        token_ids = torch.LongTensor(encoding1['input_ids'])
        attention_mask = torch.LongTensor(encoding1['attention_mask'])
        token_type_ids = torch.LongTensor(encoding1['token_type_ids'])

        token_ids2 = torch.LongTensor(encoding2['input_ids'])
        attention_mask2 = torch.LongTensor(encoding2['attention_mask'])
        token_type_ids2 = torch.LongTensor(encoding2['token_type_ids'])

        if self.isRegression:
            labels = torch.DoubleTensor(labels)
        else:
            labels = torch.LongTensor(labels)

        return (token_ids, token_type_ids, attention_mask,
                token_ids2, token_type_ids2, attention_mask2,
                labels, sent_ids)

    def collate_fn(self, all_data):
        (token_ids, token_type_ids, attention_mask,
         token_ids2, token_type_ids2, attention_mask2,
         labels, sent_ids) = self.pad_data(all_data)

        batched_data = {
            'token_ids_1': token_ids,
            'token_type_ids_1': token_type_ids,
            'attention_mask_1': attention_mask,
            'token_ids_2': token_ids2,
            'token_type_ids_2': token_type_ids2,
            'attention_mask_2': attention_mask2,
            'labels': labels,
            'sent_ids': sent_ids
        }

        return batched_data
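
# Usage sketch (not part of the original file): paraphrase labels are integers, while
# STS similarity scores are real-valued, so the STS dataset would be built with
# isRegression=True. The variable names below are illustrative assumptions.
#
#   para_dataset = SentencePairDataset(para_train_data, args)
#   sts_dataset = SentencePairDataset(sts_train_data, args, isRegression=True)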


class SentencePairTestDataset(Dataset):
    """Unlabeled sentence-pair test examples stored as (sent1, sent2, sent_id) tuples."""

    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sent1 = [x[0] for x in data]
        sent2 = [x[1] for x in data]
        sent_ids = [x[2] for x in data]

        encoding1 = self.tokenizer(sent1, return_tensors='pt', padding=True, truncation=True)
        encoding2 = self.tokenizer(sent2, return_tensors='pt', padding=True, truncation=True)

        token_ids = torch.LongTensor(encoding1['input_ids'])
        attention_mask = torch.LongTensor(encoding1['attention_mask'])
        token_type_ids = torch.LongTensor(encoding1['token_type_ids'])

        token_ids2 = torch.LongTensor(encoding2['input_ids'])
        attention_mask2 = torch.LongTensor(encoding2['attention_mask'])
        token_type_ids2 = torch.LongTensor(encoding2['token_type_ids'])

        return (token_ids, token_type_ids, attention_mask,
                token_ids2, token_type_ids2, attention_mask2,
                sent_ids)

    def collate_fn(self, all_data):
        (token_ids, token_type_ids, attention_mask,
         token_ids2, token_type_ids2, attention_mask2,
         sent_ids) = self.pad_data(all_data)

        batched_data = {
            'token_ids_1': token_ids,
            'token_type_ids_1': token_type_ids,
            'attention_mask_1': attention_mask,
            'token_ids_2': token_ids2,
            'token_type_ids_2': token_type_ids2,
            'attention_mask_2': attention_mask2,
            'sent_ids': sent_ids
        }

        return batched_data


def load_multitask_test_data():
    paraphrase_filename = 'data/quora-test.csv'
    sentiment_filename = 'data/ids-sst-test.txt'
    similarity_filename = 'data/sts-test.csv'

    sentiment_data = []

    with open(sentiment_filename, 'r') as fp:
        for record in csv.DictReader(fp, delimiter='\t'):
            sent = record['sentence'].lower().strip()
            sentiment_data.append(sent)

    print(f"Loaded {len(sentiment_data)} test examples from {sentiment_filename}")

    paraphrase_data = []
    with open(paraphrase_filename, 'r') as fp:
        for record in csv.DictReader(fp, delimiter='\t'):
            # if record['split'] != split:
            #     continue
            paraphrase_data.append((preprocess_string(record['sentence1']),
                                    preprocess_string(record['sentence2']),
                                    ))

    print(f"Loaded {len(paraphrase_data)} test examples from {paraphrase_filename}")

    similarity_data = []
    with open(similarity_filename, 'r') as fp:
        for record in csv.DictReader(fp, delimiter='\t'):
            similarity_data.append((preprocess_string(record['sentence1']),
                                    preprocess_string(record['sentence2']),
                                    ))

    print(f"Loaded {len(similarity_data)} test examples from {similarity_filename}")

    return sentiment_data, paraphrase_data, similarity_data


def load_multitask_data(sentiment_filename, paraphrase_filename, similarity_filename, split='train'):
    sentiment_data = []
    num_labels = {}
    if split == 'test':
        with open(sentiment_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                sentiment_data.append((sent, sent_id))
    else:
        with open(sentiment_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                label = int(record['sentiment'].strip())
                if label not in num_labels:
                    num_labels[label] = len(num_labels)
                sentiment_data.append((sent, label, sent_id))

    print(f"Loaded {len(sentiment_data)} {split} examples from {sentiment_filename}")

    paraphrase_data = []
    if split == 'test':
        with open(paraphrase_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent_id = record['id'].lower().strip()
                paraphrase_data.append((preprocess_string(record['sentence1']),
                                        preprocess_string(record['sentence2']),
                                        sent_id))

    else:
        with open(paraphrase_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                try:
                    sent_id = record['id'].lower().strip()
                    paraphrase_data.append((preprocess_string(record['sentence1']),
                                            preprocess_string(record['sentence2']),
                                            int(float(record['is_duplicate'])), sent_id))
                except:
                    # Skip malformed rows (e.g. missing or non-numeric is_duplicate).
                    pass

    print(f"Loaded {len(paraphrase_data)} {split} examples from {paraphrase_filename}")

    similarity_data = []
    if split == 'test':
        with open(similarity_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent_id = record['id'].lower().strip()
                similarity_data.append((preprocess_string(record['sentence1']),
                                        preprocess_string(record['sentence2']),
                                        sent_id))
    else:
        with open(similarity_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent_id = record['id'].lower().strip()
                similarity_data.append((preprocess_string(record['sentence1']),
                                        preprocess_string(record['sentence2']),
                                        float(record['similarity']), sent_id))

    print(f"Loaded {len(similarity_data)} {split} examples from {similarity_filename}")

    return sentiment_data, num_labels, paraphrase_data, similarity_data
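

# A small, self-contained smoke test (not part of the original assignment file). The
# training-split file paths are assumptions mirroring the test paths hardcoded in
# load_multitask_test_data above; adjust them to wherever the data actually lives.
if __name__ == "__main__":
    from argparse import Namespace
    from torch.utils.data import DataLoader

    sst_train, num_labels, para_train, sts_train = load_multitask_data(
        'data/ids-sst-train.csv', 'data/quora-train.csv', 'data/sts-train.csv',
        split='train')

    # The Dataset classes only store `args`; an empty namespace is enough for this demo.
    args = Namespace()
    sst_dataset = SentenceClassificationDataset(sst_train, args)
    sst_loader = DataLoader(sst_dataset, shuffle=True, batch_size=8,
                            collate_fn=sst_dataset.collate_fn)

    batch = next(iter(sst_loader))
    print({k: (tuple(v.shape) if torch.is_tensor(v) else len(v)) for k, v in batch.items()})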