Project: Testing 18.04
Path: jax.ipynb
Kernel: Python 3 (Ubuntu Linux)

JAX on CoCalc


JAX is Autograd and XLA, brought together for high-performance machine learning research.

https://github.com/google/jax

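The two transformations used below are grad, which turns a function into one that computes its gradient, and jit, which compiles a function just in time with XLA.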

JAX also provides a NumPy-compatible interface, jax.numpy.
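As a quick illustration of that compatibility, the same NumPy calls can be written against jax.numpy. This is a minimal sketch. The `jnp` alias is the common modern convention; the notebook below uses `np` instead. The input values are arbitrary. JAX defaults to 32-bit floats, so the comparison uses a loose tolerance.

```python
import numpy                # classic NumPy, for comparison
import jax.numpy as jnp     # JAX's NumPy-compatible interface

a = numpy.linspace(0.0, 1.0, 5)

# Identical function names and semantics; JAX returns its own array type.
np_result = numpy.sum(numpy.sin(a) ** 2)
jax_result = jnp.sum(jnp.sin(jnp.asarray(a)) ** 2)

print(np_result, float(jax_result))
```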

```python
from jax import grad, jit
import jax.numpy as np  # jax.numpy, not classic NumPy (often aliased jnp today)

def tanh(x):  # Define a function
    y = np.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))
```
0.4199743
/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:164: UserWarning: No GPU found, falling back to CPU. warnings.warn('No GPU found, falling back to CPU.')
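Because tanh'(x) = 1 - tanh(x)^2, the printed value can be checked by hand. Since grad returns an ordinary function, it can also be applied again to get the second derivative, tanh''(x) = -2 tanh(x) (1 - tanh(x)^2). The following is a small sketch that checks both against the closed forms, using jnp.tanh as the reference.

```python
import jax.numpy as jnp
from jax import grad

def tanh(x):
    # Same definition as above: (1 - e^{-2x}) / (1 + e^{-2x})
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

x0 = 1.0
t = jnp.tanh(x0)

autodiff = grad(tanh)(x0)
analytic = 1.0 - t ** 2                      # first derivative of tanh

second = grad(grad(tanh))(x0)                # grad composes with itself
second_analytic = -2.0 * t * (1.0 - t ** 2)  # second derivative of tanh

print(autodiff, analytic)
print(second, second_analytic)
```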
```python
%timeit grad_tanh(0.1)
```
4.58 ms ± 551 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
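The grad_tanh timed above runs op by op. Transformations compose, so the gradient itself can also be compiled with jit(grad(tanh)). The sketch below times it with a plain loop rather than %timeit. It excludes the first call, which traces and compiles, and calls block_until_ready() so that each result has actually been computed. The number it prints depends on the machine, so none is claimed here.

```python
import time
import jax.numpy as jnp
from jax import grad, jit

def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)
fast_grad_tanh = jit(grad(tanh))   # compile the gradient computation with XLA

fast_grad_tanh(0.1).block_until_ready()   # first call: trace + compile

repeats = 100
start = time.perf_counter()
for _ in range(repeats):
    fast_grad_tanh(0.1).block_until_ready()
elapsed_ms = (time.perf_counter() - start) / repeats * 1e3
print(f"jit(grad(tanh)): {elapsed_ms:.4f} ms per call")
```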
```python
def slow_f(x):
    # Element-wise ops see a large benefit from fusion
    a = x * x
    b = 2.0 + x
    return a * b

fast_f = jit(slow_f)
x = np.ones((5000, 5000))
%timeit -n10 -r3 fast_f(x)
```
146 ms ± 2.95 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
```python
%timeit -n10 -r3 slow_f(x)
```
466 ms ± 5.45 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
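Two details affect timings like these. First, the first timed call of fast_f also pays for tracing and compilation. Second, current JAX documentation recommends calling block_until_ready() when benchmarking, because JAX dispatches work asynchronously. The sketch below is hedged accordingly. The helper bench is hypothetical, written only for this illustration. It warms up each function once and then waits for every result. The 2000x2000 input is an assumption chosen to keep it quick, and no particular speedup is claimed.

```python
import time
import jax.numpy as jnp
from jax import jit

def slow_f(x):
    # Element-wise ops that XLA can fuse when compiled with jit
    a = x * x
    b = 2.0 + x
    return a * b

fast_f = jit(slow_f)

def bench(fn, x, repeats=10):
    """Hypothetical helper: mean seconds per call, excluding a warm-up call."""
    fn(x).block_until_ready()          # warm-up (compiles fn if it is jit-wrapped)
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x).block_until_ready()      # wait for the asynchronous result
    return (time.perf_counter() - start) / repeats

x = jnp.ones((2000, 2000))
t_fast = bench(fast_f, x)
t_slow = bench(slow_f, x)
print(f"jit: {t_fast * 1e3:.1f} ms, no jit: {t_slow * 1e3:.1f} ms")
```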
```python
del x
```
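To see what jit is working with, jax.make_jaxpr traces a function and prints its jaxpr. A jaxpr is JAX's intermediate list of primitive operations, which jit lowers to XLA. For slow_f, this should show two multiplies and one add. The fusion of those element-wise operations into one pass happens later, inside XLA, so it is not visible in this output. The following is a small sketch using a 3x3 input chosen arbitrarily.

```python
import jax
import jax.numpy as jnp

def slow_f(x):
    a = x * x
    b = 2.0 + x
    return a * b

# Trace slow_f and show the primitive ops (e.g. mul, add) handed to XLA.
jaxpr = jax.make_jaxpr(slow_f)(jnp.ones((3, 3)))
print(jaxpr)
```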

Comparing with Numba, which JIT-compiles the same Python function for plain NumPy arrays

```python
import numba
numba_f = numba.jit(slow_f)  # Numba compiles the same Python function

import numpy
y = numpy.ones((5000, 5000))  # a plain NumPy array, not a JAX array
%timeit -n10 -r3 numba_f(y)
```
612 ms ± 22.5 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
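Timings are only comparable if the implementations agree on the result. The sketch below checks that the jit-compiled JAX version matches plain NumPy evaluation of slow_f, which is the function numba_f compiles. It leaves Numba out so that it runs without Numba installed. The random float32 input and the tolerances are assumptions. Random values are used because an array of ones gives every element the same value.

```python
import numpy
import jax.numpy as jnp
from jax import jit

def slow_f(x):
    a = x * x
    b = 2.0 + x
    return a * b

fast_f = jit(slow_f)

rng = numpy.random.default_rng(0)
y = rng.standard_normal((500, 500)).astype(numpy.float32)

numpy_result = slow_f(y)                              # plain NumPy evaluation
jax_result = numpy.asarray(fast_f(jnp.asarray(y)))    # JAX result as a NumPy array

print("max abs difference:", numpy.max(numpy.abs(numpy_result - jax_result)))
```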