Yet another JAX tutorial
Kevin Murphy ([email protected]). Last update: September 2021.
JAX is a version of NumPy that runs fast on CPU, GPU and TPU, by compiling the computational graph to XLA (Accelerated Linear Algebra). It also has an excellent automatic differentiation library, extending the earlier autograd package. This makes it easy to compute higher order derivatives, gradients of complex functions (e.g., functions that internally use an iterative solver), etc. The JAX interface is almost identical to NumPy (by design), but with some small differences, and lots of additional features. We give a brief introduction below. For more details, see this list of JAX tutorials.
Setup
Hardware accelerators
Colab makes it easy to use GPUs and TPUs for speeding up some workflows, especially related to deep learning.
GPUs
Colab offers graphics processing units (GPUs) which can be much faster than CPUs (central processing units), as we illustrate below.
Let's see how JAX can speed up things like matrix-matrix multiplication.
First the numpy/CPU version.
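Something along these lines (a sketch; the matrix size is arbitrary, chosen just to make the timing visible):

```python
import numpy as np

# Plain NumPy on the CPU: multiply two random matrices.
A = np.random.normal(size=(1000, 1000)).astype(np.float32)
B = np.random.normal(size=(1000, 1000)).astype(np.float32)

%timeit np.dot(A, B)
```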
Now we look at the JAX version. JAX supports execution on XLA devices, which can be CPU, GPU or even TPU. We add a call to block_until_ready because JAX uses asynchronous execution by default, so without it we would only be timing the dispatch of the operation.
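For example (a sketch matching the NumPy cell above):

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
key1, key2 = random.split(key)
A_jax = random.normal(key1, (1000, 1000), dtype=jnp.float32)
B_jax = random.normal(key2, (1000, 1000), dtype=jnp.float32)

# block_until_ready() forces the asynchronous computation to finish,
# so we time the matmul itself and not just the dispatch.
%timeit jnp.dot(A_jax, B_jax).block_until_ready()
```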
In the above example we see that JAX GPU is much faster than Numpy CPU. However we also see that JAX CPU is slower than Numpy CPU - this can happen with simple functions, but usually JAX provides a speedup, even on CPU, if you JIT compile a complex function (see below).
We can move numpy arrays to the GPU for speed. The result will be transferred back to CPU for printing, saving, etc.
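A minimal sketch of this pattern:

```python
import numpy as np
import jax
import jax.numpy as jnp

A = np.random.normal(size=(1000, 1000)).astype(np.float32)

A_dev = jax.device_put(A)     # copy the NumPy array onto the default device (GPU if present)
C = jnp.dot(A_dev, A_dev.T)   # computed on the device
C_host = np.asarray(C)        # copied back to the host for printing, saving, etc.
print(C_host[0, :3])
```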
TPUs
We can turn on the tensor processing unit (TPU) by selecting it from the Colab runtime settings. Everything else "just works" as before.
If everything is set up correctly, the following command should return a list of 8 TPU devices.
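For example:

```python
import jax

# On Colab TPU runtimes of this era, one may first need:
#   import jax.tools.colab_tpu
#   jax.tools.colab_tpu.setup_tpu()

print(jax.devices())       # should list 8 TPU devices
print(jax.device_count())  # 8
```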
Vmap
We often write a function to process a single vector or matrix, and then want to apply it to a batch of data. Using for loops is slow, and manually batching the code is complex and error prone. Fortunately we can use the vmap function, which will map our function across a set of inputs, automatically batching it.
Example: 1d convolution
(This example is from the Deepmind tutorial.)
Consider standard 1d convolution of two vectors.
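Something like the following (a sketch in the spirit of the Deepmind tutorial; the signal and kernel values are arbitrary):

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
    # 'Valid' 1d convolution of a signal x with a length-3 kernel w,
    # written with an explicit loop over output positions.
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1:i + 2], w))
    return jnp.array(output)

x = jnp.arange(5.0)
w = jnp.array([2.0, 3.0, 4.0])
print(convolve(x, w))
```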
Now suppose we want to convolve multiple vectors with multiple kernels. The simplest way is to use a for loop, but this is slow.
We can manually vectorize the code, but it is complex.
Fortunately vmap can do this for us!
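Continuing the cell above, a sketch of the vmap version:

```python
# Stack several signals and kernels along a new leading (batch) axis.
xs = jnp.stack([x, x + 1.0])
ws = jnp.stack([w, w * 2.0])

batched_convolve = jax.vmap(convolve)   # maps over axis 0 of both arguments
print(batched_convolve(xs, ws))
```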
Axes
By default, vmap vectorizes over the first axis of each of its inputs. If the first argument has a batch dimension and the second does not, we specify in_axes=[0, None], so the second argument is not vectorized over.
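For example, reusing convolve, x, xs, w and ws from the cells above:

```python
# A batch of signals, but a single shared kernel.
print(jax.vmap(convolve, in_axes=[0, None])(xs, w))

# Conversely: one signal, a batch of kernels.
print(jax.vmap(convolve, in_axes=[None, 0])(x, ws))
```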
We can also vectorize over other dimensions.
Example: logistic regression
We now give another example, using binary logistic regression. Let us start with a predictor for a single example $x$.
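A sketch consistent with the traceback below (the shapes $D=2$, $N=3$ are chosen to match it; the data is random):

```python
import jax
import jax.numpy as jnp

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2.0) + 1.0)

D = 2   # number of features
N = 3   # number of examples
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key1, (D,))
X = jax.random.normal(key2, (N, D))

def predict_single(x):
    return sigmoid(jnp.dot(w, x))  # <(D), (D)> = scalar  # inner product

print(predict_single(X[0]))  # works on a single example
print(predict_single(X))     # fails: w has shape (D,) but X has shape (N, D)
```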
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-28-853aeac052fc> in <module>()
1
----> 2 print(predict_single(X)) # fails
<ipython-input-26-e29aadf4f504> in predict_single(x)
9
10 def predict_single(x):
---> 11 return sigmoid(jnp.dot(w, x)) # <(D) , (D)> = (1) # inner product
12
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in dot(a, b, precision)
4194 return lax.mul(a, b)
4195 if _max(a_ndim, b_ndim) <= 2:
-> 4196 return lax.dot(a, b, precision=precision)
4197
4198 if b_ndim == 1:
/usr/local/lib/python3.7/dist-packages/jax/_src/lax/lax.py in dot(lhs, rhs, precision, preferred_element_type)
666 else:
667 raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
--> 668 lhs.shape, rhs.shape))
669
670
TypeError: Incompatible shapes for dot: got (2,) and (3, 2).
We can manually vectorize the code by keeping track of the shapes, so that we take the inner product of each row of $X$ with $w$.
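For example, reusing sigmoid, w and X from above:

```python
def predict_batch_manual(X):
    return sigmoid(jnp.dot(X, w))   # (N, D) @ (D,) -> (N,)

print(predict_batch_manual(X))
```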
But it is easier to use vmap.
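For example:

```python
predict_batch = jax.vmap(predict_single)   # maps over axis 0 of X
print(predict_batch(X))
```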
Failure cases
Vmap requires that the shapes of all the variables created inside the function being mapped are the same for all values of the input arguments, as explained here. So vmap cannot be used for arbitrary embarrassingly parallel tasks. Below we give a simple example where this fails, since internally we create a vector whose length depends on the input 'length'.
The following fails.
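The function in question (reconstructed from the traceback below):

```python
import jax
import jax.numpy as jnp

def example_fun(length, val=4):
    return jnp.sum(jnp.ones((length,)) * val)

xs = jnp.arange(1, 10)
print(example_fun(5))              # fine for a concrete length
v = jax.vmap(example_fun)(xs)      # fails: the shape depends on a traced value
print(v)
```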
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-32-7fe93a3635da> in <module>()
----> 1 v = vmap(example_fun)(xs)
2 print(v)
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in batched_fun(*args, **kwargs)
1286 lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
-> 1287 ).call_wrapped(*args_flat)
1288 return tree_unflatten(out_tree(), out_flat)
/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
<ipython-input-31-4138480dd486> in example_fun(length, val)
1 def example_fun(length, val=4):
----> 2 return jnp.sum(jnp.ones((length,)) * val)
3
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in ones(shape, dtype)
3187 shape = (shape,) if ndim(shape) == 0 else shape
-> 3188 return lax.full(shape, 1, dtype)
3189
/usr/local/lib/python3.7/dist-packages/jax/_src/lax/lax.py in full(shape, fill_value, dtype)
1594 """
-> 1595 shape = canonicalize_shape(shape)
1596 if np.shape(fill_value):
/usr/local/lib/python3.7/dist-packages/jax/core.py in canonicalize_shape(shape, context)
1440 pass
-> 1441 raise _invalid_shape_error(shape, context)
1442
UnfilteredStackTrace: TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)>
with val = DeviceArray([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)
batch_dim = 0,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
<ipython-input-32-7fe93a3635da> in <module>()
----> 1 v = vmap(example_fun)(xs)
2 print(v)
<ipython-input-31-4138480dd486> in example_fun(length, val)
1 def example_fun(length, val=4):
----> 2 return jnp.sum(jnp.ones((length,)) * val)
3
4 xs = jnp.arange(1,10)
5
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in ones(shape, dtype)
3186 dtype = float_ if dtype is None else dtype
3187 shape = (shape,) if ndim(shape) == 0 else shape
-> 3188 return lax.full(shape, 1, dtype)
3189
3190
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)>
with val = DeviceArray([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)
batch_dim = 0,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
Stochastics
JAX is designed to be deterministic, but in some cases we want to introduce randomness in a controlled way, and to reason about it. We discuss this below.
Random number generation
One of the biggest differences from NumPy is the way JAX treats pseudo random number generation (PRNG). This is because JAX does not maintain any global state, i.e., it is purely functional. This design "provides reproducible results invariant to compilation boundaries and backends, while also maximizing performance by enabling vectorized generation and parallelization across random calls" (to quote the official page).
For example, consider this NumPy snippet. Each call to np.random.uniform updates the global state. The value of foo() is therefore only guaranteed to be the same every time if we evaluate bar() and baz() in the same order (e.g., left to right). This is why foo1 and foo2 give different answers, even though mathematically they shouldn't (we cannot just substitute in the value of a variable and derive the result, so we are violating "referential transparency").
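A sketch of the kind of snippet meant here (the exact bodies of bar, baz, foo1 and foo2 are assumptions):

```python
import numpy as np

def bar():
    return np.random.uniform()

def baz():
    return np.random.uniform()

def foo1():
    return bar() + 2 * baz()

def foo2():
    return 2 * baz() + bar()   # algebraically the same as foo1

np.random.seed(0); print(foo1())
np.random.seed(0); print(foo2())  # different: bar and baz consume different draws
```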
JAX may evaluate parts of expressions such as bar() + baz() in parallel, which would violate reproducibility. To prevent this, the user must pass in an explicit PRNG key to every function that requires a source of randomness. Using the same key will give the same results. See the example below.
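For example:

```python
from jax import random

key = random.PRNGKey(42)
print(random.uniform(key, (3,)))
print(random.uniform(key, (3,)))   # same key -> exactly the same numbers
```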
When generating independent samples, it is important to use different keys, to ensure the results are not correlated. We can do this by splitting the key into the 'master' key (which will be used in later parts of the code via further splitting) and a 'subkey', which is used temporarily to generate randomness and then thrown away, as we illustrate below.
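For instance:

```python
from jax import random

key = random.PRNGKey(0)

key, subkey = random.split(key)        # new 'master' key + throwaway subkey
sample1 = random.normal(subkey, (2,))

key, subkey = random.split(key)        # split again for the next draw
sample2 = random.normal(subkey, (2,))  # independent of sample1

print(sample1, sample2)
```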
We now reimplement the NumPy example in JAX, and show that we get the same result no matter the order in which bar and baz are evaluated.
In JAX (unlike NumPy), a single call that draws N samples will not give the same results as N separate draws of individual samples, as we show below.
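For example:

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# One call drawing 3 samples...
print(random.normal(key, (3,)))

# ...is NOT the same as 3 individual draws from subkeys.
subkeys = random.split(key, 3)
print(jnp.stack([random.normal(k) for k in subkeys]))
```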
Probability distributions
The distrax library is a JAX-native implementation of some parts of the distributions library from TensorFlow Probability (TFP). The main advantage is that the distrax source code is much easier to read and understand. For distributions not in distrax, it is possible to use TFP instead.
Here is a brief example.
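Something along these lines (assuming distrax is installed, e.g. via pip; the Normal distribution and parameter values are an arbitrary choice):

```python
import jax
import jax.numpy as jnp
import distrax   # !pip install distrax

key = jax.random.PRNGKey(0)
dist = distrax.Normal(loc=0.0, scale=1.0)

samples = dist.sample(seed=key, sample_shape=(5,))
log_probs = dist.log_prob(samples)
print(samples)
print(log_probs)
```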
Autograd
In this section, we illustrate automatic differentiation using JAX. For details, see this video or The Autodiff Cookbook.
Derivatives
We can compute the derivative $\frac{d}{dx} f(x)$ of a scalar function using grad(f)(x). For example, consider the function below.
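A sketch (the particular polynomial is an arbitrary choice):

```python
from jax import grad

def f(x):
    return x**3 + 2 * x**2 - 3 * x + 1

dfdx = grad(f)           # f'(x)  = 3x^2 + 4x - 3
d2fdx2 = grad(grad(f))   # f''(x) = 6x + 4

print(f(1.0), dfdx(1.0), d2fdx2(1.0))
```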
Partial derivatives
Gradients
Linear function: multi-input, scalar output.
Linear function: multi-input, multi-output.
Quadratic form.
Chain rule applied to sigmoid function.
Auxiliary return values
A function can return its value and other auxiliary results; the latter are not differentiated.
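For example (a sketch; the loss and data are made up):

```python
import jax
import jax.numpy as jnp

def loss_with_aux(w, x):
    pred = jnp.dot(w, x)
    loss = (pred - 1.0) ** 2
    return loss, pred          # (value to differentiate, auxiliary output)

w = jnp.array([1.0, 2.0])
x = jnp.array([0.5, -0.5])
g, pred = jax.grad(loss_with_aux, has_aux=True)(w, x)
print(g, pred)
```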
Jacobians
Example: Linear function: multi-input, multi-output.
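For instance (a sketch with an arbitrary matrix):

```python
import jax
import jax.numpy as jnp

W = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])   # (3, 2)

def f(x):
    return W @ x              # linear map R^2 -> R^3, so the Jacobian is W

x = jnp.ones(2)
print(jax.jacfwd(f)(x))   # forward-mode Jacobian
print(jax.jacrev(f)(x))   # reverse-mode Jacobian; both equal W
```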
Hessians
Quadratic form.
Example: Binary logistic regression
Vector Jacobian Products (VJP) and Jacobian Vector Products (JVP)
Suppose $f(x) = W x$ for a fixed matrix $W \in \mathbb{R}^{m \times n}$, so the Jacobian is $J = W$, and the VJP is $u^\top J = u^\top W$ for $u \in \mathbb{R}^m$. Instead of computing $J$ explicitly and then multiplying by $u$, we can do this in one operation.
Similarly, for a fixed tangent vector $v \in \mathbb{R}^n$, the JVP computes $J v = W v$ in one operation; note that $u^\top J$ has $n$ elements, but $J v$ has $m$ elements.
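A sketch of both operations in JAX:

```python
import jax
import jax.numpy as jnp

m, n = 3, 2
W = jnp.arange(1.0, 1.0 + m * n).reshape((m, n))

def f(x):
    return W @ x

x = jnp.ones(n)
u = jnp.ones(m)
v = jnp.ones(n)

# VJP: compute u^T J without forming J explicitly.
y, vjp_fun = jax.vjp(f, x)
print(vjp_fun(u))        # == (u @ W,)

# JVP: compute J v without forming J explicitly.
y, jv = jax.jvp(f, (x,), (v,))
print(jv)                # == W @ v
```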
Stop-gradient
Sometimes we want to take the gradient of a complex expression wrt some parameters $\theta$, but treating $\theta$ as a constant for some parts of the expression. For example, consider the TD(0) update in reinforcement learning, which has the following form:

$$\Delta \theta = (r_t + v_\theta(s_t) - v_\theta(s_{t-1})) \, \nabla_\theta v_\theta(s_{t-1})$$

where $s_t$ is the state, $r_t$ is the reward, and $v_\theta$ is the value function. This update is not the gradient of any loss function. However, if the dependency of the target $r_t + v_\theta(s_t)$ on the parameter $\theta$ is ignored, it can be written as the gradient of the pseudo loss function

$$L(\theta) = -\tfrac{1}{2} \left[ r_t + v_\theta(s_t) - v_\theta(s_{t-1}) \right]^2$$

since

$$\nabla_\theta L(\theta) = (r_t + v_\theta(s_t) - v_\theta(s_{t-1})) \, \nabla_\theta v_\theta(s_{t-1}) = \Delta \theta$$

when the target is held fixed. We can implement this in JAX using stop_gradient, as we show below.
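A sketch, essentially following the example in the JAX autodiff documentation, with a simple linear value function and made-up data:

```python
import jax
import jax.numpy as jnp

def value_fn(theta, state):
    return jnp.dot(theta, state)   # a simple linear value function

def td_loss(theta, s_prev, r, s):
    v_prev = value_fn(theta, s_prev)
    target = r + value_fn(theta, s)
    # Treat the target as a constant when differentiating.
    return -0.5 * (jax.lax.stop_gradient(target) - v_prev) ** 2

td_update = jax.grad(td_loss)

theta = jnp.array([0.1, -0.1, 0.0])
s_prev = jnp.array([1.0, 2.0, -1.0])
r = jnp.array(1.0)
s = jnp.array([2.0, 1.0, 0.0])
print(td_update(theta, s_prev, r, s))   # == (r + v(s) - v(s_prev)) * grad v(s_prev)
```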
Straight through estimator
The straight-through estimator is a trick for defining a 'gradient' of a function that is otherwise non-differentiable. Given a non-differentiable function $f$ that is used as part of a larger function that we wish to find a gradient of, we simply pretend during the backward pass that $f$ is the identity function, so gradients pass through $f$ as if it were not there. This can be implemented neatly using jax.lax.stop_gradient.
Here is an example of a non-differentiable function that converts a soft probability distribution to a one-hot vector (discretization).
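A sketch of such a quantization function:

```python
import jax
import jax.numpy as jnp

def onehot(labels, num_classes):
    return (labels[..., None] == jnp.arange(num_classes)).astype(jnp.float32)

def quantize(y_soft):
    # Convert a soft probability vector into a hard one-hot vector.
    return onehot(jnp.argmax(y_soft), y_soft.shape[0])

y = jnp.array([0.3, 0.2, 0.5])
print(quantize(y))   # [0., 0., 1.]
```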
Now suppose we define some linear function of the quantized variable, of the form $f(y) = w^\top q(y)$, where $q$ is the quantization function above. We can evaluate this in the forward pass without any problem, but the gradient with respect to $y$ is 0 because $q$ is not differentiable.
To use the straight-through estimator, we replace $q(y)$ with $y + \text{SG}(q(y) - y)$, where SG is stop gradient. In the forward pass, we have $y + q(y) - y = q(y)$. In the backward pass, the gradient of SG is 0, so we effectively replace $q(y)$ with $y$. So in the backward pass the gradient of $f$ with respect to $y$ becomes $\nabla_y (w^\top y) = w$.
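A sketch, reusing quantize from the cell above (the weight vector is an arbitrary choice):

```python
def f(w, y):
    return jnp.dot(w, quantize(y))

def f_ste(w, y):
    # Straight-through: forward pass uses quantize(y),
    # backward pass treats quantize as the identity in y.
    q = y + jax.lax.stop_gradient(quantize(y) - y)
    return jnp.dot(w, q)

w = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.3, 0.2, 0.5])
print(jax.grad(f, argnums=1)(w, y))      # all zeros: quantize has zero gradient
print(jax.grad(f_ste, argnums=1)(w, y))  # == w
```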
Per-example gradients
In some applications, we want to compute the gradient for every example in a batch, not just the sum of gradients over the batch. This is hard in other frameworks like TF and PyTorch but easy in JAX, as we show below.
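A minimal sketch of the pattern (the loss and data are made up):

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return (jnp.dot(w, x) - 1.0) ** 2      # per-example squared error

w = jnp.array([1.0, 2.0])
X = jnp.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])                # batch of 3 examples

# grad wrt w, then vmap over the batch dimension of x only.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))(w, X)
print(per_example_grads.shape)             # (3, 2): one gradient per example
```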
To explain the above code in more depth, note that vmap converts the function loss so that it takes a batch of inputs for each of its arguments and returns a batch of outputs. To make it work with a single weight vector, we specify in_axes=(None, 0), meaning the first argument (w) is shared across the batch (not mapped over), and the second argument (x) is mapped over its leading dimension 0.
Optimization
The Optax library implements many common optimizers. Below is a simple example.
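For instance (a sketch with a made-up loss and hyperparameters):

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params):
    return jnp.sum((params - 3.0) ** 2)

params = jnp.zeros(5)
optimizer = optax.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)

for _ in range(100):
    grads = jax.grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

print(params)   # close to 3.0 everywhere
```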
JIT (just in time compilation)
In this section, we illustrate how to use the Jax JIT compiler to make code go much faster (even on a CPU). It does this by compiling the computational graph into low-level XLA primitives, potentially fusing multiple sequential operations into a single op. However, it does not work on arbitrary Python code, as we explain below.
We can also add the @jit decorator in front of a function.
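For example (a sketch, using the standard selu function as the workload):

```python
import jax
import jax.numpy as jnp
from jax import jit

@jit
def selu(x, alpha=1.67, lmbda=1.05):
    # A simple elementwise function; jit fuses the ops into one XLA computation.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jax.random.normal(jax.random.PRNGKey(0), (1000,))
print(selu(x)[:5])
```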
How it works: Jaxprs and tracing
In this section, we briefly explain the mechanics behind JIT, which will help you understand when it does not work.
First, consider this function.
When a function is first executed (applied to an argument), it is converted to an intermediate representation called a JAX expression, or jaxpr, by a process called tracing, as we show below.
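For example (the function here is an arbitrary stand-in):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.exp(x) + 1.0)

# Tracing f with an example input produces its jaxpr.
print(jax.make_jaxpr(f)(jnp.arange(3.0)))
```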
The XLA JIT compiler can then convert the jaxpr to code that runs fast on a CPU, GPU or TPU; the original python code is no longer needed.
However, the jaxpr is created by tracing the function for a specific value. If different code is executed depending on the value of the input arguments, the resulting jaxpr will be different, so the function cannot be JITed, as we illustrate below.
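The offending function, reconstructed from the traceback below:

```python
from jax import jit

def f(x):
    if x > 0:
        return x
    else:
        return 2 * x

print(f(3.0))       # plain Python: fine
f_jit = jit(f)
print(f_jit(3.0))   # fails: the branch depends on the *value* of x
```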
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-91-545a2f514f83> in <module>()
10 f_jit = jit(f)
---> 11 print(f_jit(3.0))
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
407 device=device, backend=backend, name=flat_fun.__name__,
--> 408 donated_invars=donated_invars, inline=inline)
409 out_pytree_def = out_tree()
/usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, fun, *args, **params)
1613 def bind(self, fun, *args, **params):
-> 1614 return call_bind(self, fun, *args, **params)
1615
/usr/local/lib/python3.7/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1604 tracers = map(top_trace.full_raise, args)
-> 1605 outs = primitive.process(top_trace, fun, tracers, params)
1606 return map(full_lower, apply_todos(env_trace_todo(), outs))
/usr/local/lib/python3.7/dist-packages/jax/core.py in process(self, trace, fun, tracers, params)
1616 def process(self, trace, fun, tracers, params):
-> 1617 return trace.process_call(self, fun, tracers, params)
1618
/usr/local/lib/python3.7/dist-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
612 def process_call(self, primitive, f, tracers, params):
--> 613 return primitive.impl(f, *tracers, **params)
614 process_map = process_call
/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in _xla_call_impl(***failed resolving arguments***)
619 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 620 *unsafe_map(arg_spec, args))
621 try:
/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
261 else:
--> 262 ans = call(fun, *args)
263 cache[key] = (ans, fun.stores)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
696 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
--> 697 fun, abstract_args, pe.debug_info_final(fun, "jit"))
698 if any(isinstance(c, core.Tracer) for c in consts):
/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, debug_info)
1284 with core.new_sublevel():
-> 1285 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
1286 del fun, main
/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1262 in_tracers = map(trace.new_arg, in_avals)
-> 1263 ans = fun.call_wrapped(*in_tracers)
1264 out_tracers = map(trace.full_raise, ans)
/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
<ipython-input-91-545a2f514f83> in f(x)
2 def f(x):
----> 3 if x > 0:
4 return x
/usr/local/lib/python3.7/dist-packages/jax/core.py in __bool__(self)
541 def __nonzero__(self): return self.aval._nonzero(self)
--> 542 def __bool__(self): return self.aval._bool(self)
543 def __int__(self): return self.aval._int(self)
/usr/local/lib/python3.7/dist-packages/jax/core.py in error(self, arg)
986 def error(self, arg):
--> 987 raise ConcretizationTypeError(arg, fname_context)
988 return error
UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function f at <ipython-input-91-545a2f514f83>:2 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
ConcretizationTypeError Traceback (most recent call last)
<ipython-input-91-545a2f514f83> in <module>()
9
10 f_jit = jit(f)
---> 11 print(f_jit(3.0))
<ipython-input-91-545a2f514f83> in f(x)
1
2 def f(x):
----> 3 if x > 0:
4 return x
5 else:
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function f at <ipython-input-91-545a2f514f83>:2 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
Jit will create a new compiled version for each different ShapedArray (i.e., each combination of shape and dtype), but will reuse the cached code for different values with the same shape. If the code path depends on the concrete value of an input, we can either jit just a subfunction (whose code path is constant), or we can create a different jaxpr for each concrete value of the input arguments, as we explain below.
Static argnum
Note that JIT compilation requires that the control flow through the function can be determined by the shape (but not concrete value) of its inputs. The function below violates this, since when x<0, it takes one branch, whereas when x>0, it takes the other.
We can fix this by telling JAX to trace the control flow through the function using concrete values of some of its arguments. JAX will then compile different versions, depending on the input values. See below for an example.
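For example (a sketch, reusing the branching function from before):

```python
from functools import partial
from jax import jit

@partial(jit, static_argnums=(0,))
def f(x):
    if x > 0:
        return x
    else:
        return 2 * x

# A separate version is compiled for each concrete value of x.
print(f(3.0), f(-3.0))
```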
Jit and vmap
Unfortunately, the static argnum method fails when the function is passed to vmap, because the latter can take arguments of different shape.
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-95-00960fe99363> in <module>()
9
---> 10 ys = vmap(f)(xs)
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in batched_fun(*args, **kwargs)
1286 lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
-> 1287 ).call_wrapped(*args_flat)
1288 return tree_unflatten(out_tree(), out_flat)
/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
UnfilteredStackTrace: ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'jax.interpreters.batching.BatchTracer'>, Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)>
with val = DeviceArray([0, 1, 2, 3, 4], dtype=int32)
batch_dim = 0. The error was:
TypeError: unhashable type: 'BatchTracer'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
<ipython-input-95-00960fe99363> in <module>()
8 return 2 * x
9
---> 10 ys = vmap(f)(xs)
/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
164
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
168 # Some transformations yield from inside context managers, so we have to
ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'jax.interpreters.batching.BatchTracer'>, Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)>
with val = DeviceArray([0, 1, 2, 3, 4], dtype=int32)
batch_dim = 0. The error was:
TypeError: unhashable type: 'BatchTracer'
Side effects
Since the jaxpr is created only once, if your function has global side-effects, such as using print, they will only happen once, even if the function is called multiple times. See example below.
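For example:

```python
from jax import jit

@jit
def f(x):
    print("executing f")   # side effect: only happens while tracing
    return 2 * x

print(f(1.0))   # prints "executing f" once (during tracing), then the result
print(f(2.0))   # no print: the cached compiled code is reused
```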
Caching
If you write g = jax.jit(f), then f will get compiled and the XLA code will be cached. Subsequent calls to g reuse the cached code for speed. But if jit is called inside a loop, it effectively makes a new compiled function each time, which is slow. So jit should usually be applied at the outermost scope (assuming the shapes stay constant).
Also, if you specify static_argnums, then the cached code will be used only for the same values of the arguments labelled as static. If any of them change, recompilation occurs.
Strings
Jit does not work with functions that consume or return strings.
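A sketch consistent with the traceback below (only the call pattern comes from the traceback; the body of f is an assumption):

```python
import jax

def f(x, op):
    if op == 'add':
        return x + 1
    else:
        return x - 1

print(f(42, 'add'))      # plain Python works
fj = jax.jit(f)
print(fj(42, 'add'))     # fails: a string is not a valid JAX type
```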
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-98-613837d92373> in <module>()
10 fj = jax.jit(f)
---> 11 print(fj(42, 'add'))
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
402 for arg in args_flat:
--> 403 _check_arg(arg)
404 flat_fun, out_tree = flatten_fun(f, in_tree)
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in _check_arg(arg)
2484 if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
-> 2485 raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
2486
UnfilteredStackTrace: TypeError: Argument 'add' of type <class 'str'> is not a valid JAX type.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
<ipython-input-98-613837d92373> in <module>()
9
10 fj = jax.jit(f)
---> 11 print(fj(42, 'add'))
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in _check_arg(arg)
2483 def _check_arg(arg):
2484 if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
-> 2485 raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
2486
2487 # TODO(necula): this duplicates code in core.valid_jaxtype
TypeError: Argument 'add' of type <class 'str'> is not a valid JAX type.
Pytrees
A pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts. A leaf element is anything that’s not a pytree, e.g. an array. Pytrees are useful for representing hierarchical sets of parameters for DNNs (and other structured data).
Simple example
Treemap
We can map functions down a pytree in the same way that we can map a function down a list. We can also combine elements in two pytrees that have the same shape to make a third pytree.
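For example (a small sketch; note that older JAX releases used jax.tree_multimap for the two-tree case):

```python
import jax
import jax.numpy as jnp

tree = {'w': jnp.ones(3), 'b': jnp.zeros(2)}

# Apply a function to every leaf.
print(jax.tree_util.tree_map(lambda x: x * 2, tree))

# Combine the leaves of two pytrees with the same structure.
other = {'w': jnp.arange(3.0), 'b': jnp.ones(2)}
print(jax.tree_util.tree_map(lambda a, b: a + b, tree, other))
```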
If we have a list of dicts, we can convert to a dict of lists, as shown below.
Flattening / Unflattening
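For example:

```python
import jax
import jax.numpy as jnp

tree = {'w': jnp.ones(3), 'b': jnp.zeros(2)}

leaves, treedef = jax.tree_util.tree_flatten(tree)
print(leaves)    # flat list of arrays
print(treedef)   # the structure, which can be used to rebuild the tree

rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(rebuilt)
```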
Example: Linear regression
In this section we show how to use a pytree as the container for the parameters of a linear regression model. The code is based on the flax JAX tutorial. When we compute the gradient, it will also be a pytree with the same structure as the parameters, so we can update the parameters with the gradient without having to flatten and unflatten them.
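A condensed sketch of this kind of training loop (the data, learning rate, and number of steps are made up; older JAX releases spelled the two-tree map jax.tree_multimap):

```python
import jax
import jax.numpy as jnp

def predict(params, X):
    return X @ params['W'] + params['b']

def loss_fn(params, X, y):
    return jnp.mean((predict(params, X) - y) ** 2)

# Synthetic data from known parameters.
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (20, 3))
true_params = {'W': jnp.array([1.0, 2.0, 3.0]), 'b': 1.0}
y = predict(true_params, X)

params = {'W': jnp.zeros(3), 'b': 0.0}
lr = 0.1
for _ in range(200):
    grads = jax.grad(loss_fn)(params, X, y)
    # The gradient is a pytree with the same structure as params,
    # so we can update every leaf without flattening/unflattening.
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

print(params)   # should be close to true_params
```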
Compare the above to what the training code would look like if W and b were passed in as separate arguments:
Example: MLPs
We now show a more interesting example, from the Deepmind tutorial, where we fit an MLP using SGD. The basic structure is similar to the linear regression case.
Looping constructs
For loops in Python are slow, even when JIT-compiled. However, there are built-in primitives for loops that are fast, as we illustrate below.
For loops.
The semantics of the JAX for loop construct (jax.lax.fori_loop) are as follows:
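```python
# Reference semantics of jax.lax.fori_loop (pure Python, for exposition only).
def fori_loop(lower, upper, body_fun, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val
```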
We see that val is used to accumulate the results across iterations.
Below is an example.
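For instance:

```python
import jax

# Sum the integers 0..9 with a compiled loop.
total = jax.lax.fori_loop(0, 10, lambda i, tot: tot + i, 0)
print(total)   # 45
```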
While loops
Here are the semantics of the JAX while loop (jax.lax.while_loop):
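```python
# Reference semantics of jax.lax.while_loop (pure Python, for exposition only).
def while_loop(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val
```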
Below is an example.
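For instance:

```python
import jax

# Keep doubling until the value reaches at least 100.
result = jax.lax.while_loop(lambda x: x < 100, lambda x: 2 * x, 1)
print(result)   # 128
```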
Scan
Here is the semantics of scan:
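```python
import numpy as np

# Reference semantics of jax.lax.scan (pure Python, for exposition only).
def scan(f, init, xs):
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)
```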
Here is an example where we use scan to sample from a discrete-time, discrete-state Markov chain.
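A sketch of such a sampler, using a made-up 2-state transition matrix and jax.random.categorical for the state transitions:

```python
import jax
import jax.numpy as jnp

# Hypothetical transition matrix (rows sum to 1).
T = jnp.array([[0.9, 0.1],
               [0.2, 0.8]])

def step(carry, _):
    key, state = carry
    key, subkey = jax.random.split(key)
    state = jax.random.categorical(subkey, jnp.log(T[state]))
    return (key, state), state

init = (jax.random.PRNGKey(0), jnp.array(0))
_, states = jax.lax.scan(step, init, None, length=20)
print(states)
```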
Common gotchas
Handling state
In this section, we discuss how to transform code that uses object-oriented programming (which can be stateful) to pure functional programming, which is stateless, as required by JAX. Our presentation is based on the Deepmind tutorial.
To start, consider a simple class that maintains an internal counter, and when called, increments the counter and returns the next number from some sequence.
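A sketch of such a class (following the Deepmind tutorial):

```python
class Counter:
    """A stateful counter: each call to count() increments internal state."""
    def __init__(self):
        self.n = 0

    def count(self):
        self.n += 1
        return self.n

counter = Counter()
for _ in range(3):
    print(counter.count())   # 1, 2, 3
```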
The trouble with the above code is that the call to count depends on the internal state of the object (the value n), even though n is not an argument to the function. (The code is therefore said to violate 'referential transparency'.) When we jit-compile it, JAX will only call the code once (to convert it to a jaxpr), so the side effect of updating n will not happen, resulting in incorrect behavior, as we show below.
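For example, using the Counter class defined above:

```python
import jax

counter = Counter()
fast_count = jax.jit(counter.count)
for _ in range(3):
    print(fast_count())   # 1, 1, 1 -- the update of n only happened at trace time
```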
We can solve this problem by passing the state as an argument into the function.
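A sketch of the stateless version:

```python
CounterState = int

class CounterV2:
    """The same counter, but with the state passed in and returned explicitly."""
    def count(self, n: CounterState):
        return n + 1, n + 1          # (output, new_state)

    def reset(self) -> CounterState:
        return 0

counter = CounterV2()
state = counter.reset()
for _ in range(3):
    value, state = counter.count(state)
    print(value)                     # 1, 2, 3
```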
This version is functionally pure, so jit-compiles nicely.
We can apply the same process to any stateful method to convert it into a stateless one. We took a class of the form
and turned it into a class of the form
This is a common functional programming pattern, and, essentially, is the way that state is handled in all JAX programs (as we saw with the way Jax handles random number state, or parameters of a model that get updated). Note that the stateless version of the code no longer needs to use a class, but can instead group the functions into a common namespace using modules.
In some cases (e.g., when working with DNNs), it is more convenient to write code in an OO way. There are several libraries (notably Flax and Haiku) that let you define a model in an OO way, and then generate functionally pure code.
Mutation of arrays
Since JAX is functional, you cannot mutate arrays in place, since this makes program analysis and transformation very difficult. JAX requires a pure functional expression of a numerical program. Instead, JAX offers the functional update functions index_update, index_add, index_min, index_max, and the index helper. These are illustrated below. However, it is best to avoid these if possible, since they are slow.
Note: If the input values of index_update aren't reused, jit-compiled code will perform these operations in place, rather than making a copy.
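For instance (a sketch; note that recent JAX releases have replaced the jax.ops functions named above with the equivalent x.at[...] syntax used here):

```python
import jax.numpy as jnp

x = jnp.zeros((3, 3))

# Functional updates return a NEW array; x itself is unchanged.
# Older releases spelled the first line jax.ops.index_update(x, jax.ops.index[0, :], 1.0).
y = x.at[0, :].set(1.0)     # counterpart of index_update
z = x.at[1, 1].add(5.0)     # counterpart of index_add
m = x.at[2, 2].max(7.0)     # counterpart of index_max

print(x)   # still all zeros
print(y)
print(z)
```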
Implicitly casting lists to vectors
You cannot treat a list of numbers as a vector. Instead you must explicitly create the vector using the np.array() constructor.
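For example (jnp.array or np.array both work here):

```python
import jax.numpy as jnp

# jnp.sum([1, 2, 3])                   # fails: a Python list is not a valid JAX type
print(jnp.sum(jnp.array([1, 2, 3])))   # works
```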