Path: blob/master/labml_nn/experiments/arithmetic_dataset.py
4918 views
"""1---2title: Arithmetic Dataset3summary: >4This creates arithmetic problems.5---67*This is based on code by [Georges Harik (@gharik)](https://twitter.com/gharik).*8"""910import random11import string12from typing import List1314import torch15from labml.logger import Text16from torch.utils.data import DataLoader, Dataset1718from labml import monit, logger, tracker19from labml.configs import option20from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch212223class ArithmeticDataset(Dataset):24"""25## Arithmetic Dataset2627This creates arithmetic addition problems and solutions with workings.28We've only implemented addition so far.2930It's based on a character level tokenization.31"""3233def __init__(self, seq_len: int, max_digits: int, n_sequences: int):34"""35:param seq_len: is the sequence length of generated math problems.36We fill as many problems as possible upto this length37:max_digits: is the maximum number of digits in the operand integers38:n_sequences: is the number of sequences per epoch39"""40self.n_sequences = n_sequences41self.max_digits = max_digits42self.seq_len = seq_len43# Token id to string44self.itos = list(string.digits + 'xe =\n?+;')45# Character to token id46self.stoi = {c: i for i, c in enumerate(self.itos)}4748@staticmethod49def make_int(n_digits: int):50"""51Generates an integer with `n_digit` number of digits52"""53res = 054for i in range(n_digits):55d = random.randrange(1, 11) if i == 0 else random.randrange(0, 11)56res = res * 10 + d5758return res5960@staticmethod61def get_add_explanation(x: int, y: int):62"""63Generates the workings for `x + y`.64For example for `11+29` it generates65`1e0+9e0+0e0=10e0 1e0+2e0+1e0=4e0`.66"""6768carry = 069e = 070explanation = []71while x > 0 or y > 0 or carry > 0:72rx, ry = x % 10, y % 1073total = rx + ry + carry74explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")75x, y, carry = x // 10, y // 10, total // 1076e += 17778return ' '.join(explanation)7980# Make a problem with a pre_explanation or not81def make_add_problem(self):82"""83Creates an arithmetic addition problem with workings and answer.84"""85x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))86y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))8788explanation = self.get_add_explanation(x, y)89return f"x={x}+{y}; {explanation} x=={x + y}\n"9091def get_qa(self):92"""93Get arithmetic problem and answer. This is used for evaluation.94"""95x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))96y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))9798return f'x={x}+{y};', f'{x + y}'99100def get_packed_math_input(self):101"""102Generate multiple problems and pack them into a sequence.103"""104s_enc = []105while len(s_enc) <= self.seq_len:106s_part = self.make_add_problem()107s_part_enc = self.encode('?' + s_part)108s_enc = s_enc + s_part_enc109return s_enc110111def encode(self, s: str):112"""113Encode a given string114"""115return [self.stoi[c] for c in s]116117def decode(self, arr: List[int]):118"""119Decode a list of token ids120"""121return ''.join([self.itos[c] for c in arr])122123def __getitem__(self, idx: int):124"""125Get a input and target pair for auto-regressive modelling126"""127s = torch.tensor(self.get_packed_math_input())128return s[:self.seq_len], s[1:self.seq_len + 1]129130def __len__(self):131"""132Number of sequences per epoch133"""134return self.n_sequences135136137class ArithmeticAutoregression(NLPAutoRegressionConfigs):138"""139## Arithmetic Task Experiment Configurations140"""141# Maximum number of digits per operand integer142max_digits: int = 4143# Number of training sequences per epoch144train_sequences_per_epoch: int = 2 ** 12145# Training data loader146train_loader: DataLoader = 'arithmetic_train_loader'147# Number of problems in evaluation148n_tests: int = 64149# No need of a validation dataset150validator = None151# Number of times to run evaluations per epoch152inner_iterations = 4153# Number of tokens in the vocabulary154n_tokens = len(ArithmeticDataset(1, 1, 1).itos)155156@torch.no_grad()157def sample(self):158"""159### Evaluation160161We use the sampling function to evaluate the model on a set of problems162"""163164# Skip in the first epoch165if self.training_loop.idx < 1:166return167168# Create a dataset to generate problems169dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)170# Get a set of problems and answers171qa = [dataset.get_qa() for _ in range(self.n_tests)]172# Collect the problems only173questions = [p[0] for p in qa]174175# Create a tensor with only the initial token176data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])177# Move to device178data = data.to(self.device)179180# Number of sequences that have completed181finished = torch.zeros((len(questions),)).bool().to(self.device)182# Token id of the new line character - this marks end of the answer183new_line = dataset.stoi['\n']184185# Sampled results186results = [p[0] for p in questions]187188# Sample upto sequence length189for i in monit.iterate('Sample', self.seq_len - 1):190# If all the sequences have completed we skip this191if finished.sum() == len(finished):192continue193194# Get the model output195output, *_ = self.model(data)196# Get the model prediction (greedy)197output = output[-1].argmax(dim=-1)198199# Find which sequences have finished200finished = finished | (output == new_line)201# Skip if all have finished202if finished.sum() == len(finished):203continue204205# Override with the question206for j, p in enumerate(questions):207if len(p) > i + 1:208output[j] = dataset.stoi[p[i + 1]]209210# Add the next token to the input211data = torch.cat([data, output[None, :]], dim=0)212213# Get the sampled results214for j, c in enumerate(output):215results[j] += dataset.itos[c]216217# Discard everything after the answer in the results218results = [r.split('\n')[0] for r in results]219220# Log a sample221res_sample = results[0].split(';')222logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])223224# Get the answers225results = [r.split('x==')[-1] for r in results]226227# Count the number of correct answers228correct = 0229for r, _qa in zip(results, qa):230if r == _qa[1]:231correct += 1232233# Log the score234tracker.save('score', correct / len(results))235236237@option(ArithmeticAutoregression.train_loader)238def arithmetic_train_loader(c: ArithmeticAutoregression):239"""240Training data loader241"""242return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),243batch_size=c.batch_size,244collate_fn=transpose_batch,245num_workers=4)246247248def _test():249"""250Code to test generated problems251"""252dataset = ArithmeticDataset(256, 8, 10)253254print(dataset.decode(dataset.get_packed_math_input()))255256257#258if __name__ == '__main__':259_test()260261262