

Introduction to neural networks using Flax

Flax / Linen is a neural net library, built on top of JAX, "designed to offer an implicit variable management API to save the user from having to manually thread thousands of variables through a complex tree of functions." To handle both current and future JAX transforms (configured and composed in any way), Linen Modules are defined as explicit functions of the form $f(v_{in}, x) \rightarrow v_{out}, y$, where $v_{in}$ is the collection of variables (e.g., parameters) and PRNG state used by the model, $v_{out}$ the mutated output variable collections, $x$ the input data, and $y$ the output data. We illustrate this below. Our tutorial is based on the official flax intro and linen colab. Details are in the flax source code. Note: please be sure to read our JAX tutorial first.
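To make this pattern concrete before diving in, here is a minimal sketch of the init/apply convention for a single built-in Dense layer. This is a sketch assuming flax and jax are installed; the notebook's own imports appear in the cells below.

import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn

model = nn.Dense(features=2)  # hyper-parameters live in the module object
x = jnp.ones((1, 3))  # dummy input
v_in = model.init(random.PRNGKey(0), x)  # v_in: the variable collections (here only 'params')
y = model.apply(v_in, x)  # y = f(v_in, x); nothing is mutated, so no v_out is returned
print(jax.tree_map(jnp.shape, v_in), y.shape)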

import numpy as np

# np.set_printoptions(precision=3)
np.set_printoptions(formatter={"float": lambda x: "{0:0.5f}".format(x)})
import matplotlib.pyplot as plt
import jax

print(jax.__version__)
print(jax.devices())

from jax import lax, random, numpy as jnp

key = random.PRNGKey(0)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
0.2.19 [<jaxlib.xla_extension.Device object at 0x7f8642e14af0>]
from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple

# Useful type aliases
Array = jnp.ndarray
PRNGKey = Array
Batch = Mapping[str, np.ndarray]
OptState = Any
# Install Flax at head:
!pip install --upgrade -q git+https://github.com/google/flax.git
Building wheel for flax (setup.py) ... done
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax import optim
from jax.config import config

# config.enable_omnistaging()  # Linen requires enabling omnistaging

MLP in vanilla JAX

We construct a simple MLP with L hidden layers (relu activation), and scalar output (linear activation).

Note: JAX and Flax, like NumPy, are row-based systems, meaning that vectors are represented as row vectors and not column vectors.
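As a quick shape check (using the jnp import from above, with hypothetical values): a batch of N examples is stored as an (N, D) matrix, one example per row, and a dense layer maps it to (N, H) by right-multiplying with a (D, H) weight matrix.

xx = jnp.ones((2, 3))  # N=2 examples, D=3 features, one example per row
W = jnp.ones((3, 4))  # weight matrix mapping 3 input features to 4 outputs
print(jnp.dot(xx, W).shape)  # (2, 4)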

# We define the parameter initializers using a signature that is flax-compatible
# https://flax.readthedocs.io/en/latest/_modules/jax/_src/nn/initializers.html
def weights_init(key, shape, dtype=jnp.float32):
    return random.normal(key, shape, dtype)
    # return jnp.ones(shape, dtype)


def bias_init(key, shape, dtype=jnp.float32):
    return jnp.zeros(shape, dtype)


def relu(a):
    return jnp.maximum(a, 0)
# A minimal MLP class
class MLP0:
    features: Sequence[int]  # number of features in each layer

    def __init__(self, features):  # class constructor
        self.features = features

    def init(self, key, x):  # initialize parameters
        in_size = np.shape(x)[1]
        sizes = np.concatenate(([in_size], self.features))
        nlayers = len(sizes)
        params = {}
        for i in range(nlayers - 1):
            in_size = sizes[i]
            out_size = sizes[i + 1]
            subkey1, subkey2, key = random.split(key, num=3)
            W = weights_init(subkey1, (in_size, out_size))
            b = bias_init(subkey2, out_size)
            params[f"W{i}"] = W
            params[f"b{i}"] = b
        return params

    def apply(self, params, x):  # forwards pass
        activations = x
        nhidden_layers = len(self.features) - 1
        for i in range(nhidden_layers):
            W = params[f"W{i}"]
            b = params[f"b{i}"]
            outputs = jnp.dot(activations, W) + b
            activations = relu(outputs)
        # for final layer, no activation function
        i = nhidden_layers
        outputs = jnp.dot(activations, params[f"W{i}"]) + params[f"b{i}"]
        return outputs
key = random.PRNGKey(0)
D = 3
N = 2
x = random.normal(key, (N, D))
layer_sizes = [3, 1]  # 1 hidden layer of size 3, 1 scalar output

model0 = MLP0(layer_sizes)
params0 = model0.init(key, x)
print("params")
for k, v in params0.items():
    print(k, v.shape)
    print(v)

y0 = model0.apply(params0, x)
print("\noutput")
print(y0)
params W0 (3, 3) [[-1.83021 1.18417 0.06777] [0.34588 0.37858 -0.65318] [0.18976 0.45157 -0.33964]] b0 (3,) [0.00000 0.00000 0.00000] W1 (3, 1) [[-1.74905] [1.83313] [-0.23808]] b1 (1,) [0.00000] output [[-0.09538] [2.78382]]

Our first flax model

Here we recreate the vanilla model in flax. Since we don't specify how the parameters are initialized, the behavior will not be identical to the vanilla model --- we will fix this below, but for now, we focus on model construction.

We see that the model is a subclass of nn.Module, which behaves like a Python dataclass. The child class (written by the user) must define a model.__call__(inputs) method, which applies the function to the input, and a model.setup() method, which creates the submodules used inside this model.

The module (parent) class defines two main methods: model.apply(variables, input), which applies the function to the input (and variables) to generate an output; and model.init(key, input), which initializes the variables and returns them as a "frozen dictionary". This dictionary can contain multiple kinds of variables. In the example below, the only kind is parameters, which are immutable variables (they will usually get updated in an external optimization loop, as we show later). The parameters are automatically named after the corresponding module attributes (here, layers_0, layers_1, etc.). In this example, both modules are dense layers, so their parameters are a weight matrix (called 'kernel') and a bias vector.

The hyper-parameters (in this case, the size of each layer) are stored as attributes of the class, and are specified when the module is constructed.

class MLP(nn.Module):
    features: Sequence[int]
    default_attr: int = 42

    def setup(self):
        print("setup")
        self.layers = [nn.Dense(feat) for feat in self.features]

    def __call__(self, inputs):
        print("call")
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = nn.relu(x)
        return x
key = random.PRNGKey(0)
D = 3
N = 2
x = random.normal(key, (N, D))
layer_sizes = [3, 1]  # 1 hidden layer of size 3, 1 scalar output

print("calling constructor")
model = MLP(layer_sizes)  # just initialize attributes of the object
print("OUTPUT")
print(model)

print("\ncalling init")
variables = model.init(key, x)  # calls setup then __call__
print("OUTPUT")
print(variables)

print("Calling apply")
y = model.apply(variables, x)  # calls setup then __call__
print(y)
calling constructor OUTPUT MLP( # attributes features = [3, 1] default_attr = 42 ) calling init setup call OUTPUT FrozenDict({ params: { layers_0: { kernel: DeviceArray([[0.57725, 0.43926, 0.69045], [0.02542, 0.50461, 0.56675], [0.07185, 0.17350, -0.04227]], dtype=float32), bias: DeviceArray([0.00000, 0.00000, 0.00000], dtype=float32), }, layers_1: { kernel: DeviceArray([[0.24313], [0.94535], [-0.12602]], dtype=float32), bias: DeviceArray([0.00000], dtype=float32), }, }, }) W0 [[0.57725 0.43926 0.69045] [0.02542 0.50461 0.56675] [0.07185 0.17350 -0.04227]] Calling apply setup call [[0.02978] [0.66403]]

Compact modules

To reduce the amount of boilerplate code, flax makes it possible to define a module just by writing the __call__ method, avoiding the need to write a setup function. The corresponding layers will be created when the init function is called, so the input shape can be inferred lazily (when an input is passed).

class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat)(x)
            if i != len(self.features) - 1:
                x = nn.relu(x)
        return x


model = MLP(layer_sizes)
print(model)
params = model.init(key, x)
print(params)
y = model.apply(params, x)
print(y)
MLP( # attributes features = [3, 1] ) FrozenDict({ params: { Dense_0: { kernel: DeviceArray([[0.28216, 1.03322, 0.07901], [0.15159, -0.50100, -0.22373], [-0.40327, -0.39875, -0.09402]], dtype=float32), bias: DeviceArray([0.00000, 0.00000, 0.00000], dtype=float32), }, Dense_1: { kernel: DeviceArray([[0.25432], [0.76792], [0.48329]], dtype=float32), bias: DeviceArray([0.00000], dtype=float32), }, }, }) [[0.56035] [1.07065]]

Explicit parameter initialization

We can control the initialization of the random parameters in each submodule by specifying an init function. Below we show how to initialize our MLP to match the vanilla JAX model. We then check that both methods give the same outputs.

def make_const_init(x):
    def init_params(key, shape, dtype=jnp.float32):
        return x

    return init_params


class MLP_init(nn.Module):
    features: Sequence[int]
    params_init: Dict

    def setup(self):
        nlayers = len(self.features)
        layers = []
        for i in range(nlayers):
            W = self.params_init[f"W{i}"]
            b = self.params_init[f"b{i}"]
            weights_init = make_const_init(W)
            bias_init = make_const_init(b)
            layer = nn.Dense(self.features[i], kernel_init=weights_init, bias_init=bias_init)
            layers.append(layer)
        self.layers = layers

    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = nn.relu(x)
        return x
params_init = params0
model = MLP_init(layer_sizes, params_init)
print(model)

variables = model.init(key, x)
params = variables["params"]
print(params)

W0 = params0["W0"]
W = params["layers_0"]["kernel"]
assert np.allclose(W, W0)

y = model.apply(variables, x)
print(y)
assert np.allclose(y, y0)
MLP_init( # attributes features = [3, 1] params_init = {'W0': DeviceArray([[-1.83021, 1.18417, 0.06777], [0.34588, 0.37858, -0.65318], [0.18976, 0.45157, -0.33964]], dtype=float32), 'b0': DeviceArray([0.00000, 0.00000, 0.00000], dtype=float32), 'W1': DeviceArray([[-1.74905], [1.83313], [-0.23808]], dtype=float32), 'b1': DeviceArray([0.00000], dtype=float32)} ) FrozenDict({ layers_0: { kernel: DeviceArray([[-1.83021, 1.18417, 0.06777], [0.34588, 0.37858, -0.65318], [0.18976, 0.45157, -0.33964]], dtype=float32), bias: DeviceArray([0.00000, 0.00000, 0.00000], dtype=float32), }, layers_1: { kernel: DeviceArray([[-1.74905], [1.83313], [-0.23808]], dtype=float32), bias: DeviceArray([0.00000], dtype=float32), }, }) [[-0.09538] [2.78382]]

Creating your own modules

Now we illustrate how to create a module with its own parameters, instead of relying on composing built-in primitives. As an example, we write our own dense layer class.

class SimpleDense(nn.Module):
    features: int  # num output features for this layer
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros

    @nn.compact
    def __call__(self, inputs):
        features_in = inputs.shape[-1]  # infer shape from input
        features_out = self.features
        kernel = self.param("kernel", self.kernel_init, (features_in, features_out))
        bias = self.param("bias", self.bias_init, (features_out,))
        outputs = jnp.dot(inputs, kernel) + bias
        return outputs


model = SimpleDense(features=3)
print(model)
vars = model.init(key, x)
print(vars)
y = model.apply(vars, x)
print(y)
SimpleDense( # attributes features = 3 kernel_init = init bias_init = zeros ) FrozenDict({ params: { kernel: DeviceArray([[0.32718, 0.05599, 0.17998], [-0.12295, 0.70712, 0.28972], [0.13731, -0.02853, -0.62830]], dtype=float32), bias: DeviceArray([0.00000, 0.00000, 0.00000], dtype=float32), }, }) [[0.30842 -0.91549 -0.74603] [0.36248 0.24616 0.36943]]

Stochastic layers

Some layers may need a source of randomness. If so, we must pass them a PRNG in the init and apply functions, in addition to the PRNG used for parameter initialization. We illustrate this below using dropout. We construct two versions, one which is stochastic (for training), and one which is deterministic (for evaluation).

class Block(nn.Module):
    features: int
    training: bool

    @nn.compact
    def __call__(self, inputs):
        x = nn.Dense(self.features)(inputs)
        x = nn.Dropout(rate=0.5)(x, deterministic=not self.training)
        return x


N = 1
D = 2
x = random.uniform(key, (N, D))

model = Block(features=3, training=True)
key = random.PRNGKey(0)
variables = model.init({"params": key, "dropout": key}, x)
# variables = model.init(key, x)  # cannot share the rng
print("variables", variables)

# Apply stochastic model
for i in range(2):
    key, subkey = random.split(key)
    y = model.apply(variables, x, rngs={"dropout": subkey})
    print(f"train output {i}, ", y)

# Now make a deterministic version
eval_model = Block(features=3, training=False)
key = random.PRNGKey(0)
# variables = eval_model.init({'params': key, 'dropout': key}, x)
for i in range(2):
    key, subkey = random.split(key)
    y = eval_model.apply(variables, x, rngs={"dropout": subkey})
    print(f"eval output {i}, ", y)
variables FrozenDict({ params: { Dense_0: { kernel: DeviceArray([[0.99988, -0.14086, -0.99796], [1.46673, 0.59637, 0.38263]], dtype=float32), bias: DeviceArray([0.00000, 0.00000, 0.00000], dtype=float32), }, }, }) train output 0, [[0.00000 1.05814 0.00000]] train output 1, [[3.12202 1.05814 0.32862]] eval output 0, [[1.56101 0.52907 0.16431]] eval output 1, [[1.56101 0.52907 0.16431]]

Stateful layers

In addition to parameters, linen modules can contain other kinds of variables, which may be mutable as we illustrate below. Indeed, parameters are just a special case of variable. In particular, this line

p = self.param('param_name', init_fn, shape, dtype)

is a convenient shorthand for this:

p = self.variable('params', 'param_name', lambda s, d: init_fn(self.make_rng('params'), s, d), shape, dtype).value

Example: counter

class Counter(nn.Module):
    @nn.compact
    def __call__(self):
        # variable(collection, name, init_fn, *init_args)
        counter1 = self.variable("counter", "count1", lambda: jnp.zeros((), jnp.int32))
        counter2 = self.variable("counter", "count2", lambda: jnp.zeros((), jnp.int32))
        is_initialized = self.has_variable("counter", "count1")
        if is_initialized:
            counter1.value += 1
            counter2.value += 2
        return counter1.value, counter2.value


model = Counter()
print(model)

init_variables = model.init(key)  # calls the `__call__` method
print("initialized variables:\n", init_variables)
counter = init_variables["counter"]["count1"]
print("counter 1 value", counter)

y, mutated_variables = model.apply(init_variables, mutable=["counter"])
print("mutated variables:\n", mutated_variables)
print("output:\n", y)
Counter() initialized variables: FrozenDict({ counter: { count1: DeviceArray(1, dtype=int32), count2: DeviceArray(2, dtype=int32), }, }) counter 1 value 1 mutated variables: FrozenDict({ counter: { count1: DeviceArray(2, dtype=int32), count2: DeviceArray(4, dtype=int32), }, }) output: (DeviceArray(2, dtype=int32), DeviceArray(4, dtype=int32))
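To keep incrementing across calls, we have to thread the mutated collection back in as the state for the next call. Here is a minimal sketch using the model and init_variables from the cell above:

state = init_variables
for _ in range(3):
    y, mutated = model.apply(state, mutable=["counter"])
    state = mutated  # the 'counter' collection is the only state in this model
print(y)  # count1 grows by 1 and count2 by 2 on every call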

Combining mutable variables and immutable parameters

We can combine mutable variables with immutable parameters. As an example, consider a simplified version of batch normalization, which computes the running mean of its inputs and adds an optimizable offset (bias) term.

class BiasAdderWithRunningMean(nn.Module):
    decay: float = 0.99

    @nn.compact
    def __call__(self, x):
        is_initialized = self.has_variable("params", "bias")
        # variable(collection, name, init_fn, *init_args)
        ra_mean = self.variable("batch_stats", "mean", lambda s: jnp.zeros(s), x.shape[1:])
        dummy_mutable = self.variable("mutables", "dummy", lambda s: 42, 0)
        # param(name, init_fn, *init_args)
        bias = self.param("bias", lambda rng, shape: jnp.ones(shape), x.shape[1:])
        if is_initialized:
            ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)
        return x - ra_mean.value + bias

The initial variables are: params = (bias=1), batch_stats = (mean=0).

If we pass in x=ones(N,D), the running average becomes $0.99 \times 0 + (1-0.99) \times 1 = 0.01$ and the output becomes $1 - 0.01 + 1 = 1.99$.

key = random.PRNGKey(0)
N = 2
D = 5
x = jnp.ones((N, D))

model = BiasAdderWithRunningMean()
variables = model.init(key, x)
print("initial variables:\n", variables)
nonstats, stats = variables.pop("batch_stats")
print("nonstats", nonstats)
print("stats", stats)
initial variables: FrozenDict({ batch_stats: { mean: DeviceArray([0.00000, 0.00000, 0.00000, 0.00000, 0.00000], dtype=float32), }, mutables: { dummy: 42, }, params: { bias: DeviceArray([1.00000, 1.00000, 1.00000, 1.00000, 1.00000], dtype=float32), }, }) nonstats FrozenDict({ mutables: { dummy: 42, }, params: { bias: DeviceArray([1.00000, 1.00000, 1.00000, 1.00000, 1.00000], dtype=float32), }, }) stats FrozenDict({ mean: DeviceArray([0.00000, 0.00000, 0.00000, 0.00000, 0.00000], dtype=float32), })
y, mutables = model.apply(variables, x, mutable=["batch_stats"])
print("output", y)
print("mutables", mutables)
output [[1.99000 1.99000 1.99000 1.99000 1.99000] [1.99000 1.99000 1.99000 1.99000 1.99000]] mutables FrozenDict({ batch_stats: { mean: DeviceArray([[0.01000, 0.01000, 0.01000, 0.01000, 0.01000]], dtype=float32), }, })

To call the function with the updated batch stats, we have to stitch together the new mutated state with the old state, as shown below.

variables = unfreeze(nonstats)
print(variables)
variables["batch_stats"] = mutables["batch_stats"]
variables = freeze(variables)
print(variables)
{'mutables': {'dummy': 42}, 'params': {'bias': DeviceArray([1.00000, 1.00000, 1.00000, 1.00000, 1.00000], dtype=float32)}} FrozenDict({ mutables: { dummy: 42, }, params: { bias: DeviceArray([1.00000, 1.00000, 1.00000, 1.00000, 1.00000], dtype=float32), }, batch_stats: { mean: DeviceArray([[0.01000, 0.01000, 0.01000, 0.01000, 0.01000]], dtype=float32), }, })
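Alternatively, FrozenDict provides a copy method that adds or replaces collections in one step. The following sketch is equivalent to the cell above, using the same nonstats and mutables objects:

variables = nonstats.copy({"batch_stats": mutables["batch_stats"]})  # returns a new FrozenDict
print(variables)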

If we pass in x=2*ones(N,D), the running average gets updated to $0.99 \times 0.01 + (1-0.99) \times 2.0 = 0.0299$ and the output becomes $2 - 0.0299 + 1 = 2.9701$.

x = 2 * jnp.ones((N, D))
y, mutables = model.apply(variables, x, mutable=["batch_stats"])
print("output", y)
print("batch_stats", mutables)

assert np.allclose(y, 2.9701)
assert np.allclose(mutables["batch_stats"]["mean"], 0.0299)
output [[2.97010 2.97010 2.97010 2.97010 2.97010] [2.97010 2.97010 2.97010 2.97010 2.97010]] batch_stats FrozenDict({ batch_stats: { mean: DeviceArray([[0.02990, 0.02990, 0.02990, 0.02990, 0.02990]], dtype=float32), }, })
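Flax's built-in nn.BatchNorm follows the same params / batch_stats split. Here is a minimal sketch (not part of the original notebook; BNBlock and xb are illustrative names) showing how it is initialized and then applied with a mutable batch_stats collection:

class BNBlock(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool):
        # use_running_average=False updates batch_stats from the current batch
        return nn.BatchNorm(use_running_average=not train, momentum=0.99)(x)


bn = BNBlock()
xb = jnp.ones((2, 5))
bn_vars = bn.init(key, xb, train=True)  # contains 'params' (scale, bias) and 'batch_stats'
yb, bn_mutated = bn.apply(bn_vars, xb, train=True, mutable=["batch_stats"])
print(bn_mutated["batch_stats"])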

Optimization

Flax has several built-in (first-order) optimizers, as we illustrate below on a random linear function. (Note that we can also fit a model defined in flax using some other kind of optimizer, such as those provided by the optax library; we sketch the optax version after the example below.)

D = 5
key = jax.random.PRNGKey(0)
params = {"w": jax.random.normal(key, (D,))}
print(params)
x = jax.random.normal(key, (D,))


def loss(params):
    w = params["w"]
    return jnp.dot(x, w)


loss_grad_fn = jax.value_and_grad(loss)
v, g = loss_grad_fn(params)
print(v)
print(g)
{'w': DeviceArray([0.18784, -1.28334, -0.27109, 1.24906, 0.24447], dtype=float32)} 3.375659 {'w': DeviceArray([0.18784, -1.28334, -0.27109, 1.24906, 0.24447], dtype=float32)}
from flax import optim

optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.9)
print(optimizer_def)
optimizer = optimizer_def.create(params)
print(optimizer)
<flax.optim.momentum.Momentum object at 0x7fbd09abceb8> Optimizer(optimizer_def=<flax.optim.momentum.Momentum object at 0x7fbd09abceb8>, state=OptimizerState(step=DeviceArray(0, dtype=int32), param_states={'w': _MomentumParamState(momentum=DeviceArray([0.00000, 0.00000, 0.00000, 0.00000, 0.00000], dtype=float32))}), target={'w': DeviceArray([0.18784, -1.28334, -0.27109, 1.24906, 0.24447], dtype=float32)})
for i in range(10):
    params = optimizer.target
    loss_val, grad = loss_grad_fn(params)
    optimizer = optimizer.apply_gradient(grad)
    params = optimizer.target
    print("step {}, loss {:0.3f}, params {}".format(i, loss_val, params))
step 0, loss -10.593, params {'w': DeviceArray([-0.71837, 4.90788, 1.03673, -4.77677, -0.93493], dtype=float32)} step 1, loss -12.910, params {'w': DeviceArray([-0.85316, 5.82877, 1.23126, -5.67306, -1.11035], dtype=float32)} step 2, loss -15.332, params {'w': DeviceArray([-0.99326, 6.78590, 1.43345, -6.60462, -1.29268], dtype=float32)} step 3, loss -17.849, params {'w': DeviceArray([-1.13813, 7.77566, 1.64252, -7.56794, -1.48122], dtype=float32)} step 4, loss -20.453, params {'w': DeviceArray([-1.28730, 8.79477, 1.85780, -8.55983, -1.67536], dtype=float32)} step 5, loss -23.133, params {'w': DeviceArray([-1.44033, 9.84031, 2.07866, -9.57743, -1.87453], dtype=float32)} step 6, loss -25.884, params {'w': DeviceArray([-1.59685, 10.90963, 2.30454, -10.61818, -2.07823], dtype=float32)} step 7, loss -28.696, params {'w': DeviceArray([-1.75650, 12.00035, 2.53494, -11.67977, -2.28600], dtype=float32)} step 8, loss -31.565, params {'w': DeviceArray([-1.91897, 13.11033, 2.76941, -12.76010, -2.49745], dtype=float32)} step 9, loss -34.485, params {'w': DeviceArray([-2.08397, 14.23764, 3.00754, -13.85730, -2.71220], dtype=float32)}
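As mentioned above, we can fit the same objective with optax instead of flax.optim. Here is a minimal sketch (assuming optax is installed) that mirrors the momentum-SGD loop above:

import optax

params = {"w": jax.random.normal(jax.random.PRNGKey(0), (D,))}  # reset to the same starting point
tx = optax.sgd(learning_rate=0.1, momentum=0.9)
opt_state = tx.init(params)
for i in range(10):
    loss_val, grad = loss_grad_fn(params)
    updates, opt_state = tx.update(grad, opt_state)
    params = optax.apply_updates(params, updates)
    print("step {}, loss {:0.3f}".format(i, loss_val))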

Worked example: MLP for MNIST

We demonstrate how to fit a shallow MLP to MNIST using Flax. We use this function: https://github.com/probml/pyprobml/blob/master/scripts/fit_flax.py

Import code

!pip install superimport
Collecting superimport Downloading superimport-0.3.3.tar.gz (5.8 kB) Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from superimport) (2.23.0) Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->superimport) (2.10) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->superimport) (2021.5.30) Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->superimport) (1.24.3) Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->superimport) (3.0.4) Building wheels for collected packages: superimport Building wheel for superimport (setup.py) ... done Created wheel for superimport: filename=superimport-0.3.3-py3-none-any.whl size=5766 sha256=4a2891b002c0f5f3e2330adca027096d023e29accf21b2ae4ceb6b89445ef44f Stored in directory: /root/.cache/pip/wheels/0f/0a/7e/ba2303ac54e68950f97db02ebf09ee4ef5c794e1adb656cb68 Successfully built superimport Installing collected packages: superimport Successfully installed superimport-0.3.3
!wget https://raw.githubusercontent.com/probml/pyprobml/master/scripts/fit_flax.py
import fit_flax as ff

ff.test()
--2021-09-11 03:30:48-- https://raw.githubusercontent.com/probml/pyprobml/master/scripts/fit_flax.py Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ... Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 5638 (5.5K) [text/plain] Saving to: ‘fit_flax.py.1’ fit_flax.py.1 100%[===================>] 5.51K --.-KB/s in 0s 2021-09-11 03:30:48 (65.1 MB/s) - ‘fit_flax.py.1’ saved [5638/5638]
ERROR: superimport : missing python module: flax Trying try to install automatcially WARNING:root:Package was not found in the reverse index, trying pypi. /usr/local/lib/python3.7/dist-packages/jax/_src/config.py:171: UserWarning: enable_omnistaging() is a no-op in JAX versions 0.2.12 and higher; see https://github.com/google/jax/blob/main/design_notes/omnistaging.md "enable_omnistaging() is a no-op in JAX versions 0.2.12 and higher;\n"
testing fit-flax train step: 0, loss: 1.9212, accuracy: 0.33 train step: 1, loss: 1.8051, accuracy: 0.33 FrozenDict({ Dense_0: { bias: DeviceArray([-0.05753, -0.06485, 0.01820, -0.06623, 0.03506, 0.01663, 0.03751, 0.03186, 0.03567, 0.01368], dtype=float32), kernel: DeviceArray([[0.04317, -0.03052, 0.00153, -0.08923, 0.00806, 0.00845, 0.01886, 0.01948, 0.01805, 0.00214], [0.05920, 0.03486, -0.01555, 0.05232, -0.03050, -0.01319, -0.02862, -0.02186, -0.02522, -0.01143], [0.01316, 0.02697, -0.01047, 0.09079, -0.02449, -0.01311, -0.02791, -0.02281, -0.02392, -0.00822], [0.04932, -0.15770, -0.00438, 0.08807, -0.02104, -0.00335, 0.00400, 0.02333, 0.02383, -0.00208], [0.05093, -0.04816, -0.00149, -0.04978, -0.00104, 0.00457, 0.01233, 0.01710, 0.01562, -0.00007]], dtype=float32), }, }) {'train_loss': [DeviceArray(1.92125, dtype=float32), DeviceArray(1.80505, dtype=float32)], 'train_accuracy': [DeviceArray(0.33333, dtype=float32), DeviceArray(0.33333, dtype=float32)], 'test_loss': [DeviceArray(1.80505, dtype=float32), DeviceArray(1.59708, dtype=float32)], 'test_accuracy': [DeviceArray(0.33333, dtype=float32), DeviceArray(0.66667, dtype=float32)]} test passed

Data

import tensorflow_datasets as tfds
import tensorflow as tf
def process_record(batch):
    image = batch["image"]
    label = batch["label"]
    # flatten image to vector
    shape = image.get_shape().as_list()
    D = np.prod(shape)  # no batch dimension
    image = tf.reshape(image, (D,))
    # rescale to -1..+1
    image = tf.cast(image, dtype=tf.float32)
    image = ((image / 255.0) - 0.5) * 2.0
    # convert to standard names
    return {"X": image, "y": label}


def load_mnist(split, batch_size):
    dataset, info = tfds.load("mnist", split=split, with_info=True)
    dataset = dataset.map(process_record)
    if split == "train":
        dataset = dataset.shuffle(10 * batch_size, seed=0)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    dataset = dataset.cache()
    dataset = dataset.repeat()
    dataset = tfds.as_numpy(dataset)  # leave TF behind
    num_examples = info.splits[split].num_examples
    return iter(dataset), num_examples


batch_size = 100
train_iter, num_train = load_mnist("train", batch_size)
test_iter, num_test = load_mnist("test", batch_size)

num_epochs = 3
num_steps = num_train // batch_size
print(f"{num_epochs} epochs with batch size {batch_size} will take {num_steps} steps")

batch = next(train_iter)
print(batch["X"].shape)
print(batch["y"].shape)
3 epochs with batch size 100 will take 600 steps (100, 784) (100,)

Model

class Model(nn.Module):
    nhidden: int
    nclasses: int

    @nn.compact
    def __call__(self, x):
        if self.nhidden > 0:
            x = nn.Dense(self.nhidden)(x)
            x = nn.relu(x)
        x = nn.Dense(self.nclasses)(x)  # logits
        x = nn.log_softmax(x)  # log probabilities
        return x

Training loop
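The actual training is done by ff.fit_model from fit_flax.py below. For reference, a generic Flax training step of this kind looks roughly like the following sketch (an assumed pattern using optax and cross-entropy on log probabilities, not the actual implementation inside fit_flax.py; make_train_step is a hypothetical helper name):

import optax


def make_train_step(model, tx):
    @jax.jit
    def train_step(params, opt_state, batch):
        def loss_fn(params):
            logprobs = model.apply({"params": params}, batch["X"])
            one_hot = jax.nn.one_hot(batch["y"], logprobs.shape[-1])
            return -jnp.mean(jnp.sum(one_hot * logprobs, axis=-1))

        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = tx.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    return train_step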

model = Model(nhidden=128, nclasses=10)
rng = jax.random.PRNGKey(0)
num_steps = 200

params, history = ff.fit_model(model, rng, num_steps, train_iter, test_iter, print_every=20)
display(history)
train step: 0, loss: 2.4736, accuracy: 0.13 train step: 20, loss: 1.3480, accuracy: 0.60 train step: 40, loss: 0.6385, accuracy: 0.80 train step: 60, loss: 0.9009, accuracy: 0.71 train step: 80, loss: 0.6118, accuracy: 0.83 train step: 100, loss: 0.3172, accuracy: 0.91 train step: 120, loss: 0.5050, accuracy: 0.82 train step: 140, loss: 0.5362, accuracy: 0.83 train step: 160, loss: 0.4464, accuracy: 0.86 train step: 180, loss: 0.7583, accuracy: 0.81
{'test_accuracy': [DeviceArray(0.14000, dtype=float32), DeviceArray(0.61000, dtype=float32), DeviceArray(0.82000, dtype=float32), DeviceArray(0.80000, dtype=float32), DeviceArray(0.82000, dtype=float32), DeviceArray(0.83000, dtype=float32), DeviceArray(0.86000, dtype=float32), DeviceArray(0.84000, dtype=float32), DeviceArray(0.85000, dtype=float32), DeviceArray(0.86000, dtype=float32)], 'test_loss': [DeviceArray(2.60662, dtype=float32), DeviceArray(1.52616, dtype=float32), DeviceArray(0.58906, dtype=float32), DeviceArray(0.77289, dtype=float32), DeviceArray(0.62323, dtype=float32), DeviceArray(0.55543, dtype=float32), DeviceArray(0.45732, dtype=float32), DeviceArray(0.54828, dtype=float32), DeviceArray(0.61440, dtype=float32), DeviceArray(0.54720, dtype=float32)], 'train_accuracy': [DeviceArray(0.13000, dtype=float32), DeviceArray(0.60000, dtype=float32), DeviceArray(0.80000, dtype=float32), DeviceArray(0.71000, dtype=float32), DeviceArray(0.83000, dtype=float32), DeviceArray(0.91000, dtype=float32), DeviceArray(0.82000, dtype=float32), DeviceArray(0.83000, dtype=float32), DeviceArray(0.86000, dtype=float32), DeviceArray(0.81000, dtype=float32)], 'train_loss': [DeviceArray(2.47364, dtype=float32), DeviceArray(1.34795, dtype=float32), DeviceArray(0.63851, dtype=float32), DeviceArray(0.90089, dtype=float32), DeviceArray(0.61182, dtype=float32), DeviceArray(0.31721, dtype=float32), DeviceArray(0.50501, dtype=float32), DeviceArray(0.53616, dtype=float32), DeviceArray(0.44640, dtype=float32), DeviceArray(0.75831, dtype=float32)]}
plt.figure()
plt.plot(history["test_accuracy"], "o-", label="test accuracy")
plt.xlabel("num. minibatches")
plt.legend()
plt.show()
[Figure: test accuracy vs. number of minibatches]