If-else conditions with lax
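A minimal sketch of the usual pattern follows; the function names are illustrative. Inside jax.jit, a Python if on a traced value fails because the value is abstract at trace time. lax.cond(pred, true_fun, false_fun, *operands) instead chooses a branch at run time. For arrays, jnp.where is the element-wise alternative.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def square_if_positive(x):
    # A Python `if x > 0:` would raise an error here, because under jit `x`
    # is an abstract tracer with no concrete value. lax.cond evaluates the
    # predicate at run time and applies one of the two branches to `x`.
    # Both branches must return outputs with the same shape and dtype.
    return lax.cond(x > 0, lambda v: v ** 2, lambda v: -v, x)


@jax.jit
def square_if_positive_elementwise(x):
    # For arrays, jnp.where evaluates both expressions and selects element-wise.
    return jnp.where(x > 0, x ** 2, -x)


xs = jnp.array([-2.0, 0.0, 3.0])
print(square_if_positive(3.0))           # scalar input
print(jax.vmap(square_if_positive)(xs))  # batched predicate: cond is turned into a select
print(square_if_positive_elementwise(xs))
```

When lax.cond is vmapped with a batched predicate, JAX converts it to a select, so both branches are computed. That is the same cost model as jnp.where.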
Pairwise distances with vmap
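One way to build a pairwise distance matrix is to write the distance for a single pair of points and nest two vmaps over it. This is a sketch: euclidean_distance and pairwise_distances are illustrative names, not taken from the notebook.

```python
import jax
import jax.numpy as jnp


def euclidean_distance(a, b):
    # Distance between two single points of shape (d,).
    return jnp.sqrt(jnp.sum((a - b) ** 2))


# Inner vmap: fix `a`, map over rows of the second argument -> shape (m,).
# Outer vmap: map over rows of the first argument, keep the second fixed -> shape (n, m).
pairwise_distances = jax.vmap(
    jax.vmap(euclidean_distance, in_axes=(None, 0)), in_axes=(0, None)
)

X = jnp.array([[0.0, 0.0], [3.0, 4.0]])
Y = jnp.array([[0.0, 0.0], [3.0, 0.0], [3.0, 4.0]])
D = pairwise_distances(X, Y)
print(D.shape)  # (2, 3)
print(D)
```

The function being mapped stays written for single points, and vmap takes care of the batching. The result matches the broadcasting formula jnp.sqrt(((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)).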
Compute the Hessian with JAX
Let us consider the linear regression loss function.
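For reference, suppose the loss is the mean squared error for a design matrix X of shape n × d, targets y of length n, and parameters θ of length d. This is an assumption, since the notebook's exact scaling is not shown here. The gradient and Hessian then have closed forms that the JAX results can be checked against:

```latex
\begin{aligned}
\mathcal{L}(\theta) &= \frac{1}{n}\,\lVert X\theta - y \rVert_2^2 \\
\nabla_\theta \mathcal{L}(\theta) &= \frac{2}{n}\,X^\top (X\theta - y) \\
\nabla_\theta^2 \mathcal{L}(\theta) &= \frac{2}{n}\,X^\top X
\end{aligned}
```

The Hessian does not depend on θ, so any parameter value gives the same matrix.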
Simulate a dataset
Verify loss and gradient values
Verify the Hessian matrix
Way-1
Way-2
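Here is a minimal sketch of the whole check. It assumes a mean-squared-error loss and a small synthetic dataset; the sizes and true weights are arbitrary. The gradient is compared against its closed form. The Hessian is computed in two ways, jax.hessian and an explicit jax.jacfwd(jax.jacrev(...)). These two routes are a common choice but not necessarily the "Way-1" and "Way-2" of the original notebook.

```python
import jax
import jax.numpy as jnp

# Simulate a small synthetic regression dataset (arbitrary sizes and weights).
n, d = 100, 3
key_x, key_noise = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (n, d))
theta_true = jnp.array([1.0, -2.0, 0.5])
y = X @ theta_true + 0.1 * jax.random.normal(key_noise, (n,))


def loss(theta):
    # Mean squared error of the linear model X @ theta.
    return jnp.mean((X @ theta - y) ** 2)


theta = jnp.zeros(d)

# Closed-form gradient and Hessian for this loss.
grad_analytic = (2.0 / n) * X.T @ (X @ theta - y)
hess_analytic = (2.0 / n) * X.T @ X

# Gradient via reverse-mode autodiff.
grad_jax = jax.grad(loss)(theta)

# Hessian, two ways: the built-in helper, and forward-over-reverse written out.
hess_1 = jax.hessian(loss)(theta)
hess_2 = jax.jacfwd(jax.jacrev(loss))(theta)

print(jnp.allclose(grad_jax, grad_analytic, atol=1e-4))
print(jnp.allclose(hess_1, hess_analytic, atol=1e-4))
print(jnp.allclose(hess_2, hess_analytic, atol=1e-4))
```

jax.hessian is itself implemented as jacfwd(jacrev(fun)), so the two Hessians should agree to floating-point precision. Forward-over-reverse is usually the efficient ordering for dense Hessians.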
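tree_map in JAX
The main requirement for tree_map to work is that every tree passed after the first must have the same structure as the first argument, or have the first argument's structure as a prefix. The output then has the structure of the first argument (as explained here). For example: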
The problem here is that dists does not have the same structure as log_probs (the structure of log_probs matches that of samples). So, we should pass samples as the first argument:
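To see why the argument order matters, here is a sketch that runs without distrax. It assumes a hypothetical Normal class, standing in for a distrax distribution and registered as a pytree whose children are its parameters. With dists first, tree_map flattens all the way down to loc and scale, and samples cannot match that structure. With samples first, its structure is a prefix of the structure of dists, so each sample is paired with a whole Normal object.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Normal:
    """Hypothetical stand-in for a distrax distribution: a pytree node whose
    children (and hence leaves) are its parameters loc and scale."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z ** 2 - jnp.log(self.scale) - 0.5 * jnp.log(2 * jnp.pi)

    def tree_flatten(self):
        return (self.loc, self.scale), None  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


dists = {"a": Normal(jnp.array(0.0), jnp.array(1.0)),
         "b": Normal(jnp.array(1.0), jnp.array(2.0))}
samples = {"a": jnp.array(0.5), "b": jnp.array(-1.0)}

# dists first: its structure continues down into loc/scale, but samples stops
# at the dict values, so samples cannot be matched against it.
try:
    jax.tree_util.tree_map(lambda d, s: d.log_prob(s), dists, samples)
    dists_first_failed = False
except (ValueError, TypeError):
    dists_first_failed = True

# samples first: its structure is a prefix of dists', so each sample leaf is
# paired with the whole Normal object found at the same position.
log_probs = jax.tree_util.tree_map(lambda s, d: d.log_prob(s), samples, dists)
print(dists_first_failed, log_probs)
```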
Use of lax.scan to accelerate a training loop
Here we create a dummy training loop and check the performance of lax.scan. The example also shows how to convert a training loop into its lax.scan version.
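A minimal sketch of the conversion, assuming a dummy linear-regression problem trained with plain gradient descent; the sizes, learning rate, and step count are arbitrary. The loop body becomes a step function (carry, x) -> (new_carry, y). Here the carry is the parameter vector and the per-step output is the loss, which scan stacks into an array.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Dummy regression problem (arbitrary sizes).
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (200, 5))
y = X @ jax.random.normal(key_w, (5,))


def loss_fn(w, X, y):
    return jnp.mean((X @ w - y) ** 2)


value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))


def train_python_loop(w, X, y, lr, num_steps):
    # Plain Python loop: every iteration dispatches the jitted step separately.
    losses = []
    for _ in range(num_steps):
        loss, g = value_and_grad_fn(w, X, y)
        w = w - lr * g
        losses.append(loss)
    return w, jnp.stack(losses)


@partial(jax.jit, static_argnames="num_steps")
def train_scan(w, X, y, lr, num_steps):
    def step(w, _):
        # Carry in: current weights. Carry out: updated weights. Output: loss.
        loss, g = jax.value_and_grad(loss_fn)(w, X, y)
        return w - lr * g, loss

    # xs=None with an explicit length: nothing is scanned over; the step
    # simply runs num_steps times inside a single compiled computation.
    return jax.lax.scan(step, w, xs=None, length=num_steps)


w0 = jnp.zeros(5)
w_loop, losses_loop = train_python_loop(w0, X, y, 0.1, 100)
w_scan, losses_scan = train_scan(w0, X, y, 0.1, 100)
print(jnp.allclose(w_loop, w_scan, atol=1e-5), losses_scan.shape)
```

To compare speed, call each function once to trigger compilation, then time the later calls. Call .block_until_ready() on the results, because JAX dispatches work asynchronously. The scan version compiles the whole loop into one computation and so avoids per-iteration Python dispatch; how much this helps depends on problem size and hardware.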
Note that an xs array can be passed if we want to scan over it. An example can be found in this blackjax documentation.
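When each step needs its own input, stack that input along a leading axis and pass it as xs; scan then feeds slice xs[i] to step i. A common case in sampling loops is scanning over pre-split PRNG keys, one per step. The sketch below uses a hypothetical random-walk step as a placeholder for a real transition kernel.

```python
import jax
import jax.numpy as jnp


def random_walk_step(position, key):
    # Hypothetical transition: move by a standard-normal increment.
    new_position = position + jax.random.normal(key, position.shape)
    # Return the new carry and the value to record for this step.
    return new_position, new_position


def run_chain(key, init_position, num_steps):
    # One key per step, stacked along axis 0; scan hands keys[i] to step i.
    keys = jax.random.split(key, num_steps)
    final_position, trajectory = jax.lax.scan(random_walk_step, init_position, keys)
    return final_position, trajectory


final, traj = run_chain(jax.random.PRNGKey(42), jnp.zeros(2), 50)
print(traj.shape)  # (50, 2)
```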
tree_flatten vs. ravel_pytree
tree_flatten: returns a list of the leaves of a PyTree, together with a treedef that records its structure.
ravel_pytree: concatenates all the leaves into a single one-dimensional JAX array, and also returns a function that maps such an array back to the original PyTree.
tree_flatten and tree_unflatten
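A small sketch of the round trip, using an illustrative params dictionary. tree_flatten separates the leaves from the structure. tree_unflatten puts new leaves back into that same structure.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# Leaves as a list, plus a treedef describing the dict structure.
# Dict leaves come out in sorted-key order: "b" before "w".
leaves, treedef = jax.tree_util.tree_flatten(params)
print([leaf.shape for leaf in leaves])  # [(3,), (2, 3)]
print(treedef.num_leaves)               # 2

# Rebuild the same structure from a list of transformed leaves.
doubled = jax.tree_util.tree_unflatten(treedef, [2 * leaf for leaf in leaves])
print(doubled["w"])
```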
ravel_pytree
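The same illustrative params dictionary, this time flattened with ravel_pytree (from jax.flatten_util). The leaves keep the same sorted-key order, but they are concatenated into one vector.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# One 1-D array holding every parameter, plus the inverse mapping.
flat, unravel = ravel_pytree(params)
print(flat.shape)  # (9,)

# Any vector of the same length can be mapped back to the dict structure.
restored = unravel(flat + 1.0)
print(restored["b"])  # [1. 1. 1.]
```

A flat vector is convenient when an optimizer or sampler expects a single parameter array. tree_flatten is the better fit when each leaf should keep its own shape.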
is_leaf while working with PyTrees
Sometimes you do not want to work with the leaves of your PyTree; instead, you may want to treat a non-leaf node as a leaf. Let us see such an example with distrax.
Suppose we want to sample from the above distribution_pytree.
The problem here is that tree_leaves returns no leaves, whereas we want the leaves to be distrax distributions. Let us use is_leaf for this purpose.
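The same idea can be sketched without distrax. In this sketch a NamedTuple called Normal is a hypothetical stand-in for a distrax distribution, and the contents of distribution_pytree are illustrative. A NamedTuple is itself a pytree node, so by default tree_leaves looks through it to its parameters. Passing is_leaf stops the traversal at the distribution objects.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Normal(NamedTuple):
    # Hypothetical stand-in for a distrax distribution.
    loc: jax.Array
    scale: jax.Array


distribution_pytree = {"x": Normal(jnp.array(0.0), jnp.array(1.0)),
                       "y": Normal(jnp.array(2.0), jnp.array(0.5))}


def is_normal(node):
    # Nodes for which this returns True are treated as leaves.
    return isinstance(node, Normal)


# Default traversal descends into each Normal and returns its parameters.
print(len(jax.tree_util.tree_leaves(distribution_pytree)))  # 4

# With is_leaf, the Normal objects themselves are the leaves.
leaves = jax.tree_util.tree_leaves(distribution_pytree, is_leaf=is_normal)
print(leaves)
```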
We get what we anticipated. Now let us pass is_leaf to tree_map to get the samples.
We can see that we are able to get the samples now.
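Here is a sketch of the sampling step under the same kind of assumption: a hypothetical NamedTuple Normal with a sample method stands in for a distrax distribution. The pytree of keys is built from the is_leaf-aware structure, so each distribution gets its own independent key. is_leaf is applied to the first tree, and the key tree only needs to match the resulting structure.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Normal(NamedTuple):
    # Hypothetical stand-in for a distrax distribution.
    loc: jax.Array
    scale: jax.Array

    def sample(self, key, shape=()):
        return self.loc + self.scale * jax.random.normal(key, shape)


def is_normal(node):
    return isinstance(node, Normal)


distribution_pytree = {"x": Normal(jnp.array(0.0), jnp.array(1.0)),
                       "y": Normal(jnp.array(2.0), jnp.array(0.5))}

# A matching pytree of keys: one independent key per distribution.
treedef = jax.tree_util.tree_structure(distribution_pytree, is_leaf=is_normal)
keys = jax.random.split(jax.random.PRNGKey(0), treedef.num_leaves)
key_tree = jax.tree_util.tree_unflatten(treedef, list(keys))

# With is_leaf, tree_map hands each whole Normal (not loc/scale) to the function.
samples = jax.tree_util.tree_map(
    lambda dist, key: dist.sample(key, (1000,)),
    distribution_pytree,
    key_tree,
    is_leaf=is_normal,
)
print({name: s.shape for name, s in samples.items()})
```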