Path: blob/master/labml_nn/adaptive_computation/parity.py
4918 views
"""
---
title: "Parity Task"
summary: >
  This creates data for Parity Task from the paper Adaptive Computation Time
  for Recurrent Neural Networks
---

# Parity Task

This creates data for Parity Task from the paper
[Adaptive Computation Time for Recurrent Neural Networks](https://arxiv.org/abs/1603.08983).

The input of the parity task is a vector with $0$'s, $1$'s and $-1$'s.
The output is the parity of the $1$'s: one if there is an odd number of $1$'s and zero otherwise.
The input is generated by setting a random number of elements in the vector to either $1$ or $-1$.
"""

from typing import Tuple

import torch
from torch.utils.data import Dataset


class ParityDataset(Dataset):
    """
    ### Parity dataset

    Every call to `__getitem__` draws a fresh random sample using torch's global
    random number generator. The index only bounds the dataset; it does not
    select a fixed, reproducible sample.
    """

    def __init__(self, n_samples: int, n_elems: int = 64):
        """
        * `n_samples` is the number of samples (must be non-negative)
        * `n_elems` is the number of elements in the input vector (must be at least $1$)

        Raises `ValueError` if either argument is out of range.
        """
        if n_samples < 0:
            raise ValueError(f"n_samples must be non-negative, got {n_samples}")
        # At least one element is needed: a sample always has between $1$ and
        # `n_elems` non-zero entries, which is impossible when `n_elems` is $0$.
        if n_elems < 1:
            raise ValueError(f"n_elems must be at least 1, got {n_elems}")
        self.n_samples = n_samples
        self.n_elems = n_elems

    def __len__(self) -> int:
        """
        Size of the dataset
        """
        return self.n_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate a sample.

        Returns `(x, y)`, where:

        * `x` is a float vector of length `n_elems` with entries in $\{-1, 0, 1\}$
        * `y` is a 0-dim integer tensor holding the parity of the $1$'s in `x`

        Raises `IndexError` if `idx` is outside `[-n_samples, n_samples)`.
        """

        # Bounds check. When the dataset is iterated directly (e.g. `list(dataset)`),
        # Python calls `__getitem__(0), __getitem__(1), ...` until an `IndexError`
        # is raised. Without this check that iteration would never stop.
        if not -self.n_samples <= idx < self.n_samples:
            raise IndexError(f"index {idx} is out of range for dataset of size {self.n_samples}")

        # Empty vector
        x = torch.zeros((self.n_elems,))
        # Number of non-zero elements - a random number between $1$ and total number of elements
        n_non_zero = torch.randint(1, self.n_elems + 1, (1,)).item()
        # Fill non-zero elements with $1$'s and $-1$'s
        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
        # Randomly permute the elements
        x = x[torch.randperm(self.n_elems)]

        # The parity: count of $1$'s modulo $2$
        y = (x == 1.).sum() % 2

        #
        return x, y