Project: Testing 18.04
Path: jax.ipynb
Kernel: Python 3 (Ubuntu Linux)

JAX on CoCalc


JAX is Autograd and XLA, brought together for high-performance machine learning research.

Its two core function transformations are `grad`, which computes gradients, and `jit`, which performs just-in-time compilation with XLA.
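The two transformations compose: `jit(grad(f))` is a compiled gradient function. The sketch below assumes a recent JAX release; `sum_of_squares` is a hypothetical example function, not part of the original notebook. Its gradient is 2w, and wrapping it in `jit` should not change the result.

```python
import jax
import jax.numpy as jnp

def sum_of_squares(w):
    # Hypothetical scalar-valued function: f(w) = sum(w_i ** 2)
    return jnp.sum(w ** 2)

# grad returns a new function computing df/dw = 2 * w
grad_f = jax.grad(sum_of_squares)

# jit compiles that gradient function with XLA; the transformations compose
fast_grad_f = jax.jit(grad_f)

w = jnp.array([1.0, -2.0, 3.0])
print(grad_f(w))       # gradient, evaluated eagerly
print(fast_grad_f(w))  # same gradient, compiled
```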

JAX also provides a NumPy-compatible interface, `jax.numpy`.
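A small sketch of what "NumPy-compatible" means in practice: the same function names give matching results. It uses the alias `onp` for plain NumPy (an illustrative choice) because this notebook binds `np` to `jax.numpy`. One difference worth knowing is that JAX arrays are immutable, so in-place assignment is replaced by the functional `.at[...].set(...)` update, which returns a new array.

```python
import numpy as onp
import jax.numpy as jnp

a_np = onp.linspace(0.0, 1.0, 5)
a_jax = jnp.linspace(0.0, 1.0, 5)

# Same function names and semantics as NumPy
print(onp.sum(onp.sin(a_np)), jnp.sum(jnp.sin(a_jax)))

# JAX arrays are immutable: a_jax[0] = 10.0 would raise an error.
# .at[...].set(...) returns a new array and leaves a_jax unchanged.
b_jax = a_jax.at[0].set(10.0)
print(a_jax[0], b_jax[0])
```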

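```python
from jax import grad, jit
import jax.numpy as np  # note: np here is jax.numpy, not NumPy

def tanh(x):  # Define a function
    y = np.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))

# JAX dispatches work asynchronously; block_until_ready() waits for the result
%timeit grad_tanh(0.1).block_until_ready()
```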
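Since (1 - e^(-2x)) / (1 + e^(-2x)) is exactly tanh(x), the output of `grad` can be checked against the closed-form derivative 1 - tanh²(x). The sketch below does that at a few points and also applies `grad` twice to get the second derivative, -2 tanh(x) (1 - tanh²(x)). The test points and the tolerance are arbitrary choices.

```python
import math
import jax
import jax.numpy as jnp

def tanh(x):
    # Same definition as in the notebook: (1 - e^{-2x}) / (1 + e^{-2x})
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = jax.grad(tanh)             # first derivative
grad2_tanh = jax.grad(jax.grad(tanh))  # second derivative: grad applied twice

for x0 in [-1.0, 0.0, 0.5, 1.0]:
    t = math.tanh(x0)
    d1_exact = 1.0 - t ** 2
    d2_exact = -2.0 * t * (1.0 - t ** 2)
    d1 = float(grad_tanh(x0))
    d2 = float(grad2_tanh(x0))
    # JAX computes in float32 by default, so compare with a small tolerance
    assert abs(d1 - d1_exact) < 1e-5, (x0, d1, d1_exact)
    assert abs(d2 - d2_exact) < 1e-5, (x0, d2, d2_exact)
    print(f"x={x0:+.1f}  d/dx={d1:.6f}  d2/dx2={d2:.6f}")
```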
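```python
def slow_f(x):
    # Element-wise ops see a large benefit from fusion
    a = x * x
    b = 2.0 + x
    return a * b

fast_f = jit(slow_f)

x = np.ones((5000, 5000))
fast_f(x).block_until_ready()  # warm-up: the first call triggers XLA compilation
# block_until_ready() makes %timeit measure the computation, not just the async dispatch
%timeit -n10 -r3 fast_f(x).block_until_ready()
%timeit -n10 -r3 slow_f(x).block_until_ready()
del x  # free the large array
```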
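`%timeit` is an IPython magic, so outside a notebook the same measurement needs a small timing loop. A minimal sketch, assuming wall-clock time from `time.perf_counter` is adequate; `bench` is a hypothetical helper, and the array size and repeat count are arbitrary. As in the cell above, it warms up first so compilation is excluded, and it blocks on each result because JAX dispatches asynchronously.

```python
import time
import jax
import jax.numpy as jnp

def slow_f(x):
    # Same element-wise function as in the notebook
    a = x * x
    b = 2.0 + x
    return a * b

fast_f = jax.jit(slow_f)

def bench(fn, x, repeats=10):
    """Average wall-clock seconds per call of fn(x) (hypothetical helper)."""
    fn(x).block_until_ready()  # warm-up: keeps jit compilation out of the timing
    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for the result before the next call so the clock covers real work
        fn(x).block_until_ready()
    return (time.perf_counter() - start) / repeats

x = jnp.ones((1000, 1000))  # smaller than the notebook's 5000 x 5000 to keep it quick
t_jit = bench(fast_f, x)
t_eager = bench(slow_f, x)
print(f"jit: {t_jit * 1e3:.2f} ms/call, eager: {t_eager * 1e3:.2f} ms/call")
```

The measured ratio depends on the hardware and backend (CPU or GPU), so the sketch prints timings rather than assuming a particular speed-up.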

Comparing with Numba

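```python
import numba
import numpy

numba_f = numba.jit(slow_f)  # compile the same Python function with Numba

# Numba compiles for NumPy arrays, so the input is a regular NumPy array
y = numpy.ones((5000, 5000))
numba_f(y)  # warm-up: the first call triggers Numba compilation
%timeit -n10 -r3 numba_f(y)
```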
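One caveat for this comparison: `jnp.ones` produces float32 by default (unless JAX's 64-bit mode is enabled), whereas `numpy.ones` produces float64, so the Numba run likely moves twice as many bytes through memory. The sketch below checks the dtypes and builds a float32 NumPy input for a closer like-for-like comparison; the small shapes keep it cheap, and `y32` is a hypothetical name.

```python
import numpy
import jax.numpy as jnp

x_jax = jnp.ones((4, 4))
y64 = numpy.ones((4, 4))
y32 = numpy.ones((4, 4), dtype=numpy.float32)

print(x_jax.dtype)  # float32 unless jax_enable_x64 is set
print(y64.dtype)    # float64: NumPy's default
print(y32.dtype)    # float32: matches JAX's default precision

def slow_f(x):
    # Same element-wise function as in the notebook
    a = x * x
    b = 2.0 + x
    return a * b

# With matching dtypes, the JAX and NumPy results agree element-wise
assert numpy.array_equal(numpy.asarray(slow_f(x_jax)), slow_f(y32))
```

In the notebook, the corresponding change would be passing `numpy.ones((5000, 5000), dtype=numpy.float32)` to `numba_f`.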