

Optimization (using JAX)

In this notebook, we explore various algorithms for solving optimization problems of the form $x^* = \arg \min_{x \in \mathcal{X}} f(x)$. We focus on the case where $f: \mathbb{R}^D \rightarrow \mathbb{R}$ is a differentiable function. We make use of the JAX library for automatic differentiation.

Some other possibly useful resources:

  1. Animations of various SGD algorithms in 2d (using PyTorch)

  2. Tutorial on constrained optimization using JAX

import sklearn
import scipy
import scipy.optimize
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings("ignore")
import itertools
import time
from functools import partial
import os

import numpy as np

# np.set_printoptions(precision=3)
np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)})
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax import grad, hessian, jacfwd, jacrev, jit, vmap

print("jax version {}".format(jax.__version__))
jax version 0.2.10
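
As a quick sanity check of the autodiff machinery, here is a minimal sketch verifying that jax.grad recovers the analytic derivative of a simple scalar function (this toy function is purely illustrative):

# Sanity check: d/dx (x^2 + 3x) = 2x + 3, so the gradient at x = 2 should be 7.
def f(x):
    return x**2 + 3.0 * x

print(grad(f)(2.0))  # expected: 7.0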

Fitting a model using sklearn

Models in the sklearn library support the fit method for parameter estimation. Under the hood, this involves solving an optimization problem. In this colab, we lift the hood and replicate that functionality from first principles.

As a running example, we will use binary logistic regression on the iris dataset.

# Fit the model to a dataset, so we have an "interesting" parameter vector to use.

import sklearn.datasets
from sklearn.model_selection import train_test_split

iris = sklearn.datasets.load_iris()
X = iris["data"]
y = (iris["target"] == 2).astype(int)  # 1 if Iris-Virginica, else 0
N, D = X.shape  # 150, 4

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

from sklearn.linear_model import LogisticRegression

# We set C to a large number to turn off regularization.
# We don't fit the bias term to simplify the comparison below.
log_reg = LogisticRegression(solver="lbfgs", C=1e5, fit_intercept=False)
log_reg.fit(X_train, y_train)
w_mle_sklearn = jnp.ravel(log_reg.coef_)

print(w_mle_sklearn)
[-4.414 -9.111 6.539 12.686]

Objectives and their gradients

The key inputs to an optimization algorithm (aka solver) are the objective function and its gradient. As an example, we use the negative log likelihood (NLL) of a binary logistic regression model as the objective. We compute the gradient by hand, and also with JAX's autodiff feature.

Manual differentiation

We compute the gradient of the negative log likelihood for binary logistic regression applied to the Iris dataset.
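
Writing $\mu_n = \sigma(w^\top x_n)$ for the predicted probability of class 1, the gradient of the (per-example-averaged) NLL has the standard closed form $$\nabla_w \mathrm{NLL}(w) = \frac{1}{N} \sum_{n=1}^N (\mu_n - y_n) x_n = \frac{1}{N} X^\top (\mu - y),$$ which is what the NLL_grad function below implements.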

# Binary cross entropy
def BCE_with_logits(logits, targets):
    # BCE = -sum_n [y_n * log(p1_n) + (1 - y_n) * log(p0_n)]
    # p1 = 1/(1+exp(-a))
    # log(p1) = log(1) - log(1+exp(-a)) = 0 - logsumexp([0, -a])
    N = logits.shape[0]
    logits = logits.reshape(N, 1)
    logits_plus = jnp.hstack([jnp.zeros((N, 1)), logits])  # e^0=1
    logits_minus = jnp.hstack([jnp.zeros((N, 1)), -logits])
    logp1 = -logsumexp(logits_minus, axis=1)
    logp0 = -logsumexp(logits_plus, axis=1)
    logprobs = logp1 * targets + logp0 * (1 - targets)
    return -jnp.sum(logprobs) / N


def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2.0) + 1)


def predict_logit(weights, inputs):
    return jnp.dot(inputs, weights)


def predict_prob(weights, inputs):
    return sigmoid(predict_logit(weights, inputs))


def NLL(weights, batch):
    X, y = batch
    logits = predict_logit(weights, X)
    return BCE_with_logits(logits, y)


def NLL_grad(weights, batch):
    X, y = batch
    N = X.shape[0]
    mu = predict_prob(weights, X)
    g = jnp.sum(jnp.dot(jnp.diag(mu - y), X), axis=0) / N
    return g
w = w_mle_sklearn

y_pred = predict_prob(w, X_test)
loss = NLL(w, (X_test, y_test))
grad_np = NLL_grad(w, (X_test, y_test))
print("params {}".format(w))
# print("pred {}".format(y_pred))
print("loss {}".format(loss))
print("grad {}".format(grad_np))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
params [-4.414 -9.111 6.539 12.686]
loss 0.11824002861976624
grad [-0.235 -0.122 -0.198 -0.064]

Automatic differentiation in JAX

Below we use JAX to compute the gradient of the NLL for binary logistic regression.

grad_jax = grad(NLL)(w, (X_test, y_test))
print("grad {}".format(grad_jax))
assert np.allclose(grad_np, grad_jax)
grad [-0.235 -0.122 -0.198 -0.064]

Second-order optimization

The "gold standard" of optimization is second-order methods, that leverage Hessian information. Since the Hessian has O(D^2) parameters, such methods do not scale to high-dimensional problems. However, we can sometimes approximate the Hessian using low-rank or diagonal approximations. Below we illustrate the low-rank BFGS method, and the limited-memory version of BFGS, that uses O(D H) space and O(D^2) time per step, where H is the history length.

In general, second-order methods also require exact (rather than noisy) gradients. In the context of ML, this means they are "full batch" methods, since computing the exact gradient requires evaluating the loss on all the datapoints. However, for small data problems, this is feasible (and advisable).
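
Before turning to BFGS, here is a sketch of how JAX's hessian transform could be used to take full-batch damped Newton steps on the training NLL; the function name, damping constant, and iteration count are arbitrary choices for illustration only.

# Sketch: a few damped Newton steps on the full-batch training NLL, using jax.hessian.
# Illustrative only; the rest of the notebook uses scipy's (L-)BFGS instead.
def newton_solve(w, batch, num_iters=10, damping=1e-4):
    for _ in range(num_iters):
        g = grad(NLL)(w, batch)
        H = hessian(NLL)(w, batch)
        # Damped Newton step: solve (H + damping * I) d = g and move to w - d.
        d = jnp.linalg.solve(H + damping * jnp.eye(H.shape[0]), g)
        w = w - d
    return w

# w_newton = newton_solve(jnp.zeros(D), (X_train, y_train))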

Below we illustrate how to use BFGS and L-BFGS from scipy.optimize.

import scipy.optimize


def training_loss(w):
    return NLL(w, (X_train, y_train))


def training_grad(w):
    return NLL_grad(w, (X_train, y_train))


np.random.seed(42)
w_init = np.random.randn(D)

options = {"disp": None, "maxfun": 1000, "maxiter": 1000}
method = "BFGS"
# The gradient function is specified via the jac (Jacobian) keyword
w_mle_scipy = scipy.optimize.minimize(training_loss, w_init, jac=training_grad, method=method, options=options).x
print("parameters from sklearn {}".format(w_mle_sklearn)) print("parameters from scipy-bfgs {}".format(w_mle_scipy)) assert np.allclose(w_mle_sklearn, w_mle_scipy, atol=1e-1)
parameters from sklearn [-4.414 -9.111 6.539 12.686]
parameters from scipy-bfgs [-4.415 -9.115 6.541 12.692]
p_pred_sklearn = predict_prob(w_mle_sklearn, X_test)
p_pred_scipy = predict_prob(w_mle_scipy, X_test)
print("predictions from sklearn")
print(p_pred_sklearn)
print("predictions from scipy")
print(p_pred_scipy)
assert np.allclose(p_pred_sklearn, p_pred_scipy, atol=1e-1)
predictions from sklearn
[0.002 0.000 1.000 0.012 0.002 0.000 0.000 0.979 0.740 0.000 0.706 0.000 0.000 0.000 0.000 0.001 1.000 0.000 0.009 1.000 0.000 0.650 0.000 1.000 0.094 0.998 1.000 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.998 0.000 0.000 0.000 0.000 0.999 0.000 0.000 0.000 0.000 0.000 0.281 0.909 0.000 0.999]
predictions from scipy
[0.002 0.000 1.000 0.012 0.002 0.000 0.000 0.979 0.740 0.000 0.706 0.000 0.000 0.000 0.000 0.001 1.000 0.000 0.009 1.000 0.000 0.650 0.000 1.000 0.094 0.998 1.000 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.998 0.000 0.000 0.000 0.000 0.999 0.000 0.000 0.000 0.000 0.000 0.281 0.909 0.000 0.999]
# The limited-memory version requires that we work with 64-bit floats, since it is implemented in Fortran.
def training_loss_64bit(w):
    l = NLL(w, (X_train, y_train))
    return np.float64(l)


def training_grad_64bit(w):
    g = NLL_grad(w, (X_train, y_train))
    return np.asarray(g, dtype=np.float64)


np.random.seed(42)
w_init = np.random.randn(D)

memory = 10
options = {"disp": None, "maxcor": memory, "maxfun": 1000, "maxiter": 1000}
# The code also handles bound constraints, hence the name
method = "L-BFGS-B"

# w_mle_scipy = scipy.optimize.minimize(training_loss, w_init, jac=training_grad, method=method).x
w_mle_scipy = scipy.optimize.minimize(training_loss_64bit, w_init, jac=training_grad_64bit, method=method).x

print("parameters from sklearn {}".format(w_mle_sklearn))
print("parameters from scipy-lbfgs {}".format(w_mle_scipy))
assert np.allclose(w_mle_sklearn, w_mle_scipy, atol=1e-1)
parameters from sklearn [-4.414 -9.111 6.539 12.686]
parameters from scipy-lbfgs [-4.415 -9.114 6.540 12.691]

Stochastic gradient descent

Full batch optimization is too expensive for solving empirical risk minimization problems on large datasets. The standard approach in such settings is to use stochastic gradient descent (SGD). In this section we illustrate how to implement SGD. We apply it to a simple convex problem, namely MLE for logistic regression on the small iris dataset, so that we can compare to the exact batch methods illustrated above.
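
Concretely, at step $t$ SGD draws a minibatch $B_t$ of examples and updates the parameters using the (noisy) minibatch gradient, $$\theta_{t+1} = \theta_t - \eta_t \frac{1}{|B_t|} \sum_{n \in B_t} \nabla_\theta \ell(\theta_t; x_n, y_n),$$ where $\eta_t$ is the learning rate (step size).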

Minibatches

We use the tensorflow datasets library to make it easy to create streams of minibatches.

import tensorflow as tf
import tensorflow_datasets as tfds


def make_batch_stream(X_train, y_train, batch_size):
    dataset = tf.data.Dataset.from_tensor_slices({"X": X_train, "y": y_train})
    batches = dataset.batch(batch_size)
    batch_stream = tfds.as_numpy(batches)  # finite iterable of dicts of NumPy arrays
    N = X_train.shape[0]
    nbatches = int(np.floor(N / batch_size))
    print("{} examples split into {} batches of size {}".format(N, nbatches, batch_size))
    return batch_stream


batch_stream = make_batch_stream(X_train, y_train, 20)
for epoch in range(2):
    print("epoch {}".format(epoch))
    for batch in batch_stream:
        x, y = batch["X"], batch["y"]
        print(x.shape)  # (batch size, num features = 4)
100 examples split into 5 batches of size 20
epoch 0
(20, 4)
(20, 4)
(20, 4)
(20, 4)
(20, 4)
epoch 1
(20, 4)
(20, 4)
(20, 4)
(20, 4)
(20, 4)
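
If you would rather not depend on TensorFlow, an equivalent batch stream can be built with plain NumPy. The helper below is a hypothetical alternative (not used in the rest of the notebook); it returns a list of batch dictionaries so that, like the tfds iterable, it can be traversed once per epoch.

# Alternative minibatch "stream" using plain NumPy (hypothetical helper, not used below).
def make_batch_stream_np(X_train, y_train, batch_size):
    N = X_train.shape[0]
    nbatches = N // batch_size
    print("{} examples split into {} batches of size {}".format(N, nbatches, batch_size))
    # Return a list (rather than a generator) so it can be re-iterated each epoch.
    return [
        {"X": X_train[i * batch_size : (i + 1) * batch_size],
         "y": y_train[i * batch_size : (i + 1) * batch_size]}
        for i in range(nbatches)
    ]

# batch_stream_np = make_batch_stream_np(X_train, y_train, 20)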

SGD from scratch

We show a minimal implementation of SGD using vanilla JAX / NumPy.

def sgd(params, loss_fn, grad_loss_fn, batch_iter, max_epochs, lr):
    print_every = max(1, int(0.1 * max_epochs))
    for epoch in range(max_epochs):
        epoch_loss = 0.0
        for batch_dict in batch_iter:
            x, y = batch_dict["X"], batch_dict["y"]
            batch = (x, y)
            batch_grad = grad_loss_fn(params, batch)
            params = params - lr * batch_grad
            batch_loss = loss_fn(params, batch)  # average loss within this batch
            epoch_loss += batch_loss
        if epoch % print_every == 0:
            print("Epoch {}, batch Loss {}".format(epoch, batch_loss))
    return params
np.random.seed(42)
w_init = np.random.randn(D)

max_epochs = 5
lr = 0.1
batch_size = 10
batch_stream = make_batch_stream(X_train, y_train, batch_size)
w_mle_sgd = sgd(w_init, NLL, NLL_grad, batch_stream, max_epochs, lr)
100 examples split into 10 batches of size 10
Epoch 0, batch Loss 0.36490148305892944
Epoch 1, batch Loss 0.34500640630722046
Epoch 2, batch Loss 0.32851701974868774
Epoch 3, batch Loss 0.3143332302570343
Epoch 4, batch Loss 0.3018316924571991

Compare SGD with batch optimization

SGD is not a particularly good optimizer, even on this simple convex problem: it converges to a solution that is quite different from the global MLE. Of course, this could be due to lack of identifiability (the objective is convex, but perhaps not strongly convex, unless we add some regularization). But the predicted probabilities also differ substantially. Clearly we will need 'fancier' SGD methods, even for this simple problem.

print("parameters from sklearn {}".format(w_mle_sklearn)) print("parameters from sgd {}".format(w_mle_sgd)) # assert np.allclose(w_mle_sklearn, w_mle_sgd, atol=1e-1)
parameters from sklearn [-4.414 -9.111 6.539 12.686]
parameters from sgd [-0.538 -0.827 0.613 1.661]
p_pred_sklearn = predict_prob(w_mle_sklearn, X_test)
p_pred_sgd = predict_prob(w_mle_sgd, X_test)
print("predictions from sklearn")
print(p_pred_sklearn)
print("predictions from sgd")
print(p_pred_sgd)
assert np.allclose(p_pred_sklearn, p_pred_sgd, atol=1e-1)
predictions from sklearn
[0.002 0.000 1.000 0.012 0.002 0.000 0.000 0.979 0.740 0.000 0.706 0.000 0.000 0.000 0.000 0.001 1.000 0.000 0.009 1.000 0.000 0.650 0.000 1.000 0.094 0.998 1.000 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.998 0.000 0.000 0.000 0.000 0.999 0.000 0.000 0.000 0.000 0.000 0.281 0.909 0.000 0.999]
predictions from sgd
[0.327 0.009 0.853 0.407 0.331 0.016 0.260 0.662 0.524 0.275 0.576 0.017 0.009 0.016 0.011 0.359 0.774 0.297 0.386 0.791 0.021 0.558 0.021 0.762 0.463 0.716 0.706 0.756 0.024 0.021 0.011 0.006 0.241 0.017 0.020 0.682 0.302 0.012 0.011 0.006 0.717 0.349 0.311 0.009 0.009 0.266 0.478 0.588 0.267 0.739]
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-28-fcb0fb8cdf54> in <module>()
      6 print("predictions from sgd")
      7 print(p_pred_sgd)
----> 8 assert np.allclose(p_pred_sklearn, p_pred_sgd, atol=1e-1)

AssertionError:

Using jax.experimental.optimizers

JAX has a small optimization library focused on stochastic first-order optimizers. Every optimizer is modeled as an (init_fun, update_fun, get_params) triple of functions. The init_fun is used to initialize the optimizer state, which could include things like momentum variables, and the update_fun accepts a gradient and an optimizer state to produce a new optimizer state. The get_params function extracts the current iterate (i.e. the current parameters) from the optimizer state. The parameters being optimized can be ndarrays or arbitrarily-nested data structures, so you can store your parameters however you’d like.
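
To make this API concrete, here is a tiny sketch that exercises the (init_fun, update_fun, get_params) triple directly on a toy quadratic (the function and hyperparameters are illustrative only):

# Minimize f(w) = 0.5 * ||w||^2 with the optimizers triple; the minimizer is w = 0.
from jax.experimental import optimizers

def quad_loss(w):
    return 0.5 * jnp.sum(w**2)

opt_init, opt_update, get_params = optimizers.sgd(step_size=0.1)
opt_state = opt_init(jnp.ones(3))
for i in range(100):
    g = grad(quad_loss)(get_params(opt_state))
    opt_state = opt_update(i, g, opt_state)
print(get_params(opt_state))  # should be close to zero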

Below we show how to reproduce our numpy code using this library.

# Version that uses the JAX optimization library
from jax.experimental import optimizers

# @jit
def sgd_jax(params, loss_fn, batch_stream, max_epochs, opt_init, opt_update, get_params):
    loss_history = []
    opt_state = opt_init(params)

    # @jit
    def update(i, opt_state, batch):
        params = get_params(opt_state)
        g = grad(loss_fn)(params, batch)
        return opt_update(i, g, opt_state)

    print_every = max(1, int(0.1 * max_epochs))
    total_steps = 0
    for epoch in range(max_epochs):
        epoch_loss = 0.0
        for batch_dict in batch_stream:
            X, y = batch_dict["X"], batch_dict["y"]
            batch = (X, y)
            total_steps += 1
            opt_state = update(total_steps, opt_state, batch)
        params = get_params(opt_state)
        train_loss = float(loss_fn(params, batch))
        loss_history.append(train_loss)
        if epoch % print_every == 0:
            print("Epoch {}, batch loss {}".format(epoch, train_loss))
    return params, loss_history
# JAX with constant LR should match our minimal version of SGD
schedule = optimizers.constant(step_size=lr)
opt_init, opt_update, get_params = optimizers.sgd(step_size=schedule)

w_mle_sgd2, history = sgd_jax(w_init, NLL, batch_stream, max_epochs, opt_init, opt_update, get_params)
print(w_mle_sgd2)
print(history)
Epoch 0, batch loss 0.36490148305892944
Epoch 1, batch loss 0.34500643610954285
Epoch 2, batch loss 0.32851701974868774
Epoch 3, batch loss 0.3143332004547119
Epoch 4, batch loss 0.3018316924571991
[-0.538 -0.827 0.613 1.661]
[0.36490148305892944, 0.34500643610954285, 0.32851701974868774, 0.3143332004547119, 0.3018316924571991]
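
As noted earlier, the same training loop can be reused with 'fancier' optimizers simply by swapping the triple. For example, a sketch using momentum (or Adam), with arbitrary, untuned hyperparameters and hypothetical variable names:

# Same loop, different optimizer; the hyperparameters here are illustrative only.
opt_init, opt_update, get_params = optimizers.momentum(step_size=lr, mass=0.9)
# opt_init, opt_update, get_params = optimizers.adam(step_size=0.1)
w_sgd_momentum, history_momentum = sgd_jax(w_init, NLL, batch_stream, max_epochs, opt_init, opt_update, get_params)
print(w_sgd_momentum)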