Path: blob/master/deprecated/scripts/autodiff_demo_jax.py
# -*- coding: utf-8 -*-
"""autodiff_demo_jax.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1fOGAmlA4brfHL_5qcUfsQKsy7BCuuTCS
"""

# illustrate automatic differentiation using jax
# https://github.com/google/jax
import superimport  # pyprobml helper that installs missing packages on the fly

import numpy as np  # original numpy
import jax.numpy as jnp
from jax import grad, hessian

np.random.seed(42)
D = 5
w = np.random.randn(D)  # jax handles RNG differently

x = np.random.randn(D)
y = 0  # should be 0 or 1

def sigmoid(x): return 0.5 * (jnp.tanh(x / 2.) + 1)

# d/da sigmoid(a) = s(a) * (1 - s(a))
deriv_sigmoid = lambda a: sigmoid(a) * (1 - sigmoid(a))
deriv_sigmoid_jax = grad(sigmoid)
a0 = 1.5
assert jnp.isclose(deriv_sigmoid(a0), deriv_sigmoid_jax(a0))

# mu(w) = s(w'x), d/dw mu(w) = mu * (1 - mu) .* x
def mu(w): return sigmoid(jnp.dot(w, x))
def deriv_mu(w): return mu(w) * (1 - mu(w)) * x
deriv_mu_jax = grad(mu)
assert jnp.allclose(deriv_mu(w), deriv_mu_jax(w))

# NLL(w) = -[y*log(mu) + (1-y)*log(1-mu)]
# d/dw NLL(w) = (mu - y) * x
def nll(w): return -(y * jnp.log(mu(w)) + (1 - y) * jnp.log(1 - mu(w)))
# def deriv_nll(w): return -(y*(1-mu(w))*x - (1-y)*mu(w)*x)
def deriv_nll(w): return (mu(w) - y) * x
deriv_nll_jax = grad(nll)
assert jnp.allclose(deriv_nll(w), deriv_nll_jax(w))


# Now do it for a batch of data

def predict(weights, inputs):
    return sigmoid(jnp.dot(inputs, weights))

def loss(weights, inputs, targets):
    preds = predict(weights, inputs)
    logprobs = jnp.log(preds) * targets + jnp.log(1 - preds) * (1 - targets)
    return -jnp.sum(logprobs)  # sum the NLL over the batch


N = 3
X = np.random.randn(N, D)
y = np.random.randint(0, 2, N)

from jax import vmap
from functools import partial

# Four equivalent ways to compute per-example predictions
preds = vmap(partial(predict, w))(X)
preds2 = vmap(predict, in_axes=(None, 0))(w, X)
preds3 = [predict(w, x) for x in X]
preds4 = predict(w, X)
assert jnp.allclose(preds, preds2)
assert jnp.allclose(preds, preds3)
assert jnp.allclose(preds, preds4)

# Per-example gradients: row i is (mu_i - y_i) * x_i
grad_fun = grad(loss)
grads = vmap(partial(grad_fun, w))(X, y)
assert grads.shape == (N, D)
grads2 = jnp.dot(np.diag(preds - y), X)
assert jnp.allclose(grads, grads2)

grad_sum = jnp.sum(grads, axis=0)
grad_sum2 = jnp.dot(np.ones((1, N)), grads)
assert jnp.allclose(grad_sum, grad_sum2)

# Now make things go fast with JIT compilation
from jax import jit

grad_fun = jit(grad(loss))
grads = vmap(partial(grad_fun, w))(X, y)
assert jnp.allclose(grads, grads2)


# Hessian of the logistic regression NLL: H = X' S X, where S = diag(mu * (1 - mu))
H1 = hessian(loss)(w, X, y)
mu = predict(w, X)  # note: shadows the scalar-case mu() defined above
S = jnp.diag(mu * (1 - mu))
H2 = jnp.dot(np.dot(X.T, S), X)
assert jnp.allclose(H1, H2)
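
# Optional extra: a minimal sketch of how the jitted gradient above could drive a few
# steps of plain batch gradient descent on the same loss. The zero initialization,
# step size (0.1), and iteration count (100) are arbitrary illustrative choices.
w_hat = jnp.zeros(D)
lr = 0.1
for _ in range(100):
    w_hat = w_hat - lr * grad_fun(w_hat, X, y)  # grad_fun = jit(grad(loss)) from above
# after these steps the NLL should be lower than at the starting point
assert loss(w_hat, X, y) < loss(jnp.zeros(D), X, y)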