Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/03/mix_bernoulli_em_mnist.ipynb
1192 views
Kernel: Unknown Kernel
""" Fits Bernoulli mixture model for mnist digits using em algorithm Author: Meduri Venkata Shivaditya, Aleyna Kara(@karalleyna) """ from jax.random import PRNGKey, randint try: import tensorflow as tf except ModuleNotFoundError: %pip install -qq tensorflow import tensorflow as tf try: from probml_utils.mix_bernoulli_lib import BMM except ModuleNotFoundError: %pip install -qq git+https://github.com/probml/probml-utils.git from probml_utils.mix_bernoulli_lib import BMM from probml_utils.mix_bernoulli_em_mnist import mnist_data def main(): n_obs = 1000 observations = mnist_data(n_obs) # subsample the MNIST dataset n_vars = len(observations[0]) K, num_of_iters = 12, 10 n_row, n_col = 3, 4 bmm = BMM(K, n_vars) _ = bmm.fit_em(observations, num_of_iters=num_of_iters) bmm.plot(n_row, n_col, "bmm_em_mnist") if __name__ == "__main__": main()