Introduction to neural networks using Flax
Flax / Linen is a neural net library, built on top of JAX, "designed to offer an implicit variable management API to save the user from having to manually thread thousands of variables through a complex tree of functions." To handle both current and future JAX transforms (configured and composed in any way), Linen Modules are defined as explicit functions of the form

$f(v_{in}, x) \rightarrow v_{out}, y$

where $v_{in}$ is the collection of variables (e.g. parameters) and PRNG state used by the model, $v_{out}$ the mutated output variable collections, $x$ the input data, and $y$ the output data. We illustrate this below. Our tutorial is based on the official flax intro and linen colab. Details are in the flax source code. Note: please be sure to read our JAX tutorial first.
MLP in vanilla JAX
We construct a simple MLP with L hidden layers (relu activations) and a scalar output (linear activation).
Note: JAX and Flax, like NumPy, are row-based systems, meaning that vectors are represented as row vectors and not column vectors.
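Concretely, the vanilla model might look like the following sketch (the layer sizes, initialization scale, and function names here are illustrative, not the exact code from the notebook):

```python
import jax
import jax.numpy as jnp
from jax import random

def init_mlp_params(key, layer_sizes):
    """Initialize a list of (weights, biases), one pair per layer."""
    params = []
    for d_in, d_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, subkey = random.split(key)
        W = random.normal(subkey, (d_in, d_out)) * jnp.sqrt(2.0 / d_in)
        b = jnp.zeros(d_out)
        params.append((W, b))
    return params

def predict(params, x):
    """Forward pass: relu hidden layers, linear scalar output. x has shape (batch, d_in)."""
    for W, b in params[:-1]:
        x = jax.nn.relu(x @ W + b)
    W, b = params[-1]
    return x @ W + b  # shape (batch, 1)

key = random.PRNGKey(0)
params = init_mlp_params(key, [5, 10, 1])       # one hidden layer of 10 units (illustrative)
x = random.normal(random.PRNGKey(1), (3, 5))    # rows are examples (row-based convention)
print(predict(params, x).shape)                 # (3, 1)
```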
Our first flax model
Here we recreate the vanilla model in flax. Since we don't specify how the parameters are initialized, the behavior will not be identical to the vanilla model --- we will fix this below, but for now, we focus on model construction.
We see that the model is a subclass of nn.Module, which is a subclass of a Python dataclass. The child class (written by the user) must define a model.__call__(inputs) method, which applies the function to the input, and a model.setup() method, which creates the modules inside this model.

The module (parent) class defines two main methods: model.apply(variables, input), which applies the function to the input (and variables) to generate an output; and model.init(key, input), which initializes the variables and returns them as a "frozen dictionary". This dictionary can contain multiple kinds of variables. In the example below, the only kind is parameters, which are immutable variables (that will usually get updated in an external optimization loop, as we show later). The parameters are automatically named after the corresponding module (here, dense0, dense1, etc.). In this example, both modules are dense layers, so their parameters are a weight matrix (called 'kernel') and a bias vector.

The hyper-parameters (in this case, the size of each layer) are stored as attributes of the class, and are specified when the module is constructed.
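As a sketch of this pattern (the attribute names and layer sizes are illustrative), an MLP written with an explicit setup looks roughly like this:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    num_hidden: int = 10     # hyper-parameters are stored as dataclass attributes
    num_outputs: int = 1

    def setup(self):
        # Submodules created here; their parameters are named after these attributes.
        self.dense0 = nn.Dense(self.num_hidden)
        self.dense1 = nn.Dense(self.num_outputs)

    def __call__(self, x):
        x = nn.relu(self.dense0(x))
        return self.dense1(x)

model = MLP()
x = jnp.ones((3, 5))
variables = model.init(jax.random.PRNGKey(0), x)   # frozen dict with 'params' / dense0, dense1
y = model.apply(variables, x)
print(jax.tree_util.tree_map(lambda p: p.shape, variables))
```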
Compact modules
To reduce the amount of boilerplate code, flax makes it possible to define a module just by writing the __call__ method (decorated with nn.compact), avoiding the need to write a setup function. The corresponding layers will be created when the init function is called, so the input shape can be inferred lazily (when passed an input).
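A compact version of the same sketch (again with illustrative sizes):

```python
import flax.linen as nn

class MLPCompact(nn.Module):
    num_hidden: int = 10
    num_outputs: int = 1

    @nn.compact
    def __call__(self, x):
        # Layers are declared inline; shapes are inferred lazily from the first input.
        x = nn.relu(nn.Dense(self.num_hidden)(x))
        return nn.Dense(self.num_outputs)(x)
```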
Explicit parameter initialization
We can control the initialization of the random parameters in each submodule by specifying an init function. Below we show how to initialize our MLP to match the vanilla JAX model. We then check that both methods give the same outputs.
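For example, nn.Dense accepts kernel_init and bias_init arguments; a minimal sketch (the particular initializers below are illustrative, not necessarily the ones that match the vanilla model):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

kernel_init = jax.nn.initializers.normal(stddev=0.01)  # illustrative choice
bias_init = jax.nn.initializers.zeros

layer = nn.Dense(features=10, kernel_init=kernel_init, bias_init=bias_init)
variables = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 5)))
```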
Creating your own modules
Now we illustrate how to create a module with its own parameters, instead of relying on composing built-in primitives. As an example, we write our own dense layer class.
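A sketch of such a layer, following the standard Flax pattern of declaring parameters with self.param (the class and attribute names are illustrative):

```python
from typing import Callable
import jax.numpy as jnp
import flax.linen as nn

class SimpleDense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros

    @nn.compact
    def __call__(self, x):
        # Each self.param call declares a named parameter in the 'params' collection.
        kernel = self.param('kernel', self.kernel_init, (x.shape[-1], self.features))
        bias = self.param('bias', self.bias_init, (self.features,))
        return jnp.dot(x, kernel) + bias
```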
Stochastic layers
Some layers may need a source of randomness. If so, we must pass them a PRNG in the init and apply functions, in addition to the PRNG used for parameter initialization. We illustrate this below using dropout. We construct two versions, one which is stochastic (for training), and one which is deterministic (for evaluation).
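A minimal sketch of this pattern (module name, sizes, and dropout rate are illustrative):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class DenseWithDropout(nn.Module):
    features: int = 10
    rate: float = 0.5

    @nn.compact
    def __call__(self, x, deterministic):
        x = nn.Dense(self.features)(x)
        return nn.Dropout(rate=self.rate, deterministic=deterministic)(x)

model = DenseWithDropout()
x = jnp.ones((3, 5))
variables = model.init(jax.random.PRNGKey(0), x, deterministic=True)

# Stochastic (training) version: a 'dropout' PRNG stream must be supplied.
y_train = model.apply(variables, x, deterministic=False,
                      rngs={'dropout': jax.random.PRNGKey(1)})
# Deterministic (evaluation) version: no extra randomness needed.
y_eval = model.apply(variables, x, deterministic=True)
```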
Stateful layers
In addition to parameters, linen modules can contain other kinds of variables, which may be mutable, as we illustrate below. Indeed, parameters are just a special case of variable: the self.param call used above is a convenient shorthand for declaring a variable in the 'params' collection, as sketched below.
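A sketch of the equivalence (illustrative, not the exact Flax internals): self.param('name', init_fn, shape) roughly corresponds to declaring a variable in the 'params' collection whose initializer is fed the 'params' PRNG stream.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class TwoWays(nn.Module):
    features: int = 3

    @nn.compact
    def __call__(self, x):
        shape = (x.shape[-1], self.features)
        init_fn = nn.initializers.lecun_normal()
        # The usual shorthand:
        w1 = self.param('w1', init_fn, shape)
        # Roughly equivalent long form, using the variable API directly
        # (a sketch of the idea, not the exact Flax implementation):
        w2 = self.variable('params', 'w2',
                           lambda s: init_fn(self.make_rng('params'), s),
                           shape).value
        return x @ w1 + x @ w2

vars_ = TwoWays().init(jax.random.PRNGKey(0), jnp.ones((2, 4)))
# vars_['params'] now contains both 'w1' and 'w2'
```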
Example: counter
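A minimal counter module might look like this sketch (collection and variable names are illustrative): the module has no parameters at all, only a mutable 'counter' collection.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Counter(nn.Module):
    @nn.compact
    def __call__(self):
        # The count lives in the 'counter' collection and is mutated on every call.
        counter = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))
        counter.value += 1
        return counter.value

model = Counter()
variables = model.init(jax.random.PRNGKey(0))             # {'counter': {'count': 1}}
value, updated_state = model.apply(variables, mutable=['counter'])
print(value, updated_state)                               # 2, {'counter': {'count': 2}}
```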
Combining mutable variables and immutable parameters
We can combine mutable variables with immutable parameters. As an example, consider a simplified version of batch normalization, which computes the running mean of its inputs, and adds an optimizable offset (bias) term.
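A sketch of such a module, assuming a decay of 0.9 and a bias initialized to 1 (both choices are illustrative, picked to be consistent with the initial values quoted below): the running mean lives in the mutable 'batch_stats' collection, while the bias is an ordinary (immutable) parameter.

```python
import jax.numpy as jnp
import flax.linen as nn

class BiasAdderWithRunningMean(nn.Module):
    decay: float = 0.9  # illustrative value

    @nn.compact
    def __call__(self, x):
        is_initialized = self.has_variable('batch_stats', 'mean')
        ra_mean = self.variable('batch_stats', 'mean', jnp.zeros, x.shape[1:])
        bias = self.param('bias', nn.initializers.ones, x.shape[1:])
        if is_initialized:  # skip the update on the very first (init) call
            ra_mean.value = self.decay * ra_mean.value \
                            + (1.0 - self.decay) * jnp.mean(x, axis=0)
        return ra_mean.value + bias
```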
The initial variables are: params = (bias=1), batch_stats = (mean=0).
If we pass in x=ones(N,D), the running average moves towards the batch mean (which is 1), and the output changes accordingly.
To call the function with the updated batch stats, we have to stitch together the new mutated state with the old state, as shown below.
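A sketch of this calling pattern, reusing the module sketched above: apply with mutable=['batch_stats'] returns the output together with the updated statistics, which we merge back with the unchanged parameters.

```python
import jax
import jax.numpy as jnp

model = BiasAdderWithRunningMean()            # defined in the sketch above
x = jnp.ones((10, 3))
variables = model.init(jax.random.PRNGKey(0), x)

# First call: returns (output, mutated collections).
y, mutated = model.apply(variables, x, mutable=['batch_stats'])

# Stitch the new batch_stats together with the old params to form the updated state.
variables = {'params': variables['params'], 'batch_stats': mutated['batch_stats']}

# Second call uses (and again updates) the running statistics.
y2, mutated = model.apply(variables, 2 * jnp.ones((10, 3)), mutable=['batch_stats'])
```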
If we pass in x=2*ones(N,D), the running average gets updated again, towards the new batch mean of 2, and the output changes accordingly.
Optimization
Flax has several built-in (first-order) optimizers, as we illustrate below on a random linear function. (Note that we can also fit a model defined in flax using some other kind of optimizer, such as that provided by the optax library.)
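As an illustration of the optax route mentioned above (a sketch; the data, learning rate, and step count are arbitrary), fitting a random linear function with SGD looks roughly like this:

```python
import jax
import jax.numpy as jnp
import optax

key = jax.random.PRNGKey(0)
w_true = jnp.array([1.0, -2.0, 3.0])
X = jax.random.normal(key, (100, 3))
y = X @ w_true

def loss_fn(w):
    return jnp.mean((X @ w - y) ** 2)

optimizer = optax.sgd(learning_rate=0.1)
w = jnp.zeros(3)
opt_state = optimizer.init(w)

for _ in range(100):
    grads = jax.grad(loss_fn)(w)
    updates, opt_state = optimizer.update(grads, opt_state)
    w = optax.apply_updates(w, updates)

print(w)  # should be close to w_true
```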
Worked example: MLP for MNIST
We demonstrate how to fit a shallow MLP to MNIST using Flax. We use this function: https://github.com/probml/pyprobml/blob/master/scripts/fit_flax.py
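For reference, the kind of shallow MLP one might pass to that script could be sketched as follows (the layer sizes here are illustrative; the actual model and training loop live in fit_flax.py):

```python
import flax.linen as nn

class MnistMLP(nn.Module):
    num_hidden: int = 128
    num_classes: int = 10

    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))          # flatten 28x28 images
        x = nn.relu(nn.Dense(self.num_hidden)(x))
        return nn.Dense(self.num_classes)(x)     # logits
```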
Import code