An introduction to Haiku (a neural network library for JAX)
https://github.com/deepmind/dm-haiku
Haiku is a JAX version of the Sonnet neural network library (which was written for TensorFlow 2). Its main contribution is a way to convert object-oriented (stateful) code into functionally pure code, which can then be processed by JAX transformations such as jit and grad. In addition, it provides implementations of common neural network building blocks.
Below we give a brief introduction, based on the official docs.
Haiku function transformations
The main thing Haiku offers is a way for the user to write a function that defines and accesses mutable parameters inside its body, and then to transform it into a pure function that takes the parameters as explicit arguments. (The advantage of the implicit style will become clearer later, when we consider modules, which let the user define parameters using nested objects.)
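For example, here is an affine function, f1, that creates its parameters implicitly with hk.get_parameter; the shapes and initializers below are our own illustrative choices.

import haiku as hk
import jax
import jax.numpy as jnp

def f1(x):
    # parameters are created (or retrieved) implicitly inside the function
    w = hk.get_parameter("w", shape=[x.shape[-1], 2], init=jnp.ones)
    b = hk.get_parameter("b", shape=[2], init=jnp.zeros)
    return x @ w + b

# hk.transform returns a pair of pure functions:
#   init(rng, x) -> params   and   apply(params, rng, x) -> output
f = hk.transform(f1)
x = jnp.ones([3, 2])
params = f.init(jax.random.PRNGKey(42), x)
print(params)                  # top-level parameters live under the module name '~'
y = f.apply(params, None, x)   # rng=None is fine here: f1 uses no randomness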
Note that the parameter structure returned by init is an immutable mapping (a FlatMap), so attempts to modify it in place fail:

---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-19-c3f935c6e770> in <module>()
----> 1 params['~']['b'] = jnp.array([2.0, 2.0])
TypeError: 'FlatMap' object does not support item assignment
Transforming stateful functions
We can create a function with internal state that is mutated on each call, but that is treated separately from the fixed parameters (which are usually updated by an external optimizer). Below we illustrate this with a simple counter that is incremented on each call.
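Here is a minimal sketch of such a counter, written with hk.get_state and hk.set_state and transformed with hk.transform_with_state (the details are our own illustration):

def counter_fn(x):
    # state is created on first use, then read and updated on each call
    cnt = hk.get_state("count", shape=[], dtype=jnp.int32, init=jnp.zeros)
    hk.set_state("count", cnt + 1)
    return x + cnt

f = hk.transform_with_state(counter_fn)
# init(rng, x) -> (params, state); rng=None is allowed since no randomness is used
params, state = f.init(None, jnp.array(0))   # running init also bumps the counter once
for _ in range(3):
    # apply(params, state, rng, x) -> (output, new_state)
    y, state = f.apply(params, state, None, jnp.array(0))
print(state)   # roughly {'~': {'count': 4}}: one bump from init plus three from apply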
Modules
Creating a single dict of parameters and passing it as an argument is easy, and Haiku is overkill for such cases. However, we often have nested parameterized functions, each of which has metadata (such as output sizes) that needs to be specified. In such cases it is easier to work with Haiku modules. These are just like regular Python classes (no required methods), but they typically have an __init__ constructor and a __call__ method that is invoked when calling the module. Below we reimplement the affine function f1 as a module.
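Here is one way to write it (the class name Affine and its argument names are our own choices):

class Affine(hk.Module):
    def __init__(self, output_size, name=None):
        super().__init__(name=name)
        self.output_size = output_size

    def __call__(self, x):
        w = hk.get_parameter("w", shape=[x.shape[-1], self.output_size], init=jnp.ones)
        b = hk.get_parameter("b", shape=[self.output_size], init=jnp.zeros)
        return x @ w + b

def forward(x):
    return Affine(output_size=2)(x)

f = hk.transform(forward)
x = jnp.ones([3, 2])
params = f.init(jax.random.PRNGKey(42), x)
print(params)                 # parameters now live under the module name, e.g. 'affine'
y = f.apply(params, None, x)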
Nested and built-in modules
We can nest modules inside each other, which lets us build complex functions. Haiku ships with many common layers, as well as a small number of common models, such as MLPs and ResNets. (A model is just a composition of layers.)
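For instance, we can compose two built-in hk.nets.MLP modules inside a single forward function (the module names and sizes below are arbitrary):

def forward(x):
    # two built-in MLPs composed inside one function
    h = hk.nets.MLP(output_sizes=[128, 32], name="encoder")(x)
    return hk.nets.MLP(output_sizes=[10], name="head")(h)

f = hk.transform(forward)
params = f.init(jax.random.PRNGKey(0), jnp.ones([1, 784]))
print(jax.tree_util.tree_map(lambda p: p.shape, params))
# parameters are grouped by module path, e.g. 'encoder/~/linear_0'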
Stochastic modules
If the module is stochastic, we have to pass an RNG key to the apply function (as well as to the init function), as we show below. Inside the module we can use hk.next_rng_key() to derive a new key from the one the user passes to apply; this is useful when we have nested modules.
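As an illustration (reusing the Affine module defined above, with additive noise of our own choosing):

def noisy_affine(x):
    y = Affine(output_size=2)(x)
    # derive a fresh key from the key that was passed to apply
    noise = jax.random.normal(hk.next_rng_key(), y.shape)
    return y + noise

f = hk.transform(noisy_affine)
x = jnp.ones([3, 2])
params = f.init(jax.random.PRNGKey(0), x)
y1 = f.apply(params, jax.random.PRNGKey(1), x)
y2 = f.apply(params, jax.random.PRNGKey(2), x)   # different key, different noise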
Combining JAX function transformations and Haiku
We cannot apply JAX function transformations, like jit and grad, inside of a Haiku module, since modules are impure; instead we must use hk.jit, hk.grad, etc. (see the Haiku documentation for details). However, after transforming the Haiku code to be pure, we can apply JAX transformations as usual.
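For example, once the network has been transformed, we can compose jit and grad with the pure apply function (a sketch with arbitrary shapes):

def forward(x):
    return hk.nets.MLP([16, 1])(x)

# without_apply_rng lets us call apply(params, x) without an rng argument
f = hk.without_apply_rng(hk.transform(forward))
x = jnp.ones([4, 8])
targets = jnp.zeros([4, 1])
params = f.init(jax.random.PRNGKey(0), x)

def loss_fn(params, x, targets):
    pred = f.apply(params, x)
    return jnp.mean((pred - targets) ** 2)

# after hk.transform the code is pure, so jit and grad compose as usual
loss, grads = jax.jit(jax.value_and_grad(loss_fn))(params, x, targets)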
(See also the equinox library for an alternative approach to this problem.)
Example: MLP on MNIST
This example is modified from https://github.com/deepmind/dm-haiku/blob/main/examples/mnist.py
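Below is a minimal sketch of the setup, loosely following that script; the hyperparameters and data-pipeline details here are our own assumptions, and it requires the optax and tensorflow_datasets packages.

import optax
import tensorflow_datasets as tfds

def net_fn(images):
    # flatten 28x28 uint8 images and rescale to [0, 1]
    x = images.astype(jnp.float32).reshape(images.shape[0], -1) / 255.0
    return hk.nets.MLP([300, 100, 10])(x)

network = hk.without_apply_rng(hk.transform(net_fn))
optimizer = optax.adam(1e-3)

def loss_fn(params, images, labels):
    logits = network.apply(params, images)
    labels_one_hot = jax.nn.one_hot(labels, 10)
    return -jnp.mean(jnp.sum(labels_one_hot * jax.nn.log_softmax(logits), axis=-1))

@jax.jit
def update(params, opt_state, images, labels):
    grads = jax.grad(loss_fn)(params, images, labels)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state

def load_dataset(split, batch_size):
    ds = tfds.load("mnist", split=split, as_supervised=True)
    ds = ds.cache().repeat().shuffle(10 * batch_size).batch(batch_size)
    return iter(tfds.as_numpy(ds))

train = load_dataset("train", batch_size=128)
images, labels = next(train)
params = network.init(jax.random.PRNGKey(42), images)
opt_state = optimizer.init(params)
for step in range(1000):
    images, labels = next(train)
    params, opt_state = update(params, opt_state, images, labels)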
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.