Optimization (using JAX)
In this notebook, we explore various algorithms for solving optimization problems of the form

$$\hat{\theta} = \operatorname*{argmin}_{\theta \in \Theta} \mathcal{L}(\theta)$$

We focus on the case where $\mathcal{L}(\theta)$ is a differentiable function. We make use of the JAX library for automatic differentiation.
Some other possibly useful resources:
Fitting a model using sklearn
Models in the sklearn library support the fit method for parameter estimation. Under the hood, this involves an optimization problem. In this colab, we lift up this hood and replicate the functionality from first principles.
As a running example, we will use binary logistic regression on the iris dataset.
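As a sketch of this baseline (the binarization of the labels and the variable names X, y, log_reg, and p_pred_sklearn are our own choices for illustration), we can fit a binary logistic regression model to iris with sklearn, treating "virginica vs the rest" as the binary target:

```python
# Baseline fit with sklearn; label binarization and variable names are our own conventions.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

iris = load_iris()
X = iris["data"]
y = (iris["target"] == 2).astype(np.int32)  # 1 if virginica, 0 otherwise

# Large C means weak regularization, so the fit approximates the MLE.
log_reg = LogisticRegression(solver="lbfgs", C=1e10)
log_reg.fit(X, y)

p_pred_sklearn = log_reg.predict_proba(X)[:, 1]  # predicted P(y=1 | x)
print(p_pred_sklearn[:5])
```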
Objectives and their gradients
The key input to an optimization algorithm (aka solver) is the objective function and its gradient. As an example, we use negative log likelihood for a binary logistic regression model as the objective. We compute the gradient by hand, and also use JAX's autodiff feature.
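As a minimal sketch (reusing X and y from the sklearn cell above, and folding the bias into the weight vector by appending a column of ones, which is our own convention), the objective and its hand-derived gradient can be written as follows and checked against jax.grad:

```python
# Average NLL for binary logistic regression, plus its closed-form gradient.
# Assumes X and y from the sklearn cell above; the bias is folded into w.
import jax.numpy as jnp
from jax import grad
from jax.nn import sigmoid

def negative_log_likelihood(w, X, y):
    mu = sigmoid(X @ w)                           # predicted P(y=1 | x)
    return -jnp.mean(y * jnp.log(mu) + (1 - y) * jnp.log(1 - mu))

def nll_grad_manual(w, X, y):
    mu = sigmoid(X @ w)
    return X.T @ (mu - y) / X.shape[0]            # (1/N) X^T (mu - y)

X1 = jnp.column_stack([X, jnp.ones(X.shape[0])])  # add a bias column
w0 = jnp.zeros(X1.shape[1])

g_manual = nll_grad_manual(w0, X1, y)
g_auto = grad(negative_log_likelihood)(w0, X1, y)  # gradient w.r.t. first argument
print(jnp.allclose(g_manual, g_auto, atol=1e-4))
```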
Second-order optimization
The "gold standard" of optimization is second-order methods, that leverage Hessian information. Since the Hessian has O(D^2) parameters, such methods do not scale to high-dimensional problems. However, we can sometimes approximate the Hessian using low-rank or diagonal approximations. Below we illustrate the low-rank BFGS method, and the limited-memory version of BFGS, that uses O(D H) space and O(D^2) time per step, where H is the history length.
In general, second-order methods also require exact (rather than noisy) gradients. In the context of ML, this means they are "full batch" methods, since computing the exact gradient requires evaluating the loss on all the datapoints. However, for small data problems, this is feasible (and advisable).
Below we illustrate how to use L-BFGS as implemented in scipy.optimize.
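A sketch of this approach, reusing negative_log_likelihood, X1, and y from above (the wrapper that converts between JAX arrays and float64 numpy arrays is our own convention):

```python
# Full-batch L-BFGS via scipy.optimize, driven by the JAX objective and gradient.
import numpy as np
from scipy.optimize import minimize
from jax import value_and_grad

def objective_and_grad(w):
    # scipy expects float64 numpy inputs/outputs, so we convert at the boundary.
    loss, g = value_and_grad(negative_log_likelihood)(jnp.array(w), X1, y)
    return np.float64(loss), np.array(g, dtype=np.float64)

w_init = np.zeros(X1.shape[1])
result = minimize(objective_and_grad, w_init, jac=True, method="L-BFGS-B")
w_lbfgs = result.x
print(result.success, negative_log_likelihood(jnp.array(w_lbfgs), X1, y))
```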
Stochastic gradient descent
Full-batch optimization is too expensive for solving empirical risk minimization problems on large datasets. The standard approach in such settings is stochastic gradient descent (SGD). In this section we illustrate how to implement SGD. We apply it to a simple convex problem, namely MLE for logistic regression on the small iris dataset, so we can compare to the exact batch methods illustrated above.
Minibatches
We use the tensorflow datasets library to make it easy to create streams of minibatches.
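One way to do this, as a sketch (the batch size, shuffle buffer, and dictionary keys are arbitrary choices, and we reuse X1 and y from above), is to wrap the numpy arrays in a tf.data pipeline and convert the batches back to numpy with tfds.as_numpy:

```python
# Build an infinite stream of shuffled minibatches from the (X1, y) arrays above.
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

batch_size = 32
ds = tf.data.Dataset.from_tensor_slices({"X": np.array(X1), "y": np.array(y)})
ds = ds.shuffle(buffer_size=150, seed=0).repeat().batch(batch_size)
batch_stream = iter(tfds.as_numpy(ds))  # yields dicts of numpy arrays

batch = next(batch_stream)
print(batch["X"].shape, batch["y"].shape)  # e.g. (32, 5) (32,)
```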
SGD from scratch
We show a minimal implementation of SGD using vanilla JAX/NumPy.
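A minimal sketch, reusing negative_log_likelihood and batch_stream from the cells above (the step size and number of steps are ad hoc choices):

```python
# Plain SGD: take a gradient step on each minibatch.
import jax.numpy as jnp
from jax import grad, jit
from jax.nn import sigmoid

learning_rate = 0.1

@jit
def sgd_update(w, X_batch, y_batch):
    g = grad(negative_log_likelihood)(w, X_batch, y_batch)
    return w - learning_rate * g

w_sgd = jnp.zeros(X1.shape[1])
for step in range(500):
    batch = next(batch_stream)
    w_sgd = sgd_update(w_sgd, batch["X"], batch["y"])

p_pred_sgd = sigmoid(X1 @ w_sgd)  # predicted probabilities, for the comparison below
print(negative_log_likelihood(w_sgd, X1, y))
```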
Compare SGD with batch optimization
SGD is not a particularly good optimizer, even on this simple convex problem: it converges to a solution that is quite different from the global MLE. Of course, this could be due to lack of identifiability (the objective is convex, but not necessarily strongly convex, unless we add some regularization). But the predicted probabilities also differ substantially. Clearly we will need 'fancier' SGD methods, even for this simple problem.
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-28-fcb0fb8cdf54> in <module>()
6 print("predictions from sgd")
7 print(p_pred_sgd)
----> 8 assert np.allclose(p_pred_sklearn, p_pred_sgd, atol=1e-1)
AssertionError:
Using jax.experimental.optimizers
JAX has a small optimization library focused on stochastic first-order optimizers. Every optimizer is modeled as an (init_fun, update_fun, get_params) triple of functions. The init_fun is used to initialize the optimizer state, which could include things like momentum variables, and the update_fun accepts a gradient and an optimizer state to produce a new optimizer state. The get_params function extracts the current iterate (i.e. the current parameters) from the optimizer state. The parameters being optimized can be ndarrays or arbitrarily-nested data structures, so you can store your parameters however you'd like.
Below we show how to reproduce our numpy code using this library.
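A sketch of that translation, reusing the objective and minibatch stream from above (note that in recent JAX releases this module has moved to jax.example_libraries.optimizers):

```python
# The same SGD loop, expressed with the (init_fun, update_fun, get_params) triple.
import jax.numpy as jnp
from jax import grad, jit
from jax.experimental import optimizers

opt_init, opt_update, get_params = optimizers.sgd(step_size=0.1)

@jit
def train_step(i, opt_state, X_batch, y_batch):
    w = get_params(opt_state)
    g = grad(negative_log_likelihood)(w, X_batch, y_batch)
    return opt_update(i, g, opt_state)

opt_state = opt_init(jnp.zeros(X1.shape[1]))
for i in range(500):
    batch = next(batch_stream)
    opt_state = train_step(i, opt_state, batch["X"], batch["y"])

w_opt = get_params(opt_state)
print(negative_log_likelihood(w_opt, X1, y))
```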