GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/scripts/autodiff_demo.py
# Demonstrate automatic differentiation on binary logistic regression
# using JAX, Torch and TF

import superimport

import numpy as np
#from scipy.misc import logsumexp
from scipy.special import logsumexp

np.set_printoptions(precision=3)

USE_JAX = True
USE_TORCH = True
USE_TF = True

# We make some wrappers around random number generation
# so it works even if we switch from numpy to JAX
import numpy as np # original numpy

def set_seed(seed):
    np.random.seed(seed)

def randn(args):
    return np.random.randn(*args)

def randperm(args):
    return np.random.permutation(args)

if USE_TORCH:
    import torch
    import torchvision
    print("torch version {}".format(torch.__version__))
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name(0))
        print("current device {}".format(torch.cuda.current_device()))
    else:
        print("Torch cannot find GPU")

    def set_seed(seed):
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    #torch.backends.cudnn.benchmark = True
if USE_JAX:
    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.scipy.special import logsumexp
    from jax import grad, hessian, jacfwd, jacrev, jit, vmap
    from jax.experimental import optimizers
    print("jax version {}".format(jax.__version__))
    from jax.lib import xla_bridge
    print("jax backend {}".format(xla_bridge.get_backend().platform))
    import os
    os.environ["XLA_FLAGS"]="--xla_gpu_cuda_data_dir=/home/murphyk/miniconda3/lib"


if USE_TF:
    import tensorflow as tf
    from tensorflow import keras
    print("tf version {}".format(tf.__version__))
    if tf.test.is_gpu_available():
        print(tf.test.gpu_device_name())
    else:
        print("TF cannot find GPU")
    tf.compat.v1.enable_eager_execution()
### Dataset
import sklearn.datasets
from sklearn.model_selection import train_test_split

iris = sklearn.datasets.load_iris()
X = iris["data"]
y = (iris["target"] == 2).astype(int)  # 1 if Iris-Virginica, else 0
N, D = X.shape # 150, 4


X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

from sklearn.linear_model import LogisticRegression

# We set C to a large number to turn off regularization.
# We don't fit the bias term to simplify the comparison below.
log_reg = LogisticRegression(solver="lbfgs", C=1e5, fit_intercept=False)
log_reg.fit(X_train, y_train)
w_mle_sklearn = jnp.ravel(log_reg.coef_)

set_seed(0)
w = w_mle_sklearn

## Compute gradient of loss "by hand" using numpy


def BCE_with_logits(logits, targets):
    N = logits.shape[0]
    logits = logits.reshape(N,1)
    logits_plus = jnp.hstack([np.zeros((N,1)), logits]) # e^0=1
    logits_minus = jnp.hstack([np.zeros((N,1)), -logits])
    logp1 = -logsumexp(logits_minus, axis=1)
    logp0 = -logsumexp(logits_plus, axis=1)
    logprobs = logp1 * targets + logp0 * (1-targets)
    return -np.sum(logprobs)/N
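
# The function above evaluates the binary cross entropy directly from logits in
# a numerically stable way, using the identities
#   log(sigmoid(a))     = -logsumexp([0, -a])
#   log(1 - sigmoid(a)) = -logsumexp([0, a]).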

if True:
    # Compute using numpy
    def sigmoid(x): return 0.5 * (np.tanh(x / 2.) + 1)

    def predict_logit(weights, inputs):
        return jnp.dot(inputs, weights) # Already vectorized

    def predict_prob(weights, inputs):
        return sigmoid(predict_logit(weights, inputs))

    def NLL(weights, batch):
        X, y = batch
        logits = predict_logit(weights, X)
        return BCE_with_logits(logits, y)

    def NLL_grad(weights, batch):
        X, y = batch
        N = X.shape[0]
        mu = predict_prob(weights, X)
        g = jnp.sum(np.dot(np.diag(mu - y), X), axis=0)/N
        return g
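
    # NLL_grad above is the closed-form gradient of the mean BCE:
    #   grad_w NLL(w) = (1/N) X^T (mu - y), where mu = sigmoid(X w);
    # summing the rows of diag(mu - y) @ X is the same as X^T (mu - y).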

    y_pred = predict_prob(w, X_test)
    loss = NLL(w, (X_test, y_test))
    grad_np = NLL_grad(w, (X_test, y_test))
    print("params {}".format(w))
    #print("pred {}".format(y_pred))
    print("loss {}".format(loss))
    print("grad {}".format(grad_np))

if USE_JAX:
    print("Starting JAX demo")
    grad_jax = grad(NLL)(w, (X_test, y_test))
    print("grad {}".format(grad_jax))
    assert jnp.allclose(grad_np, grad_jax)
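
    # Optional extra (not part of the original script): grad composes with jit
    # and hessian, both already imported from jax above; grad_jax_jit and H are
    # new names introduced here.
    grad_jax_jit = jit(grad(NLL))(w, (X_test, y_test))
    assert jnp.allclose(grad_jax, grad_jax_jit, atol=1e-5)
    H = hessian(NLL)(w, (X_test, y_test))  # (D, D) second-derivative matrix
    print("hessian shape {}".format(H.shape))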

    print("Starting STAX demo")
    # Stax version
    from jax.experimental import stax

    def const_init(params):
        def init(rng_key, shape):
            return params
        return init

    #net_init, net_apply = stax.serial(stax.Dense(1), stax.elementwise(sigmoid))
    dense_layer = stax.Dense(1, W_init=const_init(np.reshape(w, (D,1))),
                             b_init=const_init(np.array([0.0])))
    net_init, net_apply = stax.serial(dense_layer)
    rng = jax.random.PRNGKey(0)
    in_shape = (-1,D)
    out_shape, net_params = net_init(rng, in_shape)

    def NLL_model(net_params, net_apply, batch):
        X, y = batch
        logits = net_apply(net_params, X)
        return BCE_with_logits(logits, y)

    y_pred2 = net_apply(net_params, X_test)
    loss2 = NLL_model(net_params, net_apply, (X_test, y_test))
    grad_jax2 = grad(NLL_model)(net_params, net_apply, (X_test, y_test))
    grad_jax3 = grad_jax2[0][0] # layer 0, block 0 (weights not bias)
    grad_jax4 = grad_jax3[:,0] # column vector
    assert jnp.allclose(grad_np, grad_jax4)

    print("params {}".format(net_params))
    #print("pred {}".format(y_pred2))
    print("loss {}".format(loss2))
    print("grad {}".format(grad_jax2))

if USE_TORCH:
    import torch

    print("Starting torch demo")
    w_torch = torch.Tensor(np.reshape(w, [D, 1])).to(device)
    w_torch.requires_grad_()
    x_test_tensor = torch.Tensor(X_test).to(device)
    y_test_tensor = torch.Tensor(y_test).to(device)
    y_pred = torch.sigmoid(torch.matmul(x_test_tensor, w_torch))[:,0]
    criterion = torch.nn.BCELoss(reduction='mean')
    loss_torch = criterion(y_pred, y_test_tensor)
    loss_torch.backward()
    grad_torch = w_torch.grad[:,0].cpu().numpy() # move to CPU before converting
    assert jnp.allclose(grad_np, grad_torch)

    print("params {}".format(w_torch))
    #print("pred {}".format(y_pred))
    print("loss {}".format(loss_torch))
    print("grad {}".format(grad_torch))

if USE_TORCH:
    print("Starting torch demo: Model version")

    class Model(torch.nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.linear = torch.nn.Linear(D, 1, bias=False)

        def forward(self, x):
            y_pred = torch.sigmoid(self.linear(x))
            return y_pred

    model = Model()
    # Manually set parameters to desired values
    print(model.state_dict())
    from collections import OrderedDict
    w1 = torch.Tensor(np.reshape(w, [1, D])).to(device) # row vector
    new_state_dict = OrderedDict({'linear.weight': w1})
    model.load_state_dict(new_state_dict, strict=False)
    #print(model.state_dict())
    model.to(device) # make sure new params are on same device as data

    criterion = torch.nn.BCELoss(reduction='mean')
    y_pred2 = model(x_test_tensor)[:,0]
    loss_torch2 = criterion(y_pred2, y_test_tensor)
    loss_torch2.backward()
    params_torch2 = list(model.parameters())
    grad_torch2 = params_torch2[0].grad[0].cpu().numpy() # move to CPU before converting
    assert jnp.allclose(grad_np, grad_torch2)

    print("params {}".format(w1))
    #print("pred {}".format(y_pred))
    print("loss {}".format(loss_torch2))
    print("grad {}".format(grad_torch2))

if USE_TF:
    print("Starting TF demo")
    w_tf = tf.Variable(np.array(np.reshape(w, (D,1)), dtype=np.float64)) # float64 to match x_test_tf
    x_test_tf = tf.convert_to_tensor(X_test, dtype=np.float64)
    y_test_tf = tf.convert_to_tensor(np.reshape(y_test, (-1,1)), dtype=np.float64)
    with tf.GradientTape() as tape:
        logits = tf.linalg.matmul(x_test_tf, w_tf)
        y_pred = tf.math.sigmoid(logits)
        loss_batch = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_test_tf, logits=logits)
        loss_tf = tf.reduce_mean(loss_batch, axis=0)
    grad_tf = tape.gradient(loss_tf, [w_tf])
    grad_tf = grad_tf[0][:,0].numpy()
    assert jnp.allclose(grad_np, grad_tf)

    print("params {}".format(w_tf))
    #print("pred {}".format(y_pred))
    print("loss {}".format(loss_tf))
    print("grad {}".format(grad_tf))

if False:
    # This no longer runs
    print("Starting TF demo: keras version")
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(1, input_shape=(D,), activation=None, use_bias=False)
    ])
    #model.compile(optimizer='sgd', loss=tf.nn.sigmoid_cross_entropy_with_logits)
    model.build()
    w_tf2 = tf.convert_to_tensor(np.reshape(w, (D,1)))
    model.set_weights([w_tf2])
    y_test_tf2 = tf.convert_to_tensor(np.reshape(y_test, (-1,1)), dtype=np.float32)
    with tf.GradientTape() as tape:
        logits_temp = model.predict(x_test_tf) # forwards pass only
        logits2 = model(x_test_tf, training=True) # OO version enables backprop
        loss_batch2 = tf.nn.sigmoid_cross_entropy_with_logits(y_test_tf2, logits2)
        loss_tf2 = tf.reduce_mean(loss_batch2, axis=0)
    grad_tf2 = tape.gradient(loss_tf2, model.trainable_variables)
    grad_tf2 = grad_tf2[0][:,0].numpy()
    assert jnp.allclose(grad_np, grad_tf2)

    print("params {}".format(w_tf2))
    print("loss {}".format(loss_tf2))
    print("grad {}".format(grad_tf2))