GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/10/logreg_jax.ipynb


Logistic regression

In this notebook, we illustrate how to perform binary logistic regression on a small dataset (Iris). We compare the model fit by sklearn with our own implementation, for which we use a batch optimizer from scipy. We code the gradients by hand, and also compute them with the JAX autodiff package (see the JAX AD colab), checking that the two agree.
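For reference, the model and objective implemented below are the standard ones for binary logistic regression. Writing $\mu_n = \sigma(\mathbf{w}^\top \mathbf{x}_n)$ with $\sigma(a) = 1/(1+e^{-a})$, the (averaged) negative log likelihood and its derivatives are

$$\mathrm{NLL}(\mathbf{w}) = -\frac{1}{N}\sum_{n=1}^{N}\bigl[y_n \log \mu_n + (1-y_n)\log(1-\mu_n)\bigr],$$

$$\nabla \mathrm{NLL}(\mathbf{w}) = \frac{1}{N}\sum_{n=1}^{N}(\mu_n - y_n)\,\mathbf{x}_n,
\qquad
\nabla^2 \mathrm{NLL}(\mathbf{w}) = \frac{1}{N}\mathbf{X}^\top \mathbf{S}\,\mathbf{X},
\quad
\mathbf{S} = \mathrm{diag}\bigl(\mu_n(1-\mu_n)\bigr).$$

These are the formulas we code by hand below and then verify against JAX's autodiff.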

# Standard Python libraries
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
import imageio
from IPython import display
import sklearn
import seaborn as sns

sns.set(style="ticks", color_codes=True)
# https://github.com/google/jax
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax import grad, hessian, jacfwd, jacrev, jit, vmap
from jax.experimental import optimizers

print("jax version {}".format(jax.__version__))
jax version 0.2.12
# First we create a dataset.
import sklearn.datasets
from sklearn.model_selection import train_test_split

iris = sklearn.datasets.load_iris()
X = iris["data"]
y = (iris["target"] == 2).astype(int)  # 1 if Iris-Virginica, else 0
N, D = X.shape  # 150, 4

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
# Now let's find the MLE using sklearn. We will use this as the "gold standard".
from sklearn.linear_model import LogisticRegression

# We set C to a large number to turn off regularization.
# We don't fit the bias term to simplify the comparison below.
log_reg = LogisticRegression(solver="lbfgs", C=1e5, fit_intercept=False)
log_reg.fit(X_train, y_train)

w_mle_sklearn = jnp.ravel(log_reg.coef_)
print(w_mle_sklearn)
[-4.41378437 -9.11061763 6.53872233 12.68572678]
# First we define the model, and check it gives the same output as sklearn.
def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2.0) + 1)


def predict_logit(weights, inputs):
    return jnp.dot(inputs, weights)  # Already vectorized


def predict_prob(weights, inputs):
    return sigmoid(predict_logit(weights, inputs))


ptest_sklearn = log_reg.predict_proba(X_test)[:, 1]
print(jnp.round(ptest_sklearn, 3))

ptest_us = predict_prob(w_mle_sklearn, X_test)
print(jnp.round(ptest_us, 3))

assert jnp.allclose(ptest_sklearn, ptest_us, atol=1e-2)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[0.002 0. 1. 0.012 0.002 0. 0. 0.979 0.74 0. 0.706 0. 0. 0. 0. 0.001 1. 0. 0.009 1. 0. 0.65 0. 1. 0.094 0.998 1. 1. 0. 0. 0. 0. 0. 0. 0. 0.998 0. 0. 0. 0. 0.999 0. 0. 0. 0. 0. 0.281 0.909 0. 0.999]
[0.002 0. 1. 0.012 0.002 0. 0. 0.979 0.74 0. 0.706 0. 0. 0. 0. 0.001 1. 0. 0.009 1. 0. 0.65 0. 1. 0.094 0.998 1. 1. 0. 0. 0. 0. 0. 0. 0. 0.998 0. 0. 0. 0. 0.999 0. 0. 0. 0. 0. 0.281 0.909 0. 0.999]
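Note that sigmoid is written in terms of tanh rather than as 1/(1 + exp(-x)); the two are algebraically identical, and the tanh form avoids overflow in exp for large negative inputs. As a quick sanity check (not part of the original notebook), it should agree with jax.nn.sigmoid:

# Optional sanity check: the tanh-based sigmoid should match jax.nn.sigmoid,
# which is also implemented in a numerically stable way.
z = jnp.linspace(-20.0, 20.0, 9)
assert jnp.allclose(sigmoid(z), jax.nn.sigmoid(z), atol=1e-6)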
# Next we define the objective and check it gives the same output as sklearn.
from sklearn.metrics import log_loss
from jax.scipy.special import logsumexp


def NLL_unstable(weights, batch):
    inputs, targets = batch
    p1 = predict_prob(weights, inputs)
    logprobs = jnp.log(p1) * targets + jnp.log(1 - p1) * (1 - targets)
    N = inputs.shape[0]
    return -jnp.sum(logprobs) / N


def NLL(weights, batch):
    # Use log-sum-exp trick
    inputs, targets = batch
    # p1 = 1/(1+exp(-logit)), p0 = 1/(1+exp(+logit))
    logits = predict_logit(weights, inputs).reshape((-1, 1))
    N = logits.shape[0]
    logits_plus = jnp.hstack([jnp.zeros((N, 1)), logits])  # e^0 = 1
    logits_minus = jnp.hstack([jnp.zeros((N, 1)), -logits])
    logp1 = -logsumexp(logits_minus, axis=1)
    logp0 = -logsumexp(logits_plus, axis=1)
    logprobs = logp1 * targets + logp0 * (1 - targets)
    return -jnp.sum(logprobs) / N


# We can use a small amount of L2 regularization, for numerical stability
def PNLL(weights, batch, l2_penalty=1e-5):
    nll = NLL(weights, batch)
    l2_norm = jnp.sum(jnp.power(weights, 2))  # squared L2 norm
    return nll + l2_penalty * l2_norm


# We evaluate the training loss at the MLE, where the parameter values are "extreme".
nll_train = log_loss(y_train, predict_prob(w_mle_sklearn, X_train))
nll_train2 = NLL(w_mle_sklearn, (X_train, y_train))
nll_train3 = NLL_unstable(w_mle_sklearn, (X_train, y_train))
print(nll_train)
print(nll_train2)
print(nll_train3)
0.06907700925379459
0.06907699
nan
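The nan from NLL_unstable arises because, at these extreme parameter values, some predicted probabilities round to exactly 0 or 1 in float32, so jnp.log returns -inf and the product 0 * (-inf) is nan. The log-sum-exp version avoids this by never forming the probabilities explicitly. As an optional cross-check (not in the original notebook), the same stable objective can be written more compactly with jax.nn.log_sigmoid, which should agree with NLL up to floating-point error:

# Optional cross-check: log p(y=1|x) = log_sigmoid(logit), log p(y=0|x) = log_sigmoid(-logit).
def NLL_logsig(weights, batch):
    inputs, targets = batch
    logits = predict_logit(weights, inputs)
    logprobs = jax.nn.log_sigmoid(logits) * targets + jax.nn.log_sigmoid(-logits) * (1 - targets)
    return -jnp.mean(logprobs)


print(NLL_logsig(w_mle_sklearn, (X_train, y_train)))  # expect roughly 0.069077, as above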
# Next we check the gradients compared to the manual formulas.
# For simplicity, we initially just do this for a single random example.
np.random.seed(42)
D = 5
w = np.random.randn(D)
x = np.random.randn(D)
y = 0

# d/da sigmoid(a) = s(a) * (1-s(a))
deriv_sigmoid = lambda a: sigmoid(a) * (1 - sigmoid(a))
deriv_sigmoid_jax = grad(sigmoid)
a = 1.5  # a random logit
assert jnp.isclose(deriv_sigmoid(a), deriv_sigmoid_jax(a))


# mu(w) = sigmoid(w'x), d/dw mu(w) = mu * (1-mu) .* x
def mu(w):
    return sigmoid(jnp.dot(w, x))


def deriv_mu(w):
    return mu(w) * (1 - mu(w)) * x


deriv_mu_jax = grad(mu)
assert jnp.allclose(deriv_mu(w), deriv_mu_jax(w))


# NLL(w) = -[y*log(mu) + (1-y)*log(1-mu)]
# d/dw NLL(w) = (mu-y)*x
def nll(w):
    return -(y * jnp.log(mu(w)) + (1 - y) * jnp.log(1 - mu(w)))


def deriv_nll(w):
    return (mu(w) - y) * x


deriv_nll_jax = grad(nll)
assert jnp.allclose(deriv_nll(w), deriv_nll_jax(w))
# Now let's check the gradients on the batch version of our data.
N = X_train.shape[0]
mu = predict_prob(w_mle_sklearn, X_train)

g1 = grad(NLL)(w_mle_sklearn, (X_train, y_train))
g2 = jnp.sum(jnp.dot(jnp.diag(mu - y_train), X_train), axis=0) / N
print(g1)
print(g2)
assert jnp.allclose(g1, g2, atol=1e-2)

H1 = hessian(NLL)(w_mle_sklearn, (X_train, y_train))
S = jnp.diag(mu * (1 - mu))
H2 = jnp.dot(jnp.dot(X_train.T, S), X_train) / N
print(H1)
print(H2)
assert jnp.allclose(H1, H2, atol=1e-2)
[ 3.5801623e-08  7.0655005e-07 -9.9190243e-07 -1.4292980e-06]
[ 2.3841858e-08  6.9618227e-07 -1.0067224e-06 -1.4327467e-06]
[[0.80245787 0.36579472 0.6444712  0.2132109 ]
 [0.36579472 0.1684845  0.29427886 0.09809215]
 [0.64447117 0.29427886 0.5187084  0.17146072]
 [0.21321094 0.09809215 0.17146073 0.05745751]]
[[0.80245805 0.36579484 0.6444712  0.21321094]
 [0.36579484 0.1684845  0.29427883 0.09809214]
 [0.6444711  0.29427883 0.51870865 0.17146075]
 [0.21321093 0.09809214 0.17146075 0.05745751]]
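The gradient is numerically zero at the MLE, as expected, and the two Hessian estimates agree. Since H = X'SX/N is positive semidefinite, the NLL is convex, so a batch optimizer started from any initial point should reach the same global optimum. As a side note, jax.hessian is implemented as forward-over-reverse differentiation; since jacfwd and jacrev were already imported above, we can (optionally) spell this out by hand. This is a small sketch, not part of the original notebook:

# Optional: composing jacfwd and jacrev gives the same Hessian as jax.hessian.
H3 = jacfwd(jacrev(NLL))(w_mle_sklearn, (X_train, y_train))
assert jnp.allclose(H1, H3, atol=1e-4)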
# Finally, use the BFGS batch optimizer to compute the MLE, and compare to sklearn.
import scipy.optimize


def training_loss(w):
    return NLL(w, (X_train, y_train))


def training_grad(w):
    return grad(training_loss)(w)


np.random.seed(43)
N, D = X_train.shape
w_init = np.random.randn(D)

w_mle_scipy = scipy.optimize.minimize(training_loss, w_init, jac=training_grad, method="BFGS").x

print("parameters from sklearn {}".format(w_mle_sklearn))
print("parameters from scipy-bfgs {}".format(w_mle_scipy))
assert jnp.allclose(w_mle_sklearn, w_mle_scipy, atol=1e-1)

prob_scipy = predict_prob(w_mle_scipy, X_test)
prob_sklearn = predict_prob(w_mle_sklearn, X_test)
print(jnp.round(prob_scipy, 3))
print(jnp.round(prob_sklearn, 3))
assert jnp.allclose(prob_scipy, prob_sklearn, atol=1e-2)
parameters from sklearn [-4.41378437 -9.11061763 6.53872233 12.68572678]
parameters from scipy-bfgs [-4.43822388 -9.04306242 6.52521732 12.7028332 ]
[0.002 0. 1. 0.012 0.002 0. 0. 0.979 0.732 0. 0.711 0. 0. 0. 0. 0.001 1. 0. 0.009 1. 0. 0.654 0. 1. 0.095 0.998 1. 1. 0. 0. 0. 0. 0. 0. 0. 0.998 0. 0. 0. 0. 0.999 0. 0. 0. 0. 0. 0.279 0.91 0. 0.999]
[0.002 0. 1. 0.012 0.002 0. 0. 0.979 0.74 0. 0.706 0. 0. 0. 0. 0.001 1. 0. 0.009 1. 0. 0.65 0. 1. 0.094 0.998 1. 1. 0. 0. 0. 0. 0. 0. 0. 0.998 0. 0. 0. 0. 0.999 0. 0. 0. 0. 0. 0.281 0.909 0. 0.999]
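Both solvers agree to the stated tolerance. For completeness, the jax.experimental.optimizers module imported at the top can minimize the same objective with a simple full-batch first-order loop. The sketch below is not in the original notebook, and the step size and iteration count are illustrative assumptions; because the classes are nearly separable, the MLE weights are large and a first-order method converges much more slowly than BFGS, so the result will only approximate the solutions above.

# Optional sketch: full-batch Adam on the same NLL, using the optimizers module imported above.
# Hyperparameters here are illustrative assumptions, not tuned values.
opt_init, opt_update, get_params = optimizers.adam(step_size=0.1)
opt_state = opt_init(jnp.zeros(D))


@jit
def step(i, opt_state):
    w = get_params(opt_state)
    g = grad(training_loss)(w)
    return opt_update(i, g, opt_state)


for i in range(2000):
    opt_state = step(i, opt_state)

w_adam = get_params(opt_state)
print("parameters from jax-adam {}".format(w_adam))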