GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/experiments/arithmetic_dataset.py
"""
---
title: Arithmetic Dataset
summary: >
  This creates arithmetic problems.
---

*This is based on code by [Georges Harik (@gharik)](https://twitter.com/gharik).*
"""

import random
import string
from typing import List

import torch
from torch.utils.data import DataLoader, Dataset

from labml import monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch


class ArithmeticDataset(Dataset):
    """
    ## Arithmetic Dataset

    This creates arithmetic addition problems and solutions with workings.
    We've only implemented addition so far.

    It uses character-level tokenization.
    """

    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
        """
        :param seq_len: is the sequence length of generated math problems.
            We fill as many problems as possible up to this length
        :param max_digits: is the maximum number of digits in the operand integers
        :param n_sequences: is the number of sequences per epoch
        """
        self.n_sequences = n_sequences
        self.max_digits = max_digits
        self.seq_len = seq_len
        # Token id to string
        self.itos = list(string.digits + 'xe =\n?+;')
        # Character to token id
        self.stoi = {c: i for i, c in enumerate(self.itos)}

    @staticmethod
    def make_int(n_digits: int):
        """
        Generates a random integer with `n_digits` digits
        """
        res = 0
        for i in range(n_digits):
            # First digit is `1..9`; the rest are `0..9`
            d = random.randrange(1, 10) if i == 0 else random.randrange(0, 10)
            res = res * 10 + d

        return res

    @staticmethod
    def get_add_explanation(x: int, y: int):
        """
        Generates the workings for `x + y`.
        For example, for `11+29` it generates
        `1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1`.
        """

        carry = 0
        e = 0
        explanation = []
        while x > 0 or y > 0 or carry > 0:
            rx, ry = x % 10, y % 10
            total = rx + ry + carry
            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
            x, y, carry = x // 10, y // 10, total // 10
            e += 1

        return ' '.join(explanation)
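
    # Worked example: for `11 + 29` the loop above emits
    # `1e0+9e0+0e0==10e0` for the units column (1 + 9 + carry 0 = 10; write 0, carry 1),
    # then `1e1+2e1+1e1==4e1` for the tens column (1 + 2 + carry 1 = 4).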

    def make_add_problem(self):
        """
        Creates an arithmetic addition problem with workings and answer,
        terminated by a newline. For example,
        `x=11+29; 1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1 x==40`
        """
        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))

        explanation = self.get_add_explanation(x, y)
        return f"x={x}+{y}; {explanation} x=={x + y}\n"

    def get_qa(self):
        """
        Get an arithmetic problem and its answer. This is used for evaluation.
        """
        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))

        return f'x={x}+{y};', f'{x + y}'
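
    # For example, `get_qa()` might return the pair `('x=11+29;', '40')`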

    def get_packed_math_input(self):
        """
        Generate multiple problems and pack them into a sequence.
        """
        s_enc = []
        while len(s_enc) <= self.seq_len:
            # Generate a problem and prefix it with `?` to mark where it starts
            s_part = self.make_add_problem()
            s_part_enc = self.encode('?' + s_part)
            s_enc = s_enc + s_part_enc
        return s_enc
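
    # A packed sequence decodes to a stream of `?`-prefixed problems, e.g.
    # `?x=7+85; 7e0+5e0+0e0==12e0 0e1+8e1+1e1==9e1 x==92` followed by a newline
    # and the next problem, continuing until the sequence exceeds `seq_len`.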

    def encode(self, s: str):
        """
        Encode a given string
        """
        return [self.stoi[c] for c in s]

    def decode(self, arr: List[int]):
        """
        Decode a list of token ids
        """
        return ''.join([self.itos[c] for c in arr])

    def __getitem__(self, idx: int):
        """
        Get an input and target pair for auto-regressive modelling
        """
        s = torch.tensor(self.get_packed_math_input())
        return s[:self.seq_len], s[1:self.seq_len + 1]

    def __len__(self):
        """
        Number of sequences per epoch
        """
        return self.n_sequences
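

# A short, self-contained sketch of how the dataset pieces fit together.
# `_demo_dataset` is added here for illustration and is not part of the
# original module.
def _demo_dataset():
    dataset = ArithmeticDataset(seq_len=64, max_digits=2, n_sequences=1)
    # A full problem string with workings, terminated by a newline
    problem = dataset.make_add_problem()
    # Character-level tokenization round-trips exactly
    assert dataset.decode(dataset.encode(problem)) == problem
    # Input and target are the same packed sequence shifted by one token
    inp, target = dataset[0]
    assert (inp[1:] == target[:-1]).all()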


class ArithmeticAutoregression(NLPAutoRegressionConfigs):
    """
    ## Arithmetic Task Experiment Configurations
    """
    # Maximum number of digits per operand integer
    max_digits: int = 4
    # Number of training sequences per epoch
    train_sequences_per_epoch: int = 2 ** 12
    # Training data loader (resolved by the `@option` function defined below)
    train_loader: DataLoader = 'arithmetic_train_loader'
    # Number of problems in evaluation
    n_tests: int = 64
    # No validation dataset is needed
    validator = None
    # Number of times to run evaluations per epoch
    inner_iterations = 4
    # Number of tokens in the vocabulary
    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)

    @torch.no_grad()
    def sample(self):
        """
        ### Evaluation

        We use the sampling function to evaluate the model on a set of problems
        """

        # Skip in the first epoch
        if self.training_loop.idx < 1:
            return

        # Create a dataset to generate problems
        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
        # Get a set of problems and answers
        qa = [dataset.get_qa() for _ in range(self.n_tests)]
        # Collect the problems only
        questions = [p[0] for p in qa]

        # Create a tensor with only the initial token of each question
        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
        # Move to device
        data = data.to(self.device)

        # Mask marking the sequences that have completed
        finished = torch.zeros((len(questions),)).bool().to(self.device)
        # Token id of the newline character, which marks the end of the answer
        new_line = dataset.stoi['\n']

        # Sampled results
        results = [p[0] for p in questions]

        # Sample up to the sequence length
        for i in monit.iterate('Sample', self.seq_len - 1):
            # If all the sequences have completed, we skip this
            if finished.sum() == len(finished):
                continue

            # Get the model output
            output, *_ = self.model(data)
            # Get the model prediction (greedy)
            output = output[-1].argmax(dim=-1)

            # Find which sequences have finished
            finished = finished | (output == new_line)
            # Skip if all have finished
            if finished.sum() == len(finished):
                continue

            # Override the prediction with the question text while we are still
            # feeding in the question (i.e. teacher-force the prompt)
            for j, p in enumerate(questions):
                if len(p) > i + 1:
                    output[j] = dataset.stoi[p[i + 1]]

            # Add the next token to the input
            data = torch.cat([data, output[None, :]], dim=0)

            # Collect the sampled characters
            for j, c in enumerate(output):
                results[j] += dataset.itos[c]

        # Discard everything after the answer in the results
        results = [r.split('\n')[0] for r in results]

        # Log a sample
        res_sample = results[0].split(';')
        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])

        # Get the answers
        results = [r.split('x==')[-1] for r in results]

        # Count the number of correct answers
        correct = 0
        for r, _qa in zip(results, qa):
            if r == _qa[1]:
                correct += 1

        # Log the score
        tracker.save('score', correct / len(results))


@option(ArithmeticAutoregression.train_loader)
def arithmetic_train_loader(c: ArithmeticAutoregression):
    """
    Training data loader
    """
    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      num_workers=4)
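

# A minimal sketch of how this experiment might be wired up and run, following
# the usual labml pattern of creating an experiment, overriding configs, and
# calling `run()` inside `experiment.start()`. `_sketch_train` and the override
# values are illustrative assumptions, not part of the original module.
def _sketch_train():
    from labml import experiment

    # Hypothetical experiment name
    experiment.create(name='arithmetic_addition')
    conf = ArithmeticAutoregression()
    # Assumed overrides; a model would also have to be configured,
    # e.g. via the options registered on `NLPAutoRegressionConfigs`
    experiment.configs(conf, {'max_digits': 4, 'batch_size': 32, 'seq_len': 256})
    with experiment.start():
        conf.run()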


def _test():
    """
    Code to test generated problems
    """
    dataset = ArithmeticDataset(256, 8, 10)

    print(dataset.decode(dataset.get_packed_math_input()))


#
if __name__ == '__main__':
    _test()