Kernel: Python 3 (Ubuntu Linux)
JAX on CoCalc
JAX is Autograd and XLA, brought together for high-performance machine learning research.
Its core transformations include grad, which computes gradients, and jit, which performs just-in-time compilation with XLA.
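A minimal sketch of both transformations, assuming a toy function f(x) = sum(x**2) chosen here for illustration (its gradient is 2x):

```python
import jax
import jax.numpy as jnp

# Toy function: f(x) = sum(x ** 2), whose gradient is 2 * x.
def f(x):
    return jnp.sum(x ** 2)

# grad turns f into a new function that returns df/dx.
df = jax.grad(f)

# jit compiles f with XLA; the first call traces and compiles,
# later calls with the same input shape/dtype reuse the compiled code.
f_fast = jax.jit(f)

# The transformations compose: this compiles the gradient function itself.
df_fast = jax.jit(jax.grad(f))

x = jnp.array([1.0, 2.0, 3.0])
print(df(x))       # [2. 4. 6.]
print(f_fast(x))   # 14.0
print(df_fast(x))  # [2. 4. 6.]
```

Because grad and jit are function transformations rather than special array types, they can be stacked in either order on ordinary Python functions.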
JAX also provides a NumPy-compatible interface, jax.numpy.
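As a sketch, the same expression can be written against NumPy and jax.numpy; the arrays and the expression below are illustrative. One notable difference is that JAX arrays are immutable, so indexed updates use the .at[...] syntax instead of in-place assignment.

```python
import numpy as np
import jax.numpy as jnp

# The same computation written against NumPy and against jax.numpy.
a_np = np.linspace(0.0, 1.0, 5)
a_jx = jnp.linspace(0.0, 1.0, 5)

r_np = np.sin(a_np) * np.exp(-a_np)
r_jx = jnp.sin(a_jx) * jnp.exp(-a_jx)

# Results agree up to float32 precision (JAX defaults to 32-bit floats).
print(np.allclose(r_np, np.asarray(r_jx), atol=1e-6))  # True

# JAX arrays are immutable: b[1] = 5.0 would raise an error,
# so an updated copy is produced with .at[...].set(...).
b = jnp.zeros(3)
b = b.at[1].set(5.0)
print(b)  # [0. 5. 0.]
```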
Comparing with Numba
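A comparison of this kind can be sketched as follows, assuming a hypothetical workload (the sum of sin(x)**2 over a large array) chosen for illustration. Numba's njit compiles explicit Python loops to machine code via LLVM, while jax.jit compiles whole array expressions via XLA, so each version is written in the style its compiler expects. Numba is treated as optional here, and timings depend entirely on the hardware, so no numbers are claimed.

```python
import time
import numpy as np
import jax
import jax.numpy as jnp

try:
    from numba import njit
except ImportError:  # Numba is optional; its half of the comparison is skipped if missing
    njit = None

# Hypothetical workload: the sum of sin(x)**2 over a large array.
# JAX version: a vectorized array expression, compiled by XLA.
@jax.jit
def work_jax(x):
    return jnp.sum(jnp.sin(x) ** 2)

# Numba version: an explicit loop, compiled to machine code by njit.
def work_loop(x):
    total = 0.0
    for i in range(x.shape[0]):
        s = np.sin(x[i])
        total += s * s
    return total

work_numba = njit(work_loop) if njit is not None else None

def run(fn, arg):
    out = fn(arg)
    # JAX dispatches work asynchronously; block so timings cover the actual computation.
    return out.block_until_ready() if hasattr(out, "block_until_ready") else out

def best_time(fn, arg, repeat=5):
    run(fn, arg)  # warm-up: the first call triggers compilation for both tools
    times = []
    for _ in range(repeat):
        t0 = time.perf_counter()
        run(fn, arg)
        times.append(time.perf_counter() - t0)
    return min(times)

x_np = np.random.default_rng(0).random(1_000_000)  # float64
x_jx = jnp.asarray(x_np)  # converted to float32 under JAX's default settings

print(f"jax.jit:    {best_time(work_jax, x_jx):.5f} s")
if work_numba is not None:
    print(f"numba.njit: {best_time(work_numba, x_np):.5f} s")
```

Excluding the warm-up call from the timings keeps one-time compilation cost separate from steady-state speed; the JAX result is computed in float32, so it matches the float64 Numba result only approximately.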