Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
aimacode
GitHub Repository: aimacode/aima-python
Path: blob/master/notebook4e.py
615 views
1
import time
2
from collections import defaultdict
3
from inspect import getsource
4
5
import ipywidgets as widgets
6
import matplotlib.pyplot as plt
7
import networkx as nx
8
import numpy as np
9
from IPython.display import HTML
10
from IPython.display import display
11
from PIL import Image
12
from matplotlib import lines
13
from matplotlib.colors import ListedColormap
14
15
from games import TicTacToe, alpha_beta_player, random_player, Fig52Extended
16
from learning import DataSet
17
from logic import parse_definite_clause, standardize_variables, unify_mm, subst
18
from search import GraphProblem, romania_map
19
20
21
# ______________________________________________________________________________
22
# Magic Words
23
24
25
def pseudocode(algorithm):
    """Return the pseudocode for the given algorithm as renderable Markdown.

    Spaces in `algorithm` are replaced by hyphens to form a file name, and
    the matching ``.md`` file is downloaded from the aimacode/aima-pseudocode
    repository on GitHub.  The first line of that file is dropped and the
    remainder, stripped of surrounding whitespace, is prefixed with '#'.

    Returns an ``IPython.display.Markdown`` object.
    Any error raised by ``urlopen`` (e.g. unknown algorithm, no network)
    propagates to the caller.
    """
    from urllib.request import urlopen
    from IPython.display import Markdown

    algorithm = algorithm.replace(' ', '-')
    url = "https://raw.githubusercontent.com/aimacode/aima-pseudocode/master/md/{}.md".format(algorithm)
    # Use the response as a context manager so the connection is always
    # closed, even when reading or decoding fails.
    with urlopen(url) as f:
        md = f.read().decode('utf-8')
    md = md.split('\n', 1)[-1].strip()
    md = '#' + md
    return Markdown(md)
37
38
39
def psource(*functions):
    """Show the source code of the given function(s).

    The individual sources are separated by a blank line.  When pygments
    can be imported the listing is displayed as highlighted HTML; on an
    ImportError it is printed as plain text instead.
    """
    sources = [getsource(function) for function in functions]
    listing = '\n\n'.join(sources)
    try:
        from pygments import highlight
        from pygments.formatters import HtmlFormatter
        from pygments.lexers import PythonLexer

        highlighted = highlight(listing, PythonLexer(), HtmlFormatter(full=True))
        display(HTML(highlighted))

    except ImportError:
        print(listing)
51
52
53
def plot_model_boundary(dataset, attr1, attr2, model=None):
    """Plot the values a model produces over the plane of two attributes.

    dataset -- object providing `examples` and `target`
    attr1, attr2 -- column indices used for the horizontal and vertical axis
    model -- callable invoked with a two-element list [x, y] for every grid
             point; its results color the background mesh

    The training points are drawn on top, colored by their target column.
    """
    step = 0.1

    # pull the two plotted columns and the target column out of the examples
    rows = np.asarray(dataset.examples)
    points = np.asarray([rows[:, attr1], rows[:, attr2]])
    targets = rows[:, dataset.target]

    # light colors for the regions, bold ones for the points
    region_colors = ListedColormap(['#FFAAAA', '#AAFFAA', '#00AAFF'])
    point_colors = ListedColormap(['#FF0000', '#00FF00', '#00AAFF'])

    # grid covering the data with a margin of 1 on every side
    x_lo, x_hi = points[0].min() - 1, points[0].max() + 1
    y_lo, y_hi = points[1].min() - 1, points[1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_lo, x_hi, step),
                         np.arange(y_lo, y_hi, step))

    # evaluate the model on every grid point, rounded to one decimal
    predictions = [model(np.round(cell, decimals=1).tolist())
                   for cell in zip(xx.ravel(), yy.ravel())]
    Z = np.asarray(predictions).reshape(xx.shape)

    plt.figure()
    plt.pcolormesh(xx, yy, Z, cmap=region_colors)

    # training points on top of the regions
    plt.scatter(points[0], points[1], c=targets, cmap=point_colors)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.show()
86
87
88
# ______________________________________________________________________________
89
# Iris Visualization
90
91
92
def show_iris(i=0, j=1, k=2):
    """Draw the iris dataset as a 3D scatter plot.

    i, j and k pick which three of the four iris features
    are shown on the x, y and z axis respectively.
    """
    plt.rcParams.update(plt.rcParamsDefault)

    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')

    by_class = DataSet(name="iris").split_values_by_classes()

    feature_names = ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"]
    x_label, y_label, z_label = feature_names[i], feature_names[j], feature_names[k]

    # (class name, color, marker) in drawing order
    styles = [("setosa", 'b', 's'),
              ("virginica", 'g', '^'),
              ("versicolor", 'r', 'o')]

    # collect the three coordinate lists of every class before drawing
    series = [(color, marker,
               [[values[axis] for values in by_class[name]] for axis in (i, j, k)])
              for name, color, marker in styles]

    for color, marker, (xs, ys, zs) in series:
        axes.scatter(xs, ys, zs, c=color, marker=marker)

    axes.set_xlabel(x_label)
    axes.set_ylabel(y_label)
    axes.set_zlabel(z_label)

    plt.show()
130
131
132
# ______________________________________________________________________________
133
# MNIST
134
135
136
def _read_idx_images(filename):
    """Read an idx3-ubyte image file and return (pixels, rows, cols).

    pixels is a flat array of unsigned bytes holding all images back to back.
    """
    import array
    import struct

    with open(filename, "rb") as f:
        # header: four big-endian unsigned 32-bit ints (magic, count, rows, cols)
        _magic, _count, rows, cols = struct.unpack(">IIII", f.read(16))
        return array.array("B", f.read()), rows, cols


def _read_idx_labels(filename):
    """Read an idx1-ubyte label file and return (labels, count)."""
    import array
    import struct

    with open(filename, "rb") as f:
        # header: two big-endian unsigned 32-bit ints (magic, count)
        _magic, count = struct.unpack(">II", f.read(8))
        return array.array("b", f.read()), count


def _mnist_arrays(pixels, labels, count, rows, cols):
    """Return (images, labels) as numpy arrays holding the first `count` examples."""
    size = rows * cols
    if len(pixels) < count * size or len(labels) < count:
        raise ValueError("MNIST files hold fewer than the {} examples "
                         "announced by the label file".format(count))
    images = np.array(pixels[:count * size], dtype=np.int16).reshape((count, size))
    return images, np.array(labels[:count], dtype=np.int8)


def load_MNIST(path="aima-data/MNIST/Digits", fashion=False):
    """Load the MNIST training and test sets from their idx-ubyte files.

    path -- directory holding the four files `train-images-idx3-ubyte`,
            `train-labels-idx1-ubyte`, `t10k-images-idx3-ubyte` and
            `t10k-labels-idx1-ubyte`
    fashion -- when true, `path` is ignored and "aima-data/MNIST/Fashion"
               is read instead

    Returns (train_img, train_lbl, test_img, test_lbl).  The image arrays
    have dtype int16 and one flattened image (rows * cols values) per row;
    the label arrays have dtype int8.  The number of examples is taken from
    the header of the corresponding label file.

    As a side effect the matplotlib rcParams are reset to their defaults and
    figure size, image interpolation and image colormap are set for
    displaying the images.

    Raises ValueError when a file holds fewer examples than announced.
    """
    import os

    if fashion:
        path = "aima-data/MNIST/Fashion"

    plt.rcParams.update(plt.rcParamsDefault)
    plt.rcParams['figure.figsize'] = (10.0, 8.0)
    plt.rcParams['image.interpolation'] = 'nearest'
    plt.rcParams['image.cmap'] = 'gray'

    tr_img, tr_rows, tr_cols = _read_idx_images(os.path.join(path, "train-images-idx3-ubyte"))
    tr_lbl, tr_size = _read_idx_labels(os.path.join(path, "train-labels-idx1-ubyte"))
    te_img, te_rows, te_cols = _read_idx_images(os.path.join(path, "t10k-images-idx3-ubyte"))
    te_lbl, te_size = _read_idx_labels(os.path.join(path, 't10k-labels-idx1-ubyte'))

    # Each set is shaped with its OWN rows/cols (the training set used to be
    # reshaped with the test set's column count).
    train_img, train_lbl = _mnist_arrays(tr_img, tr_lbl, tr_size, tr_rows, tr_cols)
    test_img, test_lbl = _mnist_arrays(te_img, te_lbl, te_size, te_rows, te_cols)

    return (train_img, train_lbl, test_img, test_lbl)
184
185
186
# Display names of the ten classes, indexed by label value.
digit_classes = list(map(str, range(10)))
fashion_classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]
189
190
191
def show_MNIST(labels, images, samples=8, fashion=False):
    """Show randomly chosen example images, one column per class.

    labels -- class index of every image
    images -- flattened images; each one is reshaped to 28x28 for display
    samples -- number of images drawn (without replacement) for each class
    fashion -- title the columns with the fashion class names instead of digits
    """
    classes = fashion_classes if fashion else digit_classes
    n_columns = len(classes)

    for column, title in enumerate(classes):
        matches = np.nonzero([label == column for label in labels])[0]
        chosen = np.random.choice(matches, samples, replace=False)
        for row, image_index in enumerate(chosen):
            # subplots are numbered row by row, starting at 1
            plt.subplot(samples, n_columns, row * n_columns + column + 1)
            plt.imshow(images[image_index].reshape((28, 28)))
            plt.axis("off")
            if row == 0:
                plt.title(title)

    plt.show()
211
212
213
def show_ave_MNIST(labels, images, fashion=False):
    """Show the pixel-wise mean image of every class in a single row.

    labels -- class index of every image
    images -- flattened images; the mean image is reshaped to 28x28
    fashion -- use the fashion class names (and "Apparel" in the printed
               summary) instead of the digit names

    Also prints how many images were found for each class.
    """
    if fashion:
        item_type, classes = "Apparel", fashion_classes
    else:
        item_type, classes = "Digit", digit_classes

    n_classes = len(classes)

    for position, title in enumerate(classes):
        members = np.nonzero([label == position for label in labels])[0]
        print(item_type, position, ":", len(members), "images.")

        stacked = np.vstack([images[member] for member in members])
        mean_image = np.mean(stacked, axis=0)

        plt.subplot(1, n_classes, position + 1)
        plt.imshow(mean_image.reshape((28, 28)))
        plt.axis("off")
        plt.title(title)

    plt.show()
236
237
238
# ______________________________________________________________________________
239
# MDP
240
241
242
def make_plot_grid_step_function(columns, rows, U_over_time):
    """Build a one-argument plotting function for ipywidgets' interactive.

    The returned function takes an iteration index and draws
    U_over_time[iteration] as a grid of `columns` x `rows` cells keyed by
    (column, row); cells missing from the data are shown as 0.
    """

    def plot_grid_step(iteration):
        utilities = defaultdict(lambda: 0, U_over_time[iteration])
        # highest row first, so the output looks like the book
        grid = [[utilities[(x, y)] for x in range(columns)]
                for y in reversed(range(rows))]
        image = plt.imshow(grid, cmap=plt.cm.bwr, interpolation='nearest')

        plt.axis('off')
        image.axes.get_xaxis().set_visible(False)
        image.axes.get_yaxis().set_visible(False)

        # write every value into the middle of its cell
        for y, values in enumerate(grid):
            for x, value in enumerate(values):
                image.axes.text(x, y, "{0:.2f}".format(value), va='center', ha='center')

        plt.show()

    return plot_grid_step
271
272
273
def make_visualize(slider):
    """Create the timer/animation callback that drives the given slider.

    The returned callback takes (visualize, time_step).  Only when
    `visualize` is exactly True does it step slider.value from slider.min
    to slider.max (inclusive), sleeping `time_step` seconds after each step.
    """

    def visualize_callback(visualize, time_step):
        if visualize is not True:
            return
        for position in range(slider.min, slider.max + 1):
            slider.value = position
            time.sleep(float(time_step))

    return visualize_callback
284
285
286
# ______________________________________________________________________________
287
288
289
_canvas = """
290
<script type="text/javascript" src="./js/canvas.js"></script>
291
<div>
292
<canvas id="{0}" width="{1}" height="{2}" style="background:rgba(158, 167, 184, 0.2);" onclick='click_callback(this, event, "{3}")'></canvas>
293
</div>
294
295
<script> var {0}_canvas_object = new Canvas("{0}");</script>
296
""" # noqa
297
298
299
class Canvas:
    """Inherit from this class to manage the HTML canvas element in jupyter notebooks.
    To create an object of this class any_name_xyz = Canvas("any_name_xyz")
    The first argument given must be the name of the object being created.
    IPython must be able to reference the variable name that is being passed.

    Drawing methods do not draw immediately: they queue JavaScript calls with
    execute(), and update() sends everything queued so far to the notebook.
    Methods ending in `_n` take coordinates normalized to the range 0..1 and
    scale them by the canvas width/height."""

    def __init__(self, varname, width=800, height=600, cid=None):
        """Create the canvas element and display it.

        varname -- name of the Python variable the new object is bound to
        width, height -- size of the canvas in pixels
        cid -- id of the HTML canvas element; defaults to `varname`
        """
        self.name = varname
        self.cid = cid or varname
        self.width = width
        self.height = height
        self.html = _canvas.format(self.cid, self.width, self.height, self.name)
        self.exec_list = []  # JavaScript statements waiting for update()
        display_html(self.html)

    def mouse_click(self, x, y):
        """Override this method to handle mouse click at position (x, y)"""
        raise NotImplementedError

    def mouse_move(self, x, y):
        """Override this method to handle mouse movement to position (x, y)"""
        raise NotImplementedError

    def execute(self, exec_str):
        """Stores the command to be executed to a list which is used later during update()

        A command that is not a string is reported (printed and shown as an
        alert) and then ignored."""
        if not isinstance(exec_str, str):
            print("Invalid execution argument:", exec_str)
            self.alert("Received invalid execution command format")
            # Nothing valid to queue; concatenating a non-string below
            # would raise a TypeError.
            return
        prefix = "{0}_canvas_object.".format(self.cid)
        self.exec_list.append(prefix + exec_str + ';')

    def fill(self, r, g, b):
        """Changes the fill color to a color in rgb format"""
        self.execute("fill({0}, {1}, {2})".format(r, g, b))

    def stroke(self, r, g, b):
        """Changes the colors of line/strokes to rgb"""
        self.execute("stroke({0}, {1}, {2})".format(r, g, b))

    def strokeWidth(self, w):
        """Changes the width of lines/strokes to 'w' pixels"""
        self.execute("strokeWidth({0})".format(w))

    def rect(self, x, y, w, h):
        """Draw a rectangle with 'w' width, 'h' height and (x, y) as the top-left corner"""
        self.execute("rect({0}, {1}, {2}, {3})".format(x, y, w, h))

    def rect_n(self, xn, yn, wn, hn):
        """Similar to rect(), but the dimensions are normalized to fall between 0 and 1"""
        x = round(xn * self.width)
        y = round(yn * self.height)
        w = round(wn * self.width)
        h = round(hn * self.height)
        self.rect(x, y, w, h)

    def line(self, x1, y1, x2, y2):
        """Draw a line from (x1, y1) to (x2, y2)"""
        self.execute("line({0}, {1}, {2}, {3})".format(x1, y1, x2, y2))

    def line_n(self, x1n, y1n, x2n, y2n):
        """Similar to line(), but the dimensions are normalized to fall between 0 and 1"""
        x1 = round(x1n * self.width)
        y1 = round(y1n * self.height)
        x2 = round(x2n * self.width)
        y2 = round(y2n * self.height)
        self.line(x1, y1, x2, y2)

    def arc(self, x, y, r, start, stop):
        """Draw an arc with (x, y) as centre, 'r' as radius from angles 'start' to 'stop'"""
        self.execute("arc({0}, {1}, {2}, {3}, {4})".format(x, y, r, start, stop))

    def arc_n(self, xn, yn, rn, start, stop):
        """Similar to arc(), but the dimensions are normalized to fall between 0 and 1
        The normalizing factor for radius is selected between width and height by
        seeing which is smaller."""
        x = round(xn * self.width)
        y = round(yn * self.height)
        r = round(rn * min(self.width, self.height))
        self.arc(x, y, r, start, stop)

    def clear(self):
        """Clear the HTML canvas"""
        self.execute("clear()")

    def font(self, font):
        """Changes the font of text"""
        self.execute('font("{0}")'.format(font))

    def text(self, txt, x, y, fill=True):
        """Display a text at (x, y); filled when 'fill' is true, outlined otherwise"""
        if fill:
            self.execute('fill_text("{0}", {1}, {2})'.format(txt, x, y))
        else:
            self.execute('stroke_text("{0}", {1}, {2})'.format(txt, x, y))

    def text_n(self, txt, xn, yn, fill=True):
        """Similar to text(), but with normalized coordinates"""
        x = round(xn * self.width)
        y = round(yn * self.height)
        self.text(txt, x, y, fill)

    def alert(self, message):
        """Immediately display an alert"""
        display_html('<script>alert("{0}")</script>'.format(message))

    def update(self):
        """Execute the JS code to execute the commands queued by execute()"""
        exec_code = "<script>\n" + '\n'.join(self.exec_list) + "\n</script>"
        self.exec_list = []
        display_html(exec_code)
408
409
410
def display_html(html_string):
    """Display the given HTML markup in the notebook output."""
    markup = HTML(html_string)
    display(markup)
412
413
414
################################################################################
415
416
417
class Canvas_TicTacToe(Canvas):
    """Interactive 3x3 TicTacToe game drawn on an HTML canvas.

    The board occupies the upper 6/7 of the canvas; the strip below it
    shows the status message and, once the game is over, a restart button.
    """

    def __init__(self, varname, player_1='human', player_2='random',
                 width=300, height=350, cid=None):
        """Set up a new game; each player is 'human', 'random' or 'alpha_beta'.

        Raises TypeError when a player is none of these.
        """
        valid_players = ('human', 'random', 'alpha_beta')
        if any(chosen not in valid_players for chosen in (player_1, player_2)):
            raise TypeError("Players must be one of {}".format(valid_players))
        super().__init__(varname, width, height, cid)
        self.ttt = TicTacToe()
        self.state = self.ttt.initial
        self.turn = 0
        self.strokeWidth(5)
        self.players = (player_1, player_2)
        self.font("20px Arial")
        self.draw_board()

    def mouse_click(self, x, y):
        """Advance the game by one move, or restart a finished game.

        For a human player the click position selects the square; for the
        other player kinds any click triggers their move.
        """
        if self.ttt.terminal_test(self.state):
            on_restart = (0.55 <= x / self.width <= 0.95
                          and 6 / 7 <= y / self.height <= 6 / 7 + 1 / 8)
            if on_restart:
                self.state = self.ttt.initial
                self.turn = 0
                self.draw_board()
            return

        player = self.players[self.turn]
        if player == 'human':
            square = (int(3 * x / self.width) + 1, int(3 * y / (self.height * 6 / 7)) + 1)
            if square not in self.ttt.actions(self.state):
                # not a legal move: ignore the click
                return
            move = square
        elif player == 'alpha_beta':
            move = alpha_beta_player(self.ttt, self.state)
        else:
            move = random_player(self.ttt, self.state)
        self.state = self.ttt.result(self.state, move)
        self.turn ^= 1
        self.draw_board()

    def draw_board(self):
        """Redraw the grid, all marks and the status area for the current state."""
        self.clear()
        self.stroke(0, 0, 0)
        offset = 1 / 20
        # the two horizontal grid lines, then the two vertical ones
        for third in (1, 2):
            self.line_n(0 + offset, (third / 3) * 6 / 7, 1 - offset, (third / 3) * 6 / 7)
        for third in (1, 2):
            self.line_n(third / 3, (0 + offset) * 6 / 7, third / 3, (1 - offset) * 6 / 7)

        board = self.state.board
        for square in board:
            symbol = board[square]
            if symbol == 'X':
                self.draw_x(square)
            elif symbol == 'O':
                self.draw_o(square)

        if not self.ttt.terminal_test(self.state):
            # announce whose turn it is
            self.text_n("Player {}'s move({})".format("XO"[self.turn], self.players[self.turn]),
                        offset, 6 / 7 + offset)
        else:
            # end of game message
            utility = self.ttt.utility(self.state, self.ttt.to_move(self.ttt.initial))
            if utility == 0:
                self.text_n('Game Draw!', offset, 6 / 7 + offset)
            else:
                self.text_n('Player {} wins!'.format("XO"[utility < 0]), offset, 6 / 7 + offset)

                def three_in_line(cells):
                    """True when all cells are marked and carry the same mark."""
                    return (all([cell in board for cell in cells])
                            and len({board[cell] for cell in cells}) == 1)

                # strike through every completed line
                red, green = ((255, 0), (0, 255))[self.turn]
                self.stroke(red, green, 0)
                for i in range(3):
                    if three_in_line([(i + 1, j + 1) for j in range(3)]):
                        self.line_n(i / 3 + 1 / 6, offset * 6 / 7, i / 3 + 1 / 6, (1 - offset) * 6 / 7)
                    if three_in_line([(j + 1, i + 1) for j in range(3)]):
                        self.line_n(offset, (i / 3 + 1 / 6) * 6 / 7, 1 - offset, (i / 3 + 1 / 6) * 6 / 7)
                if three_in_line([(i + 1, i + 1) for i in range(3)]):
                    self.line_n(offset, offset * 6 / 7, 1 - offset, (1 - offset) * 6 / 7)
                if three_in_line([(i + 1, 3 - i) for i in range(3)]):
                    self.line_n(offset, (1 - offset) * 6 / 7, 1 - offset, offset * 6 / 7)
            # restart button
            self.fill(0, 0, 255)
            self.rect_n(0.5 + offset, 6 / 7, 0.4, 1 / 8)
            self.fill(0, 0, 0)
            self.text_n('Restart', 0.5 + 2 * offset, 13 / 14)

        self.update()

    def draw_x(self, position):
        """Draw a green cross in the square at the 1-based (x, y) position."""
        self.stroke(0, 255, 0)
        col, row = [coordinate - 1 for coordinate in position]
        pad = 1 / 15
        left, right = col / 3 + pad, col / 3 + 1 / 3 - pad
        top, bottom = (row / 3 + pad) * 6 / 7, (row / 3 + 1 / 3 - pad) * 6 / 7
        self.line_n(left, top, right, bottom)
        self.line_n(right, top, left, bottom)

    def draw_o(self, position):
        """Draw a red circle in the square at the 1-based (x, y) position."""
        self.stroke(255, 0, 0)
        col, row = [coordinate - 1 for coordinate in position]
        centre_x = col / 3 + 1 / 6
        centre_y = (row / 3 + 1 / 6) * 6 / 7
        self.arc_n(centre_x, centre_y, 1 / 9, 0, 360)
516
517
518
class Canvas_min_max(Canvas):
    """MinMax for Fig52Extended on HTML canvas

    The whole search is run once and recorded as a list of events
    (self.change_list); every mouse click then replays the events up to the
    next pause and redraws the tree.

    Event codes used in change_list:
      ('a', node)      push node onto the stack of highlighted nodes
      ('e', node)      mark node as explored (its utility gets displayed)
      ('l', (i, j))    draw the edge from node i to its j-th child thick
      ('p',)           pop the top of the highlighted stack
      ('h',)           pause until the next click
    """

    def __init__(self, varname, util_list, width=800, height=600, cid=None):
        """util_list supplies the utilities of nodes 13..39, in that order."""
        super().__init__(varname, width, height, cid)
        self.utils = {node: util for node, util in zip(range(13, 40), util_list)}
        self.game = Fig52Extended()
        # the game shares the very dict that the search later extends
        self.game.utils = self.utils
        self.nodes = list(range(40))
        # side length of a node box, in normalized canvas units
        self.l = 1 / 40
        # node -> normalized (x, y) of the top-left corner of its box
        self.node_pos = {}
        # four levels holding 1, 3, 9 and 27 nodes, numbered consecutively
        for i in range(4):
            base = len(self.node_pos)
            row_size = 3 ** i
            for node in [base + j for j in range(row_size)]:
                self.node_pos[node] = ((node - base) / row_size + 1 / (2 * row_size) - self.l / 2,
                                       self.l / 2 + (self.l + (1 - 5 * self.l) / 3) * i)
        self.font("12px Arial")
        self.node_stack = []
        # nodes that start with a utility count as explored from the beginning
        self.explored = {node for node in self.utils}
        self.thick_lines = set()
        self.change_list = []
        self.draw_graph()
        # generator that replays change_list; advanced by mouse_click()
        self.stack_manager = self.stack_manager_gen()

    def min_max(self, node):
        """Run minimax from `node`, recording every step in self.change_list.

        Returns the value computed for `node`; the value of every visited
        inner node is stored in self.utils.
        """
        game = self.game
        player = game.to_move(node)

        def max_value(node):
            if game.terminal_test(node):
                return game.utility(node, player)
            self.change_list.append(('a', node))
            self.change_list.append(('h',))
            # evaluating the key function recurses into the children
            max_a = max(game.actions(node), key=lambda x: min_value(game.result(node, x)))
            max_node = game.result(node, max_a)
            self.utils[node] = self.utils[max_node]
            # NOTE(review): x1, y1, x2, y2 are not used afterwards
            x1, y1 = self.node_pos[node]
            x2, y2 = self.node_pos[max_node]
            # second element is the child index, matching draw_graph's i * 3 + j + 1 numbering
            self.change_list.append(('l', (node, max_node - 3 * node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return self.utils[node]

        def min_value(node):
            if game.terminal_test(node):
                return game.utility(node, player)
            self.change_list.append(('a', node))
            self.change_list.append(('h',))
            # evaluating the key function recurses into the children
            min_a = min(game.actions(node), key=lambda x: max_value(game.result(node, x)))
            min_node = game.result(node, min_a)
            self.utils[node] = self.utils[min_node]
            # NOTE(review): x1, y1, x2, y2 are not used afterwards
            x1, y1 = self.node_pos[node]
            x2, y2 = self.node_pos[min_node]
            # second element is the child index, matching draw_graph's i * 3 + j + 1 numbering
            self.change_list.append(('l', (node, min_node - 3 * node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return self.utils[node]

        return max_value(node)

    def stack_manager_gen(self):
        """Generator replaying the recorded search; yields at every ('h',) event.

        The search itself runs when the generator is first advanced.
        """
        self.min_max(0)
        for change in self.change_list:
            if change[0] == 'a':
                self.node_stack.append(change[1])
            elif change[0] == 'e':
                self.explored.add(change[1])
            elif change[0] == 'h':
                yield
            elif change[0] == 'l':
                self.thick_lines.add(change[1])
            elif change[0] == 'p':
                self.node_stack.pop()

    def mouse_click(self, x, y):
        """Replay the search up to the next pause and redraw; the position is ignored."""
        try:
            self.stack_manager.send(None)
        except StopIteration:
            # the replay has finished; further clicks only redraw
            pass
        self.draw_graph()

    def draw_graph(self):
        """Redraw the tree: highlights, node boxes with utilities, then edges."""
        self.clear()
        # draw nodes
        self.stroke(0, 0, 0)
        self.strokeWidth(1)
        # highlight for nodes in stack
        for node in self.node_stack:
            x, y = self.node_pos[node]
            self.fill(200, 200, 0)
            # slightly larger box, drawn first so it shows as a border
            self.rect_n(x - self.l / 5, y - self.l / 5, self.l * 7 / 5, self.l * 7 / 5)
        for node in self.nodes:
            x, y = self.node_pos[node]
            # white once explored, gray before
            if node in self.explored:
                self.fill(255, 255, 255)
            else:
                self.fill(200, 200, 200)
            self.rect_n(x, y, self.l, self.l)
            # outline of the box
            self.line_n(x, y, x + self.l, y)
            self.line_n(x, y, x, y + self.l)
            self.line_n(x + self.l, y + self.l, x + self.l, y)
            self.line_n(x + self.l, y + self.l, x, y + self.l)
            self.fill(0, 0, 0)
            if node in self.explored:
                self.text_n(self.utils[node], x + self.l / 10, y + self.l * 9 / 10)
        # draw edges
        # nodes 0..12 are the inner nodes; node i has the children 3i+1 .. 3i+3
        for i in range(13):
            x1, y1 = self.node_pos[i][0] + self.l / 2, self.node_pos[i][1] + self.l
            for j in range(3):
                x2, y2 = self.node_pos[i * 3 + j + 1][0] + self.l / 2, self.node_pos[i * 3 + j + 1][1]
                # edges leaving nodes 1-3 are red, all others green
                if i in [1, 2, 3]:
                    self.stroke(200, 0, 0)
                else:
                    self.stroke(0, 200, 0)
                if (i, j) in self.thick_lines:
                    self.strokeWidth(3)
                else:
                    self.strokeWidth(1)
                self.line_n(x1, y1, x2, y2)
        self.update()
641
642
643
class Canvas_alpha_beta(Canvas):
    """Alpha-beta pruning for Fig52Extended on HTML canvas

    The whole search is run once and recorded as a list of events
    (self.change_list); every mouse click then replays the events up to the
    next pause and redraws the tree.

    Event codes used in change_list:
      ('a', node)          push node onto the stack of highlighted nodes
      ('ab', node, a, b)   remember the pair (a, b) shown above the node
      ('e', node)          mark node as explored
      ('l', (i, j))        draw the edge from node i to its j-th child thick
      ('p',)               pop the top of the highlighted stack
      ('h',)               pause until the next click
    """

    def __init__(self, varname, util_list, width=800, height=600, cid=None):
        """util_list supplies the utilities of nodes 13..39, in that order."""
        super().__init__(varname, width, height, cid)
        self.utils = {node: util for node, util in zip(range(13, 40), util_list)}
        self.game = Fig52Extended()
        # the game shares the very dict that the search later extends
        self.game.utils = self.utils
        self.nodes = list(range(40))
        # side length of a node box, in normalized canvas units
        self.l = 1 / 40
        # node -> normalized (x, y) of the top-left corner of its box
        self.node_pos = {}
        # four levels holding 1, 3, 9 and 27 nodes, numbered consecutively
        for i in range(4):
            base = len(self.node_pos)
            row_size = 3 ** i
            for node in [base + j for j in range(row_size)]:
                self.node_pos[node] = ((node - base) / row_size + 1 / (2 * row_size) - self.l / 2,
                                       3 * self.l / 2 + (self.l + (1 - 6 * self.l) / 3) * i)
        self.font("12px Arial")
        self.node_stack = []
        # nodes that start with a utility count as explored from the beginning
        self.explored = {node for node in self.utils}
        self.pruned = set()
        # node -> (alpha, beta) pair displayed above the node
        self.ab = {}
        self.thick_lines = set()
        self.change_list = []
        self.draw_graph()
        # generator that replays change_list; advanced by mouse_click()
        self.stack_manager = self.stack_manager_gen()

    def alpha_beta_search(self, node):
        """Run alpha-beta search from `node`, recording every step in self.change_list.

        Returns the value computed for `node`.  Values of visited inner
        nodes are stored in self.utils.  Nodes whose loop was cut off are
        added to self.pruned directly while searching, not through
        change_list.
        """
        game = self.game
        player = game.to_move(node)

        # Functions used by alpha_beta
        def max_value(node, alpha, beta):
            if game.terminal_test(node):
                # terminal nodes are highlighted for one step only
                self.change_list.append(('a', node))
                self.change_list.append(('h',))
                self.change_list.append(('p',))
                return game.utility(node, player)
            v = -np.inf
            self.change_list.append(('a', node))
            # the current value v is recorded in the alpha position
            self.change_list.append(('ab', node, v, beta))
            self.change_list.append(('h',))
            for a in game.actions(node):
                min_val = min_value(game.result(node, a), alpha, beta)
                if v < min_val:
                    v = min_val
                    max_node = game.result(node, a)
                    self.change_list.append(('ab', node, v, beta))
                if v >= beta:
                    # cut-off: the remaining children are skipped
                    self.change_list.append(('h',))
                    self.pruned.add(node)
                    break
                alpha = max(alpha, v)
            self.utils[node] = v
            if node not in self.pruned:
                # second element is the child index, matching draw_graph's i * 3 + j + 1 numbering
                self.change_list.append(('l', (node, max_node - 3 * node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return v

        def min_value(node, alpha, beta):
            if game.terminal_test(node):
                # terminal nodes are highlighted for one step only
                self.change_list.append(('a', node))
                self.change_list.append(('h',))
                self.change_list.append(('p',))
                return game.utility(node, player)
            v = np.inf
            self.change_list.append(('a', node))
            # the current value v is recorded in the beta position
            self.change_list.append(('ab', node, alpha, v))
            self.change_list.append(('h',))
            for a in game.actions(node):
                max_val = max_value(game.result(node, a), alpha, beta)
                if v > max_val:
                    v = max_val
                    min_node = game.result(node, a)
                    self.change_list.append(('ab', node, alpha, v))
                if v <= alpha:
                    # cut-off: the remaining children are skipped
                    self.change_list.append(('h',))
                    self.pruned.add(node)
                    break
                beta = min(beta, v)
            self.utils[node] = v
            if node not in self.pruned:
                # second element is the child index, matching draw_graph's i * 3 + j + 1 numbering
                self.change_list.append(('l', (node, min_node - 3 * node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return v

        return max_value(node, -np.inf, np.inf)

    def stack_manager_gen(self):
        """Generator replaying the recorded search; yields at every ('h',) event.

        The search itself runs when the generator is first advanced.
        """
        self.alpha_beta_search(0)
        for change in self.change_list:
            if change[0] == 'a':
                self.node_stack.append(change[1])
            elif change[0] == 'ab':
                self.ab[change[1]] = change[2:]
            elif change[0] == 'e':
                self.explored.add(change[1])
            elif change[0] == 'h':
                yield
            elif change[0] == 'l':
                self.thick_lines.add(change[1])
            elif change[0] == 'p':
                self.node_stack.pop()

    def mouse_click(self, x, y):
        """Replay the search up to the next pause and redraw; the position is ignored."""
        try:
            self.stack_manager.send(None)
        except StopIteration:
            # the replay has finished; further clicks only redraw
            pass
        self.draw_graph()

    def draw_graph(self):
        """Redraw the tree: highlights, node boxes, edges, then the alpha/beta labels."""
        self.clear()
        # draw nodes
        self.stroke(0, 0, 0)
        self.strokeWidth(1)
        # highlight for nodes in stack
        for node in self.node_stack:
            x, y = self.node_pos[node]
            # alpha > beta
            if node not in self.explored and self.ab[node][0] > self.ab[node][1]:
                self.fill(200, 100, 100)
            else:
                self.fill(200, 200, 0)
            # slightly larger box, drawn first so it shows as a border
            self.rect_n(x - self.l / 5, y - self.l / 5, self.l * 7 / 5, self.l * 7 / 5)
        for node in self.nodes:
            x, y = self.node_pos[node]
            # dark when explored and pruned, white when explored, gray before
            if node in self.explored:
                if node in self.pruned:
                    self.fill(50, 50, 50)
                else:
                    self.fill(255, 255, 255)
            else:
                self.fill(200, 200, 200)
            self.rect_n(x, y, self.l, self.l)
            # outline of the box
            self.line_n(x, y, x + self.l, y)
            self.line_n(x, y, x, y + self.l)
            self.line_n(x + self.l, y + self.l, x + self.l, y)
            self.line_n(x + self.l, y + self.l, x, y + self.l)
            self.fill(0, 0, 0)
            if node in self.explored and node not in self.pruned:
                self.text_n(self.utils[node], x + self.l / 10, y + self.l * 9 / 10)
        # draw edges
        # nodes 0..12 are the inner nodes; node i has the children 3i+1 .. 3i+3
        for i in range(13):
            x1, y1 = self.node_pos[i][0] + self.l / 2, self.node_pos[i][1] + self.l
            for j in range(3):
                x2, y2 = self.node_pos[i * 3 + j + 1][0] + self.l / 2, self.node_pos[i * 3 + j + 1][1]
                # edges leaving nodes 1-3 are red, all others green
                if i in [1, 2, 3]:
                    self.stroke(200, 0, 0)
                else:
                    self.stroke(0, 200, 0)
                if (i, j) in self.thick_lines:
                    self.strokeWidth(3)
                else:
                    self.strokeWidth(1)
                self.line_n(x1, y1, x2, y2)
        # display alpha and beta
        for node in self.node_stack:
            if node not in self.explored:
                x, y = self.node_pos[node]
                alpha, beta = self.ab[node]
                self.text_n(alpha, x - self.l / 2, y - self.l / 10)
                self.text_n(beta, x + self.l, y - self.l / 10)
        self.update()
811
812
813
class Canvas_fol_bc_ask(Canvas):
    """fol_bc_ask() on HTML canvas

    Runs backward chaining for `query` against `kb` and draws the proof tree
    of the first answer, one row per depth.  Clicking a node shows its full
    text in the strip at the bottom.  When the query has no proof the
    drawing area is filled red instead.
    """

    def __init__(self, varname, kb, query, width=800, height=600, cid=None):
        super().__init__(varname, width, height, cid)
        self.kb = kb
        self.query = query
        # height and width of a node box, in normalized canvas units
        self.l = 1 / 20
        self.b = 3 * self.l
        # Empty layout by default, so that mouse_click() and draw_table()
        # also work when the query has no proof and make_table() never runs.
        self.table = []
        self.pos = {}
        self.edges = set()
        bc_out = list(self.fol_bc_ask())
        if len(bc_out) == 0:
            self.valid = False
        else:
            self.valid = True
            # first answer: ([root node], substitution)
            graph = bc_out[0][0][0]
            s = bc_out[0][1]
            # apply the substitution until the tree stops changing
            while True:
                new_graph = subst(s, graph)
                if graph == new_graph:
                    break
                graph = new_graph
            self.make_table(graph)
        # (row, column) of the node whose text is shown, or None
        self.context = None
        self.draw_table()

    def fol_bc_ask(self):
        """Return a generator of (proof trees, substitution) pairs for the query.

        A proof tree node is a pair (goal, list of child nodes).
        """
        KB = self.kb
        query = self.query

        def fol_bc_or(KB, goal, theta):
            for rule in KB.fetch_rules_for_goal(goal):
                lhs, rhs = parse_definite_clause(standardize_variables(rule))
                for theta1 in fol_bc_and(KB, lhs, unify_mm(rhs, goal, theta)):
                    yield ([(goal, theta1[0])], theta1[1])

        def fol_bc_and(KB, goals, theta):
            if theta is None:
                # unification failed: no answers
                pass
            elif not goals:
                yield ([], theta)
            else:
                first, rest = goals[0], goals[1:]
                for theta1 in fol_bc_or(KB, subst(theta, first), theta):
                    for theta2 in fol_bc_and(KB, rest, theta1[1]):
                        yield (theta1[0] + theta2[0], theta2[1])

        return fol_bc_or(KB, query, {})

    def make_table(self, graph):
        """Lay out the proof tree `graph` and store the result on the object.

        Sets self.table (node labels by row and column), self.pos
        ((row, column) -> normalized top-left corner of the box) and
        self.edges (normalized line segments from parent to child).
        """
        table = []
        pos = {}
        links = set()
        edges = set()

        def dfs(node, depth):
            """Append node and its subtree to table; return the node's (row, column)."""
            if len(table) <= depth:
                table.append([])
            column = len(table[depth])
            table[depth].append(node[0])
            for child in node[1]:
                child_id = dfs(child, depth + 1)
                links.add(((depth, column), child_id))
            return (depth, column)

        dfs(graph, 0)
        # rows share 85% of the height, the nodes of a row 95% of the width
        y_off = 0.85 / len(table)
        for i, row in enumerate(table):
            x_off = 0.95 / len(row)
            for j in range(len(row)):
                pos[(i, j)] = (0.025 + j * x_off + (x_off - self.b) / 2, 0.025 + i * y_off + (y_off - self.l) / 2)
        for p, c in links:
            x1, y1 = pos[p]
            x2, y2 = pos[c]
            # from the bottom centre of the parent to the top centre of the child
            edges.add((x1 + self.b / 2, y1 + self.l, x2 + self.b / 2, y2))

        self.table = table
        self.pos = pos
        self.edges = edges

    def mouse_click(self, x, y):
        """Select the node under the click (if any) and redraw."""
        x, y = x / self.width, y / self.height
        for node in self.pos:
            xs, ys = self.pos[node]
            xe, ye = xs + self.b, ys + self.l
            if xs <= x <= xe and ys <= y <= ye:
                self.context = node
                break
        self.draw_table()

    def draw_table(self):
        """Redraw the proof tree (or the red failure area) and the text strip."""
        self.clear()
        self.strokeWidth(3)
        self.stroke(0, 0, 0)
        self.font("12px Arial")
        if self.valid:
            # draw nodes
            for i, j in self.pos:
                x, y = self.pos[(i, j)]
                self.fill(200, 200, 200)
                self.rect_n(x, y, self.b, self.l)
                self.line_n(x, y, x + self.b, y)
                self.line_n(x, y, x, y + self.l)
                self.line_n(x + self.b, y, x + self.b, y + self.l)
                self.line_n(x, y + self.l, x + self.b, y + self.l)
                self.fill(0, 0, 0)
                self.text_n(self.table[i][j], x + 0.01, y + self.l - 0.01)
            # draw edges
            for x1, y1, x2, y2 in self.edges:
                self.line_n(x1, y1, x2, y2)
        else:
            self.fill(255, 0, 0)
            self.rect_n(0, 0, 1, 1)
        # text area
        self.fill(255, 255, 255)
        self.rect_n(0, 0.9, 1, 0.1)
        self.strokeWidth(5)
        self.stroke(0, 0, 0)
        self.line_n(0, 0.9, 1, 0.9)
        self.font("22px Arial")
        self.fill(0, 0, 0)
        self.text_n(self.table[self.context[0]][self.context[1]] if self.context else "Click for text", 0.025, 0.975)
        self.update()
935
936
937
############################################################################################################
938
939
##################### Functions to assist plotting in search.ipynb ####################
940
941
############################################################################################################
942
943
944
def show_map(graph_data, node_colors=None):
    """Draw the graph described by `graph_data` with matplotlib/networkx.

    Parameters
    ----------
    graph_data : dict
        Must provide the keys 'graph_dict' (passed to `nx.Graph`),
        'node_colors', 'node_positions', 'node_label_positions' and
        'edge_weights'.
    node_colors : dict, optional
        Mapping node -> colour used for this drawing. When omitted (or
        empty) `graph_data['node_colors']` is used instead.
    """
    G = nx.Graph(graph_data['graph_dict'])
    node_colors = node_colors or graph_data['node_colors']
    node_positions = graph_data['node_positions']
    node_label_pos = graph_data['node_label_positions']
    edge_weights = graph_data['edge_weights']

    # set the size of the plot
    plt.figure(figsize=(18, 13))
    # draw the graph (both nodes and edges) at the positions given in graph_data
    nx.draw(G, pos={k: node_positions[k] for k in G.nodes()},
            node_color=[node_colors[node] for node in G.nodes()], linewidths=0.3, edgecolors='k')

    # draw labels for nodes
    node_label_handles = nx.draw_networkx_labels(G, pos=node_label_pos, font_size=14)

    # add a white bounding box behind the node labels
    for label in node_label_handles.values():
        label.set_bbox(dict(facecolor='white', edgecolor='none'))

    # add edge labels to the graph
    nx.draw_networkx_edge_labels(G, pos=node_positions, edge_labels=edge_weights, font_size=14)

    # add a legend: one filled circle per search state
    legend_entries = (
        ("white", 'Un-explored'),
        ("orange", 'Frontier'),
        ("red", 'Currently Exploring'),
        ("gray", 'Explored'),
        ("green", 'Final Solution'),
    )
    legend_handles = tuple(
        lines.Line2D([], [], color=color, marker='o', markersize=15, markerfacecolor=color)
        for color, _ in legend_entries
    )
    legend_labels = tuple(text for _, text in legend_entries)
    plt.legend(legend_handles, legend_labels, numpoints=1, prop={'size': 16}, loc=(.8, .75))

    # show the plot. No need to use in notebooks. nx.draw will show the graph itself.
    plt.show()
978
979
980
# helper functions for visualisations
981
982
def final_path_colors(initial_node_colors, problem, solution):
    """Build the node_colors dict that highlights the final path.

    A copy of `initial_node_colors` is returned in which `problem.initial`
    and every node in `solution` are coloured "green"; the input mapping is
    left untouched.
    """
    path_nodes = [problem.initial]
    path_nodes.extend(solution)

    highlighted = dict(initial_node_colors)
    highlighted.update(dict.fromkeys(path_nodes, "green"))
    return highlighted
992
993
994
def display_visual(graph_data, user_input, algorithm=None, problem=None):
    """Show interactive widgets that animate a search over `graph_data`.

    Parameters
    ----------
    graph_data : dict
        Graph description forwarded to `show_map`; its 'node_colors' entry
        also supplies the city names for the dropdowns.
    user_input : bool
        ``False``: run `algorithm(problem)` when the toggle button is
        pressed. ``True``: let the user pick start city, goal city and (when
        `algorithm` is a dict) the algorithm; the problem is then built as
        ``GraphProblem(start, goal, romania_map)``, so this mode is tied to
        the Romania map. Any other value displays nothing.
    algorithm : callable or dict, optional
        A search function returning ``(iterations, all_node_colors, node)``,
        or, with ``user_input=True``, a dict mapping display names to such
        functions.
    problem : optional
        Problem passed to `algorithm` when ``user_input`` is ``False``.

    Returns
    -------
    ``0`` when ``user_input`` is ``True`` and no algorithm was given,
    otherwise ``None``.

    Notes
    -----
    The recorded frames are stored in the module-level `all_node_colors`,
    which the slider callback reads.
    """
    initial_node_colors = graph_data['node_colors']

    def slider_callback(iteration):
        # Don't show a graph the first time the cell runs: until a search has
        # stored its frames, all_node_colors is undefined (NameError), and the
        # slider may point past the recorded frames (IndexError).
        try:
            frame_colors = all_node_colors[iteration]
        except (NameError, IndexError):
            return
        show_map(graph_data, node_colors=frame_colors)

    def run_search(search, search_problem):
        # Run the search, record its frames globally and step the slider
        # through every frame so each one gets drawn.
        global all_node_colors

        _, all_node_colors, node = search(search_problem)
        solution = node.solution()
        all_node_colors.append(final_path_colors(all_node_colors[0], search_problem, solution))

        slider.max = len(all_node_colors) - 1

        for i in range(slider.max + 1):
            slider.value = i
            # time.sleep(.5)

    if user_input is False:
        def visualize_callback(visualize):
            if visualize is True:
                button.value = False
                run_search(algorithm, problem)

        slider = widgets.IntSlider(min=0, max=1, step=1, value=0)
        slider_visual = widgets.interactive(slider_callback, iteration=slider)
        display(slider_visual)

        button = widgets.ToggleButton(value=False)
        button_visual = widgets.interactive(visualize_callback, visualize=button)
        display(button_visual)

    elif user_input is True:
        algo_dropdown = None
        if isinstance(algorithm, dict):
            assert set(algorithm.keys()).issubset({"Breadth First Tree Search",
                                                   "Depth First Tree Search",
                                                   "Breadth First Search",
                                                   "Depth First Graph Search",
                                                   "Best First Graph Search",
                                                   "Uniform Cost Search",
                                                   "Depth Limited Search",
                                                   "Iterative Deepening Search",
                                                   "Greedy Best First Search",
                                                   "A-star Search",
                                                   "Recursive Best First Search"})

            algo_names = sorted(algorithm)
            # Keep the historical default when it is offered; otherwise fall
            # back to the first option instead of passing a value that is not
            # among the options.
            if "Breadth First Tree Search" in algorithm:
                default_algo = "Breadth First Tree Search"
            else:
                default_algo = algo_names[0] if algo_names else None
            algo_dropdown = widgets.Dropdown(description="Search algorithm: ",
                                             options=algo_names,
                                             value=default_algo)
            display(algo_dropdown)
        elif algorithm is None:
            print("No algorithm to run.")
            return 0

        def visualize_callback(visualize):
            if visualize is True:
                button.value = False

                user_problem = GraphProblem(start_dropdown.value, end_dropdown.value, romania_map)
                # A dict is resolved through the dropdown; a bare callable is
                # used directly.
                if algo_dropdown is not None:
                    user_algorithm = algorithm[algo_dropdown.value]
                else:
                    user_algorithm = algorithm
                run_search(user_algorithm, user_problem)

        city_names = sorted(initial_node_colors)

        start_dropdown = widgets.Dropdown(description="Start city: ",
                                          options=city_names, value="Arad")
        display(start_dropdown)

        end_dropdown = widgets.Dropdown(description="Goal city: ",
                                        options=city_names, value="Fagaras")
        display(end_dropdown)

        button = widgets.ToggleButton(value=False)
        button_visual = widgets.interactive(visualize_callback, visualize=button)
        display(button_visual)

        slider = widgets.IntSlider(min=0, max=1, step=1, value=0)
        slider_visual = widgets.interactive(slider_callback, iteration=slider)
        display(slider_visual)
1092
1093
1094
# Function to plot NQueensCSP in csp.py and NQueensProblem in search.py
1095
def plot_NQueens(solution):
    """Plot an N-Queens solution on a checkerboard.

    Parameters
    ----------
    solution : dict or sequence
        Queen placements as pairs ``(k, v)``: the queen is drawn at
        horizontal cell ``k`` and vertical cell ``v`` counted from the top.
        A dict (as produced by NQueensCSP) supplies the pairs through
        ``items()``; any other sequence (as produced by NQueensProblem)
        supplies them through ``enumerate``. ``len(solution)`` is the board
        size.

    The queen sprite is read from 'images/queen_s.png' relative to the
    current working directory.
    """
    n = len(solution)
    rows, cols = np.indices((n, n))
    board = 2 * ((rows + cols) % 2)

    # np.float was removed from NumPy; the builtin float is the same dtype.
    # The context manager closes the image file once the pixels are copied.
    with Image.open('images/queen_s.png') as queen_file:
        im = np.array(queen_file).astype(float) / 255

    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(111)
    ax.set_title('{} Queens'.format(n))
    plt.imshow(board, cmap='binary', interpolation='nearest')

    if isinstance(solution, dict):
        placements = solution.items()
    else:
        placements = enumerate(solution)

    if n:
        # Figure-fraction layout scaled from the values hand-tuned for the
        # 8x8 board (pitch 0.112, sprite 0.1, origin 0.064 / 0.062); for
        # n == 8 these reduce to the original constants.
        # NOTE(review): placement for other board sizes is extrapolated from
        # the 8x8 tuning -- confirm visually.
        cell = 0.896 / n
        sprite = 0.8 / n
        inset = (cell - sprite) / 2
        for k, v in placements:
            newax = fig.add_axes([0.058 + k * cell + inset,
                                  0.056 + (n - 1 - v) * cell + inset,
                                  sprite, sprite], zorder=1)
            newax.imshow(im)
            newax.axis('off')

    fig.tight_layout()
    plt.show()
1119
1120
1121
# Function to plot a heatmap, given a grid
1122
def heatmap(grid, cmap='binary', interpolation='nearest'):
    """Display `grid` as an image titled 'Heatmap' in a 7x7 inch figure.

    `cmap` and `interpolation` are forwarded unchanged to `plt.imshow`.
    """
    figure = plt.figure(figsize=(7, 7))
    axes = figure.add_subplot(1, 1, 1)
    axes.set_title('Heatmap')
    plt.imshow(grid, interpolation=interpolation, cmap=cmap)
    figure.tight_layout()
    plt.show()
1129
1130
1131
# Generates a gaussian kernel
1132
def gaussian_kernel(l=5, sig=1.0):
    """Return an `l` x `l` array of unnormalised Gaussian weights.

    Entry ``[i, j]`` is ``exp(-(x_j**2 + y_i**2) / (2 * sig**2))`` where the
    coordinates run over ``np.arange(-l // 2 + 1., l // 2 + 1.)``; the value
    at the zero coordinate is 1. The result is not scaled to sum to one.
    """
    offsets = np.arange(-l // 2 + 1., l // 2 + 1.)
    squared = offsets ** 2
    # broadcasting a row against a column gives the same grid as meshgrid
    radius_sq = squared[np.newaxis, :] + squared[:, np.newaxis]
    return np.exp(-radius_sq / (2. * sig ** 2))
1137
1138
1139
# Plots utility function for a POMDP
1140
def plot_pomdp_utility(utility):
    """Plot the value vectors in `utility` and mark two crossover points.

    `utility` is indexed by the action keys '0', '1' and '2'; each entry is
    a sequence of value vectors. Every vector is plotted with a colour
    chosen by ``int(action)`` (0 green, 1 blue, 2 black). Dashed vertical
    lines are drawn where the first '0' vector meets the first '2' vector
    and where the first '1' vector meets the last '2' vector, and the three
    resulting regions are labelled 'Save', 'Ask' and 'Delete'.
    """

    def crossover(first, second):
        # position where the two value vectors are equal
        gap = first[0] - second[0]
        return gap / (gap + second[1] - first[1])

    left = crossover(utility['0'][0], utility['2'][0])
    right = crossover(utility['1'][0], utility['2'][-1])

    palette = ['g', 'b', 'k']
    for action, vectors in utility.items():
        for vector in vectors:
            plt.plot(vector, color=palette[int(action)])

    plt.vlines([left, right], -20, 10, linestyles='dashed', colors='c')
    plt.ylim(-20, 13)
    plt.xlim(0, 1)

    region_labels = (
        (left / 2 - 0.05, 'Save'),
        ((right + left) / 2 - 0.02, 'Ask'),
        ((right + 1) / 2 - 0.07, 'Delete'),
    )
    for x_pos, caption in region_labels:
        plt.text(x_pos, 10, caption)
    plt.show()
1159
1160