# GitHub Repository: yiming-wange/cs224n-2023-solution
# Path: blob/main/a5/src/dataset.py
import random
import torch
from torch.utils.data import Dataset
import argparse

"""
7
The input-output pairs (x, y) of the NameDataset are of the following form:
8
9
x: Where was Khatchig Mouradian born?⁇Lebanon⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
10
y: □□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□⁇Lebanon⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
11
x: Where was Jacob Henry Studer born?⁇Columbus⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
12
y: □□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□⁇Columbus⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
13
14
Using the PAD_CHAR characters in y before the ⁇[place] keeps the trainer from
15
optimizing the model to predict the question, "Where was...".
16
17
Note that the NameDataset should take the pretraining_dataset defined in run.py
18
as an input. This is to allow the vocab specification of the NameDataset to be
19
the same as that of the pretraining dataset.
20
21
You don't need to implement anything in NameDataset.
22
"""
23
24
class NameDataset(Dataset):
    def __init__(self, pretraining_dataset, data):
        self.MASK_CHAR = u"\u2047"  # the doublequestionmark character, for mask
        self.PAD_CHAR = u"\u25A1"  # the empty square character, for pad
        self.itos = pretraining_dataset.itos
        self.stoi = pretraining_dataset.stoi
        self.block_size = pretraining_dataset.block_size
        self.data = list(data.encode('utf-8').decode('ascii', errors='ignore').split('\n'))

    def __len__(self):
        # Returns the number of (question, answer) lines; the trailing empty
        # entry from the final newline is excluded.
        return len(self.data) - 1

    def __getitem__(self, idx):
        inp, oup = self.data[idx].split('\t')
        x = inp + self.MASK_CHAR + oup + self.MASK_CHAR
        x = x + self.PAD_CHAR * (self.block_size - len(x))
        y = self.PAD_CHAR * (len(inp) - 1) + x[len(inp):]

        x = x[:-1]
        x = torch.tensor([self.stoi[c] for c in x], dtype=torch.long)
        y = torch.tensor([self.stoi[c] for c in y], dtype=torch.long)
        return x, y

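
# A minimal sketch (hypothetical toy values, mirroring NameDataset.__getitem__
# on plain strings) of how the (x, y) pair is aligned: y pads out the question
# positions so that only the ⁇answer⁇ span contributes to the
# next-character-prediction loss.
def _toy_name_example(inp="Q?", oup="A", block_size=16,
                      mask=u"\u2047", pad=u"\u25A1"):
    x = inp + mask + oup + mask
    x = x + pad * (block_size - len(x))
    y = pad * (len(inp) - 1) + x[len(inp):]
    x = x[:-1]
    # Both strings have length block_size - 1, and over the answer span
    # y is simply x shifted left by one position.
    return x, y
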
"""
50
[part e]
51
52
Write a class that yields examples of a simplified span corruption objective.
53
Do not change the signature of the __init__ or __getitem__ functions.
54
55
Make sure to implement the full spec for full credit -- we list below the
56
criteria that must be satisfied for a full implementation.
57
58
--------------
59
Vocabulary Specification
60
61
Your vocabulary is to be accessible via two dictionaries:
62
self.stoi: a dictionary from characters in the vocabulary to indices of type
63
int
64
self.itos: a dictionary from indices of type int to characters in the
65
vocabulary
66
67
Your vocabulary must have the following form:
68
69
Identifier 0 must be assigned to the unicode element u"\u25A1".
70
This is the empty_square_character.
71
Further, let self.PAD_CHAR = u"\u25A1"
72
Identifier 1 must be assigned to the unicode element u"\u2047".
73
This is the doublequestionmark character, which we'll use
74
as a sentinel to represent that text is missing from the input
75
Further, let self.MASK_CHAR = u"\u2047"
76
Identifiers 2, ..., len(self.itos)-1 should be the sorted list of characters
77
that appear in the data argument.
78
79
--------------
80
Masking Specification
81
82
The __getitem__ function takes an index and returns a data point (x, y) where
83
x and y are Long tensors of length self.block_size. x encodes the input
84
sequence, and y encodes the output sequence.
85
86
0. Use the idx argument of __getitem__ to retrieve the element of self.data
87
at the given index. We'll call the resulting data entry a document.
88
89
1. Randomly truncate the document to a length no less than 4 characters,
90
and no more than int(self.block_size*7/8) characters.
91
92
- IMPORTANT: You are free to decide how to perform this random truncation, but
93
make sure that the length is picked _randomly_ (every possible length from 4
94
to int(self.block_size*7/8) has a chance of being picked) for full credit.
95
96
2. Now, break the (truncated) document into three substrings:
97
98
[prefix] [masked_content] [suffix]
99
100
In other words, choose three strings prefix, masked_content and suffix
101
such that prefix + masked_content + suffix = [the original document].
102
The length of [masked_content] should be random, and 1/4 the length of the
103
truncated document on average.
104
105
- IMPORTANT: You are free to decide how to perform this operation, but
106
make sure that the length is picked _randomly_ (has a chance of being more or
107
less than 1/4 the length of the truncated document) for full credit.
108
109
3. Rearrange these substrings into the following form:
110
111
[prefix] MASK_CHAR [suffix] MASK_CHAR [masked_content] [pads]
112
113
This resulting string, denoted masked_string, serves as the output example.
114
Here MASK_CHAR is the masking character defined in Vocabulary Specification,
115
and [pads] is a string of repeated PAD_CHAR characters chosen so that the
116
entire string is of length self.block_size.
117
Intuitively, the [masked_content], a string, is removed from the document and
118
replaced with MASK_CHAR (the masking character defined in Vocabulary
119
Specification). After the suffix of the string, the MASK_CHAR is seen again,
120
followed by the content that was removed, and the padding characters.
121
122
4. We now use masked_string to construct the input and output example pair. To
123
do so, simply take the input string to be masked_string[:-1], and the output
124
string to be masked_string[1:]. In other words, for each character, the goal is
125
to predict the next character in the masked string.
126
127
5. Making use of the vocabulary that you defined, encode the resulting input
128
and output strings as Long tensors and return the resulting data point.
129
130
----------------
131
Here are some examples of input-output pairs (x, y):
132
133
x: Khatchig Mouradian. Khatchig Mouradian is a jour⁇and tran⁇nalist, writer ⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
134
y: hatchig Mouradian. Khatchig Mouradian is a jour⁇and tran⁇nalist, writer ⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
135
136
x: Jaco⁇enry ⁇b H⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
137
y: aco⁇enry ⁇b H⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
138
139
x: John Stephen. Born in Glasgow, Steph⁇lder's apprentice on⁇en became a we⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
140
y: ohn Stephen. Born in Glasgow, Steph⁇lder's apprentice on⁇en became a we⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
141
142
143
"""
144
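
# A quick sketch of the vocabulary layout above (hypothetical helper and data
# string, for illustration only): PAD_CHAR gets id 0, MASK_CHAR gets id 1, and
# the sorted characters of the data follow from id 2 onward.
def _toy_vocab(data="cba"):
    chars = [u"\u25A1", u"\u2047"] + sorted(set(data))
    return {ch: i for i, ch in enumerate(chars)}
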
class CharCorruptionDataset(Dataset):
    def __init__(self, data, block_size):
        self.MASK_CHAR = u"\u2047"  # the doublequestionmark character, for mask
        self.PAD_CHAR = u"\u25A1"  # the empty square character, for pad

        chars = list(sorted(list(set(data))))
        assert self.MASK_CHAR not in chars
        assert self.PAD_CHAR not in chars
        chars.insert(0, self.MASK_CHAR)
        chars.insert(0, self.PAD_CHAR)

        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}

        data_size, vocab_size = len(data), len(chars)
        print('data has %d characters, %d unique.' % (data_size, vocab_size))

        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data.split('\n')

    def __len__(self):
        # returns the number of documents (one per line)
        return len(self.data)

    def __getitem__(self, idx):
        # [part e]: see spec above
        document = self.data[idx]
        # 1. Randomly truncate the document to between 4 and
        #    int(block_size * 7/8) characters.
        truncated_length = random.randint(4, int(self.block_size * 7 / 8))
        document = document[:truncated_length]
        # 2. Choose a random masked span. Drawing its length uniformly from
        #    [1, len(document) // 2] makes it 1/4 of the truncated document
        #    on average, as the spec requires, and bounding everything by the
        #    actual (possibly shorter) document length avoids out-of-range
        #    indices for short documents.
        masked_length = random.randint(1, max(1, len(document) // 2))
        start_idx = random.randint(0, len(document) - masked_length)
        prefix = document[:start_idx]
        masked_content = document[start_idx:start_idx + masked_length]
        suffix = document[start_idx + masked_length:]
        # 3. Rearrange the pieces with sentinels and pad out to block_size.
        masked_string = (prefix + self.MASK_CHAR + suffix
                         + self.MASK_CHAR + masked_content)
        masked_string += self.PAD_CHAR * (self.block_size - len(masked_string))
        # 4./5. Shift by one character to form the (input, output) pair and
        # encode both strings with the vocabulary as Long tensors.
        x = torch.tensor([self.stoi[c] for c in masked_string[:-1]], dtype=torch.long)
        y = torch.tensor([self.stoi[c] for c in masked_string[1:]], dtype=torch.long)
        return x, y

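
# A self-contained sketch of the masking steps above on a plain string
# (hypothetical helper, independent of the class): truncate, split into
# prefix/masked_content/suffix, rearrange with sentinels, pad, then shift
# by one character to form the (input, output) pair.
def _toy_span_corruption(document, block_size=32,
                         mask=u"\u2047", pad=u"\u25A1"):
    document = document[:random.randint(4, int(block_size * 7 / 8))]
    masked_length = random.randint(1, max(1, len(document) // 2))
    start = random.randint(0, len(document) - masked_length)
    prefix = document[:start]
    masked_content = document[start:start + masked_length]
    suffix = document[start + masked_length:]
    s = prefix + mask + suffix + mask + masked_content
    s += pad * (block_size - len(s))
    # The two returned strings are one-character-shifted views of s.
    return s[:-1], s[1:]
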
"""
184
Code under here is strictly for your debugging purposes; feel free to modify
185
as desired.
186
"""
187
if __name__ == '__main__':
188
argp = argparse.ArgumentParser()
189
argp.add_argument('dataset_type', help="Type of dataset to sample from."
190
"Options: namedata, charcorruption.",
191
choices=["namedata", "charcorruption"])
192
args = argp.parse_args()
193
194
if args.dataset_type == 'namedata':
195
# Even if it hasn't been implemented, we use it to define the vocab
196
corruption_dataset = CharCorruptionDataset(open('wiki.txt', encoding='utf-8').read(), 128)
197
# Make the name dataset
198
name_dataset = NameDataset(corruption_dataset,
199
open('birth_places_train.tsv', encoding='utf-8').read())
200
for _, example in zip(range(4), name_dataset):
201
x, y = example
202
print('x:', ''.join([name_dataset.itos[int(c)] for c in x]))
203
print('y:', ''.join([name_dataset.itos[int(c)] for c in y]))
204
pass
205
elif args.dataset_type == 'charcorruption':
206
corruption_dataset = CharCorruptionDataset(open('wiki.txt', encoding='utf-8').read(), 128)
207
for _, example in zip(range(4), corruption_dataset):
208
x, y = example
209
print('x:', ''.join([corruption_dataset.itos[int(c)] for c in x]))
210
print('y:', ''.join([corruption_dataset.itos[int(c)] for c in y]))
211
else:
212
raise ValueError("Unknown dataset type in command line args: {}"
213
.format(args.dataset_type))
214
215
216