Gibbs sampling for a Potts model on a 2d lattice
Ming Liang Ang.
The math behind the model
The Potts model
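The equations for this section did not survive in this copy. As a reference point, one standard way to write a Potts model on a lattice and its Gibbs conditional is sketched below. The symbols are assumptions rather than the notebook's own notation: $J$ is a coupling strength, $K$ the number of states, $i \sim j$ ranges over neighbouring lattice sites, and $n_i(k)$ counts the neighbours of site $i$ that are in state $k$.

```latex
p(x) = \frac{1}{Z} \exp\Big( J \sum_{i \sim j} \mathbb{1}[x_i = x_j] \Big)

p(x_i = k \mid x_{-i})
  = \frac{\exp\big( J \, n_i(k) \big)}{\sum_{k'=1}^{K} \exp\big( J \, n_i(k') \big)},
\qquad
n_i(k) = \sum_{j \in \mathrm{nbr}(i)} \mathbb{1}[x_j = k]
```

Under this form the conditional of a site depends only on its neighbours, and the quantities $J \, n_i(k)$ play the role of logits.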
To compute the logits efficiently for all the different states in our Potts model, we use a convolution. The idea is to first represent each Potts model state as a one-hot vector and then apply a convolution to the one-hot grid to compute the logits for every site at once.
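A minimal sketch of this idea follows, written in NumPy rather than JAX so that it runs without extra dependencies. It assumes a 4-neighbour lattice with zero padding at the edges and a hypothetical coupling `J`. The helper names `one_hot` and `neighbour_counts` are illustrative and not taken from the notebook.

```python
import numpy as np

def one_hot(state, n_states):
    """Turn an (H, W) integer grid into an (H, W, K) one-hot tensor."""
    return np.eye(n_states, dtype=np.float32)[state]

def neighbour_counts(onehot):
    """For every site and every candidate value k, count the 4-neighbours holding k.

    This equals convolving each one-hot channel with a plus-shaped 3x3 kernel,
    with zero padding (an assumed free boundary) at the lattice edges.
    """
    padded = np.pad(onehot, ((1, 1), (1, 1), (0, 0)))
    return (padded[:-2, 1:-1] + padded[2:, 1:-1]      # up + down
            + padded[1:-1, :-2] + padded[1:-1, 2:])   # left + right

J = 1.0  # hypothetical coupling strength (an assumption)
state = np.array([[0, 1, 1],
                  [0, 0, 2],
                  [2, 0, 2]])
counts = neighbour_counts(one_hot(state, n_states=3))  # shape (3, 3, 3)
logits = J * counts  # unnormalised log-probabilities for each site and value
```

Here `logits[i, j, k]` is the unnormalised log-probability that site `(i, j)` takes value `k` given its current neighbours. For the centre site the neighbours are 1, 0, 0 and 2, so `counts[1, 1]` is `[2, 1, 1]`.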
An example
Each entry of the resulting matrix is the number of neighbours that have the same value as the corresponding entry of the state matrix.
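The original example matrices are missing from this copy, so here is a small stand-in. It assumes a two-state 3x3 grid, 4-neighbour connectivity, and edges that simply have fewer neighbours. The grid values are made up for illustration.

```python
import numpy as np

state = np.array([[1, 1, 0],
                  [1, 0, 0],
                  [0, 0, 1]])

# Pad with -1, a value that never matches a real state, so that
# sites on the border just see fewer neighbours.
padded = np.pad(state, 1, constant_values=-1)

# Compare every site with its up, down, left and right neighbour.
same = ((padded[:-2, 1:-1] == state).astype(int)
        + (padded[2:, 1:-1] == state)
        + (padded[1:-1, :-2] == state)
        + (padded[1:-1, 2:] == state))
print(same)
```

For this grid the result is `[[2, 1, 1], [1, 2, 2], [1, 2, 0]]`. The bottom-right site, for example, has value 1 but both of its neighbours are 0, so its count is 0.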
For more than two states, we represent the state matrix as a 3D tensor, which you can think of as the state matrix with each element replaced by a one-hot vector.
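As a small illustration of this representation (a NumPy sketch with made-up values, not the notebook's code), indexing an identity matrix with the state grid yields the one-hot tensor, and `argmax` over the last axis recovers the original grid.

```python
import numpy as np

n_states = 3
state = np.array([[0, 2],
                  [1, 1]])

# Row k of the identity matrix is the one-hot vector for value k,
# so fancy indexing turns an (H, W) grid into an (H, W, K) tensor.
onehot = np.eye(n_states, dtype=int)[state]

# The integer grid can be recovered by taking the argmax over the last axis.
recovered = onehot.argmax(axis=-1)
```

With this layout, `onehot[..., k]` is a binary image marking the sites currently in state `k`, which is the form a convolution over the grid can consume directly.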
Import libraries
RNG key
The number of states and size of the 2d grid
The convolutional kernel for computing the energy of the Markov blanket of each node
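The kernel cell itself is not preserved here. A plausible form, assuming the Markov blanket of a lattice node is its four nearest neighbours, is a plus-shaped 3x3 kernel with a zero centre, so that a node never counts itself. The sketch below applies such a kernel to one binary channel with a hand-written `correlate_same` helper, a hypothetical name used in place of a library convolution.

```python
import numpy as np

# Plus-shaped kernel: ones on the four nearest neighbours, zero at the centre.
kernel = np.array([[0, 1, 0],
                   [1, 0, 1],
                   [0, 1, 0]])

def correlate_same(channel, kernel):
    """Same-size 2D cross-correlation with zero padding, for a 3x3 kernel."""
    h, w = channel.shape
    padded = np.pad(channel, 1)
    out = np.zeros((h, w), dtype=channel.dtype)
    for di in range(3):
        for dj in range(3):
            # Shifted view of the padded grid, weighted by the kernel entry.
            out += kernel[di, dj] * padded[di:di + h, dj:dj + w]
    return out

# One binary channel of a one-hot grid: 1 where the site is in the state of interest.
channel = np.array([[1, 0, 1],
                    [0, 1, 0],
                    [1, 1, 0]])
counts = correlate_same(channel, kernel)
```

Because the kernel is symmetric, cross-correlation and convolution give the same result. Each entry of `counts` is the number of neighbours in the chosen state, which for this channel is `[[0, 3, 0], [3, 1, 2], [1, 2, 1]]`.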
Creating the checkerboard
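On a 4-neighbour lattice, sites of one checkerboard colour are conditionally independent given the sites of the other colour, so all sites of one colour can be resampled in parallel. A minimal NumPy sketch of such a mask is shown below; the function name `checkerboard` is an assumption, not necessarily what the notebook uses.

```python
import numpy as np

def checkerboard(height, width):
    """Boolean mask that is True where (row + column) is even."""
    rows, cols = np.indices((height, width))
    return (rows + cols) % 2 == 0

mask = checkerboard(4, 4)
# mask marks one colour of the board; ~mask marks the other.
```

No two horizontally or vertically adjacent sites share a colour, which is what makes the simultaneous update of `mask` sites (and then of `~mask` sites) a valid Gibbs step.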
Running the test
Running the model
DeviceArray([[9, 3, 4, ..., 5, 9, 9],
             [3, 2, 3, ..., 3, 8, 5],
             [4, 0, 3, ..., 3, 7, 6],
             ...,
             [5, 5, 5, ..., 8, 2, 5],
             [5, 5, 7, ..., 3, 2, 3],
             [4, 9, 5, ..., 6, 7, 3]], dtype=int32)
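The sampling code is not preserved in this copy. The sketch below shows how a checkerboard Gibbs sampler for a Potts model could be put together in NumPy. It is not the notebook's JAX implementation: the function names, the zero-padded boundary, the coupling value, and the use of the Gumbel-max trick for categorical sampling are all assumptions made for illustration.

```python
import numpy as np

def neighbour_counts(onehot):
    """Count, per site and per value, the 4-neighbours holding that value."""
    padded = np.pad(onehot, ((1, 1), (1, 1), (0, 0)))
    return (padded[:-2, 1:-1] + padded[2:, 1:-1]
            + padded[1:-1, :-2] + padded[1:-1, 2:])

def gibbs_sweep(state, n_states, coupling, rng):
    """One full sweep: resample the even sites, then the odd sites."""
    rows, cols = np.indices(state.shape)
    for colour in (0, 1):
        mask = (rows + cols) % 2 == colour
        # Logits are recomputed for each colour so they reflect the latest state.
        logits = coupling * neighbour_counts(np.eye(n_states)[state])
        # Gumbel-max trick: argmax(logits + Gumbel noise) is a draw
        # from the categorical distribution softmax(logits).
        gumbel = rng.gumbel(size=logits.shape)
        proposal = np.argmax(logits + gumbel, axis=-1)
        # Only sites of the current colour are replaced.
        state = np.where(mask, proposal, state)
    return state

def run_gibbs(n_sweeps, shape, n_states, coupling, seed=0):
    """Start from a uniformly random grid and apply n_sweeps Gibbs sweeps."""
    rng = np.random.default_rng(seed)
    state = rng.integers(n_states, size=shape)
    for _ in range(n_sweeps):
        state = gibbs_sweep(state, n_states, coupling, rng)
    return state

# Hypothetical settings: ten states on a small 16x16 grid.
sample = run_gibbs(n_sweeps=50, shape=(16, 16), n_states=10, coupling=1.0)
```

The returned `sample` is an integer grid with entries in `0..n_states-1`. Larger values of `coupling` favour neighbouring sites that agree, so the samples show larger patches of a single state.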