Path: blob/master/notebooks/book2/29/supplementary/hmm_with_ngram.ipynb
Kernel: Python 3
Setup
In [1]:
Out[1]:
Requirement already satisfied: nltk in /usr/local/lib/python3.7/dist-packages (3.2.5)
Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from nltk) (1.15.0)
Requirement already satisfied: distrax in /usr/local/lib/python3.7/dist-packages (0.0.2)
Requirement already satisfied: tensorflow-probability>=0.13.0rc0 in /usr/local/lib/python3.7/dist-packages (from distrax) (0.13.0)
Requirement already satisfied: chex>=0.0.7 in /usr/local/lib/python3.7/dist-packages (from distrax) (0.0.8)
Requirement already satisfied: absl-py>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from distrax) (0.12.0)
Requirement already satisfied: jaxlib>=0.1.67 in /usr/local/lib/python3.7/dist-packages (from distrax) (0.1.70+cuda110)
Requirement already satisfied: jax>=0.2.13 in /usr/local/lib/python3.7/dist-packages (from distrax) (0.2.19)
Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from distrax) (1.19.5)
Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.9.0->distrax) (1.15.0)
Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.7->distrax) (0.11.1)
Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.7->distrax) (0.1.6)
Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->distrax) (3.3.0)
Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.67->distrax) (1.12)
Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.67->distrax) (1.4.1)
Requirement already satisfied: decorator in /usr/local/lib/python3.7/dist-packages (from tensorflow-probability>=0.13.0rc0->distrax) (4.4.2)
Requirement already satisfied: gast>=0.3.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow-probability>=0.13.0rc0->distrax) (0.4.0)
Requirement already satisfied: cloudpickle>=1.3 in /usr/local/lib/python3.7/dist-packages (from tensorflow-probability>=0.13.0rc0->distrax) (1.3.0)
In [2]:
Out[2]:
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
100 4230k 100 4230k 0 0 9704k 0 --:--:-- --:--:-- --:--:-- 9704k
In [3]:
Out[3]:
Requirement already satisfied: superimport in /usr/local/lib/python3.7/dist-packages (0.3.3)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from superimport) (2.23.0)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->superimport) (1.24.3)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->superimport) (3.0.4)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->superimport) (2021.5.30)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->superimport) (2.10)
In [4]:
ClassConditionalBMM
In [5]:
Out[5]:
/usr/local/lib/python3.7/dist-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
In [6]:
In [7]:
HMM
In [8]:
In [9]:
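As a minimal sketch of how an HMM log-likelihood can be computed, the block below implements the standard forward algorithm in log space. It is written in plain Python rather than with the JAX and distrax packages installed in Setup, and every name in it (`logsumexp`, `hmm_log_likelihood`, the toy two-state parameters) is a hypothetical choice for illustration, not the notebook's own code. It assumes all probabilities are strictly positive.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs)) for finite inputs."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def hmm_log_likelihood(log_init, log_trans, log_obs):
    """Forward algorithm in log space (illustrative helper, not a library API).

    log_init[k]     : log p(z_1 = k)
    log_trans[j][k] : log p(z_t = k | z_{t-1} = j)
    log_obs[t][k]   : log p(x_t | z_t = k)
    Returns log p(x_1, ..., x_T).
    """
    n_states = len(log_init)
    # Initialise with the first observation.
    alpha = [log_init[k] + log_obs[0][k] for k in range(n_states)]
    # Propagate one step at a time, summing over the previous state.
    for t in range(1, len(log_obs)):
        alpha = [
            logsumexp([alpha[j] + log_trans[j][k] for j in range(n_states)])
            + log_obs[t][k]
            for k in range(n_states)
        ]
    # Marginalise over the final state.
    return logsumexp(alpha)


# Toy two-state example; all numbers are made up for illustration.
init = [0.6, 0.4]
trans = [[0.7, 0.3], [0.2, 0.8]]
obs_lik = [[0.9, 0.2], [0.1, 0.7], [0.8, 0.3]]  # p(x_t | z_t = k) for t = 1..3

log_init = [math.log(p) for p in init]
log_trans = [[math.log(p) for p in row] for row in trans]
log_obs = [[math.log(p) for p in row] for row in obs_lik]

ll = hmm_log_likelihood(log_init, log_trans, log_obs)
print(ll)
```

Working in log space avoids underflow on long sequences, which is why log-likelihoods of this kind are typically large negative numbers.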
Loading Dataset
In [10]:
Sampling Images
In [11]:
Out[11]:
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:5847: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
lax._check_user_dtype_supported(dtype, "astype")
In [12]:
Out[12]:
In [13]:
In [14]:
Out[14]:
NGram
In [15]:
In [16]:
In [17]:
In [18]:
Out[18]:
In [19]:
Out[19]:
DeviceArray(-374.44446, dtype=float32)
In [20]:
Out[20]:
'Christians first int'
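The 20-character string above appears to be text drawn from a character-level model. As a rough sketch of the idea, assuming a bigram (n = 2) model with add-one smoothing, the block below fits transition probabilities on a toy corpus, scores a string, and samples 20 characters. The function names (`fit_bigram`, `log_prob`, `sample`) and the corpus are hypothetical and use only the Python standard library; they are not the notebook's own code.

```python
import math
import random


def fit_bigram(text, alphabet):
    """Return add-one-smoothed bigram probabilities probs[prev][next]."""
    # Start every count at 1 so that unseen bigrams keep nonzero probability.
    counts = {a: {b: 1.0 for b in alphabet} for a in alphabet}
    for prev, nxt in zip(text, text[1:]):
        counts[prev][nxt] += 1.0
    probs = {}
    for a in alphabet:
        total = sum(counts[a].values())
        probs[a] = {b: c / total for b, c in counts[a].items()}
    return probs


def log_prob(probs, text):
    """Log-probability of text[1:] conditioned on its first character."""
    return sum(math.log(probs[p][n]) for p, n in zip(text, text[1:]))


def sample(probs, start, length, rng):
    """Sample a string of `length` characters beginning with `start`."""
    out = [start]
    while len(out) < length:
        row = probs[out[-1]]
        chars = list(row.keys())
        weights = list(row.values())
        out.append(rng.choices(chars, weights=weights, k=1)[0])
    return "".join(out)


# Toy corpus, made up for illustration.
corpus = "the cat sat on the mat and the rat sat on the cat"
alphabet = sorted(set(corpus))
probs = fit_bigram(corpus, alphabet)

rng = random.Random(0)  # fixed seed so the sample is reproducible
text = sample(probs, "t", 20, rng)
score = log_prob(probs, text)
print(repr(text), score)
```

A higher-order n-gram would condition on the previous n - 1 characters instead of one; the counting and sampling steps stay the same, with a tuple of characters as the context key.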
In [21]:
Out[21]: