Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
aimacode
GitHub Repository: aimacode/aima-python
Path: blob/master/notebook.py
615 views
1
import time
2
from collections import defaultdict
3
from inspect import getsource
4
5
import ipywidgets as widgets
6
import matplotlib.pyplot as plt
7
import networkx as nx
8
import numpy as np
9
from IPython.display import HTML
10
from IPython.display import display
11
from PIL import Image
12
from matplotlib import lines
13
14
from games import TicTacToe, alpha_beta_player, random_player, Fig52Extended
15
from learning import DataSet
16
from logic import parse_definite_clause, standardize_variables, unify_mm, subst
17
from search import GraphProblem, romania_map
18
19
20
# ______________________________________________________________________________
21
# Magic Words
22
23
24
def pseudocode(algorithm):
    """Fetch the pseudocode for the given algorithm and return it as Markdown.

    The algorithm name is turned into a file name by replacing spaces with
    hyphens. It is then downloaded from the aimacode/aima-pseudocode
    repository on GitHub (``md/<name>.md``).

    :param algorithm: algorithm name, e.g. "Breadth First Search".
    :return: an IPython ``Markdown`` object with the file's text. The first
        line is dropped and an extra '#' is prepended to the remainder.
    """
    from urllib.request import urlopen
    from IPython.display import Markdown

    algorithm = algorithm.replace(' ', '-')
    url = "https://raw.githubusercontent.com/aimacode/aima-pseudocode/master/md/{}.md".format(algorithm)
    # Use the response as a context manager so the connection is released
    # even if reading or decoding fails.
    with urlopen(url) as f:
        md = f.read().decode('utf-8')
    # Drop the first line (if there is more than one) and prefix '#';
    # presumably this demotes the file's leading heading by one level.
    md = md.split('\n', 1)[-1].strip()
    md = '#' + md
    return Markdown(md)
36
37
38
def psource(*functions):
    """Display the source code of one or more functions.

    The sources are separated by a blank line. If pygments can be imported,
    the code is rendered as syntax-highlighted HTML. Otherwise it is printed
    as plain text.
    """
    source_code = '\n\n'.join(map(getsource, functions))
    try:
        from pygments.formatters import HtmlFormatter
        from pygments.lexers import PythonLexer
        from pygments import highlight

        highlighted = highlight(source_code, PythonLexer(), HtmlFormatter(full=True))
        display(HTML(highlighted))
    except ImportError:
        # pygments is optional: fall back to plain text.
        print(source_code)
50
51
52
# ______________________________________________________________________________
53
# Iris Visualization
54
55
56
def show_iris(i=0, j=1, k=2):
    """Scatter-plot three of the four iris features in 3D.

    :param i: index of the feature plotted on the x axis.
    :param j: index of the feature plotted on the y axis.
    :param k: index of the feature plotted on the z axis.

    Feature indices refer to ["Sepal Length", "Sepal Width",
    "Petal Length", "Petal Width"]. Each species is drawn with its own
    colour and marker.
    """
    plt.rcParams.update(plt.rcParamsDefault)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    buckets = DataSet(name="iris").split_values_by_classes()

    feature_names = ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"]
    x_label, y_label, z_label = feature_names[i], feature_names[j], feature_names[k]

    # (species, colour, marker) in the order the species are drawn.
    styles = (("setosa", 'b', 's'),
              ("virginica", 'g', '^'),
              ("versicolor", 'r', 'o'))

    # Extract every species' three coordinate columns before plotting anything.
    series = []
    for species, colour, marker in styles:
        rows = buckets[species]
        xs = [row[i] for row in rows]
        ys = [row[j] for row in rows]
        zs = [row[k] for row in rows]
        series.append((colour, marker, xs, ys, zs))

    for colour, marker, xs, ys, zs in series:
        ax.scatter(xs, ys, zs, c=colour, marker=marker)

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_zlabel(z_label)

    plt.show()
94
95
96
# ______________________________________________________________________________
97
# MNIST
98
99
100
def load_MNIST(path="aima-data/MNIST/Digits", fashion=False):
    """Load an MNIST-format dataset (handwritten digits or Fashion-MNIST).

    Reads the IDX files ``train-images-idx3-ubyte``,
    ``train-labels-idx1-ubyte``, ``t10k-images-idx3-ubyte`` and
    ``t10k-labels-idx1-ubyte`` from *path*. As a side effect, matplotlib's
    rcParams are reset and then set to defaults suited to showing the
    images (10x8 figure, nearest interpolation, gray colour map).

    :param path: directory containing the four IDX files.
    :param fashion: if True, *path* is ignored and
        "aima-data/MNIST/Fashion" is used instead.
    :return: ``(train_img, train_lbl, test_img, test_lbl)``. The image arrays
        are int16 with one flattened image (rows * cols pixels) per row; the
        label arrays are int8.
    """
    import os
    import struct

    if fashion:
        path = "aima-data/MNIST/Fashion"

    plt.rcParams.update(plt.rcParamsDefault)
    plt.rcParams['figure.figsize'] = (10.0, 8.0)
    plt.rcParams['image.interpolation'] = 'nearest'
    plt.rcParams['image.cmap'] = 'gray'

    def read_images(file_name):
        # Image file header: magic, count, rows, cols as big-endian uint32,
        # followed by unsigned-byte pixels.
        with open(os.path.join(path, file_name), "rb") as f:
            _magic, size, rows, cols = struct.unpack(">IIII", f.read(16))
            pixels = np.frombuffer(f.read(), dtype=np.uint8, count=size * rows * cols)
        # One flattened image per row; astype also yields a writable copy.
        return pixels.reshape(size, rows * cols).astype(np.int16)

    def read_labels(file_name):
        # Label file header: magic, count as big-endian uint32, followed by
        # one byte per label.
        with open(os.path.join(path, file_name), "rb") as f:
            _magic, size = struct.unpack(">II", f.read(8))
            labels = np.frombuffer(f.read(), dtype=np.int8, count=size)
        return labels.astype(np.int8)

    train_img = read_images("train-images-idx3-ubyte")
    train_lbl = read_labels("train-labels-idx1-ubyte")
    test_img = read_images("t10k-images-idx3-ubyte")
    test_lbl = read_labels("t10k-labels-idx1-ubyte")

    return (train_img, train_lbl, test_img, test_lbl)
148
149
150
# Display names for the MNIST label values: in show_MNIST/show_ave_MNIST the
# class at position y of these lists titles the images whose label equals y.
digit_classes = list(map(str, range(10)))
fashion_classes = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]
153
154
155
def show_MNIST(labels, images, samples=8, fashion=False):
    """Plot a grid of randomly chosen example images for every class.

    Column y holds *samples* images whose label equals y, drawn without
    replacement. The top image of each column is titled with the class name.

    :param labels: sequence of integer class labels, one per image.
    :param images: indexable collection of flattened 28x28 images.
    :param samples: number of examples per class (rows in the grid).
    :param fashion: use the Fashion-MNIST class names instead of digits.
    :raises ValueError: if a class has fewer than *samples* images
        (raised by ``np.random.choice`` with ``replace=False``).
    """
    classes = fashion_classes if fashion else digit_classes
    num_classes = len(classes)
    # Convert once so that per-class selection is a vectorized comparison
    # instead of a Python loop over every label for every class.
    label_array = np.asarray(labels)

    for y, cls in enumerate(classes):
        idxs = np.random.choice(np.flatnonzero(label_array == y), samples, replace=False)
        for i, idx in enumerate(idxs):
            plt_idx = i * num_classes + y + 1
            plt.subplot(samples, num_classes, plt_idx)
            plt.imshow(images[idx].reshape((28, 28)))
            plt.axis("off")
            if i == 0:
                plt.title(cls)

    plt.show()
175
176
177
def show_ave_MNIST(labels, images, fashion=False):
    """Print each class's image count and plot the mean image of every class.

    :param labels: sequence of integer class labels, one per image.
    :param images: indexable collection of flattened 28x28 images.
    :param fashion: use the Fashion-MNIST class names instead of digits.

    A class with no images is reported as having 0 images and gets an empty,
    titled panel rather than raising from ``np.vstack``.
    """
    if fashion:
        item_type, classes = "Apparel", fashion_classes
    else:
        item_type, classes = "Digit", digit_classes

    num_classes = len(classes)
    # Convert once; the per-class selection below is then vectorized.
    label_array = np.asarray(labels)

    for y, cls in enumerate(classes):
        idxs = np.flatnonzero(label_array == y)
        print(item_type, y, ":", len(idxs), "images.")

        plt.subplot(1, num_classes, y + 1)
        if len(idxs):
            ave_img = np.mean(np.vstack([images[i] for i in idxs]), axis=0)
            plt.imshow(ave_img.reshape((28, 28)))
        plt.axis("off")
        plt.title(cls)

    plt.show()
200
201
202
# ______________________________________________________________________________
203
# MDP
204
205
206
def make_plot_grid_step_function(columns, rows, U_over_time):
    """Build a one-argument plotting callback for an ipywidgets slider.

    ipywidgets' ``interactive`` passes a single value to its callback, so the
    grid dimensions and the utility history are captured in a closure.

    :param columns: number of grid columns.
    :param rows: number of grid rows.
    :param U_over_time: sequence indexed by iteration. Each entry maps
        (column, row) to a value, and missing cells are shown as 0.
    :return: ``plot_grid_step(iteration)``, which draws that iteration's grid
        as a blue-white-red colour map with each cell's value printed on it.
    """

    def plot_grid_step(iteration):
        values = defaultdict(lambda: 0, U_over_time[iteration])
        grid = [[values[(column, row)] for column in range(columns)]
                for row in range(rows)]
        # Reverse so that row 0 is the last image row (drawn at the bottom
        # with imshow's default origin), i.e. output like the book.
        grid.reverse()

        image = plt.imshow(grid, cmap=plt.cm.bwr, interpolation='nearest')

        plt.axis('off')
        image.axes.get_xaxis().set_visible(False)
        image.axes.get_yaxis().set_visible(False)

        # Annotate every cell: x is the column position, y the image row.
        for y, grid_row in enumerate(grid):
            for x, value in enumerate(grid_row):
                image.axes.text(x, y, "{0:.2f}".format(value), va='center', ha='center')

        plt.show()

    return plot_grid_step
235
236
237
def make_visualize(slider):
    """Take a slider as input and return a callback that animates it.

    The returned ``visualize_callback(visualize, time_step)`` does nothing
    unless *visualize* is exactly ``True``. Otherwise it sets
    ``slider.value`` to every integer from ``slider.min`` to ``slider.max``
    inclusive, sleeping ``float(time_step)`` seconds after each step.
    """

    def visualize_callback(visualize, time_step):
        if visualize is not True:
            return
        for step in range(slider.min, slider.max + 1):
            slider.value = step
            time.sleep(float(time_step))

    return visualize_callback
248
249
250
# ______________________________________________________________________________
251
252
253
_canvas = """
254
<script type="text/javascript" src="./js/canvas.js"></script>
255
<div>
256
<canvas id="{0}" width="{1}" height="{2}" style="background:rgba(158, 167, 184, 0.2);" onclick='click_callback(this, event, "{3}")'></canvas>
257
</div>
258
259
<script> var {0}_canvas_object = new Canvas("{0}");</script>
260
""" # noqa
261
262
263
class Canvas:
    """Inherit from this class to manage the HTML canvas element in jupyter notebooks.

    To create an object of this class any_name_xyz = Canvas("any_name_xyz")
    The first argument given must be the name of the object being created.
    IPython must be able to reference the variable name that is being passed.

    Drawing methods do not draw immediately. They queue JavaScript calls on
    ``self.exec_list``, and ``update()`` sends the queued batch to the
    notebook. Methods ending in ``_n`` take coordinates normalized to [0, 1]
    and scale them by the canvas width/height.
    """

    def __init__(self, varname, width=800, height=600, cid=None):
        self.name = varname
        self.cid = cid or varname
        self.width = width
        self.height = height
        self.html = _canvas.format(self.cid, self.width, self.height, self.name)
        self.exec_list = []
        display_html(self.html)

    @staticmethod
    def _js_string(value):
        """Return ``str(value)`` as a quoted JavaScript string literal.

        Quotes, backslashes and control characters are escaped so that the
        text cannot break the generated call. "</" is also escaped so that
        the literal cannot close the surrounding <script> element.
        """
        import json
        return json.dumps(str(value)).replace('</', '<\\/')

    def mouse_click(self, x, y):
        """Override this method to handle mouse click at position (x, y)"""
        raise NotImplementedError

    def mouse_move(self, x, y):
        """Override this method to handle mouse movement to position (x, y)"""
        raise NotImplementedError

    def execute(self, exec_str):
        """Stores the command to be executed to a list which is used later during update()

        *exec_str* is a method call on this canvas's JS object, such as
        "rect(0, 0, 10, 10)". A non-string command is rejected: an error is
        printed, an alert is shown and nothing is queued.
        """
        if not isinstance(exec_str, str):
            print("Invalid execution argument:", exec_str)
            self.alert("Received invalid execution command format")
            # Without this return the string concatenation below would
            # raise TypeError for the very input just reported.
            return
        prefix = "{0}_canvas_object.".format(self.cid)
        self.exec_list.append(prefix + exec_str + ';')

    def fill(self, r, g, b):
        """Changes the fill color to a color in rgb format"""
        self.execute("fill({0}, {1}, {2})".format(r, g, b))

    def stroke(self, r, g, b):
        """Changes the colors of line/strokes to rgb"""
        self.execute("stroke({0}, {1}, {2})".format(r, g, b))

    def strokeWidth(self, w):
        """Changes the width of lines/strokes to 'w' pixels"""
        self.execute("strokeWidth({0})".format(w))

    def rect(self, x, y, w, h):
        """Draw a rectangle with 'w' width, 'h' height and (x, y) as the top-left corner"""
        self.execute("rect({0}, {1}, {2}, {3})".format(x, y, w, h))

    def rect_n(self, xn, yn, wn, hn):
        """Similar to rect(), but the dimensions are normalized to fall between 0 and 1"""
        x = round(xn * self.width)
        y = round(yn * self.height)
        w = round(wn * self.width)
        h = round(hn * self.height)
        self.rect(x, y, w, h)

    def line(self, x1, y1, x2, y2):
        """Draw a line from (x1, y1) to (x2, y2)"""
        self.execute("line({0}, {1}, {2}, {3})".format(x1, y1, x2, y2))

    def line_n(self, x1n, y1n, x2n, y2n):
        """Similar to line(), but the dimensions are normalized to fall between 0 and 1"""
        x1 = round(x1n * self.width)
        y1 = round(y1n * self.height)
        x2 = round(x2n * self.width)
        y2 = round(y2n * self.height)
        self.line(x1, y1, x2, y2)

    def arc(self, x, y, r, start, stop):
        """Draw an arc with (x, y) as centre, 'r' as radius from angles 'start' to 'stop'"""
        self.execute("arc({0}, {1}, {2}, {3}, {4})".format(x, y, r, start, stop))

    def arc_n(self, xn, yn, rn, start, stop):
        """Similar to arc(), but the dimensions are normalized to fall between 0 and 1

        The normalizing factor for radius is selected between width and height by
        seeing which is smaller."""
        x = round(xn * self.width)
        y = round(yn * self.height)
        r = round(rn * min(self.width, self.height))
        self.arc(x, y, r, start, stop)

    def clear(self):
        """Clear the HTML canvas"""
        self.execute("clear()")

    def font(self, font):
        """Changes the font of text"""
        self.execute('font({0})'.format(self._js_string(font)))

    def text(self, txt, x, y, fill=True):
        """Display a text at (x, y); ``str(txt)`` is drawn filled or stroked."""
        if fill:
            self.execute('fill_text({0}, {1}, {2})'.format(self._js_string(txt), x, y))
        else:
            self.execute('stroke_text({0}, {1}, {2})'.format(self._js_string(txt), x, y))

    def text_n(self, txt, xn, yn, fill=True):
        """Similar to text(), but with normalized coordinates"""
        x = round(xn * self.width)
        y = round(yn * self.height)
        self.text(txt, x, y, fill)

    def alert(self, message):
        """Immediately display an alert"""
        display_html('<script>alert({0})</script>'.format(self._js_string(message)))

    def update(self):
        """Execute the JS code to execute the commands queued by execute()"""
        exec_code = "<script>\n" + '\n'.join(self.exec_list) + "\n</script>"
        self.exec_list = []
        display_html(exec_code)
372
373
374
def display_html(html_string):
    """Render *html_string* as HTML in the current IPython/Jupyter output."""
    html_object = HTML(html_string)
    display(html_object)
376
377
378
################################################################################
379
380
381
class Canvas_TicTacToe(Canvas):
    """Play a 3x3 TicTacToe game on HTML canvas

    The board occupies the top 6/7 of the canvas. The strip below it shows
    whose move it is or the result, and a Restart button once the game has
    ended. Every click advances the game: a 'human' player moves by
    clicking a free cell, while a 'random' or 'alpha_beta' player makes its
    move whenever the canvas is clicked.
    """

    def __init__(self, varname, player_1='human', player_2='random',
                 width=300, height=350, cid=None):
        # Each player is one of 'human', 'random' or 'alpha_beta'.
        # player_1 is labelled X and player_2 O in the status text.
        valid_players = ('human', 'random', 'alpha_beta')
        if player_1 not in valid_players or player_2 not in valid_players:
            raise TypeError("Players must be one of {}".format(valid_players))
        super().__init__(varname, width, height, cid)
        self.ttt = TicTacToe()
        self.state = self.ttt.initial
        # Index into self.players of the player to move (0 or 1).
        self.turn = 0
        self.strokeWidth(5)
        self.players = (player_1, player_2)
        self.font("20px Arial")
        self.draw_board()

    def mouse_click(self, x, y):
        """Advance the game in response to a click at pixel (x, y)."""
        player = self.players[self.turn]
        if self.ttt.terminal_test(self.state):
            # Game over: only a click inside the Restart button area
            # (x in 55%-95% of the width, y in the strip below the board)
            # does anything.
            if 0.55 <= x / self.width <= 0.95 and 6 / 7 <= y / self.height <= 6 / 7 + 1 / 8:
                self.state = self.ttt.initial
                self.turn = 0
                self.draw_board()
            return

        if player == 'human':
            # Map the pixel to a 1-based (column, row) cell. The board uses
            # only the top 6/7 of the canvas height.
            x, y = int(3 * x / self.width) + 1, int(3 * y / (self.height * 6 / 7)) + 1
            if (x, y) not in self.ttt.actions(self.state):
                # Invalid move
                return
            move = (x, y)
        elif player == 'alpha_beta':
            move = alpha_beta_player(self.ttt, self.state)
        else:
            move = random_player(self.ttt, self.state)
        self.state = self.ttt.result(self.state, move)
        self.turn ^= 1
        self.draw_board()

    def draw_board(self):
        """Redraw grid, marks and status strip, then send everything to the browser."""
        self.clear()
        self.stroke(0, 0, 0)
        offset = 1 / 20
        # Two horizontal and two vertical grid lines. y coordinates are
        # scaled by 6/7 because the board only uses that part of the canvas.
        self.line_n(0 + offset, (1 / 3) * 6 / 7, 1 - offset, (1 / 3) * 6 / 7)
        self.line_n(0 + offset, (2 / 3) * 6 / 7, 1 - offset, (2 / 3) * 6 / 7)
        self.line_n(1 / 3, (0 + offset) * 6 / 7, 1 / 3, (1 - offset) * 6 / 7)
        self.line_n(2 / 3, (0 + offset) * 6 / 7, 2 / 3, (1 - offset) * 6 / 7)

        # Keys are 1-based (column, row) cells; cells holding 'X' or 'O' get
        # drawn (presumably those are the only values the game stores).
        board = self.state.board
        for mark in board:
            if board[mark] == 'X':
                self.draw_x(mark)
            elif board[mark] == 'O':
                self.draw_o(mark)
        if self.ttt.terminal_test(self.state):
            # End game message
            # The utility is taken for the player to move in the initial state,
            # which this text labels X: negative means O won, 0 is a draw.
            utility = self.ttt.utility(self.state, self.ttt.to_move(self.ttt.initial))
            if utility == 0:
                self.text_n('Game Draw!', offset, 6 / 7 + offset)
            else:
                self.text_n('Player {} wins!'.format("XO"[utility < 0]), offset, 6 / 7 + offset)
                # Find the 3 and draw a line
                # self.turn was already flipped after the winning move, so this
                # picks the winner's mark colour (green for X, red for O, as in
                # draw_x/draw_o).
                self.stroke([255, 0][self.turn], [0, 255][self.turn], 0)
                for i in range(3):
                    # Column i: three equal marks with the same x coordinate.
                    if all([(i + 1, j + 1) in self.state.board for j in range(3)]) and \
                            len({self.state.board[(i + 1, j + 1)] for j in range(3)}) == 1:
                        self.line_n(i / 3 + 1 / 6, offset * 6 / 7, i / 3 + 1 / 6, (1 - offset) * 6 / 7)
                    # Row i: three equal marks with the same y coordinate.
                    if all([(j + 1, i + 1) in self.state.board for j in range(3)]) and \
                            len({self.state.board[(j + 1, i + 1)] for j in range(3)}) == 1:
                        self.line_n(offset, (i / 3 + 1 / 6) * 6 / 7, 1 - offset, (i / 3 + 1 / 6) * 6 / 7)
                # Main diagonal (1,1)-(3,3).
                if all([(i + 1, i + 1) in self.state.board for i in range(3)]) and \
                        len({self.state.board[(i + 1, i + 1)] for i in range(3)}) == 1:
                    self.line_n(offset, offset * 6 / 7, 1 - offset, (1 - offset) * 6 / 7)
                # Anti-diagonal (1,3)-(3,1).
                if all([(i + 1, 3 - i) in self.state.board for i in range(3)]) and \
                        len({self.state.board[(i + 1, 3 - i)] for i in range(3)}) == 1:
                    self.line_n(offset, (1 - offset) * 6 / 7, 1 - offset, offset * 6 / 7)
            # restart button
            # Its area must match the hit test in mouse_click().
            self.fill(0, 0, 255)
            self.rect_n(0.5 + offset, 6 / 7, 0.4, 1 / 8)
            self.fill(0, 0, 0)
            self.text_n('Restart', 0.5 + 2 * offset, 13 / 14)
        else:  # Print which player's turn it is
            self.text_n("Player {}'s move({})".format("XO"[self.turn], self.players[self.turn]),
                        offset, 6 / 7 + offset)

        self.update()

    def draw_x(self, position):
        """Draw a green X (two diagonal strokes) in the 1-based cell *position*."""
        self.stroke(0, 255, 0)
        x, y = [i - 1 for i in position]
        offset = 1 / 15
        self.line_n(x / 3 + offset, (y / 3 + offset) * 6 / 7, x / 3 + 1 / 3 - offset, (y / 3 + 1 / 3 - offset) * 6 / 7)
        self.line_n(x / 3 + 1 / 3 - offset, (y / 3 + offset) * 6 / 7, x / 3 + offset, (y / 3 + 1 / 3 - offset) * 6 / 7)

    def draw_o(self, position):
        """Draw a red circle centred in the 1-based cell *position*."""
        self.stroke(255, 0, 0)
        x, y = [i - 1 for i in position]
        self.arc_n(x / 3 + 1 / 6, (y / 3 + 1 / 6) * 6 / 7, 1 / 9, 0, 360)
480
481
482
class Canvas_min_max(Canvas):
    """MinMax for Fig52Extended on HTML canvas

    Draws the 40-node game tree of Fig52Extended, which has four levels of
    1, 3, 9 and 27 nodes, and steps through a minimax search one mouse click
    at a time. Node n's children are 3n+1 .. 3n+3. *util_list* supplies
    the utilities of the leaves, nodes 13..39, in order.

    The search is recorded up front as a list of change codes in
    ``self.change_list``:
      ('a', node)       push node onto the highlighted stack
      ('h',)            pause: the animation stops here until the next click
      ('l', (node, j))  draw the edge from node to its j-th child thick
      ('e', node)       mark node explored (its value is shown)
      ('p',)            pop the highlighted stack
    """

    def __init__(self, varname, util_list, width=800, height=600, cid=None):
        super().__init__(varname, width, height, cid)
        # Leaf utilities; zip stops at the shorter of the two sequences.
        self.utils = dict(zip(range(13, 40), util_list))
        self.game = Fig52Extended()
        self.game.utils = self.utils
        self.nodes = list(range(40))
        self.l = 1 / 40  # side of a node's square, in normalized units
        self.node_pos = {}
        # Level i holds 3**i nodes spread evenly across the width.
        # node_pos maps a node to the normalized top-left corner of its square.
        for i in range(4):
            base = len(self.node_pos)
            row_size = 3 ** i
            for node in range(base, base + row_size):
                self.node_pos[node] = ((node - base) / row_size + 1 / (2 * row_size) - self.l / 2,
                                       self.l / 2 + (self.l + (1 - 5 * self.l) / 3) * i)
        self.font("12px Arial")
        self.node_stack = []
        # Leaves start out explored because their utilities are known.
        self.explored = set(self.utils)
        self.thick_lines = set()
        self.change_list = []
        self.draw_graph()
        self.stack_manager = self.stack_manager_gen()

    def min_max(self, node):
        """Run minimax from *node*, recording animation steps in self.change_list.

        Each internal node's backed-up value is stored in self.utils. The
        return value is the minimax value of *node*.
        """
        game = self.game
        player = game.to_move(node)

        def max_value(node):
            if game.terminal_test(node):
                return game.utility(node, player)
            self.change_list.append(('a', node))
            self.change_list.append(('h',))
            max_a = max(game.actions(node), key=lambda x: min_value(game.result(node, x)))
            max_node = game.result(node, max_a)
            self.utils[node] = self.utils[max_node]
            # Highlight the edge to the chosen child (child index 0..2).
            self.change_list.append(('l', (node, max_node - 3 * node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return self.utils[node]

        def min_value(node):
            if game.terminal_test(node):
                return game.utility(node, player)
            self.change_list.append(('a', node))
            self.change_list.append(('h',))
            min_a = min(game.actions(node), key=lambda x: max_value(game.result(node, x)))
            min_node = game.result(node, min_a)
            self.utils[node] = self.utils[min_node]
            # Highlight the edge to the chosen child (child index 0..2).
            self.change_list.append(('l', (node, min_node - 3 * node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return self.utils[node]

        return max_value(node)

    def stack_manager_gen(self):
        """Generator that replays self.change_list, pausing at every ('h',) entry."""
        self.min_max(0)
        for change in self.change_list:
            if change[0] == 'a':
                self.node_stack.append(change[1])
            elif change[0] == 'e':
                self.explored.add(change[1])
            elif change[0] == 'h':
                yield
            elif change[0] == 'l':
                self.thick_lines.add(change[1])
            elif change[0] == 'p':
                self.node_stack.pop()

    def mouse_click(self, x, y):
        """Advance the animation to the next pause point and redraw."""
        try:
            self.stack_manager.send(None)
        except StopIteration:
            # Animation finished; keep showing the final state.
            pass
        self.draw_graph()

    def draw_graph(self):
        """Draw every node (with its value once explored) and all tree edges."""
        self.clear()
        # draw nodes
        self.stroke(0, 0, 0)
        self.strokeWidth(1)
        # highlight for nodes in stack
        for node in self.node_stack:
            x, y = self.node_pos[node]
            self.fill(200, 200, 0)
            self.rect_n(x - self.l / 5, y - self.l / 5, self.l * 7 / 5, self.l * 7 / 5)
        for node in self.nodes:
            x, y = self.node_pos[node]
            if node in self.explored:
                self.fill(255, 255, 255)
            else:
                self.fill(200, 200, 200)
            self.rect_n(x, y, self.l, self.l)
            self.line_n(x, y, x + self.l, y)
            self.line_n(x, y, x, y + self.l)
            self.line_n(x + self.l, y + self.l, x + self.l, y)
            self.line_n(x + self.l, y + self.l, x, y + self.l)
            self.fill(0, 0, 0)
            if node in self.explored:
                self.text_n(self.utils[node], x + self.l / 10, y + self.l * 9 / 10)
        # draw edges
        # From the bottom-centre of parent i to the top-centre of each child.
        for i in range(13):
            x1, y1 = self.node_pos[i][0] + self.l / 2, self.node_pos[i][1] + self.l
            for j in range(3):
                x2, y2 = self.node_pos[i * 3 + j + 1][0] + self.l / 2, self.node_pos[i * 3 + j + 1][1]
                # Edges leaving nodes 1-3 are red, all others green.
                if i in [1, 2, 3]:
                    self.stroke(200, 0, 0)
                else:
                    self.stroke(0, 200, 0)
                if (i, j) in self.thick_lines:
                    self.strokeWidth(3)
                else:
                    self.strokeWidth(1)
                self.line_n(x1, y1, x2, y2)
        self.update()
605
606
607
class Canvas_alpha_beta(Canvas):
    """Alpha-beta pruning for Fig52Extended on HTML canvas

    Draws the 40-node game tree of Fig52Extended, which has levels of 1, 3,
    9 and 27 nodes, and steps through an alpha-beta search one mouse click
    at a time. Node n's children are 3n+1 .. 3n+3. *util_list* supplies the
    utilities of the leaves, nodes 13..39.

    The search is recorded up front in ``self.change_list`` as codes:
      ('a', node)              push node onto the highlighted stack
      ('ab', node, lo, hi)     record the pair shown above node
      ('h',)                   pause until the next click
      ('l', (node, j))         draw the edge from node to its j-th child thick
      ('e', node)              mark node explored
      ('p',)                   pop the highlighted stack
    Nodes where a cutoff happened are added to ``self.pruned``.
    """

    def __init__(self, varname, util_list, width=800, height=600, cid=None):
        super().__init__(varname, width, height, cid)
        # Leaf utilities for nodes 13..39.
        self.utils = {node: util for node, util in zip(range(13, 40), util_list)}
        self.game = Fig52Extended()
        self.game.utils = self.utils
        self.nodes = list(range(40))
        self.l = 1 / 40  # side of a node's square, in normalized units
        self.node_pos = {}
        # Level i holds 3**i nodes spread evenly across the width; the tree
        # starts 3*l/2 from the top to leave room for the alpha/beta text.
        for i in range(4):
            base = len(self.node_pos)
            row_size = 3 ** i
            for node in [base + j for j in range(row_size)]:
                self.node_pos[node] = ((node - base) / row_size + 1 / (2 * row_size) - self.l / 2,
                                       3 * self.l / 2 + (self.l + (1 - 6 * self.l) / 3) * i)
        self.font("12px Arial")
        self.node_stack = []
        # Leaves start out explored because their utilities are known.
        self.explored = {node for node in self.utils}
        self.pruned = set()
        self.ab = {}
        self.thick_lines = set()
        self.change_list = []
        self.draw_graph()
        self.stack_manager = self.stack_manager_gen()

    def alpha_beta_search(self, node):
        """Run alpha-beta from *node*, recording animation steps in self.change_list.

        Internal nodes' values are stored in self.utils. The return value is
        the value computed for *node*.
        """
        game = self.game
        player = game.to_move(node)

        # Functions used by alpha_beta
        def max_value(node, alpha, beta):
            if game.terminal_test(node):
                # Leaves are briefly highlighted: push, pause, pop.
                self.change_list.append(('a', node))
                self.change_list.append(('h',))
                self.change_list.append(('p',))
                return game.utility(node, player)
            v = -np.inf
            self.change_list.append(('a', node))
            # For a MAX node the recorded pair is (v, beta): the node's current
            # best value is displayed in the alpha position.
            self.change_list.append(('ab', node, v, beta))
            self.change_list.append(('h',))
            for a in game.actions(node):
                min_val = min_value(game.result(node, a), alpha, beta)
                if v < min_val:
                    v = min_val
                    # NOTE(review): max_node is only bound here; if no child ever
                    # exceeds -inf (e.g. all utilities are -inf) the 'l' entry
                    # below raises UnboundLocalError. Confirm utilities are finite.
                    max_node = game.result(node, a)
                    self.change_list.append(('ab', node, v, beta))
                if v >= beta:
                    # Cutoff: remaining children are skipped.
                    self.change_list.append(('h',))
                    self.pruned.add(node)
                    break
                alpha = max(alpha, v)
            self.utils[node] = v
            if node not in self.pruned:
                self.change_list.append(('l', (node, max_node - 3 * node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return v

        def min_value(node, alpha, beta):
            if game.terminal_test(node):
                # Leaves are briefly highlighted: push, pause, pop.
                self.change_list.append(('a', node))
                self.change_list.append(('h',))
                self.change_list.append(('p',))
                return game.utility(node, player)
            v = np.inf
            self.change_list.append(('a', node))
            # For a MIN node the recorded pair is (alpha, v): the node's current
            # best value is displayed in the beta position.
            self.change_list.append(('ab', node, alpha, v))
            self.change_list.append(('h',))
            for a in game.actions(node):
                max_val = max_value(game.result(node, a), alpha, beta)
                if v > max_val:
                    v = max_val
                    # NOTE(review): same unbound-name caveat as max_node above.
                    min_node = game.result(node, a)
                    self.change_list.append(('ab', node, alpha, v))
                if v <= alpha:
                    # Cutoff: remaining children are skipped.
                    self.change_list.append(('h',))
                    self.pruned.add(node)
                    break
                beta = min(beta, v)
            self.utils[node] = v
            if node not in self.pruned:
                self.change_list.append(('l', (node, min_node - 3 * node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return v

        return max_value(node, -np.inf, np.inf)

    def stack_manager_gen(self):
        """Generator that replays self.change_list, pausing at every ('h',) entry."""
        self.alpha_beta_search(0)
        for change in self.change_list:
            if change[0] == 'a':
                self.node_stack.append(change[1])
            elif change[0] == 'ab':
                self.ab[change[1]] = change[2:]
            elif change[0] == 'e':
                self.explored.add(change[1])
            elif change[0] == 'h':
                yield
            elif change[0] == 'l':
                self.thick_lines.add(change[1])
            elif change[0] == 'p':
                self.node_stack.pop()

    def mouse_click(self, x, y):
        """Advance the animation to the next pause point and redraw."""
        try:
            self.stack_manager.send(None)
        except StopIteration:
            # Animation finished; keep showing the final state.
            pass
        self.draw_graph()

    def draw_graph(self):
        """Draw nodes, edges and the alpha/beta pairs of unexplored nodes on the stack."""
        self.clear()
        # draw nodes
        self.stroke(0, 0, 0)
        self.strokeWidth(1)
        # highlight for nodes in stack
        for node in self.node_stack:
            x, y = self.node_pos[node]
            # alpha > beta
            # Leaves are already explored, so self.ab (which has no leaf
            # entries) is only consulted for internal nodes.
            if node not in self.explored and self.ab[node][0] > self.ab[node][1]:
                self.fill(200, 100, 100)
            else:
                self.fill(200, 200, 0)
            self.rect_n(x - self.l / 5, y - self.l / 5, self.l * 7 / 5, self.l * 7 / 5)
        for node in self.nodes:
            x, y = self.node_pos[node]
            if node in self.explored:
                # Dark fill for explored nodes where a cutoff happened.
                if node in self.pruned:
                    self.fill(50, 50, 50)
                else:
                    self.fill(255, 255, 255)
            else:
                self.fill(200, 200, 200)
            self.rect_n(x, y, self.l, self.l)
            self.line_n(x, y, x + self.l, y)
            self.line_n(x, y, x, y + self.l)
            self.line_n(x + self.l, y + self.l, x + self.l, y)
            self.line_n(x + self.l, y + self.l, x, y + self.l)
            self.fill(0, 0, 0)
            if node in self.explored and node not in self.pruned:
                self.text_n(self.utils[node], x + self.l / 10, y + self.l * 9 / 10)
        # draw edges
        # From the bottom-centre of parent i to the top-centre of each child.
        for i in range(13):
            x1, y1 = self.node_pos[i][0] + self.l / 2, self.node_pos[i][1] + self.l
            for j in range(3):
                x2, y2 = self.node_pos[i * 3 + j + 1][0] + self.l / 2, self.node_pos[i * 3 + j + 1][1]
                # Edges leaving nodes 1-3 are red, all others green.
                if i in [1, 2, 3]:
                    self.stroke(200, 0, 0)
                else:
                    self.stroke(0, 200, 0)
                if (i, j) in self.thick_lines:
                    self.strokeWidth(3)
                else:
                    self.strokeWidth(1)
                self.line_n(x1, y1, x2, y2)
        # display alpha and beta
        # The first value of the pair goes left of the node, the second right.
        for node in self.node_stack:
            if node not in self.explored:
                x, y = self.node_pos[node]
                alpha, beta = self.ab[node]
                self.text_n(alpha, x - self.l / 2, y - self.l / 10)
                self.text_n(beta, x + self.l, y - self.l / 10)
        self.update()
775
776
777
class Canvas_fol_bc_ask(Canvas):
    """fol_bc_ask() on HTML canvas

    Runs backward chaining for *query* against *kb*. The first proof found is
    drawn as a tree of boxes, one row per depth. Clicking a box shows its
    full text in the strip at the bottom. If no proof exists the canvas is
    filled red.
    """

    def __init__(self, varname, kb, query, width=800, height=600, cid=None):
        super().__init__(varname, width, height, cid)
        self.kb = kb
        self.query = query
        self.l = 1 / 20  # box height (normalized)
        self.b = 3 * self.l  # box width (normalized)
        # Empty layout so mouse_click()/draw_table() also work when no proof
        # exists; previously these attributes were never set in that case.
        self.table = []
        self.pos = {}
        self.edges = set()
        # Only the first proof is drawn, so stop the search once it is found
        # instead of enumerating every proof.
        first_proof = next(self.fol_bc_ask(), None)
        if first_proof is None:
            self.valid = False
        else:
            self.valid = True
            graph = first_proof[0][0]
            s = first_proof[1]
            # Apply the final substitution until the tree stops changing, so
            # that chains of variable bindings are fully resolved.
            while True:
                new_graph = subst(s, graph)
                if graph == new_graph:
                    break
                graph = new_graph
            self.make_table(graph)
        self.context = None
        self.draw_table()

    def fol_bc_ask(self):
        """Generator of backward-chaining proofs of self.query in self.kb.

        Yields ``(proof, theta)`` pairs. ``proof`` is a one-element list
        holding a ``(goal, subproofs)`` tree, and ``theta`` is the resulting
        substitution.
        """
        KB = self.kb
        query = self.query

        def fol_bc_or(KB, goal, theta):
            for rule in KB.fetch_rules_for_goal(goal):
                lhs, rhs = parse_definite_clause(standardize_variables(rule))
                for theta1 in fol_bc_and(KB, lhs, unify_mm(rhs, goal, theta)):
                    yield ([(goal, theta1[0])], theta1[1])

        def fol_bc_and(KB, goals, theta):
            if theta is None:
                # Unification failed: no proofs.
                return
            if not goals:
                yield ([], theta)
                return
            first, rest = goals[0], goals[1:]
            for theta1 in fol_bc_or(KB, subst(theta, first), theta):
                for theta2 in fol_bc_and(KB, rest, theta1[1]):
                    yield (theta1[0] + theta2[0], theta2[1])

        return fol_bc_or(KB, query, {})

    def make_table(self, graph):
        """Lay out the proof tree *graph* as rows of boxes.

        Sets ``self.table`` (labels per depth), ``self.pos`` ((depth, index) ->
        normalized top-left corner) and ``self.edges`` (parent-to-child lines).
        """
        table = []
        pos = {}
        links = set()
        edges = set()

        def dfs(node, depth):
            # node is (label, children); returns its (depth, index) id.
            if len(table) <= depth:
                table.append([])
            index = len(table[depth])
            table[depth].append(node[0])
            for child in node[1]:
                child_id = dfs(child, depth + 1)
                links.add(((depth, index), child_id))
            return (depth, index)

        dfs(graph, 0)
        # Spread rows over the top 85% of the canvas and boxes evenly within a row.
        y_off = 0.85 / len(table)
        for i, row in enumerate(table):
            x_off = 0.95 / len(row)
            for j, node in enumerate(row):
                pos[(i, j)] = (0.025 + j * x_off + (x_off - self.b) / 2, 0.025 + i * y_off + (y_off - self.l) / 2)
        for p, c in links:
            x1, y1 = pos[p]
            x2, y2 = pos[c]
            # From the bottom-centre of the parent to the top-centre of the child.
            edges.add((x1 + self.b / 2, y1 + self.l, x2 + self.b / 2, y2))

        self.table = table
        self.pos = pos
        self.edges = edges

    def mouse_click(self, x, y):
        """Select the box under pixel (x, y), if any, and redraw."""
        x, y = x / self.width, y / self.height
        for node, (xs, ys) in self.pos.items():
            xe, ye = xs + self.b, ys + self.l
            if xs <= x <= xe and ys <= y <= ye:
                self.context = node
                break
        self.draw_table()

    def draw_table(self):
        """Draw the proof tree (or a red canvas if there is none) and the text strip."""
        self.clear()
        self.strokeWidth(3)
        self.stroke(0, 0, 0)
        self.font("12px Arial")
        if self.valid:
            # draw nodes
            for i, j in self.pos:
                x, y = self.pos[(i, j)]
                self.fill(200, 200, 200)
                self.rect_n(x, y, self.b, self.l)
                self.line_n(x, y, x + self.b, y)
                self.line_n(x, y, x, y + self.l)
                self.line_n(x + self.b, y, x + self.b, y + self.l)
                self.line_n(x, y + self.l, x + self.b, y + self.l)
                self.fill(0, 0, 0)
                self.text_n(self.table[i][j], x + 0.01, y + self.l - 0.01)
            # draw edges
            for x1, y1, x2, y2 in self.edges:
                self.line_n(x1, y1, x2, y2)
        else:
            self.fill(255, 0, 0)
            self.rect_n(0, 0, 1, 1)
        # text area
        self.fill(255, 255, 255)
        self.rect_n(0, 0.9, 1, 0.1)
        self.strokeWidth(5)
        self.stroke(0, 0, 0)
        self.line_n(0, 0.9, 1, 0.9)
        self.font("22px Arial")
        self.fill(0, 0, 0)
        self.text_n(self.table[self.context[0]][self.context[1]] if self.context else "Click for text", 0.025, 0.975)
        self.update()
899
900
901
############################################################################################################
902
903
##################### Functions to assist plotting in search.ipynb ####################
904
905
############################################################################################################
906
907
908
def show_map(graph_data, node_colors=None):
    """Draw a map graph with coloured nodes, edge weights and a legend.

    :param graph_data: dict with the keys 'graph_dict', 'node_colors',
        'node_positions', 'node_label_positions' and 'edge_weights'.
    :param node_colors: optional mapping of node to colour. If it is None
        or empty, ``graph_data['node_colors']`` is used.
    """
    G = nx.Graph(graph_data['graph_dict'])
    node_colors = node_colors or graph_data['node_colors']
    node_positions = graph_data['node_positions']
    node_label_pos = graph_data['node_label_positions']
    edge_weights = graph_data['edge_weights']

    # set the size of the plot
    plt.figure(figsize=(18, 13))
    # draw the graph (both nodes and edges) at the given node positions
    nx.draw(G, pos={k: node_positions[k] for k in G.nodes()},
            node_color=[node_colors[node] for node in G.nodes()], linewidths=0.3, edgecolors='k')

    # draw labels for nodes
    node_label_handles = nx.draw_networkx_labels(G, pos=node_label_pos, font_size=14)

    # add a white bounding box behind the node labels
    for label in node_label_handles.values():
        label.set_bbox(dict(facecolor='white', edgecolor='none'))

    # add edge labels to the graph
    nx.draw_networkx_edge_labels(G, pos=node_positions, edge_labels=edge_weights, font_size=14)

    # add a legend: one circle marker per node state colour
    legend_entries = (("white", 'Un-explored'),
                      ("orange", 'Frontier'),
                      ("red", 'Currently Exploring'),
                      ("gray", 'Explored'),
                      ("green", 'Final Solution'))
    handles = tuple(lines.Line2D([], [], color=colour, marker='o', markersize=15, markerfacecolor=colour)
                    for colour, _ in legend_entries)
    labels = tuple(text for _, text in legend_entries)
    plt.legend(handles, labels, numpoints=1, prop={'size': 16}, loc=(.8, .75))

    # show the plot. No need to use in notebooks. nx.draw will show the graph itself.
    plt.show()
942
943
944
# helper functions for visualisations
945
946
def final_path_colors(initial_node_colors, problem, solution):
    """Return a node_colors dict of the final path provided the problem and solution.

    Builds a new dict from ``initial_node_colors``. It marks the start node
    ``problem.initial`` and every node in ``solution`` as "green" and leaves
    all other entries unchanged. The input mapping is not modified.
    """
    recolored = dict(initial_node_colors)
    # start node first, then the solution nodes, in path order
    recolored.update(dict.fromkeys([problem.initial, *solution], "green"))
    return recolored
956
957
958
# Names accepted as keys of the `algorithm` dict in
# display_visual(..., user_input=True, algorithm={...}).
_VISUAL_SEARCH_ALGORITHMS = frozenset({
    "Breadth First Tree Search",
    "Depth First Tree Search",
    "Breadth First Search",
    "Depth First Graph Search",
    "Best First Graph Search",
    "Uniform Cost Search",
    "Depth Limited Search",
    "Iterative Deepening Search",
    "Greedy Best First Search",
    "A-star Search",
    "Recursive Best First Search",
})


def _show_iteration(graph_data, iteration):
    """Draw frame `iteration` of the colour history in the module global `all_node_colors`."""
    try:
        node_colors = all_node_colors[iteration]
    except (NameError, IndexError):
        # `all_node_colors` does not exist until the first search has been run
        # (e.g. the first time the cell calling display_visual is executed),
        # so there is nothing to draw yet. Errors raised while drawing are
        # deliberately NOT swallowed any more.
        return
    show_map(graph_data, node_colors=node_colors)


def _run_search_and_animate(search, problem, slider):
    """Run `search` on `problem`, store its frames globally and step `slider` through them.

    `search(problem)` must return ``(iterations, all_node_colors, node)``. A
    final frame highlighting the solution path of `node` is appended. The list
    is published as the module global `all_node_colors`, which
    _show_iteration reads.
    """
    global all_node_colors

    _, all_node_colors, node = search(problem)
    solution = node.solution()
    all_node_colors.append(final_path_colors(all_node_colors[0], problem, solution))

    slider.max = len(all_node_colors) - 1
    # moving the slider redraws the map once per frame
    for i in range(slider.max + 1):
        slider.value = i


def display_visual(graph_data, user_input, algorithm=None, problem=None):
    """Display ipywidgets controls that animate a graph search on a map.

    :param graph_data: map description consumed by show_map. Must contain the
        key 'node_colors'.
    :param user_input: if False, a slider and a toggle button are shown, and the
        button runs `algorithm` on `problem`. If True, start/goal city
        dropdowns are shown (plus an algorithm dropdown when `algorithm` is a
        dict), and the button builds a GraphProblem on `romania_map` from the
        chosen cities. Any other value displays nothing.
    :param algorithm: a search function returning
        ``(iterations, all_node_colors, node)``. With user_input=True it may
        also be a dict mapping names from _VISUAL_SEARCH_ALGORITHMS to such
        functions.
    :param problem: the problem to solve when user_input is False.
    :returns: 0 when user_input is True and there is no algorithm to run,
        otherwise None.
    :raises ValueError: if a dict of algorithms contains an unsupported name.
    """
    initial_node_colors = graph_data['node_colors']

    def slider_callback(iteration):
        _show_iteration(graph_data, iteration)

    if user_input is False:
        def visualize_callback(visualize):
            if visualize is True:
                # reset the toggle so it can be pressed again
                button.value = False
                _run_search_and_animate(algorithm, problem, slider)

        slider = widgets.IntSlider(min=0, max=1, step=1, value=0)
        slider_visual = widgets.interactive(slider_callback, iteration=slider)
        display(slider_visual)

        button = widgets.ToggleButton(value=False)
        button_visual = widgets.interactive(visualize_callback, visualize=button)
        display(button_visual)

    if user_input is True:
        if isinstance(algorithm, dict):
            unsupported = set(algorithm) - _VISUAL_SEARCH_ALGORITHMS
            if unsupported:
                raise ValueError("Unsupported search algorithm(s): {}".format(
                    ", ".join(sorted(map(repr, unsupported)))))
            if not algorithm:
                print("No algorithm to run.")
                return 0

            names = sorted(algorithm)
            default_name = "Breadth First Tree Search"
            # Fall back to the first name when the preferred default is not
            # offered; Dropdown rejects a value that is not among its options.
            algo_dropdown = widgets.Dropdown(description="Search algorithm: ",
                                             options=names,
                                             value=default_name if default_name in algorithm else names[0])
            display(algo_dropdown)

            def selected_search():
                return algorithm[algo_dropdown.value]
        elif algorithm is None:
            print("No algorithm to run.")
            return 0
        else:
            # A single search function was given: there is nothing to choose.
            def selected_search():
                return algorithm

        def visualize_callback(visualize):
            if visualize is True:
                # reset the toggle so it can be pressed again
                button.value = False
                user_problem = GraphProblem(start_dropdown.value, end_dropdown.value, romania_map)
                _run_search_and_animate(selected_search(), user_problem, slider)

        start_dropdown = widgets.Dropdown(description="Start city: ",
                                          options=sorted(initial_node_colors), value="Arad")
        display(start_dropdown)

        end_dropdown = widgets.Dropdown(description="Goal city: ",
                                        options=sorted(initial_node_colors), value="Fagaras")
        display(end_dropdown)

        button = widgets.ToggleButton(value=False)
        button_visual = widgets.interactive(visualize_callback, visualize=button)
        display(button_visual)

        slider = widgets.IntSlider(min=0, max=1, step=1, value=0)
        slider_visual = widgets.interactive(slider_callback, iteration=slider)
        display(slider_visual)
1056
1057
1058
# Function to plot NQueensCSP in csp.py and NQueensProblem in search.py
def plot_NQueens(solution):
    """Draw a chessboard with a queen image on every square of `solution`.

    :param solution: either a dict mapping column -> row (the form NQueensCSP
        produces) or a list/tuple whose k-th entry is the row of the queen in
        column k (the form NQueensProblem produces). The board size n is
        ``len(solution)``. Any other type draws an empty board.

    The queen picture is read from 'images/queen_s.png', relative to the
    current working directory.
    """
    n = len(solution)
    board = np.array([2 * int((i + j) % 2) for j in range(n) for i in range(n)]).reshape((n, n))
    # Close the image file once its pixels have been copied into an array.
    # `np.float` was removed in NumPy 1.24; the builtin `float` is the same dtype.
    with Image.open('images/queen_s.png') as im:
        queen = np.array(im).astype(float) / 255
    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(111)
    ax.set_title('{} Queens'.format(n))
    plt.imshow(board, cmap='binary', interpolation='nearest')

    if isinstance(solution, dict):
        placements = solution.items()
    elif isinstance(solution, (list, tuple)):
        placements = enumerate(solution)
    else:
        placements = ()

    # NOTE(review): these figure-fraction offsets (0.064 / 0.062 origin,
    # 0.112 step, 0.1 size, row flip `7 - row`) only line up with an 8 x 8
    # board; other sizes will misplace queens.
    for col, row in placements:
        newax = fig.add_axes([0.064 + (col * 0.112), 0.062 + ((7 - row) * 0.112), 0.1, 0.1], zorder=1)
        newax.imshow(queen)
        newax.axis('off')
    fig.tight_layout()
    plt.show()
1083
1084
1085
# Function to plot a heatmap of a grid
def heatmap(grid, cmap='binary', interpolation='nearest'):
    """Show `grid` as an image in a 7 x 7 inch figure titled 'Heatmap'.

    :param grid: array-like handed directly to ``plt.imshow``.
    :param cmap: matplotlib colormap name (default 'binary').
    :param interpolation: ``imshow`` interpolation mode (default 'nearest').
    """
    figure = plt.figure(figsize=(7, 7))
    axes = figure.add_subplot(111)
    axes.set_title('Heatmap')
    plt.imshow(grid, interpolation=interpolation, cmap=cmap)
    figure.tight_layout()
    plt.show()
1093
1094
1095
# Generates a gaussian kernel
def gaussian_kernel(l=5, sig=1.0):
    """Return an l x l grid of unnormalised Gaussian weights.

    Each entry is ``exp(-(x**2 + y**2) / (2 * sig**2))``. Both x and y run over
    the integers from ``-l // 2 + 1`` to ``l // 2`` inclusive. For odd l the
    peak value 1.0 sits in the centre. For even l the coordinates are not
    symmetric about 0 (e.g. l=4 gives -1, 0, 1, 2). The weights are not scaled
    to sum to 1.
    """
    lower, upper = -l // 2 + 1., l // 2 + 1.
    coords = np.arange(lower, upper)
    xx, yy = np.meshgrid(coords, coords)
    squared_radius = xx ** 2 + yy ** 2
    return np.exp(-squared_radius / (2. * sig ** 2))
1101
1102
1103
# Plots utility function for a POMDP
def plot_pomdp_utility(utility):
    """Plot POMDP utility vectors and mark the belief thresholds between actions.

    :param utility: mapping from action keys '0', '1', '2' to lists of value
        vectors. Each vector is drawn as a line, coloured 'g', 'b' or 'k' by
        ``int(action)``.

    Two dashed lines mark where the first '0' vector crosses the first '2'
    vector (``left``) and where the first '1' vector crosses the last '2'
    vector (``right``). The regions are labelled 'Save', 'Ask' and 'Delete'.
    """
    def crossing(p, q):
        # x where the line from (0, p[0]) to (1, p[1]) meets the line from
        # (0, q[0]) to (1, q[1]); presumably each vector has two entries, since
        # the plot is limited to x in [0, 1] — TODO confirm
        gap = p[0] - q[0]
        return gap / (gap + q[1] - p[1])

    save, delete = utility['0'][0], utility['1'][0]
    ask_save, ask_delete = utility['2'][0], utility['2'][-1]
    left = crossing(save, ask_save)
    right = crossing(delete, ask_delete)

    palette = ['g', 'b', 'k']
    for action, vectors in utility.items():
        for vector in vectors:
            plt.plot(vector, color=palette[int(action)])

    plt.vlines([left, right], -20, 10, linestyles='dashed', colors='c')
    plt.ylim(-20, 13)
    plt.xlim(0, 1)
    plt.text(left / 2 - 0.05, 10, 'Save')
    plt.text((right + left) / 2 - 0.02, 10, 'Ask')
    plt.text((right + 1) / 2 - 0.07, 10, 'Delete')
    plt.show()
1123
1124