Path: blob/master/deprecated/scripts/autodiff_logreg.py
1192 views
#!/usr/bin/python1"""2Demonstrate automatic differentiaiton on binary logistic regression3using JAX, Torch and TF4"""56import superimport78import warnings910import tensorflow as tf11from absl import app, flags1213warnings.filterwarnings("ignore", category=DeprecationWarning)1415FLAGS = flags.FLAGS1617# Define a command-line argument using the Abseil library:18# https://abseil.io/docs/python/guides/flags19flags.DEFINE_boolean('jax', True, 'Whether to use JAX.')20flags.DEFINE_boolean('tf', True, 'Whether to use Tensorflow 2.')21flags.DEFINE_boolean('pytorch', True, 'Whether to use PyTorch.')22flags.DEFINE_boolean('verbose', True, 'Whether to print lots of output.')2324import numpy as np25#from scipy.misc import logsumexp26from scipy.special import logsumexp27import numpy as np # original numpy2829np.set_printoptions(precision=3)3031import jax32import jax.numpy as jnp33import numpy as np34from jax.scipy.special import logsumexp35from jax import grad, hessian, jacfwd, jacrev, jit, vmap36from jax.experimental import optimizers37from jax.experimental import stax38print("jax version {}".format(jax.__version__))39from jax.lib import xla_bridge40print("jax backend {}".format(xla_bridge.get_backend().platform))41import os42os.environ["XLA_FLAGS"]="--xla_gpu_cuda_data_dir=/home/murphyk/miniconda3/lib"434445import torch46import torchvision47print("torch version {}".format(torch.__version__))48if torch.cuda.is_available():49print(torch.cuda.get_device_name(0))50print("current device {}".format(torch.cuda.current_device()))51else:52print("Torch cannot find GPU")5354def set_torch_seed(seed):55np.random.seed(seed)56torch.manual_seed(seed)57torch.cuda.manual_seed_all(seed)5859use_cuda = torch.cuda.is_available()60device = torch.device("cuda:0" if use_cuda else "cpu")61#torch.backends.cudnn.benchmark = True62636465import tensorflow as tf66from tensorflow import keras67print("tf version {}".format(tf.__version__))68if tf.test.is_gpu_available():69print(tf.test.gpu_device_name())70else:71print("TF cannot find GPU")7273tf.compat.v1.enable_eager_execution()7475# We make some wrappers around random number generation76# so it works even if we switch from numpy to JAX7778def set_seed(seed):79return np.random.seed(seed)8081def randn(args):82return np.random.randn(*args)8384def randperm(args):85return np.random.permutation(args)8687def BCE_with_logits(logits, targets):88'''Binary cross entropy loss'''89N = logits.shape[0]90logits = logits.reshape(N,1)91logits_plus = jnp.hstack([np.zeros((N,1)), logits]) # e^0=192logits_minus = jnp.hstack([np.zeros((N,1)), -logits])93logp1 = -logsumexp(logits_minus, axis=1)94logp0 = -logsumexp(logits_plus, axis=1)95logprobs = logp1 * targets + logp0 * (1-targets)96return -np.sum(logprobs)/N9798def sigmoid(x): return 0.5 * (np.tanh(x / 2.) + 1)99100def predict_logit(weights, inputs):101return jnp.dot(inputs, weights) # Already vectorized102103def predict_prob(weights, inputs):104return sigmoid(predict_logit(weights, inputs))105106def NLL(weights, batch):107X, y = batch108logits = predict_logit(weights, X)109return BCE_with_logits(logits, y)110111def NLL_grad(weights, batch):112X, y = batch113N = X.shape[0]114mu = predict_prob(weights, X)115g = jnp.sum(np.dot(np.diag(mu - y), X), axis=0)/N116return g117118119120121def setup_sklearn():122import sklearn.datasets123from sklearn.model_selection import train_test_split124125iris = sklearn.datasets.load_iris()126X = iris["data"]127y = (iris["target"] == 2).astype(np.int) # 1 if Iris-Virginica, else 0'128N, D = X.shape # 150, 4129130131X_train, X_test, y_train, y_test = train_test_split(132X, y, test_size=0.33, random_state=42)133134from sklearn.linear_model import LogisticRegression135136# We set C to a large number to turn off regularization.137# We don't fit the bias term to simplify the comparison below.138log_reg = LogisticRegression(solver="lbfgs", C=1e5, fit_intercept=False)139log_reg.fit(X_train, y_train)140w_mle_sklearn = jnp.ravel(log_reg.coef_)141set_seed(0)142w = w_mle_sklearn143return w, X_test, y_test144145def compute_gradients_manually(w, X_test, y_test):146y_pred = predict_prob(w, X_test)147loss = NLL(w, (X_test, y_test))148grad_np = NLL_grad(w, (X_test, y_test))149print("params {}".format(w))150#print("pred {}".format(y_pred))151print("loss {}".format(loss))152print("grad {}".format(grad_np))153return grad_np154155156def compute_gradients_jax(w, X_test, y_test):157print("Starting JAX demo")158grad_jax = jax.grad(NLL)(w, (X_test, y_test))159print("grad {}".format(grad_jax))160return grad_jax161162163def compute_gradients_stax(w, X_test, y_test):164print("Starting STAX demo")165N, D = X_test.shape166def const_init(params):167def init(rng_key, shape):168return params169return init170171#net_init, net_apply = stax.serial(stax.Dense(1), stax.elementwise(sigmoid))172dense_layer = stax.Dense(1, W_init=const_init(np.reshape(w, (D,1))),173b_init=const_init(np.array([0.0])))174net_init, net_apply = stax.serial(dense_layer)175rng = jax.random.PRNGKey(0)176in_shape = (-1,D)177out_shape, net_params = net_init(rng, in_shape)178179def NLL_model(net_params, net_apply, batch):180X, y = batch181logits = net_apply(net_params, X)182return BCE_with_logits(logits, y)183184y_pred2 = net_apply(net_params, X_test)185loss2 = NLL_model(net_params, net_apply, (X_test, y_test))186grad_jax2 = grad(NLL_model)(net_params, net_apply, (X_test, y_test))187grad_jax3 = grad_jax2[0][0] # layer 0, block 0 (weights not bias)188grad_jax4 = grad_jax3[:,0] # column vector189190print("params {}".format(net_params))191#print("pred {}".format(y_pred2))192print("loss {}".format(loss2))193print("grad {}".format(grad_jax2))194return grad_jax4195196197198199200def compute_gradients_torch(w, X_test, y_test):201print("Starting torch demo")202N, D = X_test.shape203w_torch = torch.Tensor(np.reshape(w, [D, 1])).to(device)204w_torch.requires_grad_()205x_test_tensor = torch.Tensor(X_test).to(device)206y_test_tensor = torch.Tensor(y_test).to(device)207y_pred = torch.sigmoid(torch.matmul(x_test_tensor, w_torch))[:,0]208criterion = torch.nn.BCELoss(reduction='mean')209loss_torch = criterion(y_pred, y_test_tensor)210loss_torch.backward()211grad_torch = w_torch.grad[:,0].numpy()212print("params {}".format(w_torch))213#print("pred {}".format(y_pred))214print("loss {}".format(loss_torch))215print("grad {}".format(grad_torch))216return grad_torch217218def compute_gradients_torch_nn(w, X_test, y_test):219print("Starting torch demo: NN version")220N, D = X_test.shape221x_test_tensor = torch.Tensor(X_test).to(device)222y_test_tensor = torch.Tensor(y_test).to(device)223class Model(torch.nn.Module):224def __init__(self):225super(Model, self).__init__()226self.linear = torch.nn.Linear(D, 1, bias=False)227228def forward(self, x):229y_pred = torch.sigmoid(self.linear(x))230return y_pred231232model = Model()233# Manually set parameters to desired values234print(model.state_dict())235from collections import OrderedDict236w1 = torch.Tensor(np.reshape(w, [1, D])).to(device) # row vector237new_state_dict = OrderedDict({'linear.weight': w1})238model.load_state_dict(new_state_dict, strict=False)239#print(model.state_dict())240model.to(device) # make sure new params are on same device as data241242criterion = torch.nn.BCELoss(reduction='mean')243y_pred2 = model(x_test_tensor)[:,0]244loss_torch2 = criterion(y_pred2, y_test_tensor)245loss_torch2.backward()246params_torch2 = list(model.parameters())247grad_torch2 = params_torch2[0].grad[0].numpy()248249print("params {}".format(w1))250#print("pred {}".format(y_pred))251print("loss {}".format(loss_torch2))252print("grad {}".format(grad_torch2))253return grad_torch2254255def compute_gradients_tf(w, X_test, y_test):256print("Starting TF demo")257N, D = X_test.shape258w_tf = tf.Variable(np.reshape(w, (D,1)))259x_test_tf = tf.convert_to_tensor(X_test, dtype=np.float64)260y_test_tf = tf.convert_to_tensor(np.reshape(y_test, (-1,1)), dtype=np.float64)261with tf.GradientTape() as tape:262logits = tf.linalg.matmul(x_test_tf, w_tf)263y_pred = tf.math.sigmoid(logits)264loss_batch = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_test_tf, logits=logits)265loss_tf = tf.reduce_mean(loss_batch, axis=0)266grad_tf = tape.gradient(loss_tf, [w_tf])267grad_tf = grad_tf[0][:,0].numpy()268269print("params {}".format(w_tf))270#print("pred {}".format(y_pred))271print("loss {}".format(loss_tf))272print("grad {}".format(grad_tf))273return grad_tf274275276def compute_gradients_keras(w, X_test, y_test):277# This no longer runs278N, D = X_test.shape279print("Starting TF demo: keras version")280model = tf.keras.models.Sequential([281tf.keras.layers.Dense(1, input_shape=(D,), activation=None, use_bias=False)282])283#model.compile(optimizer='sgd', loss=tf.nn.sigmoid_cross_entropy_with_logits)284model.build()285w_tf2 = tf.convert_to_tensor(np.reshape(w, (D,1)))286model.set_weights([w_tf2])287y_test_tf2 = tf.convert_to_tensor(np.reshape(y_test, (-1,1)), dtype=np.float32)288with tf.GradientTape() as tape:289logits_temp = model.predict(x_test_tf) # forwards pass only290logits2 = model(x_test_tf, training=True) # OO version enables backprop291loss_batch2 = tf.nn.sigmoid_cross_entropy_with_logits(y_test_tf2, logits2)292loss_tf2 = tf.reduce_mean(loss_batch2, axis=0)293grad_tf2 = tape.gradient(loss_tf2, model.trainable_variables)294grad_tf2 = grad_tf2[0][:,0].numpy()295print("params {}".format(w_tf2))296print("loss {}".format(loss_tf2))297print("grad {}".format(grad_tf2))298return grad_tf2299300301def main(_):302if FLAGS.verbose:303print('We will compute gradients for binary logistic regression')304305w, X_test, y_test = setup_sklearn()306grad_np = compute_gradients_manually(w, X_test, y_test)307if FLAGS.jax:308grad_jax = compute_gradients_jax(w, X_test, y_test)309assert jnp.allclose(grad_np, grad_jax)310grad_stax = compute_gradients_stax(w, X_test, y_test)311assert jnp.allclose(grad_np, grad_stax)312313if FLAGS.pytorch:314grad_torch = compute_gradients_torch(w, X_test, y_test)315assert jnp.allclose(grad_np, grad_torch)316grad_torch_nn = compute_gradients_torch_nn(w, X_test, y_test)317assert jnp.allclose(grad_np, grad_torch_nn)318319if FLAGS.tf:320grad_tf = compute_gradients_tf(w, X_test, y_test)321assert jnp.allclose(grad_np, grad_tf)322#grad_tf = compute_gradients_keras(w)323324325if __name__ == '__main__':326app.run(main)327328329