Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/29/supplementary/hmm_with_ngram.ipynb
1193 views
Kernel: Python 3

Setup

!pip install nltk !pip install distrax
Requirement already satisfied: nltk in /usr/local/lib/python3.7/dist-packages (3.2.5) Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from nltk) (1.15.0) Requirement already satisfied: distrax in /usr/local/lib/python3.7/dist-packages (0.0.2) Requirement already satisfied: tensorflow-probability>=0.13.0rc0 in /usr/local/lib/python3.7/dist-packages (from distrax) (0.13.0) Requirement already satisfied: chex>=0.0.7 in /usr/local/lib/python3.7/dist-packages (from distrax) (0.0.8) Requirement already satisfied: absl-py>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from distrax) (0.12.0) Requirement already satisfied: jaxlib>=0.1.67 in /usr/local/lib/python3.7/dist-packages (from distrax) (0.1.70+cuda110) Requirement already satisfied: jax>=0.2.13 in /usr/local/lib/python3.7/dist-packages (from distrax) (0.2.19) Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from distrax) (1.19.5) Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.9.0->distrax) (1.15.0) Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.7->distrax) (0.11.1) Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.7->distrax) (0.1.6) Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->distrax) (3.3.0) Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.67->distrax) (1.12) Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.67->distrax) (1.4.1) Requirement already satisfied: decorator in /usr/local/lib/python3.7/dist-packages (from tensorflow-probability>=0.13.0rc0->distrax) (4.4.2) Requirement already satisfied: gast>=0.3.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow-probability>=0.13.0rc0->distrax) (0.4.0) Requirement 
already satisfied: cloudpickle>=1.3 in /usr/local/lib/python3.7/dist-packages (from tensorflow-probability>=0.13.0rc0->distrax) (1.3.0)
!git clone --depth 1 https://github.com/probml/pyprobml /pyprobml &> /dev/null !curl -o bible.txt https://raw.githubusercontent.com/probml/probml-data/main/data/bible.txt %cd -q /pyprobml/scripts
% Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 4230k 100 4230k 0 0 9704k 0 --:--:-- --:--:-- --:--:-- 9704k
!pip install superimport
Requirement already satisfied: superimport in /usr/local/lib/python3.7/dist-packages (0.3.3) Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from superimport) (2.23.0) Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->superimport) (1.24.3) Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->superimport) (3.0.4) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->superimport) (2021.5.30) Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->superimport) (2.10)
from conditional_bernoulli_mix_lib import ClassConditionalBMM from conditional_bernoulli_mix_utils import ( fake_test_data, encode, decode, get_decoded_samples, get_emnist_images_per_class, ) from noisy_spelling_hmm import Word from ngram_character_demo import ngram_model_fit, read_file, preprocessing, ngram_model_sample, ngram_loglikelihood from distrax import HMM from nltk.util import ngrams from nltk import FreqDist, LidstoneProbDist import numpy as np import re import string from collections import defaultdict from dataclasses import dataclass from jax import vmap import jax.numpy as jnp import jax from jax.random import PRNGKey, split import distrax import numpy as np from matplotlib import pyplot as plt

ClassConditionalBMM

# Number passed to get_emnist_images_per_class (presumably images kept per
# character class -- confirm against the helper).
select_n = 25

# The helper returns an (images, labels) pair; convert both halves to JAX arrays.
dataset, targets = map(jnp.array, get_emnist_images_per_class(select_n))
/usr/local/lib/python3.7/dist-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.) return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s) WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
""" During preprocessing of the text data, we removed punctuation whereas case folding is not applied. The text data only contains upper and lower case letters and hence there are 52 different characters in total. """ n_char = 2 * 26
def get_bmm(n_mix, dataset, targets):
    """Fit a class-conditional Bernoulli mixture model and return its model.

    Args:
        n_mix: Size of the mixture axis; the mixing coefficients have shape
            ``(n_char, n_mix)``.
        dataset: Training inputs forwarded to ``fit_em``.
        targets: Training labels forwarded to ``fit_em``.

    Returns:
        The ``model`` attribute of the ``ClassConditionalBMM`` after
        ``fit_em`` has run.
    """
    n_pixels = 28 * 28
    p_min, p_max = 0.4, 0.6

    # Every class starts with equal weight on each of its mixture components.
    uniform_weights = np.full((n_char, n_mix), 1.0 / n_mix)
    # Pixel probabilities are drawn at random between p_min and p_max using
    # NumPy's global random state.
    initial_probs = np.random.uniform(p_min, p_max, (n_char, n_mix, n_pixels))

    class_cond_bmm = ClassConditionalBMM(
        mixing_coeffs=jnp.array(uniform_weights),
        probs=jnp.array(initial_probs),
        class_priors=jnp.eye(n_char),
        n_char=n_char,
    )
    # NOTE(review): the literal 8 is presumably the number of EM iterations --
    # confirm against ClassConditionalBMM.fit_em.
    class_cond_bmm.fit_em(dataset, targets, 8)
    return class_cond_bmm.model

HMM

def get_transition_probs(bigram):
    """Build a 52 x 52 letter-to-letter transition table from a bigram model.

    Rows index the previous character and columns the current character.
    Indices 0-25 stand for ``A``-``Z`` and indices 26-51 for ``a``-``z``.

    Args:
        bigram: Object exposing ``prob_dists``, a mapping from a previous
            character to a distribution that provides ``samples()`` and
            ``prob(sample)``.

    Returns:
        A ``(52, 52)`` NumPy array holding the probabilities reported by the
        distributions. Anything that is not a single ASCII letter (the space,
        digits, punctuation, ...) is skipped on either side of a transition,
        so rows are not guaranteed to sum to one.
    """
    # Upper case first, then lower case, matching the 0-25 / 26-51 layout.
    letter_index = {
        ch: k
        for k, ch in enumerate(string.ascii_uppercase + string.ascii_lowercase)
    }
    size = len(letter_index)
    probs = np.zeros((size, size))

    for prev, pd in bigram.prob_dists.items():
        i = letter_index.get(prev)
        if i is None:
            # A plain ``ord(ch) - 97`` would turn e.g. a digit into a negative
            # index and silently add its mass to an unrelated letter's row.
            continue
        for cur in pd.samples():
            j = letter_index.get(cur)
            if j is None:
                continue
            probs[i, j] += pd.prob(cur)
    return probs
def init_hmm_from_bigram(bigram, bmm):
    """Assemble an HMM whose transitions come from a bigram model.

    Args:
        bigram: Bigram model handed to ``get_transition_probs``.
        bmm: Distribution used as the HMM's observation model.

    Returns:
        An ``HMM`` built from an initial distribution with equal logits over
        the ``n_char`` states, the bigram-derived transition distribution and
        ``bmm``.
    """
    # All-zero logits give every state the same initial probability.
    uniform_init = distrax.Categorical(logits=jnp.zeros((n_char,)))
    transitions = distrax.Categorical(probs=get_transition_probs(bigram))
    return HMM(uniform_init, transitions, bmm)

Loading Dataset

# NOTE(review): the EMNIST loading below repeats the earlier cell with the
# same select_n.
select_n = 25
dataset, targets = map(jnp.array, get_emnist_images_per_class(select_n))

# Read the raw text and run it through the preprocessing helper; the second
# argument is False (presumably the case-folding switch -- confirm against
# ngram_character_demo.preprocessing).
filepath = "/content/bible.txt"
text = read_file(filepath)
data = preprocessing(text, False)

Sampling Images

# n is the n-gram order (2 -> bigram); n_mix is forwarded to get_bmm.
n = 2
n_mix = 30

# Fit the bigram model and the image model, then combine them into an HMM.
bigram = ngram_model_fit(n, data, smoothing=1)
bmm = get_bmm(n_mix, dataset, targets)
hmm = init_hmm_from_bigram(bigram, bmm)
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:5847: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. lax._check_user_dtype_supported(dtype, "astype")
# Draw a sequence of seq_len steps from the HMM with a fixed seed.
seq_len = 6
rng_key = PRNGKey(0)
# Z and X are the two values returned by hmm.sample (presumably the hidden
# states and the observed images -- confirm against distrax.HMM.sample).
Z, X = hmm.sample(seed=rng_key, seq_len=seq_len)
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:5847: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. lax._check_user_dtype_supported(dtype, "astype")
def plot_seq(X, seq_len, figsize):
    """Show a sequence of 28 x 28 images side by side in a single row.

    Args:
        X: Iterable of images; each one is reshaped to ``(28, 28)`` before
            being drawn.
        seq_len: Number of subplot columns to create. Images and axes are
            paired with ``zip``, so the shorter of the two decides how many
            images are drawn.
        figsize: Figure size forwarded to ``plt.subplots``.
    """
    # squeeze=False keeps ``axes`` a 2-D array even when seq_len == 1; without
    # it plt.subplots returns a bare Axes object that has no ``flatten``.
    _, axes = plt.subplots(nrows=1, ncols=seq_len, figsize=figsize, squeeze=False)
    for x, ax in zip(X, axes.flatten()):
        ax.imshow(x.reshape((28, 28)), cmap="gray")
    plt.tight_layout()
    plt.show()
# Draw the images in X as one row of seq_len panels on a 10 x 8 figure.
plot_seq(X, seq_len, (10, 8))
Image in a Jupyter notebook

NGram

# Fit a higher-order (n = 10) character n-gram model on the same text.
n = 10
n_gram = ngram_model_fit(n, data, smoothing=1)

# Sample text that continues the given prefix and score it under the model.
prefix = "Christian"
text_length = 11
Z = ngram_model_sample(n_gram, text_length, prefix)
log_p_Z = ngram_loglikelihood(n_gram, Z)
def sample_img_seq_given_char_seq(bmm, z, rng_key):
    """Sample one image per character of ``z`` and total their log-probabilities.

    Args:
        bmm: Object providing ``sample(seed=...)`` and ``log_prob(x)``; both
            results are indexed by the character's class index.
        z: String of characters. Letters map to class indices 0-25
            (``A``-``Z``) and 26-51 (``a``-``z``); a space yields a blank
            image.
        rng_key: JAX PRNG key, split into one sub-key per character.

    Returns:
        A tuple ``(images, LL)``: ``images`` stacks one 784-element row per
        character, and ``LL`` is the sum of the log-probabilities of the
        sampled letter images. Spaces contribute nothing to ``LL``.
    """
    keys = split(rng_key, len(z))
    images = []
    total_log_prob = 0
    for cur_char, key in zip(z, keys):
        if cur_char == " ":
            # A space is drawn as an all-zero image and must not touch the
            # total: reusing the previous letter's log-probability here would
            # count that letter twice (or fail when the text starts with a
            # space).
            images.append(jnp.zeros((784,)))
            continue
        c = cur_char.islower() * 26 + (ord(cur_char.lower()) - 97)
        image = bmm.sample(seed=key)[c]
        images.append(image)
        total_log_prob += bmm.log_prob(image)[c]
    return jnp.vstack(images), total_log_prob
# Render the sampled text Z as a sequence of images, using a fixed seed.
rng_key = PRNGKey(0)
images, LL = sample_img_seq_given_char_seq(bmm=bmm, z=Z, rng_key=rng_key)
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:5847: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. lax._check_user_dtype_supported(dtype, "astype")
# Display the total log-probability returned by sample_img_seq_given_char_seq.
LL
DeviceArray(-374.44446, dtype=float32)
# Display the text returned by ngram_model_sample.
Z
'Christians first int'
# One panel per character: the Z shown in the output has 20 characters, which
# matches text_length + len(prefix) (11 + 9).
plot_seq(images, text_length + len(prefix), figsize=(40, 20))
Image in a Jupyter notebook