GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/tutorials/haiku_intro.ipynb
Kernel: Python 3

Open In Colab

An introduction to haiku (a neural network library for JAX)

https://github.com/deepmind/dm-haiku

Haiku is a JAX version of the Sonnet neural network library (which was written for TensorFlow 2). Its main job is to provide a way to convert object-oriented (stateful) code into functionally pure code, which can then be processed by JAX transformations like jit and grad. In addition, it provides implementations of common neural net building blocks.

Below we give a brief introduction, based on the official docs.

%%capture
!pip install git+https://github.com/deepmind/dm-haiku
import haiku as hk
%%capture
!pip install git+https://github.com/deepmind/optax.git
import optax
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

Haiku function transformations

The main thing Haiku offers is a way for the user to write a function that defines and accesses mutable parameters internally, and then to transform it into a function that takes the parameters as explicit arguments. (The advantage of the implicit style will become clearer later, when we consider modules, which let the user define parameters using nested objects.)

# Here is a function that takes in data x, and meta-data output_size,
# but creates its mutable parameters internally.
# The parameters define an affine mapping, f1(x) = b + W*x
def f1(x, output_size):
    j, k = x.shape[-1], output_size
    w_init = hk.initializers.TruncatedNormal(1.0 / np.sqrt(j))
    w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
    b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.ones)
    return jnp.dot(x, w) + b
# transform will convert f1 to a function that explicitly uses parameters, which we call f2.
# (We explain the rng part later.)
f2 = hk.without_apply_rng(hk.transform(f1))

# f2 is a struct with two functions, init and apply
print(f2)
Transformed(init=<function without_state.<locals>.init_fn at 0x7fbe71410320>, apply=<function without_apply_rng.<locals>.apply_fn at 0x7fbe714105f0>)
# The init function creates an initial random set of parameters
# by calling f1 on some data x (the values don't matter, just the shape)
# and using the RNG.
# The params are stored in a haiku FlatMap (like a FrozenDict)
output_size = 2
dummy_x = jnp.array([[1.0, 2.0, 3.0]])
rng_key = jax.random.PRNGKey(42)

# params = f2.init(rng=rng_key, x=dummy_x, output_size=output_size)
params = f2.init(rng_key, dummy_x, output_size)
print(params)
FlatMap({
  '~': FlatMap({
    'w': DeviceArray([[-0.30350363,  0.5123802 ],
                      [ 0.08009141, -0.3163005 ],
                      [ 0.6056666 ,  0.58207023]], dtype=float32),
    'b': DeviceArray([1., 1.], dtype=float32),
  }),
})
p = params["~"]
print(p["b"])
[1. 1.]
# params are frozen
params["~"]["b"] = jnp.array([2.0, 2.0])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-19-c3f935c6e770> in <module>()
----> 1 params['~']['b'] = jnp.array([2.0, 2.0])

TypeError: 'FlatMap' object does not support item assignment
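Since the params structure is immutable, any change has to produce a new structure rather than mutate it in place. Here is a minimal sketch (not part of the original notebook) that builds a rescaled copy of the params with jax.tree_map, which works because a FlatMap is a JAX pytree.

# A sketch: build a *new* pytree instead of mutating params in place.
# Here every parameter is scaled by 0.5 purely for illustration.
scaled_params = jax.tree_map(lambda p: 0.5 * p, params)
print(scaled_params["~"]["b"])  # a new FlatMap; the original params are unchanged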
# The apply function takes a param FlatMap and injects it into the original f1 function
sample_x = jnp.array([[1.0, 2.0, 3.0]])
output_1 = f2.apply(params=params, x=sample_x, output_size=output_size)
print(output_1)
[[2.6736789 2.62599 ]]

Transforming stateful functions

We can create a function with internal state that is mutated on each call; this state is treated separately from the parameters (which are usually updated by an external optimizer). Below we illustrate this with a simple counter that gets incremented on each call.

def stateful_f(x):
    counter = hk.get_state("counter", shape=[], dtype=jnp.int32, init=jnp.ones)
    multiplier = hk.get_parameter("multiplier", shape=[1], dtype=x.dtype, init=jnp.ones)
    hk.set_state("counter", counter + 1)
    output = x + multiplier * counter
    return output


stateful_forward = hk.without_apply_rng(hk.transform_with_state(stateful_f))

sample_x = jnp.array([[5.0]])
params, state = stateful_forward.init(x=sample_x, rng=rng_key)
print(f"Initial params:\n{params}\nInitial state:\n{state}")
print("##########")

for i in range(3):
    output, state = stateful_forward.apply(params, state, x=sample_x)
    print(f"After {i+1} iterations:\nOutput: {output}\nState: {state}")
    print("##########")
Initial params:
FlatMap({'~': FlatMap({'multiplier': DeviceArray([1.], dtype=float32)})})
Initial state:
FlatMap({'~': FlatMap({'counter': DeviceArray(1, dtype=int32)})})
##########
After 1 iterations:
Output: [[6.]]
State: FlatMap({'~': FlatMap({'counter': DeviceArray(2, dtype=int32)})})
##########
After 2 iterations:
Output: [[7.]]
State: FlatMap({'~': FlatMap({'counter': DeviceArray(3, dtype=int32)})})
##########
After 3 iterations:
Output: [[8.]]
State: FlatMap({'~': FlatMap({'counter': DeviceArray(4, dtype=int32)})})
##########
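Because the state is now threaded explicitly, the transformed apply function is pure, so standard JAX transformations can be applied to it. A minimal sketch (not in the original notebook), reusing params, state, and sample_x from above:

# A sketch: since state is explicit, stateful_forward.apply is a pure function
# and can be jit-compiled like any other JAX function.
jitted_apply = jax.jit(stateful_forward.apply)
output, state = jitted_apply(params, state, sample_x)
print(output, state)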

Modules

Creating a single dict of parameters and passing it as an argument is easy, and haiku is overkill for such cases. However, we often have nested parameterized functions, each of which has metadata (like output_size above) that needs to be specified. In such cases it is easier to work with haiku modules. These are just like regular Python classes (no required methods), but typically have an __init__ constructor and a __call__ method that is invoked when the module is called. Below we reimplement the affine function f1 as a module.

class MyLinear1(hk.Module):
    def __init__(self, output_size, name=None):
        super().__init__(name=name)
        self.output_size = output_size

    def __call__(self, x):
        j, k = x.shape[-1], self.output_size
        w_init = hk.initializers.TruncatedNormal(1.0 / np.sqrt(j))
        w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
        b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.ones)
        return jnp.dot(x, w) + b
def _forward_fn_linear1(x):
    module = MyLinear1(output_size=2)
    return module(x)


forward_linear1 = hk.without_apply_rng(hk.transform(_forward_fn_linear1))
dummy_x = jnp.array([[1.0, 2.0, 3.0]])
rng_key = jax.random.PRNGKey(42)

params = forward_linear1.init(rng=rng_key, x=dummy_x)
print(params)

sample_x = jnp.array([[1.0, 2.0, 3.0]])
output_1 = forward_linear1.apply(params=params, x=sample_x)
print(output_1)
FlatMap({
  'my_linear1': FlatMap({
    'w': DeviceArray([[-0.30350363,  0.5123802 ],
                      [ 0.08009141, -0.3163005 ],
                      [ 0.6056666 ,  0.58207023]], dtype=float32),
    'b': DeviceArray([1., 1.], dtype=float32),
  }),
})
[[2.6736789 2.62599 ]]
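For a simple affine layer like this, we could also use Haiku's built-in hk.Linear module. A minimal sketch (not in the original notebook); note that hk.Linear's default initializers differ from MyLinear1 (for example, the bias starts at zero), so the numbers will not match exactly.

# A sketch: the built-in hk.Linear implements the same kind of affine map,
# though its default initializers differ from MyLinear1.
forward_builtin = hk.without_apply_rng(hk.transform(lambda x: hk.Linear(2)(x)))
params_builtin = forward_builtin.init(rng_key, dummy_x)
print(forward_builtin.apply(params_builtin, dummy_x))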

Nested and built-in modules

We can nest modules inside each other, which allows us to build complex functions. Haiku ships with many common layers, as well as a small number of common models such as MLPs and ResNets. (A model is just a composition of layers.)

class MyModuleCustom(hk.Module):
    def __init__(self, output_size=2, name="custom_linear"):
        super().__init__(name=name)
        self._internal_linear_1 = hk.nets.MLP(output_sizes=[2, 3], name="hk_internal_linear")
        self._internal_linear_2 = MyLinear1(output_size=output_size, name="old_linear")

    def __call__(self, x):
        return self._internal_linear_2(self._internal_linear_1(x))


def _custom_forward_fn(x):
    module = MyModuleCustom()
    return module(x)


custom_forward_without_rng = hk.without_apply_rng(hk.transform(_custom_forward_fn))
params = custom_forward_without_rng.init(rng=rng_key, x=sample_x)
params
FlatMap({
  'custom_linear/~/hk_internal_linear/~/linear_0': FlatMap({
    'w': DeviceArray([[ 1.51595   , -0.23353335]], dtype=float32),
    'b': DeviceArray([0., 0.], dtype=float32),
  }),
  'custom_linear/~/hk_internal_linear/~/linear_1': FlatMap({
    'w': DeviceArray([[-0.22075887, -0.27375957,  0.5931483 ],
                      [ 0.78180677,  0.72626334, -0.6860752 ]], dtype=float32),
    'b': DeviceArray([0., 0., 0.], dtype=float32),
  }),
  'custom_linear/~/old_linear': FlatMap({
    'w': DeviceArray([[ 0.28584382,  0.31626168],
                      [ 0.23357749, -0.4827032 ],
                      [-0.14647584, -0.7185701 ]], dtype=float32),
    'b': DeviceArray([1., 1.], dtype=float32),
  }),
})
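The module/parameter path names shown above can be used to select subsets of the parameters. A minimal sketch (not in the original notebook) using hk.data_structures.filter:

# A sketch: keep only the parameters belonging to the nested MLP,
# selected by the module-name paths shown in the printout above.
mlp_params = hk.data_structures.filter(
    lambda module_name, name, value: "hk_internal_linear" in module_name,
    params,
)
print(list(mlp_params.keys()))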

Stochastic modules

If the module is stochastic, we have to pass the RNG to the apply function (as well as to the init function), as we show below. We can use hk.next_rng_key() to derive a new key from the one that the user passes to apply, which is useful when we have nested modules.

class HkRandom2(hk.Module):
    def __init__(self, rate=0.5):
        super().__init__()
        self.rate = rate

    def __call__(self, x):
        key1 = hk.next_rng_key()
        return jax.random.bernoulli(key1, 1.0 - self.rate, shape=x.shape)


class HkRandomNest(hk.Module):
    def __init__(self, rate=0.5):
        super().__init__()
        self.rate = rate
        self._another_random_module = HkRandom2()

    def __call__(self, x):
        key2 = hk.next_rng_key()
        p1 = self._another_random_module(x)
        p2 = jax.random.bernoulli(key2, 1.0 - self.rate, shape=x.shape)
        print(f"Bernoullis are : {p1, p2}")


# Note that stochastic modules cannot be wrapped with hk.without_apply_rng()
forward = hk.transform(lambda x: HkRandomNest()(x))

x = jnp.array(1.0)
params = forward.init(rng_key, x=x)

# The 2 Bernoullis can be different, since they use key1 and key2.
# But across the 5 iterations the answers should be the same,
# since they are all produced by passing the same rng_key to apply.
for i in range(5):
    print(f"\n Iteration {i+1}")
    prediction = forward.apply(params, x=x, rng=rng_key)
Bernoullis are : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 1
Bernoullis are : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 2
Bernoullis are : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 3
Bernoullis are : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 4
Bernoullis are : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 5
Bernoullis are : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))
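To draw different samples on each call, we can pass a fresh subkey to apply each time. A minimal sketch (not in the original notebook) using jax.random.split:

# A sketch: split the key and pass a fresh subkey to apply on each call,
# so the Bernoulli samples differ across iterations.
key = rng_key
for i in range(3):
    key, subkey = jax.random.split(key)
    forward.apply(params, x=x, rng=subkey)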

Combining JAX Function transformations and Haiku

We cannot apply JAX function transformations, like jit and grad, inside a Haiku module, since modules are impure; instead we have to use hk.jit, hk.grad, etc. (see the Haiku documentation for details). However, after transforming the Haiku code to be pure, we can apply JAX transformations as usual, as sketched below.
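For example, here is a minimal sketch (not in the original notebook) that applies jax.grad and jax.jit to the pure apply function of f2 from the earlier section; we re-initialize its parameters under a new name, since params has been reused for other models above.

# A sketch: after hk.transform, f2.apply is pure, so ordinary JAX transforms apply.
f2_params = f2.init(rng_key, dummy_x, output_size)


def sum_of_outputs(p, x):
    # A toy scalar objective so that jax.grad has something to differentiate.
    return jnp.sum(f2.apply(p, x, output_size))


grads = jax.grad(sum_of_outputs)(f2_params, dummy_x)  # same FlatMap structure as f2_params
fast_loss = jax.jit(sum_of_outputs)
print(fast_loss(f2_params, dummy_x))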

(See also the equinox library for an alternative approach to this problem.)

Example: MLP on MNIST

This example is modified from https://github.com/deepmind/dm-haiku/blob/main/examples/mnist.py

from typing import Generator, Mapping, Tuple

from absl import app
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds

Batch = Mapping[str, np.ndarray]
# Data
def load_dataset(
    split: str,
    *,
    is_training: bool,
    batch_size: int,
) -> Generator[Batch, None, None]:
    """Loads the dataset as a generator of batches."""
    ds = tfds.load("mnist:3.*.*", split=split).cache().repeat()
    if is_training:
        ds = ds.shuffle(10 * batch_size, seed=0)
    ds = ds.batch(batch_size)
    return iter(tfds.as_numpy(ds))


# Make datasets.
train = load_dataset("train", is_training=True, batch_size=1000)
train_eval = load_dataset("train", is_training=False, batch_size=10000)
test_eval = load_dataset("test", is_training=False, batch_size=10000)
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your local data directory. If you'd instead prefer to read directly from our public GCS bucket (recommended if you're running on GCP), you can instead pass `try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
# Model
NCLASSES = 10


def net_fn(batch: Batch) -> jnp.ndarray:
    """Standard LeNet-300-100 MLP network."""
    x = batch["image"].astype(jnp.float32) / 255.0
    mlp = hk.Sequential(
        [
            hk.Flatten(),
            hk.Linear(300),
            jax.nn.relu,
            hk.Linear(100),
            jax.nn.relu,
            hk.Linear(NCLASSES),
        ]
    )
    return mlp(x)


net = hk.without_apply_rng(hk.transform(net_fn))

L2_REGULARIZER = 1e-4
# Metrics

# Training loss (cross-entropy).
def loss(params: hk.Params, batch: Batch) -> jnp.ndarray:
    """Compute the loss of the network, including L2."""
    logits = net.apply(params, batch)
    labels = jax.nn.one_hot(batch["label"], NCLASSES)

    l2_loss = 0.5 * sum(jnp.sum(jnp.square(p)) for p in jax.tree_leaves(params))
    softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
    softmax_xent /= labels.shape[0]

    return softmax_xent + L2_REGULARIZER * l2_loss


# Evaluation metric (classification accuracy).
@jax.jit
def accuracy(params: hk.Params, batch: Batch) -> jnp.ndarray:
    predictions = net.apply(params, batch)
    return jnp.mean(jnp.argmax(predictions, axis=-1) == batch["label"])


@jax.jit
def update(
    params: hk.Params,
    opt_state: optax.OptState,
    batch: Batch,
) -> Tuple[hk.Params, optax.OptState]:
    """Learning rule (stochastic gradient descent)."""
    grads = jax.grad(loss)(params, batch)
    updates, opt_state = opt.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state


# We maintain avg_params, the exponential moving average of the "live" params.
# avg_params is used only for evaluation (cf. https://doi.org/10.1137/0330046)
@jax.jit
def ema_update(params, avg_params):
    return optax.incremental_update(params, avg_params, step_size=0.001)
# Optimizer
LR = 1e-3
opt = optax.adam(LR)

# Initialize network and optimiser; note we draw an input to get shapes.
params = avg_params = net.init(jax.random.PRNGKey(42), next(train))
opt_state = opt.init(params)

# Train/eval loop.
nsteps = 500
print_every = 100


def callback(step, avg_params, train_eval, test_eval):
    if step % print_every == 0:
        # Periodically evaluate classification accuracy on train & test sets.
        train_accuracy = accuracy(avg_params, next(train_eval))
        test_accuracy = accuracy(avg_params, next(test_eval))
        train_accuracy, test_accuracy = jax.device_get((train_accuracy, test_accuracy))
        print(f"[Step {step}] Train / Test accuracy: " f"{train_accuracy:.3f} / {test_accuracy:.3f}.")


for step in range(nsteps + 1):
    params, opt_state = update(params, opt_state, next(train))
    avg_params = ema_update(params, avg_params)
    callback(step, avg_params, train_eval, test_eval)
[Step 0] Train / Test accuracy: 0.129 / 0.132.
[Step 100] Train / Test accuracy: 0.544 / 0.544.
[Step 200] Train / Test accuracy: 0.802 / 0.809.
[Step 300] Train / Test accuracy: 0.887 / 0.884.
[Step 400] Train / Test accuracy: 0.919 / 0.919.
[Step 500] Train / Test accuracy: 0.941 / 0.937.