GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/scripts/autodiff_logreg.py

#!/usr/bin/python
"""
Demonstrate automatic differentiation on binary logistic regression
using JAX, PyTorch and TensorFlow.
"""

import superimport

import warnings

import tensorflow as tf
from absl import app, flags

warnings.filterwarnings("ignore", category=DeprecationWarning)

FLAGS = flags.FLAGS

# Define command-line arguments using the Abseil library:
# https://abseil.io/docs/python/guides/flags
flags.DEFINE_boolean('jax', True, 'Whether to use JAX.')
flags.DEFINE_boolean('tf', True, 'Whether to use TensorFlow 2.')
flags.DEFINE_boolean('pytorch', True, 'Whether to use PyTorch.')
flags.DEFINE_boolean('verbose', True, 'Whether to print lots of output.')
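
# Example invocation (Abseil boolean flags can be negated with a "no" prefix):
#   python autodiff_logreg.py --nopytorch --notf
#   python autodiff_logreg.py --verbose=false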

import numpy as np  # "original" numpy, as opposed to jax.numpy below
#from scipy.misc import logsumexp
from scipy.special import logsumexp

np.set_printoptions(precision=3)

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp  # shadows the SciPy version imported above
from jax import grad, hessian, jacfwd, jacrev, jit, vmap
from jax.experimental import optimizers
from jax.experimental import stax
print("jax version {}".format(jax.__version__))
from jax.lib import xla_bridge
print("jax backend {}".format(xla_bridge.get_backend().platform))
import os
# Machine-specific path so XLA can find CUDA.
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/home/murphyk/miniconda3/lib"


import torch
import torchvision
print("torch version {}".format(torch.__version__))
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    print("current device {}".format(torch.cuda.current_device()))
else:
    print("Torch cannot find GPU")

def set_torch_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
#torch.backends.cudnn.benchmark = True


import tensorflow as tf
from tensorflow import keras
print("tf version {}".format(tf.__version__))
if tf.test.is_gpu_available():
    print(tf.test.gpu_device_name())
else:
    print("TF cannot find GPU")

tf.compat.v1.enable_eager_execution()

# We make some wrappers around random number generation
# so it works even if we switch from numpy to JAX.

def set_seed(seed):
    return np.random.seed(seed)

def randn(args):
    return np.random.randn(*args)

def randperm(args):
    return np.random.permutation(args)
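
# For reference (added note, not in the original script): JAX's own RNG is
# explicit-key based, e.g.
#   key = jax.random.PRNGKey(seed)
#   x = jax.random.normal(key, (3, 2))
# which is why thin NumPy wrappers like the ones above are handy when switching
# between backends.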

def BCE_with_logits(logits, targets):
    '''Binary cross entropy loss'''
    N = logits.shape[0]
    logits = logits.reshape(N, 1)
    logits_plus = jnp.hstack([np.zeros((N, 1)), logits])  # e^0=1
    logits_minus = jnp.hstack([np.zeros((N, 1)), -logits])
    logp1 = -logsumexp(logits_minus, axis=1)
    logp0 = -logsumexp(logits_plus, axis=1)
    logprobs = logp1 * targets + logp0 * (1 - targets)
    return -jnp.sum(logprobs) / N  # use jnp so jax.grad can trace through this
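
# Illustrative sanity check (added; not part of the original script and not
# called by main): on moderate logits the log-sum-exp formulation above should
# agree with the naive -[y*log(p) + (1-y)*log(1-p)] form of the BCE.
def _check_bce_matches_naive():
    logits = jnp.array([0.5, -1.2, 2.0])
    targets = jnp.array([1.0, 0.0, 1.0])
    p = 1.0 / (1.0 + jnp.exp(-logits))
    naive = -jnp.mean(targets * jnp.log(p) + (1 - targets) * jnp.log(1 - p))
    assert jnp.allclose(BCE_with_logits(logits, targets), naive, atol=1e-6)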

def sigmoid(x): return 0.5 * (np.tanh(x / 2.) + 1)  # numerically stable logistic sigmoid via tanh

def predict_logit(weights, inputs):
    return jnp.dot(inputs, weights)  # Already vectorized

def predict_prob(weights, inputs):
    return sigmoid(predict_logit(weights, inputs))

def NLL(weights, batch):
    X, y = batch
    logits = predict_logit(weights, X)
    return BCE_with_logits(logits, y)

def NLL_grad(weights, batch):
    X, y = batch
    N = X.shape[0]
    mu = predict_prob(weights, X)
    g = jnp.sum(np.dot(np.diag(mu - y), X), axis=0) / N
    return g
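
# Note (added for clarity, not in the original script): NLL_grad is the standard
# closed-form logistic-regression gradient (1/N) * X^T (mu - y); the diag-matrix
# product above is equivalent to the more direct
#   g = jnp.dot(X.T, mu - y) / N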


def setup_sklearn():
    import sklearn.datasets
    from sklearn.model_selection import train_test_split

    iris = sklearn.datasets.load_iris()
    X = iris["data"]
    y = (iris["target"] == 2).astype(int)  # 1 if Iris-Virginica, else 0
    N, D = X.shape  # 150, 4

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.33, random_state=42)

    from sklearn.linear_model import LogisticRegression

    # We set C to a large number to turn off regularization.
    # We don't fit the bias term to simplify the comparison below.
    log_reg = LogisticRegression(solver="lbfgs", C=1e5, fit_intercept=False)
    log_reg.fit(X_train, y_train)
    w_mle_sklearn = jnp.ravel(log_reg.coef_)
    set_seed(0)
    w = w_mle_sklearn
    return w, X_test, y_test

def compute_gradients_manually(w, X_test, y_test):
    y_pred = predict_prob(w, X_test)
    loss = NLL(w, (X_test, y_test))
    grad_np = NLL_grad(w, (X_test, y_test))
    print("params {}".format(w))
    #print("pred {}".format(y_pred))
    print("loss {}".format(loss))
    print("grad {}".format(grad_np))
    return grad_np


def compute_gradients_jax(w, X_test, y_test):
    print("Starting JAX demo")
    grad_jax = jax.grad(NLL)(w, (X_test, y_test))
    print("grad {}".format(grad_jax))
    return grad_jax
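
# Optional extension (added sketch, not in the original script): the JAX imports
# above also provide value_and_grad and hessian, which apply to NLL in the same
# way as jax.grad.
def compute_loss_grad_hessian_jax(w, X_test, y_test):
    loss, g = jax.value_and_grad(NLL)(w, (X_test, y_test))
    H = hessian(NLL)(w, (X_test, y_test))  # (D, D) Hessian of the NLL
    return loss, g, H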


def compute_gradients_stax(w, X_test, y_test):
    print("Starting STAX demo")
    N, D = X_test.shape

    def const_init(params):
        def init(rng_key, shape):
            return params
        return init

    #net_init, net_apply = stax.serial(stax.Dense(1), stax.elementwise(sigmoid))
    dense_layer = stax.Dense(1, W_init=const_init(np.reshape(w, (D, 1))),
                             b_init=const_init(np.array([0.0])))
    net_init, net_apply = stax.serial(dense_layer)
    rng = jax.random.PRNGKey(0)
    in_shape = (-1, D)
    out_shape, net_params = net_init(rng, in_shape)

    def NLL_model(net_params, net_apply, batch):
        X, y = batch
        logits = net_apply(net_params, X)
        return BCE_with_logits(logits, y)

    y_pred2 = net_apply(net_params, X_test)
    loss2 = NLL_model(net_params, net_apply, (X_test, y_test))
    grad_jax2 = grad(NLL_model)(net_params, net_apply, (X_test, y_test))
    grad_jax3 = grad_jax2[0][0]   # layer 0, block 0 (weights, not bias)
    grad_jax4 = grad_jax3[:, 0]   # flatten the (D,1) weight gradient to a vector

    print("params {}".format(net_params))
    #print("pred {}".format(y_pred2))
    print("loss {}".format(loss2))
    print("grad {}".format(grad_jax2))
    return grad_jax4
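
# Note (added for clarity): stax stores parameters as a list with one entry per
# layer; for a Dense layer that entry is a (W, b) tuple, which is why the weight
# gradient is extracted as grad_jax2[0][0] above.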


def compute_gradients_torch(w, X_test, y_test):
    print("Starting torch demo")
    N, D = X_test.shape
    w_torch = torch.Tensor(np.reshape(w, [D, 1])).to(device)
    w_torch.requires_grad_()
    x_test_tensor = torch.Tensor(X_test).to(device)
    y_test_tensor = torch.Tensor(y_test).to(device)
    y_pred = torch.sigmoid(torch.matmul(x_test_tensor, w_torch))[:, 0]
    criterion = torch.nn.BCELoss(reduction='mean')
    loss_torch = criterion(y_pred, y_test_tensor)
    loss_torch.backward()
    grad_torch = w_torch.grad[:, 0].cpu().numpy()  # move to CPU before converting
    print("params {}".format(w_torch))
    #print("pred {}".format(y_pred))
    print("loss {}".format(loss_torch))
    print("grad {}".format(grad_torch))
    return grad_torch
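
# Alternative sketch (added; not in the original script): the same gradient can
# be computed from the logits with the numerically more stable BCEWithLogitsLoss,
# avoiding the explicit sigmoid + BCELoss pair.
def compute_gradients_torch_logits(w, X_test, y_test):
    N, D = X_test.shape
    w_torch = torch.Tensor(np.reshape(w, [D, 1])).to(device)
    w_torch.requires_grad_()
    x_t = torch.Tensor(X_test).to(device)
    y_t = torch.Tensor(y_test).to(device)
    logits = torch.matmul(x_t, w_torch)[:, 0]
    loss = torch.nn.BCEWithLogitsLoss(reduction='mean')(logits, y_t)
    loss.backward()
    return w_torch.grad[:, 0].cpu().numpy()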

def compute_gradients_torch_nn(w, X_test, y_test):
    print("Starting torch demo: NN version")
    N, D = X_test.shape
    x_test_tensor = torch.Tensor(X_test).to(device)
    y_test_tensor = torch.Tensor(y_test).to(device)

    class Model(torch.nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.linear = torch.nn.Linear(D, 1, bias=False)

        def forward(self, x):
            y_pred = torch.sigmoid(self.linear(x))
            return y_pred

    model = Model()
    # Manually set parameters to desired values
    print(model.state_dict())
    from collections import OrderedDict
    w1 = torch.Tensor(np.reshape(w, [1, D])).to(device)  # row vector
    new_state_dict = OrderedDict({'linear.weight': w1})
    model.load_state_dict(new_state_dict, strict=False)
    #print(model.state_dict())
    model.to(device)  # make sure new params are on same device as data

    criterion = torch.nn.BCELoss(reduction='mean')
    y_pred2 = model(x_test_tensor)[:, 0]
    loss_torch2 = criterion(y_pred2, y_test_tensor)
    loss_torch2.backward()
    params_torch2 = list(model.parameters())
    grad_torch2 = params_torch2[0].grad[0].cpu().numpy()  # move to CPU before converting

    print("params {}".format(w1))
    #print("pred {}".format(y_pred2))
    print("loss {}".format(loss_torch2))
    print("grad {}".format(grad_torch2))
    return grad_torch2

def compute_gradients_tf(w, X_test, y_test):
    print("Starting TF demo")
    N, D = X_test.shape
    w_tf = tf.Variable(np.reshape(w, (D, 1)))
    x_test_tf = tf.convert_to_tensor(X_test, dtype=np.float64)
    y_test_tf = tf.convert_to_tensor(np.reshape(y_test, (-1, 1)), dtype=np.float64)
    with tf.GradientTape() as tape:
        logits = tf.linalg.matmul(x_test_tf, w_tf)
        y_pred = tf.math.sigmoid(logits)
        loss_batch = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_test_tf, logits=logits)
        loss_tf = tf.reduce_mean(loss_batch, axis=0)
    grad_tf = tape.gradient(loss_tf, [w_tf])
    grad_tf = grad_tf[0][:, 0].numpy()

    print("params {}".format(w_tf))
    #print("pred {}".format(y_pred))
    print("loss {}".format(loss_tf))
    print("grad {}".format(grad_tf))
    return grad_tf
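
# Optional sketch (added; not in the original script): the same loss/gradient
# computation can be wrapped in tf.function to trace and compile it.
@tf.function
def tf_loss_and_grad(w_tf, x_tf, y_tf):
    with tf.GradientTape() as tape:
        tape.watch(w_tf)
        logits = tf.linalg.matmul(x_tf, w_tf)
        loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=y_tf, logits=logits))
    return loss, tape.gradient(loss, w_tf)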


def compute_gradients_keras(w, X_test, y_test):
    # This no longer runs
    N, D = X_test.shape
    print("Starting TF demo: keras version")
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(1, input_shape=(D,), activation=None, use_bias=False)
    ])
    #model.compile(optimizer='sgd', loss=tf.nn.sigmoid_cross_entropy_with_logits)
    model.build()
    w_tf2 = tf.convert_to_tensor(np.reshape(w, (D, 1)))
    model.set_weights([w_tf2])
    # Convert the inputs locally (the original referenced a tensor defined in compute_gradients_tf).
    x_test_tf2 = tf.convert_to_tensor(X_test, dtype=np.float32)
    y_test_tf2 = tf.convert_to_tensor(np.reshape(y_test, (-1, 1)), dtype=np.float32)
    with tf.GradientTape() as tape:
        logits_temp = model.predict(x_test_tf2)  # forwards pass only
        logits2 = model(x_test_tf2, training=True)  # OO version enables backprop
        loss_batch2 = tf.nn.sigmoid_cross_entropy_with_logits(y_test_tf2, logits2)
        loss_tf2 = tf.reduce_mean(loss_batch2, axis=0)
    grad_tf2 = tape.gradient(loss_tf2, model.trainable_variables)
    grad_tf2 = grad_tf2[0][:, 0].numpy()
    print("params {}".format(w_tf2))
    print("loss {}".format(loss_tf2))
    print("grad {}".format(grad_tf2))
    return grad_tf2


def main(_):
    if FLAGS.verbose:
        print('We will compute gradients for binary logistic regression')

    w, X_test, y_test = setup_sklearn()
    grad_np = compute_gradients_manually(w, X_test, y_test)
    if FLAGS.jax:
        grad_jax = compute_gradients_jax(w, X_test, y_test)
        assert jnp.allclose(grad_np, grad_jax)
        grad_stax = compute_gradients_stax(w, X_test, y_test)
        assert jnp.allclose(grad_np, grad_stax)

    if FLAGS.pytorch:
        grad_torch = compute_gradients_torch(w, X_test, y_test)
        assert jnp.allclose(grad_np, grad_torch)
        grad_torch_nn = compute_gradients_torch_nn(w, X_test, y_test)
        assert jnp.allclose(grad_np, grad_torch_nn)

    if FLAGS.tf:
        grad_tf = compute_gradients_tf(w, X_test, y_test)
        assert jnp.allclose(grad_np, grad_tf)
        #grad_tf = compute_gradients_keras(w, X_test, y_test)


if __name__ == '__main__':
    app.run(main)