GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/scripts/bandit_demo.py
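
"""Contextual bandit demo.

The bandit problem is built from MNIST: the context is a flattened image, each
arm corresponds to a digit class, and the reward is 1 if the pulled arm matches
the true label and 0 otherwise. The main loop below runs a Bayesian linear
bandit (per-arm linear regression with a Normal-Inverse-Gamma posterior and
Thompson sampling); an epsilon-greedy neural bandit (NeuralGreedy) is also
defined and exercised in the commented-out block at the bottom.
"""
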
import matplotlib.pyplot as plt
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, vmap
from jax.tree_util import tree_leaves, tree_map
from jax.random import split, PRNGKey, permutation
from jax import random

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

import tensorflow_datasets as tfds

import flax
import flax.linen as nn

import optax
from sgmcmc_utils import build_optax_optimizer


class BanditEnvironment:
    def __init__(self, key, X, Y):
        # Randomise dataset rows
        n_obs, n_features = X.shape
        key, mykey = split(key)
        new_ixs = random.choice(mykey, n_obs, (n_obs,), replace=False)
        X = jnp.asarray(X)[new_ixs]
        Y = jnp.asarray(Y)[new_ixs]
        self.contexts = X
        self.labels_onehot = Y

    def get_context(self, t):
        return self.contexts[t]

    def get_reward(self, t, action):
        # Reward is 1.0 if the pulled arm matches the true label, else 0.0
        return float(self.labels_onehot[t][action])

    def warmup(self, num_pulls):
        num_steps, num_actions = self.labels_onehot.shape
        # Create array of round-robin actions: 0, 1, 2, 0, 1, 2, 0, 1, 2, ...
        warmup_actions = jnp.arange(num_actions)
        warmup_actions = jnp.repeat(warmup_actions, num_pulls).reshape(num_actions, -1)
        warmup_actions = warmup_actions.reshape(-1, order="F")
        num_warmup_actions, *_ = warmup_actions.shape
        actions = [int(a) for a in warmup_actions]
        contexts = []
        rewards = []
        for t, a in enumerate(actions):
            context = self.get_context(t)
            reward = self.get_reward(t, a)
            contexts.append(context)
            rewards.append(reward)
        return contexts, actions, rewards


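# For example, with the 10-arm MNIST bandit below, env.warmup(num_pulls=2) pulls
# each arm twice in round-robin order (0, 1, ..., 9, 0, 1, ..., 9) on the first
# 20 contexts and returns the resulting (contexts, actions, rewards) lists.

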
class LinearBandit:
    def __init__(self, num_features, num_arms):
        self.num_features = num_features
        self.num_arms = num_arms

    def init_bel(self, key, contexts, actions, rewards):
        # Per-arm Normal-Inverse-Gamma prior on (weights, noise variance)
        eta = 6.0
        lmbda = 0.25
        bel = {
            "mu": jnp.zeros((self.num_arms, self.num_features)),
            "Sigma": 1 * lmbda * jnp.eye(self.num_features) * jnp.ones((self.num_arms, 1, 1)),
            "a": eta * jnp.ones(self.num_arms),
            "b": eta * jnp.ones(self.num_arms),
        }
        nwarmup = len(rewards)
        for t in range(nwarmup):  # could do batch update
            context = contexts[t]
            action = actions[t]
            reward = rewards[t]
            bel = self.update_bel(key, bel, context, action, reward)
        return bel

    def update_bel(self, key, bel, context, action, reward):
        mu_k = bel["mu"][action]
        Sigma_k = bel["Sigma"][action]
        Lambda_k = jnp.linalg.inv(Sigma_k)
        a_k = bel["a"][action]
        b_k = bel["b"][action]

        # weight params
        Lambda_update = jnp.outer(context, context) + Lambda_k
        Sigma_update = jnp.linalg.inv(Lambda_update)
        mu_update = Sigma_update @ (Lambda_k @ mu_k + context * reward)
        # noise params
        a_update = a_k + 1/2
        b_update = b_k + (reward ** 2 + mu_k.T @ Lambda_k @ mu_k - mu_update.T @ Lambda_update @ mu_update) / 2

        # Update only the chosen action at time t
        mu = bel["mu"].at[action].set(mu_update)
        Sigma = bel["Sigma"].at[action].set(Sigma_update)
        a = bel["a"].at[action].set(a_update)
        b = bel["b"].at[action].set(b_update)

        bel = {"mu": mu, "Sigma": Sigma, "a": a, "b": b}
        return bel

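    # Notation for the conjugate update above (standard Bayesian linear regression
    # with a Normal-Inverse-Gamma prior per arm, spelled out here for readability):
    #   Lambda' = Lambda + x x^T                       (precision update)
    #   mu'     = (Lambda')^{-1} (Lambda mu + r x)     (posterior mean)
    #   a'      = a + 1/2
    #   b'      = b + (r^2 + mu^T Lambda mu - mu'^T Lambda' mu') / 2
    # Thompson sampling (sample_params below) then draws
    #   sigma^2 ~ InverseGamma(a', b'),  w ~ N(mu', sigma^2 (Lambda')^{-1}).
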
    def sample_params(self, key, bel):
        key_sigma, key_w = random.split(key, 2)
        sigma2_samp = tfd.InverseGamma(concentration=bel["a"], scale=bel["b"]).sample(seed=key_sigma)
        cov_matrix_samples = sigma2_samp[:, None, None] * bel["Sigma"]
        w_samp = tfd.MultivariateNormalFullCovariance(loc=bel["mu"], covariance_matrix=cov_matrix_samples).sample(seed=key_w)
        return sigma2_samp, w_samp

    def choose_action(self, key, bel, context):
        # Thompson sampling strategy
        # Could also use epsilon greedy or UCB; see the sketch after this class
        sigma2_samp, w_samp = self.sample_params(key, bel)
        predicted_reward = jnp.einsum("m,km->k", context, w_samp)
        action = predicted_reward.argmax()
        return action


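# A minimal sketch (not used elsewhere in this script) of the epsilon-greedy
# alternative mentioned in LinearBandit.choose_action. The exploration rate
# `epsilon` is a hypothetical parameter; the greedy arm is chosen from the
# posterior mean weights stored in `bel`.
def linear_choose_action_eps_greedy(key, bel, context, num_arms, epsilon=0.05):
    key_coin, key_arm = random.split(key)
    explore = random.bernoulli(key_coin, epsilon)  # True with probability epsilon
    greedy_action = jnp.einsum("m,km->k", context, bel["mu"]).argmax()
    random_action = random.choice(key_arm, jnp.arange(num_arms))
    return jnp.where(explore, random_action, greedy_action)

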
class MLP(nn.Module):
    num_features: int
    num_arms: int

    @nn.compact
    def __call__(self, x):  # x has both context and action
        x = nn.relu(nn.Dense(100)(x))
        x = nn.relu(nn.Dense(50)(x))
        x = nn.Dense(1)(x)  # identity activation for scalar regression output
        return x

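# Note: build_optax_optimizer comes from pyprobml's sgmcmc_utils. The code below
# assumes it maximizes logprior(params) + sum_i loglik(params, x_i, y_i) over
# minibatches of `data` using the supplied optax optimizer.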
def fit_model(key, model, X, y, variables):
    # Accept Python lists of examples (as stored by NeuralGreedy) or arrays
    X = jnp.asarray(X)
    y = jnp.asarray(y)
    opt = optax.adam(learning_rate=1e-1)
    data = (X, y)
    batch_size = 512
    nsteps = 100

    def loglik(params, x, y):
        # Gaussian log-likelihood up to an additive constant, evaluated at the
        # parameters being optimized (not the initial `variables`)
        pred_y = model.apply({"params": params}, x).squeeze(-1)
        return -jnp.square(y - pred_y)

    def logprior(params):
        # Spherical Gaussian prior
        l2_regularizer = 0.01
        leaves_of_params = tree_leaves(params)
        return sum(tree_map(lambda p: jnp.sum(jax.scipy.stats.norm.logpdf(p, scale=l2_regularizer)), leaves_of_params))

    optimizer = build_optax_optimizer(opt, loglik, logprior, data, batch_size, pbar=False)
    key, mykey = split(key)
    params = variables["params"]
    params, log_post_trace = optimizer(mykey, nsteps, params)
    variables = {"params": params}  # the MLP has no non-parameter collections
    return variables


class NeuralGreedy:
    def __init__(self, num_features, num_arms, epsilon, memory=None):
        self.num_features = num_features
        self.num_arms = num_arms
        self.model = MLP(num_features, num_arms)
        self.epsilon = epsilon
        self.memory = memory

    def encode(self, context, action):
        # Input to the MLP is the context concatenated with the one-hot action
        action_onehot = jax.nn.one_hot(action, self.num_arms)
        x = jnp.concatenate([context, action_onehot])
        return x

    def init_bel(self, key, contexts, actions, rewards):
        # Keep the replay data as Python lists so update_bel can append / pop
        X = [self.encode(c, a) for c, a in zip(contexts, actions)]
        y = list(rewards)
        variables = self.model.init(key, jnp.asarray(X))
        variables = fit_model(key, self.model, X, y, variables)
        bel = (X, y, variables)
        return bel

    def update_bel(self, key, bel, context, action, reward):
        (X, y, variables) = bel
        if self.memory is not None:  # finite memory
            if len(y) == self.memory:  # memory is full, drop the oldest example
                X.pop(0)
                y.pop(0)
        x = self.encode(context, action)
        X.append(x)
        y.append(reward)
        variables = fit_model(key, self.model, X, y, variables)
        bel = (X, y, variables)
        return bel

    def choose_action(self, key, bel, context):
        (X, y, variables) = bel
        key, mykey = split(key)
        coin = jax.random.bernoulli(mykey, self.epsilon)
        if coin:
            # random action (with probability epsilon)
            actions = jnp.arange(self.num_arms)
            key, mykey = split(key)
            action = jax.random.choice(mykey, actions)
        else:
            # greedy action
            predicted_rewards = jnp.zeros((self.num_arms,))
            # could make this a minibatch of A examples so we can predict
            # all rewards in parallel; see the sketch after this class
            for a in range(self.num_arms):
                x = self.encode(context, a)
                predicted_rewards = predicted_rewards.at[a].set(self.model.apply(variables, x)[0])
            action = predicted_rewards.argmax()
        return action


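# A minimal sketch (not wired into NeuralGreedy) of the batched prediction
# suggested by the comment in choose_action: encode the context with every arm,
# stack into a minibatch of shape (num_arms, num_features + num_arms), and run a
# single forward pass. `bandit` is assumed to be a NeuralGreedy instance.
def predict_all_rewards(bandit, variables, context):
    xs = jnp.stack([bandit.encode(context, a) for a in range(bandit.num_arms)])
    return bandit.model.apply(variables, xs).squeeze(-1)  # shape (num_arms,)

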
def run_bandit(key, bandit, env, nsteps, npulls):
    contexts, actions, rewards = env.warmup(npulls)
    nwarmup = len(rewards)
    key, mykey = split(key)
    bel = bandit.init_bel(mykey, contexts, actions, rewards)
    for i in range(nsteps - nwarmup):
        t = nwarmup + i
        print(f'step {t}')
        context = env.get_context(t)
        key, mykey = split(key)
        action = bandit.choose_action(mykey, bel, context)
        reward = env.get_reward(t, action)
        key, mykey = split(key)
        bel = bandit.update_bel(mykey, bel, context, action, reward)
        contexts.append(context)
        actions.append(action)
        rewards.append(reward)
    return contexts, actions, rewards


def get_datasets():
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds


def get_mnist():
    train_ds, test_ds = get_datasets()
    train_ds["image"] = train_ds["image"].reshape(-1, 28 ** 2)
    test_ds["image"] = test_ds["image"].reshape(-1, 28 ** 2)

    num_arms = len(jnp.unique(train_ds["label"]))
    num_obs, num_features = train_ds["image"].shape

    train_ds["X"] = train_ds.pop("image")
    train_ds["Y"] = jax.nn.one_hot(train_ds.pop("label"), num_arms)

    test_ds["X"] = test_ds.pop("image")
    test_ds["Y"] = jax.nn.one_hot(test_ds.pop("label"), num_arms)

    num_train = 5000
    X = train_ds["X"][:num_train]
    Y = train_ds["Y"][:num_train]
    return X, Y


X, Y = get_mnist()

# test the code
key = random.PRNGKey(314)
env = BanditEnvironment(key, X, Y)
contexts, actions, rewards = env.warmup(2)
print(len(contexts))
print(contexts[0].shape)

num_obs, num_features = X.shape
_, num_arms = Y.shape
bandit = LinearBandit(num_features, num_arms)

# main loop
contexts, actions, rewards = run_bandit(key, bandit, env, nsteps=20, npulls=1)
print(len(rewards))

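# Optional visualization (not in the original script): a minimal sketch that uses
# the matplotlib import above to plot the cumulative reward of the run.
plt.figure()
plt.plot(np.cumsum(np.asarray(rewards)))
plt.xlabel("step")
plt.ylabel("cumulative reward")
plt.title("Linear bandit with Thompson sampling on the MNIST bandit")
plt.show()
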
# multiple trials
'''
ntrials = 2
keys = random.split(key, ntrials)
npulls = 1
nsteps = 12
def trial(key):
    env = BanditEnvironment(key, X, Y)
    contexts, actions, rewards = run_bandit(key, bandit, env, nsteps, npulls)
    return jnp.array(42)

res = vmap(trial, in_axes=(0,))(keys)
print(res)
'''

# Neural greedy
'''
bandit = NeuralGreedy(num_features, num_arms, epsilon=0.1)
contexts, actions, rewards = run_bandit(key, bandit, env, nsteps=20, npulls=1)
print(len(rewards))
'''