GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/scripts/autodiff_demo_jax.py

# -*- coding: utf-8 -*-
"""autodiff_demo_jax.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1fOGAmlA4brfHL_5qcUfsQKsy7BCuuTCS
"""

# Illustrate automatic differentiation using JAX
# https://github.com/google/jax
import superimport  # pyprobml helper that installs missing imports on the fly

import numpy as np  # original numpy
import jax.numpy as jnp
from jax import grad, hessian

np.random.seed(42)  # we use numpy's RNG for the data; JAX handles RNG differently
D = 5
w = np.random.randn(D)

x = np.random.randn(D)
y = 0  # binary label, should be 0 or 1

def sigmoid(a): return 0.5 * (jnp.tanh(a / 2.) + 1)

# d/da sigmoid(a) = s(a) * (1 - s(a))
deriv_sigmoid = lambda a: sigmoid(a) * (1 - sigmoid(a))
deriv_sigmoid_jax = grad(sigmoid)
a0 = 1.5
assert jnp.isclose(deriv_sigmoid(a0), deriv_sigmoid_jax(a0))
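
# Extra check (a small sketch, not part of the original derivation): grad composes,
# so nesting it gives higher-order derivatives.
# Analytically, d^2/da^2 sigmoid(a) = s(a) * (1 - s(a)) * (1 - 2*s(a)).
deriv2_sigmoid = lambda a: sigmoid(a) * (1 - sigmoid(a)) * (1 - 2 * sigmoid(a))
deriv2_sigmoid_jax = grad(grad(sigmoid))
assert jnp.isclose(deriv2_sigmoid(a0), deriv2_sigmoid_jax(a0))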

# mu(w) = s(w'x), so d/dw mu(w) = mu(w) * (1 - mu(w)) * x
def mu(w): return sigmoid(jnp.dot(w, x))
def deriv_mu(w): return mu(w) * (1 - mu(w)) * x
deriv_mu_jax = grad(mu)
assert jnp.allclose(deriv_mu(w), deriv_mu_jax(w))

# NLL(w) = -[y*log(mu) + (1-y)*log(1-mu)]
# d/dw NLL(w) = (mu - y) * x
def nll(w): return -(y * jnp.log(mu(w)) + (1 - y) * jnp.log(1 - mu(w)))
# equivalent unsimplified form: deriv_nll(w) = -(y*(1-mu(w))*x - (1-y)*mu(w)*x)
def deriv_nll(w): return (mu(w) - y) * x
deriv_nll_jax = grad(nll)
assert jnp.allclose(deriv_nll(w), deriv_nll_jax(w))
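
# Extra check (not in the original): value_and_grad evaluates the objective and
# its gradient in a single call, which is typically what an optimization loop needs.
from jax import value_and_grad
nll_val, nll_grad = value_and_grad(nll)(w)
assert jnp.isclose(nll_val, nll(w))
assert jnp.allclose(nll_grad, deriv_nll(w))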

# Now do it for a batch of data

def predict(weights, inputs):
    return sigmoid(jnp.dot(inputs, weights))

def loss(weights, inputs, targets):
    preds = predict(weights, inputs)
    logprobs = jnp.log(preds) * targets + jnp.log(1 - preds) * (1 - targets)
    return -jnp.sum(logprobs)  # jnp.sum (rather than np.sum) keeps the reduction traceable by JAX

N = 3
X = np.random.randn(N, D)
y = np.random.randint(0, 2, N)

from jax import vmap
from functools import partial

# Four equivalent ways to compute per-example predictions
preds = vmap(partial(predict, w))(X)
preds2 = vmap(predict, in_axes=(None, 0))(w, X)
preds3 = [predict(w, xi) for xi in X]
preds4 = predict(w, X)
assert jnp.allclose(preds, preds2)
assert jnp.allclose(preds, preds3)
assert jnp.allclose(preds, preds4)
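
# Extra check (not in the original): the Jacobian of the batched predictions with
# respect to the weights has one row per example, d mu_n/dw = mu_n * (1 - mu_n) * x_n.
from jax import jacobian
J = jacobian(predict)(w, X)                     # shape (N, D)
J2 = jnp.dot(jnp.diag(preds4 * (1 - preds4)), X)
assert jnp.allclose(J, J2)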

grad_fun = grad(loss)
grads = vmap(partial(grad_fun, w))(X, y)        # per-example gradients, one row per example
assert grads.shape == (N, D)
grads2 = jnp.dot(jnp.diag(preds - y), X)        # analytic: row n is (mu_n - y_n) * x_n
assert jnp.allclose(grads, grads2)

grad_sum = jnp.sum(grads, axis=0)
grad_sum2 = jnp.dot(np.ones((1, N)), grads)
assert jnp.allclose(grad_sum, grad_sum2)
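
# Extra check (not in the original): differentiating the batch loss directly gives
# the same result as summing the per-example gradients.
grad_batch = grad(loss)(w, X, y)
assert jnp.allclose(grad_batch, grad_sum)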

# Now make things go fast
from jax import jit

grad_fun = jit(grad(loss))
grads = vmap(partial(grad_fun, w))(X, y)
assert jnp.allclose(grads, grads2)
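
# Rough timing sketch (optional, not in the original; numbers depend on hardware and
# the first call pays the compilation cost). block_until_ready forces the async result.
import timeit
grad_jit = jit(grad(loss))
grad_jit(w, X, y).block_until_ready()  # warm up / compile once
t_jit = timeit.timeit(lambda: grad_jit(w, X, y).block_until_ready(), number=100)
t_raw = timeit.timeit(lambda: grad(loss)(w, X, y).block_until_ready(), number=100)
print(f"jit grad: {t_jit:.4f}s, un-jitted grad: {t_raw:.4f}s (100 calls each)")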

# Hessian of the logistic regression NLL: H = X' S X, where S = diag(mu * (1 - mu))
H1 = hessian(loss)(w, X, y)
mu = predict(w, X)  # note: rebinds the name mu from the scalar example above
S = jnp.diag(mu * (1 - mu))
H2 = jnp.dot(jnp.dot(X.T, S), X)
assert jnp.allclose(H1, H2)
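
# Extra check (not in the original): jax.hessian is forward-over-reverse, i.e.
# jacfwd(jacrev(f)), so composing those transforms by hand gives the same matrix.
from jax import jacfwd, jacrev
H3 = jacfwd(jacrev(loss))(w, X, y)
assert jnp.allclose(H1, H3)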