Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/adaptive_computation/parity.py
4918 views
1
"""
2
---
3
title: "Parity Task"
4
summary: >
5
This creates data for the Parity Task from the paper Adaptive Computation Time
6
for Recurrent Neural Networks
7
---
8
9
# Parity Task
10
11
This creates data for the Parity Task from the paper
12
[Adaptive Computation Time for Recurrent Neural Networks](https://arxiv.org/abs/1603.08983).
13
14
The input of the parity task is a vector of $0$'s, $1$'s and $-1$'s.
15
The output is the parity of $1$'s - one if there is an odd number of $1$'s and zero otherwise.
16
The input is generated by setting a random number (at least one) of the vector's elements to either $1$ or $-1$, leaving the rest $0$.
17
"""
18
19
from typing import Tuple
20
21
import torch
22
from torch.utils.data import Dataset
23
24
25
class ParityDataset(Dataset):
    """
    ### Parity dataset

    Samples are generated on the fly: `idx` is ignored, and every call to
    `__getitem__` draws a fresh random input vector from torch's global RNG.
    """

    def __init__(self, n_samples: int, n_elems: int = 64):
        """
        * `n_samples` is the number of samples (must be non-negative)
        * `n_elems` is the number of elements in the input vector (must be at least $1$)
        """
        # Validate eagerly. Otherwise a bad `n_elems` only surfaces later as an
        # obscure `torch.randint` error inside `__getitem__`, and a negative
        # `n_samples` makes `len()` fail with a confusing message.
        if n_samples < 0:
            raise ValueError(f"n_samples must be non-negative, got {n_samples}")
        if n_elems < 1:
            raise ValueError(f"n_elems must be at least 1, got {n_elems}")
        self.n_samples = n_samples
        self.n_elems = n_elems

    def __len__(self):
        """
        Size of the dataset
        """
        return self.n_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate a sample.

        Returns a float vector `x` of length `n_elems` with entries in
        $\\{-1, 0, 1\\}$ (at least one non-zero), and a scalar integer tensor `y`
        that is $1$ if `x` contains an odd number of $1$'s and $0$ otherwise.
        """

        # Empty vector
        x = torch.zeros((self.n_elems,))
        # Number of non-zero elements - a random number between $1$ and total number of elements
        # (`torch.randint`'s upper bound is exclusive, hence the `+ 1`)
        n_non_zero = torch.randint(1, self.n_elems + 1, (1,)).item()
        # Fill non-zero elements with $1$'s and $-1$'s (maps $\{0, 1\}$ to $\{-1, 1\}$)
        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
        # Randomly permute the elements so the non-zeros are not always at the front
        x = x[torch.randperm(self.n_elems)]

        # The parity: count of $1$'s modulo $2$ ($-1$'s do not contribute)
        y = (x == 1.).sum() % 2

        # Return the input vector and its parity
        return x, y
63
64