GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/29/supplementary/hmm_bach_chorales_distrax.ipynb
Kernel: Python 3

Bach Chorales

Imports

from IPython.utils import io

with io.capture_output() as captured:
    !pip install distrax
    !pip install numpyro

import warnings

warnings.filterwarnings("ignore")
import jax
from jax import jit, lax, vmap
from jax.random import split, PRNGKey, permutation, normal
import jax.numpy as jnp
from jax.nn import log_softmax, sigmoid

try:
    import tensorflow as tf
except ModuleNotFoundError:
    %pip install -qq tensorflow
    import tensorflow as tf

try:
    import tensorflow_probability as tfp
except ModuleNotFoundError:
    %pip install -qq tensorflow-probability
    import tensorflow_probability as tfp

import itertools

try:
    import distrax
except ModuleNotFoundError:
    %pip install -qq distrax
    import distrax

from distrax import HMM

import numpy as np
import matplotlib.pyplot as plt

Load Training Dataset

try:
    from numpyro.examples.datasets import JSB_CHORALES, load_dataset
except ModuleNotFoundError:
    %pip install -qq numpyro
    from numpyro.examples.datasets import JSB_CHORALES, load_dataset

_, fetch = load_dataset(JSB_CHORALES, split="train", shuffle=False)
lengths, sequences_all = fetch()

# find all the notes that are present at least once in the training set
present_notes = (sequences_all == 1).sum(0).sum(0) > 0
# remove notes that are never played (we remove 37/88 notes with default args)
sequences = sequences_all[..., present_notes]

print(sequences_all.shape)
print(sequences.shape)
print(lengths)
Downloading - https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/jsb_chorales.pickle. Download complete. (229, 129, 88) (229, 129, 51) [129 65 49 65 114 33 57 49 64 33 108 48 49 48 61 48 65 53 41 52 33 61 41 45 69 39 57 80 86 57 61 105 68 65 48 57 57 52 48 33 93 41 65 49 73 48 33 45 65 52 49 109 49 52 65 41 49 65 77 73 57 41 65 57 44 33 85 72 60 41 49 33 56 52 56 83 57 57 52 41 53 64 61 65 65 48 25 96 41 37 76 65 65 72 41 65 41 37 77 48 85 57 72 48 45 53 49 84 68 49 73 50 96 55 66 98 37 76 65 33 49 77 76 65 128 41 45 61 33 88 61 63 65 68 60 49 48 45 65 45 33 64 64 65 49 63 61 48 45 65 84 120 41 102 49 57 69 57 82 48 76 48 52 113 97 83 33 68 41 49 65 109 108 60 65 57 49 33 57 33 61 113 58 64 65 33 57 49 64 68 60 65 41 72 53 49 57 33 49 89 65 44 49 61 65 69 52 76 33 57 76 57 33 76 101 40 41 44 76 65 49 54 49 49 64 51 109 65 65]
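The training set therefore contains 229 chorales, each padded to a maximum length of 129 time steps, with 51 of the original 88 note columns retained. A quick way to confirm that shorter sequences are zero-padded beyond their true length (an illustrative check, not part of the original notebook):

# Illustrative check: the second sequence has true length lengths[1] = 65,
# so all entries after that time step should be padding zeros.
print(int(sequences[1, : lengths[1]].sum()))  # notes played within the true length
print(int(sequences[1, lengths[1] :].sum()))  # expected to be 0 beyond the true length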

Present Notes

print(jnp.where(present_notes)[0])
[22 25 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75]
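If the 88 columns of the raw data correspond to the 88 piano keys (an assumption about the encoding, with the lowest key at MIDI note 21), the retained column indices can be translated into MIDI pitches as a rough sketch:

# Hypothetical mapping from column index to MIDI pitch, assuming the 88 columns
# are the 88 piano keys and index 0 corresponds to MIDI note 21 (A0).
note_indices = jnp.where(present_notes)[0]
midi_pitches = note_indices + 21
print(midi_pitches)  # under this assumed encoding, roughly MIDI 43 (G2) up to 96 (C7)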

Visualization of Training Sequences

N = 4
fig, axs = plt.subplots(1, N, figsize=(20, 10))
axs = axs.reshape(N)
for i in range(N):
    X = sequences[i, :, :].squeeze()
    ax = axs[i]
    ax.imshow(X)
    ax.set_axis_off()
plt.show()
[Figure: piano-roll plots of the first four training sequences]

Gradient Descent Implementation

Helper that yields mini-batches of padded sequences together with their true lengths; a small usage sketch follows the function.

def hmm_sample_minibatches(iterables, batch_size):
    sequences, lens = iterables
    n_seq = len(sequences)
    for idx in range(0, n_seq, batch_size):
        yield sequences[idx : min(idx + batch_size, n_seq)], lens[idx : min(idx + batch_size, n_seq)]
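As a quick illustration (not part of the original notebook), the generator can be exercised directly on the training arrays; the batch size of 20 below is arbitrary.

# Illustrative usage of the mini-batch generator, assuming `sequences` and
# `lengths` from the loading cell above.
batch_gen = hmm_sample_minibatches((sequences, lengths), batch_size=20)
first_batch, first_lens = next(batch_gen)
print(first_batch.shape)  # (20, 129, 51): 20 padded sequences
print(first_lens.shape)   # (20,): their true lengths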
def make_hmm(params):
    obs_logits, trans_logits = params
    trans_logits = log_softmax(trans_logits)
    obs_logits = jnp.clip(obs_logits, a_min=-12)
    trans_logits = jnp.clip(trans_logits, a_min=-12)
    b = distrax.Bernoulli(logits=obs_logits)
    obs_dist = distrax.Independent(b, 1)
    hmm = HMM(
        trans_dist=distrax.Categorical(logits=trans_logits),
        obs_dist=obs_dist,
        init_dist=distrax.Categorical(probs=initial_hmm.init_dist.probs),
    )
    return hmm


@jit
def hmm_loglikelihood(hmm, observations, lens):
    def forward_(x, length):
        return hmm.forward(x, length)[0] / length

    return vmap(forward_, in_axes=(0, 0))(observations, lens)


@jit
def loss_fn(params, batch, lens):
    hmm = make_hmm(params)
    return -hmm_loglikelihood(hmm, batch, lens).mean()


@jit
def update(i, opt_state, batch, lens):
    params = get_params(opt_state)
    loss, grads = jax.value_and_grad(loss_fn)(params, batch, lens)
    return opt_update(i, grads, opt_state), loss


def fit(observations, lens, beta, trans_dist, batch_size, rng_key=None, num_epochs=300):
    if rng_key is None:
        rng_key = PRNGKey(0)

    opt_state = opt_init((beta.logits, trans_dist.logits))
    # opt_state = opt_init((logit(beta.probs), logit(trans_dist.probs)))
    itercount = itertools.count()

    num_complete_batches, leftover = jnp.divmod(num_epochs, batch_size)
    num_batches = num_complete_batches + jnp.where(leftover == 0, 0, 1)

    def epoch_step(opt_state, key):
        perm = permutation(key, len(observations))
        observations_, lens_ = observations[perm], lens[perm]
        sample_generator = hmm_sample_minibatches((observations_, lens_), batch_size)

        def train_step(opt_state, i):
            batch, length = next(sample_generator)
            opt_state, loss = update(next(itercount), opt_state, batch, length)
            return opt_state, loss

        opt_state, losses = jax.lax.scan(train_step, opt_state, jnp.arange(num_batches))
        return opt_state, losses.mean()

    epochs = split(rng_key, num_epochs)
    opt_state, losses = jax.lax.scan(epoch_step, opt_state, epochs)
    losses = losses.flatten()
    params = get_params(opt_state)
    hmm = make_hmm(params)
    obs_logits, _ = params
    return hmm, losses, obs_logits

Initialize HMM

def init_hmm(K, D, initial_prob=0.1):
    trans_probs = distrax.as_distribution(
        tfp.substrates.jax.distributions.Dirichlet(0.9 * jnp.eye(K) + 0.1)
    ).sample(seed=1)
    trans_dist = distrax.Categorical(probs=trans_probs)
    init_dist = distrax.Categorical(logits=jnp.zeros((K,)))
    b = distrax.Bernoulli(probs=jnp.full((K, D), initial_prob))
    obs_dist = distrax.Independent(b, 1)
    initial_hmm = HMM(trans_dist=trans_dist, obs_dist=obs_dist, init_dist=init_dist)
    return initial_hmm, b
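A quick sanity check of the initialization (illustrative, not in the original notebook); the K and D values below are arbitrary.

# Shapes produced by init_hmm for a hypothetical 4-state model over 51 notes.
hmm0, b0 = init_hmm(K=4, D=51)
print(hmm0.trans_dist.probs.shape)  # (4, 4): row-stochastic transition matrix
print(b0.probs.shape)               # (4, 51): per-state Bernoulli emission probabilities
print(hmm0.init_dist.probs.shape)   # (4,): uniform initial-state distribution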

Model Selection

Determine Hyperparameters

num_epochs = 200  # @param {type:"slider", min:1, max:10000, step:0}
batch_size = 20  # @param {type:"slider", min:1, max:229, step:0}
hidden_states = [4, 8, 16, 24]
batch_sizes = [1, 20, 229]
learning_rates = [1e-3, 5e-2, 5e-2]
opt_init, opt_update, get_params = None, None, None
D = sequences.shape[2]
all_models, all_losses, all_obs_logits = {}, [], []

for learning_rate, batch_size in zip(learning_rates, batch_sizes):
    opt_init, opt_update, get_params = jax.experimental.optimizers.adam(learning_rate)
    models = {}
    losses, obs_logits = [], []
    for k in hidden_states:
        initial_hmm, b = init_hmm(k, D)
        hmm, loss_, obs_logits_ = fit(
            sequences, lengths, b, initial_hmm.trans_dist, batch_size, num_epochs=num_epochs
        )
        # Store results
        models[k] = hmm
        losses.append(loss_)
        obs_logits.append(obs_logits_)
    all_models[batch_size] = models
    all_losses.append(losses)
    all_obs_logits.append(obs_logits)
fig, axes = plt.subplots(nrows=1, ncols=len(batch_sizes))
axes = np.array(axes)
fig.set_figheight(6)
fig.set_figwidth(18)
for b, (batch_size, ax) in enumerate(zip(batch_sizes, axes.flatten())):
    for h, k in enumerate(hidden_states):
        ax.plot(all_losses[b][h], label=f"K={k}")
    ax.set_title(f"batch_size={batch_size}")
    ax.legend(loc="upper right")
fig.tight_layout()
plt.show()
[Figure: training-loss curves for each K, one panel per batch size]

Visualization of Parameters of HMM

from mpl_toolkits.axes_grid1 import make_axes_locatable
def plot_observation_dist(idx, batch_sizes, all_obs_logits):
    fig, axes = plt.subplots(nrows=1, ncols=len(batch_sizes))
    fig.set_figwidth(24)
    for b, (batch_size, ax) in enumerate(zip(batch_sizes, axes.flatten())):
        im = ax.imshow(sigmoid(all_obs_logits[b][idx]))
        ax.set_title(f"batch_size={batch_size}")
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="2%", pad=0.1)
        plt.colorbar(im, cax=cax)
    fig.tight_layout()
    plt.show()
# @title Select the Number of Hidden States
K = 16  # @param [4, 8, 16, 24] {type:"raw", allow-input: false}
i = np.where(np.array(hidden_states) == K)[0][0]

Observation Distribution

plot_observation_dist(i, batch_sizes, all_obs_logits)
[Figure: learned emission probabilities (sigmoid of the observation logits) for the selected K, one panel per batch size]

Transition Probabilities

def plot_transition_probs(n_hidden, batch_sizes, all_models):
    fig, axes = plt.subplots(nrows=1, ncols=len(batch_sizes))
    fig.set_figwidth(20)
    for batch_size, ax in zip(batch_sizes, axes.flatten()):
        im = ax.imshow(all_models[batch_size][n_hidden].trans_dist.probs)
        ax.set_title(f"batch_size={batch_size}")
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.1)
        plt.colorbar(im, cax=cax)
    fig.tight_layout()
    plt.show()
n_hidden = hidden_states[i]
plot_transition_probs(n_hidden, batch_sizes, all_models)
[Figure: learned transition matrices for the selected K, one panel per batch size]

Model Selection

Load Test Dataset

_, fetch = load_dataset(JSB_CHORALES, split="test", shuffle=False)
test_lengths, test_sequences = fetch()

# keep only the notes that are present in the training set
test_sequences = test_sequences[..., present_notes]
def neg_log_likelihood(hmm, test_sequences, test_lengths):
    return -hmm_loglikelihood(hmm, test_sequences, test_lengths).mean()


colors = ["tab:blue", "tab:orange", "tab:green", "tab:red"]

plt.figure(figsize=(12, 6))
ax = plt.gca()
negative_log_likelihoods = []

for b, batch_size in enumerate(batch_sizes):
    nlls = jnp.array([])
    for h, (k, color) in enumerate(zip(hidden_states, colors)):
        nll = neg_log_likelihood(all_models[batch_size][k], test_sequences, test_lengths)
        nlls = jnp.append(nlls, nll)
        ax.bar(b * 1.8 + h * 0.35, nll, color=color, label=f"K={k}" if b == 0 else None, width=0.25, alpha=0.8)
    negative_log_likelihoods.append(nlls)

plt.xticks(
    [b * 1.8 + 0.52 for b in range(len(batch_sizes))],
    [f"batch_size={batch_size}" for batch_size in batch_sizes],
    fontsize=14,
)
plt.ylabel("Negative Log Likelihood", fontsize=18)
ax.legend(loc="upper right")
plt.show()
[Figure: test-set negative log-likelihood bars for each K, grouped by batch size]

Best Model

negative_log_likelihoods = np.vstack(negative_log_likelihoods)
best_batch_idx, best_hidden_idx = np.unravel_index(negative_log_likelihoods.argmin(), negative_log_likelihoods.shape)
best_batch_size, best_n_hidden = batch_sizes[best_batch_idx], hidden_states[best_hidden_idx]
print(f"Best Model : batch_size={best_batch_size}, K={best_n_hidden}") print(f"Test Negative Loglikelihood of Best Model : {negative_log_likelihoods[best_batch_idx][best_hidden_idx]}")
Best Model : batch_size=20, K=24
Test Negative Loglikelihood of Best Model : 8.054402351379395

Visualization of Parameters of the Best HMM

Observation Distribution

plt.figure(figsize=(16, 16))
ax = plt.gca()
im = plt.imshow(sigmoid(all_obs_logits[best_batch_idx][best_hidden_idx]))
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="2%", pad=0.1)
plt.colorbar(im, cax=cax);
[Figure: emission probabilities of the best HMM (batch_size=20, K=24)]

Transition Probabilities

plt.figure(figsize=(6, 6))
ax = plt.gca()
im = plt.imshow(all_models[best_batch_size][best_n_hidden].trans_dist.probs)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
plt.colorbar(im, cax=cax);
[Figure: transition matrix of the best HMM (batch_size=20, K=24)]
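Generate a New Sequence

As a final illustration (not part of the original notebook), distrax's HMM exposes a sample method that draws a hidden-state path and the corresponding binary note vectors, so the fitted model can be used to generate a new chorale-like sequence; the seed and sequence length below are arbitrary.

# Illustrative sketch: sample a new sequence from the best HMM
# (assumes `all_models`, `best_batch_size`, `best_n_hidden` from above).
best_hmm = all_models[best_batch_size][best_n_hidden]
states, notes = best_hmm.sample(seed=PRNGKey(42), seq_len=64)
print(states.shape)  # (64,): sampled hidden-state path
print(notes.shape)   # (64, 51): sampled binary note vectors

plt.figure(figsize=(10, 4))
plt.imshow(notes.T, aspect="auto")
plt.xlabel("time step")
plt.ylabel("note index")
plt.show()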