GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/10/iris_logreg_loss_surface.ipynb
# Plot 2d NLL loss surface for binary logistic regression with 1 feature
# Loosely based on
# https://peterroelants.github.io/posts/neural-network-implementation-part02/

import jax
import jax.numpy as jnp
from jax.config import config
import matplotlib.pyplot as plt
import seaborn as sns

try:
    import probml_utils as pml
except ModuleNotFoundError:
    %pip install -qq git+https://github.com/probml/probml-utils.git
    import probml_utils as pml

import os

os.environ["FIG_DIR"] = "/Users/kpmurphy/github/bookv2/figures"

from probml_utils import LogisticRegression as lr
from mpl_toolkits.mplot3d import axes3d, Axes3D

try:
    import sklearn
except ModuleNotFoundError:
    %pip install -qq scikit-learn
    import sklearn

from sklearn.linear_model import LogisticRegression
from sklearn import datasets
config.update("jax_enable_x64", True)
iris = datasets.load_iris()
X = iris["data"][:, 3:]  # petal width
y = (iris["target"] == 2).astype("int64")  # 1 if Iris-Virginica, else 0
print(X.shape)

# Fit with sklearn (unregularized MLE)
log_reg = LogisticRegression(solver="lbfgs", fit_intercept=True, penalty="none")
log_reg.fit(X, y)
w_mle = log_reg.coef_[0][0]  # 12.947270212450366
b_mle = log_reg.intercept_[0]  # -21.125250539711022
ypred = log_reg.predict_proba(X)

# Fit with the probml_utils LogisticRegression and check both fits agree
parameters, b_mle1, w_mle1 = lr.fit(X, y, lambd=0)
print(parameters, b_mle1, w_mle1)
ypred1 = lr.predict_proba(parameters, X)

assert jnp.isclose(w_mle, w_mle1[0], rtol=1e-4)
assert jnp.isclose(b_mle, b_mle1, rtol=1e-4)
assert jnp.isclose(ypred[0][1], ypred1[0])
(150, 1) [-21.12564903 12.94751255] -21.125649030019847 [12.94751255]
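Since the fitted model is $p(y=1 \mid x) = \sigma(wx + b)$, the decision boundary, where the predicted probability crosses 0.5, lies at $x^* = -b/w$. A minimal sketch (added here, not part of the original notebook) using the MLE values printed above:

# Decision boundary: p = 0.5 exactly where w * x + b = 0
x_star = -b_mle / w_mle
print(x_star)  # roughly 1.63 cm; flowers with wider petals are predicted Virginica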
# Define the logistic (sigmoid) function
def logistic(z):
    return 1.0 / (1 + jnp.exp(-z))


# Define the prediction function
@jax.jit
def predict_prob(x, w):
    z = x.dot(w.T)
    p = logistic(z)
    return p.reshape(-1)  # drop the last dimension


# Define the NLL loss function (p = predicted probability, y = binary target)
@jax.jit
def loss(p, y):
    return -jnp.mean(jnp.multiply(y, jnp.log(p)) + jnp.multiply((1 - y), jnp.log(1 - p)))


N = X.shape[0]
ones = jnp.ones((N, 1))
X1 = jnp.hstack((X, ones))  # append a column of ones so the bias acts as the last weight

# We compute the loss on a grid of (w, b) values centered at the MLE.
# We use for loops for simplicity.
ngrid = 50
fudge = 5
ws = jnp.linspace(w_mle - fudge, w_mle + fudge, ngrid)
bs = jnp.linspace(b_mle - fudge, b_mle + fudge, ngrid)
grid_w, grid_b = jnp.meshgrid(ws, bs)

loss_grid = []
for i in range(ngrid):
    losses = []
    for j in range(ngrid):
        params = jnp.array([[grid_w[i, j], grid_b[i, j]]])
        p = predict_prob(X1, params)
        losses.append(loss(p, y))
    loss_grid.append(losses)
loss_grid = jnp.array(loss_grid)
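For reference, the loss implemented above is the average Bernoulli negative log-likelihood:

$$\mathrm{NLL}(w, b) = -\frac{1}{N} \sum_{n=1}^{N} \Big[ y_n \log p_n + (1 - y_n) \log(1 - p_n) \Big], \qquad p_n = \sigma(w x_n + b) = \frac{1}{1 + e^{-(w x_n + b)}}.$$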
print(y.shape, p.shape)
(150,) (150,)
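The double for loop above is simple but slow for larger grids. A vectorized alternative using jax.vmap evaluates the whole grid in one call; this is a sketch added here (nll_at is a helper name introduced for illustration, not part of the original notebook):

# Sketch: vectorized grid evaluation with jax.vmap instead of Python loops
def nll_at(params_vec):
    # params_vec = [w, b]; reuses X1 (with appended ones column) and y from above
    p = predict_prob(X1, params_vec.reshape(1, 2))
    return loss(p, y)


params_flat = jnp.stack([grid_w.ravel(), grid_b.ravel()], axis=1)  # (ngrid * ngrid, 2)
loss_grid_vmap = jax.vmap(nll_at)(params_flat).reshape(grid_w.shape)
assert jnp.allclose(loss_grid_vmap, loss_grid)  # should match the loop result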
fig, ax = plt.subplots()
CS = plt.contour(grid_w, grid_b, loss_grid, cmap="jet")
pml.savefig("logregIrisLossContours2")
sns.despine()
plt.show()
[Figure: contour plot of the NLL over (w, b); saved as logregIrisLossContours2]
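To see where the optimum sits relative to the contours, one can overlay the MLE; a small sketch (an addition, not in the original notebook):

# Sketch: redraw the contours with the fitted optimum marked
fig, ax = plt.subplots()
ax.contour(grid_w, grid_b, loss_grid, cmap="jet")
ax.plot(w_mle, b_mle, "kx", markersize=10, label="MLE")
ax.set_xlabel("$w$")
ax.set_ylabel("$b$")
ax.legend()
plt.show()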
# Plot the loss function surface
plt.figure()
plt.contourf(grid_w, grid_b, loss_grid, 20)
cbar = plt.colorbar()
cbar.ax.set_ylabel("NLL", fontsize=12)
plt.xlabel("$w$", fontsize=12)
plt.ylabel("$b$", fontsize=12)
plt.title("Loss function surface")
pml.savefig("logregIrisLossHeatmap2")
plt.show()
[Figure: filled-contour heatmap of the NLL over (w, b); saved as logregIrisLossHeatmap2]
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
surf = ax.plot_surface(grid_w, grid_b, loss_grid)
ax.set_xlabel("$w$")
ax.set_ylabel("$b$")
pml.savefig("logregIrisLossSurf2")
plt.show()
[Figure: 3d surface plot of the NLL over (w, b); saved as logregIrisLossSurf2]
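As a final sanity check (an addition, not part of the original notebook), the gradient of the NLL should be close to zero at the MLE; jax.grad makes this a one-liner, reusing the nll_at helper from the vmap sketch above:

# The gradient of the NLL should (nearly) vanish at the fitted optimum
g = jax.grad(nll_at)(jnp.array([w_mle, b_mle]))
print(g)  # both components should be close to zero, up to solver tolerance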