Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
aimacode
GitHub Repository: aimacode/aima-python
Path: blob/master/gui/grid_mdp.py
621 views
1
import os.path
2
import sys
3
import tkinter as tk
4
import tkinter.messagebox
5
from functools import partial
6
from tkinter import ttk
7
8
import matplotlib
9
import matplotlib.animation as animation
10
from matplotlib import pyplot as plt
11
from matplotlib import style
12
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
13
from matplotlib.figure import Figure
14
from matplotlib.ticker import MaxNLocator
15
16
# Extend the module search path *before* the star import below. Appending the
# parent directory afterwards cannot help that import resolve.
# NOTE(review): presumably ``mdp`` lives in the parent directory -- confirm against the repo layout.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

# GridMDP, value_iteration, best_policy and np are used further down and are not
# imported anywhere else in this file, so they are expected to come from here.
from mdp import *  # noqa: E402,F401,F403

matplotlib.use('TkAgg')
style.use('ggplot')

# Module-level figure and axes; SolveMDP embeds ``fig`` in its canvas and draws on ``sub``.
fig = Figure(figsize=(20, 15))
sub = fig.add_subplot(111)
plt.rcParams['axes.grid'] = False

# Sentinel numbers: WALL_VALUE is stored in the grid for wall cells, and both are
# used as the "selected" value of the Wall / Terminal radio buttons.
WALL_VALUE = -99999.0
TERM_VALUE = -999999.0

# Colour palette used by the ttk button styles.
black = '#000'
white = '#fff'
gray2 = '#222'
gray9 = '#999'
grayd = '#ddd'
grayef = '#efefef'
pblue = '#000040'
green8 = '#008080'
green4 = '#004040'

# The per-cell dialog that is currently open, if any (managed by dialogbox()).
cell_window_mantainer = None
41
42
43
def extents(f):
    """Return ``[low, high]`` axis limits that centre heatmap cells on the samples in *f*.

    The limits extend half a sample spacing beyond the first and last sample.
    The spacing is taken from the first two samples.

    A single sample, or coincident leading samples (spacing 0), has no usable
    spacing, so a unit-wide cell is assumed instead of returning a zero-width
    extent.

    Raises IndexError if *f* is empty.
    """
    delta = f[1] - f[0] if len(f) > 1 else 0
    if delta == 0:
        # Degenerate axis (e.g. a grid that is one cell wide): use a unit cell.
        delta = 1
    return [f[0] - delta / 2, f[-1] + delta / 2]
48
49
50
def display(gridmdp, _height, _width):
    """Show the values of *gridmdp* as a table of labels in a new 'Values' window.

    Rows are shown vertically flipped: window row ``i`` displays matrix row
    ``_height - i - 1``. Each value is formatted with three decimals. At least
    one row and one column are always drawn.
    """
    dialog = tk.Toplevel()
    dialog.wm_title('Values')

    container = tk.Frame(dialog)
    container.pack(side=tk.TOP, fill=tk.BOTH, expand=True)

    for i in range(max(1, _height)):
        for j in range(max(1, _width)):
            label = ttk.Label(container, text=f'{gridmdp[_height - i - 1][j]:.3f}', font=('Helvetica', 12))
            label.grid(row=i + 1, column=j + 1, padx=3, pady=3)

    # No nested dialog.mainloop() here: a nested loop does not return when this
    # window is closed, which blocked the caller from continuing. The callers
    # in this file run inside Tk callbacks, so the window is serviced by the
    # event loop that is already running.
65
66
67
def display_best_policy(_best_policy, _height, _width):
    """Show *_best_policy* as a table of bold labels in a new 'Best Policy' window.

    ``_best_policy[i][j]`` is used as the label text for window row ``i`` and
    column ``j``. At least one row and one column are always drawn.
    """
    dialog = tk.Toplevel()
    dialog.wm_title('Best Policy')

    container = tk.Frame(dialog)
    container.pack(side=tk.TOP, fill=tk.BOTH, expand=True)

    for i in range(max(1, _height)):
        for j in range(max(1, _width)):
            label = ttk.Label(container, text=_best_policy[i][j], font=('Helvetica', 12, 'bold'))
            label.grid(row=i + 1, column=j + 1, padx=3, pady=3)

    # No nested dialog.mainloop() here: a nested loop does not return when this
    # window is closed. The caller in this file runs inside a Tk callback, so
    # the window is serviced by the event loop that is already running.
81
82
83
def initialize_dialogbox(_width, _height, gridmdp, terminals, buttons):
    """Open the 'Initialize' dialog, whose Apply/Reset buttons act on every cell of the grid.

    The dialog offers a reward entry plus 'Terminal' and 'Wall' radio buttons.
    'Apply' runs initialize_update_table, 'Reset' runs initialize_reset_all and
    'Ok' closes the window.
    """
    dialog = tk.Toplevel()
    dialog.wm_title('Initialize')

    container = tk.Frame(dialog)
    container.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
    container.grid_rowconfigure(0, weight=1)
    container.grid_columnconfigure(0, weight=1)

    # Tk variables backing the widgets. The Apply/Reset callbacks below hold
    # references to them, which keeps them alive after this function returns.
    wall = tk.IntVar()
    wall.set(0)
    term = tk.IntVar()
    term.set(0)
    reward = tk.DoubleVar()
    reward.set(0.0)

    label = ttk.Label(container, text='Initialize', font=('Helvetica', 12), anchor=tk.N)
    label.grid(row=0, column=0, columnspan=3, sticky='new', pady=15, padx=5)
    label_reward = ttk.Label(container, text='Reward', font=('Helvetica', 10), anchor=tk.N)
    label_reward.grid(row=1, column=0, columnspan=3, sticky='new', pady=1, padx=5)
    entry_reward = ttk.Entry(container, font=('Helvetica', 10), justify=tk.CENTER, exportselection=0,
                             textvariable=reward)
    entry_reward.grid(row=2, column=0, columnspan=3, sticky='new', pady=5, padx=50)

    rbtn_term = ttk.Radiobutton(container, text='Terminal', variable=term, value=TERM_VALUE)
    rbtn_term.grid(row=3, column=0, columnspan=3, sticky='nsew', padx=160, pady=5)
    rbtn_wall = ttk.Radiobutton(container, text='Wall', variable=wall, value=WALL_VALUE)
    rbtn_wall.grid(row=4, column=0, columnspan=3, sticky='nsew', padx=172, pady=5)

    # Pre-set the widget states from what the grid currently contains.
    initialize_widget_disability_checks(_width, _height, gridmdp, terminals, label_reward, entry_reward, rbtn_wall,
                                        rbtn_term)

    btn_apply = ttk.Button(container, text='Apply',
                           command=partial(initialize_update_table, _width, _height, gridmdp, terminals, buttons,
                                           reward, term, wall, label_reward, entry_reward, rbtn_term, rbtn_wall))
    btn_apply.grid(row=5, column=0, sticky='nsew', pady=5, padx=5)
    btn_reset = ttk.Button(container, text='Reset',
                           command=partial(initialize_reset_all, _width, _height, gridmdp, terminals, buttons, reward,
                                           term, wall, label_reward, entry_reward, rbtn_wall, rbtn_term))
    btn_reset.grid(row=5, column=1, sticky='nsew', pady=5, padx=5)
    btn_ok = ttk.Button(container, text='Ok', command=dialog.destroy)
    btn_ok.grid(row=5, column=2, sticky='nsew', pady=5, padx=5)

    dialog.geometry('400x200')
    # No nested dialog.mainloop() here: it would not return when the dialog is
    # closed, so every opening stacked another event loop. The already running
    # event loop services this window.
130
131
132
def update_table(i, j, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, rbtn_term,
                 rbtn_wall):
    """Apply the dialog's current selections to cell ``(i, j)`` ('Apply' button handler).

    * Wall selected: the cell becomes a wall (grid value WALL_VALUE, wall style),
      and the reward widgets and the terminal radio button are disabled.
    * Otherwise a non-zero reward is written to the grid and shown on the
      button. If 'Terminal' is selected the cell is recorded in *terminals*,
      the wall radio button is disabled, and the button is styled by the sign
      of the cell's value.
    """
    if wall.get() == WALL_VALUE:
        cell = buttons[i][j]
        cell.configure(style='wall.TButton')
        cell.config(text='Wall')
        label_reward.config(foreground='#999')
        entry_reward.config(state=tk.DISABLED)
        rbtn_term.state(['!focus', '!selected'])
        rbtn_term.config(state=tk.DISABLED)
        gridmdp[i][j] = WALL_VALUE
        return

    # Not a wall: a zero reward leaves the stored value and the button untouched.
    amount = reward.get()
    if amount != 0.0:
        gridmdp[i][j] = amount
        cell = buttons[i][j]
        cell.configure(style='reward.TButton')
        cell.config(text=f'R = {amount}')

    if term.get() != TERM_VALUE:
        return

    if (i, j) not in terminals:
        terminals.append((i, j))
    rbtn_wall.state(['!focus', '!selected'])
    rbtn_wall.config(state=tk.DISABLED)

    # Terminal cells are coloured by the sign of their stored value.
    value = gridmdp[i][j]
    if value < 0:
        buttons[i][j].configure(style='-term.TButton')
    elif value > 0:
        buttons[i][j].configure(style='+term.TButton')
    elif value == 0.0:
        buttons[i][j].configure(style='=term.TButton')
164
165
166
def initialize_update_table(_width, _height, gridmdp, terminals, buttons, reward, term, wall, label_reward,
                            entry_reward, rbtn_term, rbtn_wall):
    """Run update_table on every cell of the grid, row by row.

    At least one row and one column are always visited.
    """
    row_count = max(1, _height)
    col_count = max(1, _width)

    for row in range(row_count):
        for col in range(col_count):
            update_table(row, col, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward,
                         rbtn_term, rbtn_wall)
174
175
176
def reset_all(_height, i, j, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, rbtn_wall,
              rbtn_term):
    """Restore cell ``(i, j)`` and the dialog widgets to their initial state ('Reset' button handler).

    The dialog variables are zeroed, the cell's value becomes 0.0, its button
    gets the default style and its ``(row, column)`` caption back, the cell is
    dropped from *terminals*, and every dialog widget is re-enabled with both
    radio buttons deselected.
    """
    # Dialog variables back to their defaults.
    for variable, default in ((reward, 0.0), (term, 0), (wall, 0)):
        variable.set(default)

    # The cell itself.
    gridmdp[i][j] = 0.0
    cell = buttons[i][j]
    cell.configure(style='TButton')
    cell.config(text=f'({_height - i - 1}, {j})')

    position = (i, j)
    if position in terminals:
        terminals.remove(position)

    # Re-enable the dialog widgets, then clear both radio selections.
    label_reward.config(foreground='#000')
    entry_reward.config(state=tk.NORMAL)
    rbtn_term.config(state=tk.NORMAL)
    rbtn_wall.config(state=tk.NORMAL)
    for radio in (rbtn_wall, rbtn_term):
        radio.state(['!focus', '!selected'])
195
196
197
def initialize_reset_all(_width, _height, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward,
                         rbtn_wall, rbtn_term):
    """Run reset_all on every cell of the grid, row by row.

    At least one row and one column are always visited.
    """
    row_count = max(1, _height)
    col_count = max(1, _width)

    for row in range(row_count):
        for col in range(col_count):
            reset_all(_height, row, col, gridmdp, terminals, buttons, reward, term, wall, label_reward,
                      entry_reward, rbtn_wall, rbtn_term)
205
206
207
def external_reset(_width, _height, gridmdp, terminals, buttons):
    """Reset every cell of the grid (Edit > Reset).

    Each cell's value becomes 0.0 and its button gets the default style and
    its ``(row, column)`` caption back. The cell is also removed from
    *terminals* (in place), matching what reset_all does for a single cell;
    otherwise stale terminal states would survive the reset.
    """
    for i in range(max(1, _height)):
        for j in range(max(1, _width)):
            gridmdp[i][j] = 0.0
            buttons[i][j].configure(style='TButton')
            buttons[i][j].config(text=f'({_height - i - 1}, {j})')

            if (i, j) in terminals:
                terminals.remove((i, j))
214
215
216
def widget_disability_checks(i, j, gridmdp, terminals, label_reward, entry_reward, rbtn_wall, rbtn_term):
    """Put a cell dialog's widgets into the state that matches cell ``(i, j)``.

    A wall cell greys out the reward widgets, disables the terminal radio
    button and shows 'Wall' as selected. A terminal cell disables and
    deselects the wall radio button.
    """
    cell_is_wall = gridmdp[i][j] == WALL_VALUE
    cell_is_terminal = (i, j) in terminals

    if cell_is_wall:
        label_reward.config(foreground='#999')
        entry_reward.config(state=tk.DISABLED)
        rbtn_term.config(state=tk.DISABLED)
        rbtn_wall.state(['!focus', 'selected'])
        rbtn_term.state(['!focus', '!selected'])

    if cell_is_terminal:
        rbtn_wall.config(state=tk.DISABLED)
        rbtn_wall.state(['!focus', '!selected'])
229
230
231
def flatten_list(_list):
    """Return a new list with the items of every sub-sequence of *_list*, in order.

    Only one level of nesting is removed. Built in a single pass, rather than
    by repeated list concatenation (which copies the growing result for each
    sub-list), and accepts any iterable rows, not just lists.
    """
    return [item for row in _list for item in row]
234
235
236
def initialize_widget_disability_checks(_width, _height, gridmdp, terminals, label_reward, entry_reward, rbtn_wall,
                                        rbtn_term):
    """Put the 'Initialize' dialog's widgets into the state that matches the whole grid.

    If every cell is a wall, the reward widgets are greyed out, the terminal
    radio button is disabled and 'Wall' is shown as selected. If every cell
    is a terminal, the wall radio button is disabled and 'Terminal' is shown
    as selected. At least one row and one column are always inspected.
    """
    rows = range(max(1, _height))
    cols = range(max(1, _width))

    every_cell_is_wall = all(gridmdp[i][j] == WALL_VALUE for i in rows for j in cols)
    every_cell_is_terminal = all((i, j) in terminals for i in rows for j in cols)

    if every_cell_is_wall:
        label_reward.config(foreground='#999')
        entry_reward.config(state=tk.DISABLED)
        rbtn_term.config(state=tk.DISABLED)
        rbtn_wall.state(['!focus', 'selected'])
        rbtn_term.state(['!focus', '!selected'])

    if every_cell_is_terminal:
        rbtn_wall.config(state=tk.DISABLED)
        rbtn_wall.state(['!focus', '!selected'])
        rbtn_term.state(['!focus', 'selected'])
266
267
268
def dialogbox(i, j, gridmdp, terminals, buttons, _height):
    """Open the configuration dialog for cell ``(i, j)``.

    Only one cell dialog is kept open at a time: the previously opened one
    (tracked in the module-level ``cell_window_mantainer``) is destroyed
    first. The dialog offers a reward entry plus 'Terminal' and 'Wall' radio
    buttons, pre-filled from the cell's current state. 'Apply' runs
    update_table, 'Reset' runs reset_all and 'Ok' closes the window.
    """
    global cell_window_mantainer
    if cell_window_mantainer is not None:
        cell_window_mantainer.destroy()

    dialog = tk.Toplevel()
    cell_window_mantainer = dialog
    dialog.wm_title(f'{_height - i - 1}, {j}')

    container = tk.Frame(dialog)
    container.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
    container.grid_rowconfigure(0, weight=1)
    container.grid_columnconfigure(0, weight=1)

    # Tk variables backing the widgets, seeded from the cell. The Apply/Reset
    # callbacks below hold references to them, which keeps them alive after
    # this function returns.
    wall = tk.IntVar()
    wall.set(gridmdp[i][j])
    term = tk.IntVar()
    term.set(TERM_VALUE if (i, j) in terminals else 0.0)
    reward = tk.DoubleVar()
    reward.set(gridmdp[i][j] if gridmdp[i][j] != WALL_VALUE else 0.0)

    label = ttk.Label(container, text=f'Configure cell {_height - i - 1}, {j}', font=('Helvetica', 12), anchor=tk.N)
    label.grid(row=0, column=0, columnspan=3, sticky='new', pady=15, padx=5)
    label_reward = ttk.Label(container, text='Reward', font=('Helvetica', 10), anchor=tk.N)
    label_reward.grid(row=1, column=0, columnspan=3, sticky='new', pady=1, padx=5)
    entry_reward = ttk.Entry(container, font=('Helvetica', 10), justify=tk.CENTER, exportselection=0,
                             textvariable=reward)
    entry_reward.grid(row=2, column=0, columnspan=3, sticky='new', pady=5, padx=50)

    rbtn_term = ttk.Radiobutton(container, text='Terminal', variable=term, value=TERM_VALUE)
    rbtn_term.grid(row=3, column=0, columnspan=3, sticky='nsew', padx=160, pady=5)
    rbtn_wall = ttk.Radiobutton(container, text='Wall', variable=wall, value=WALL_VALUE)
    rbtn_wall.grid(row=4, column=0, columnspan=3, sticky='nsew', padx=172, pady=5)

    widget_disability_checks(i, j, gridmdp, terminals, label_reward, entry_reward, rbtn_wall, rbtn_term)

    btn_apply = ttk.Button(container, text='Apply',
                           command=partial(update_table, i, j, gridmdp, terminals, buttons, reward, term, wall,
                                           label_reward, entry_reward, rbtn_term, rbtn_wall))
    btn_apply.grid(row=5, column=0, sticky='nsew', pady=5, padx=5)
    btn_reset = ttk.Button(container, text='Reset',
                           command=partial(reset_all, _height, i, j, gridmdp, terminals, buttons, reward, term, wall,
                                           label_reward, entry_reward, rbtn_wall, rbtn_term))
    btn_reset.grid(row=5, column=1, sticky='nsew', pady=5, padx=5)
    btn_ok = ttk.Button(container, text='Ok', command=dialog.destroy)
    btn_ok.grid(row=5, column=2, sticky='nsew', pady=5, padx=5)

    dialog.geometry('400x200')
    # No nested dialog.mainloop() here: it would not return when the dialog is
    # closed, so every cell click stacked another event loop. The already
    # running event loop services this window.
318
319
320
class MDPapp(tk.Tk):
    """Root window: owns the menu bar, the shared grid dimensions and the page frames."""

    def __init__(self, *args, **kwargs):
        """Build the menu bar, create one instance of every page and show the home page."""
        super().__init__(*args, **kwargs)
        self.wm_title('Grid MDP')

        # Grid dimensions shared between pages; bound to the entries on HomePage.
        self.shared_data = {
            'height': tk.IntVar(),
            'width': tk.IntVar()}
        self.shared_data['height'].set(1)
        self.shared_data['width'].set(1)

        self.container = tk.Frame(self)
        self.container.pack(side='top', fill='both', expand=True)
        self.container.grid_rowconfigure(0, weight=1)
        self.container.grid_columnconfigure(0, weight=1)

        # Page class -> page instance.
        self.frames = {}

        self.menu_bar = tk.Menu(self.container)
        self.file_menu = tk.Menu(self.menu_bar, tearoff=0)
        self.file_menu.add_command(label='Exit', command=self.exit)
        self.menu_bar.add_cascade(label='File', menu=self.file_menu)

        # 'Edit' and 'Build' start disabled; BuildMDP.create_buttons enables them.
        self.edit_menu = tk.Menu(self.menu_bar, tearoff=1)
        self.edit_menu.add_command(label='Reset', command=self.master_reset)
        self.edit_menu.add_command(label='Initialize', command=self.initialize)
        self.edit_menu.add_separator()
        self.edit_menu.add_command(label='View matrix', command=self.view_matrix)
        self.edit_menu.add_command(label='View terminals', command=self.view_terminals)
        self.menu_bar.add_cascade(label='Edit', menu=self.edit_menu)
        self.menu_bar.entryconfig('Edit', state=tk.DISABLED)

        self.build_menu = tk.Menu(self.menu_bar, tearoff=1)
        self.build_menu.add_command(label='Build and Run', command=self.build)
        self.menu_bar.add_cascade(label='Build', menu=self.build_menu)
        self.menu_bar.entryconfig('Build', state=tk.DISABLED)
        self.config(menu=self.menu_bar)

        # All pages share one grid cell; show_frame raises the one to display.
        for F in (HomePage, BuildMDP, SolveMDP):
            frame = F(self.container, self)
            self.frames[F] = frame
            frame.grid(row=0, column=0, sticky='nsew')

        self.show_frame(HomePage)

    def placeholder_function(self):
        """Placeholder for actions that are not implemented; prints a notice."""
        print('Not supported yet!')

    def exit(self):
        """Ask for confirmation, then terminate the program."""
        if tkinter.messagebox.askokcancel('Exit?', 'All changes will be lost'):
            # sys.exit() instead of the builtin quit(): quit() is installed by the
            # site module for interactive use and is not meant for programs.
            sys.exit()

    def new(self):
        """Reset the grid, drop the BuildMDP page's grid state and return to the home page."""
        self.master_reset()
        build_page = self.get_page(BuildMDP)
        build_page.gridmdp = None
        build_page.terminals = None
        build_page.buttons = None
        self.show_frame(HomePage)

    def get_page(self, page_class):
        """Return the stored page instance for *page_class*."""
        return self.frames[page_class]

    def view_matrix(self):
        """Print the current matrix to the console and show it in a 'Values' window."""
        build_page = self.get_page(BuildMDP)
        _height = self.shared_data['height'].get()
        _width = self.shared_data['width'].get()
        print(build_page.gridmdp)
        display(build_page.gridmdp, _height, _width)

    def view_terminals(self):
        """Print the current terminals to the console."""
        build_page = self.get_page(BuildMDP)
        print('Terminals', build_page.terminals)

    def initialize(self):
        """Delegate to BuildMDP.initialize."""
        build_page = self.get_page(BuildMDP)
        build_page.initialize()

    def master_reset(self):
        """Delegate to BuildMDP.master_reset."""
        build_page = self.get_page(BuildMDP)
        build_page.master_reset()

    def build(self):
        """Replace the SolveMDP page with a fresh one, show it and start solving the current grid."""
        frame = SolveMDP(self.container, self)
        self.frames[SolveMDP] = frame
        frame.grid(row=0, column=0, sticky='nsew')
        self.show_frame(SolveMDP)

        build_page = self.get_page(BuildMDP)
        gridmdp = build_page.gridmdp
        terminals = build_page.terminals
        solve_page = self.get_page(SolveMDP)
        _height = self.shared_data['height'].get()
        _width = self.shared_data['width'].get()
        solve_page.create_graph(gridmdp, terminals, _height, _width)

    def show_frame(self, controller, cb=False):
        """Raise the page stored under the page class *controller*.

        When *cb* is true, BuildMDP.create_buttons is run first.
        """
        if cb:
            build_page = self.get_page(BuildMDP)
            build_page.create_buttons()
        frame = self.frames[controller]
        frame.tkraise()
436
437
438
class HomePage(tk.Frame):
    """Landing page: a colour legend for the cell styles plus the grid-dimension entries."""

    def __init__(self, parent, controller):
        """Lay out the title, the style legend, the dimension entries and the build button."""
        tk.Frame.__init__(self, parent)
        self.controller = controller

        # Four stacked rows; creation and packing order determines the layout.
        title_row = tk.Frame(self)
        title_row.pack(side=tk.TOP)
        legend_row = tk.Frame(self)
        legend_row.pack(side=tk.TOP)
        heading_row = tk.Frame(self)
        heading_row.pack(side=tk.TOP)
        dimensions_row = tk.Frame(self)
        dimensions_row.pack(side=tk.TOP)

        # Button styles, one per kind of cell.
        theme = ttk.Style()
        theme.theme_use('clam')
        for style_name, options in (
                ('TButton', {'background': grayd, 'padding': 0}),
                ('wall.TButton', {'background': gray2, 'foreground': white}),
                ('reward.TButton', {'background': gray9}),
                ('+term.TButton', {'background': green8}),
                ('-term.TButton', {'background': pblue, 'foreground': white}),
                ('=term.TButton', {'background': green4})):
            theme.configure(style_name, **options)

        title = ttk.Label(title_row, text='GridMDP builder', font=('Helvetica', 18, 'bold'), background=grayef)
        title.pack(pady=75, padx=50, side=tk.TOP)

        # Legend: one inert button per style, packed left to right.
        for caption, style_name in (
                ('Empty cells', 'TButton'),
                ('Walls', 'wall.TButton'),
                ('Rewards', 'reward.TButton'),
                ('Positive terminals', '+term.TButton'),
                ('Neutral terminals', '=term.TButton'),
                ('Negative terminals', '-term.TButton')):
            swatch = ttk.Button(legend_row, text=caption, width=20)
            swatch.pack(pady=0, padx=0, side=tk.LEFT, ipady=10)
            swatch.configure(style=style_name)

        heading = ttk.Label(heading_row, text='Dimensions', font=('Verdana', 14), background=grayef)
        heading.pack(pady=15, padx=10, side=tk.TOP)

        # Height and width entries write straight into the controller's shared variables.
        shared = self.controller.shared_data
        height_entry = tk.Entry(dimensions_row, textvariable=shared['height'], font=('Verdana', 10), width=3,
                                justify=tk.CENTER)
        height_entry.pack(pady=10, padx=10, side=tk.LEFT)
        separator = ttk.Label(dimensions_row, text='X', font=('Verdana', 10), background=grayef)
        separator.pack(pady=10, padx=4, side=tk.LEFT)
        width_entry = tk.Entry(dimensions_row, textvariable=shared['width'], font=('Verdana', 10), width=3,
                               justify=tk.CENTER)
        width_entry.pack(pady=10, padx=10, side=tk.LEFT)

        build_button = ttk.Button(self, text='Build a GridMDP',
                                  command=lambda: controller.show_frame(BuildMDP, cb=True))
        build_button.pack(pady=10, padx=10, side=tk.TOP, ipadx=20, ipady=10)
        build_button.configure(style='reward.TButton')
503
504
505
class BuildMDP(tk.Frame):
    """Page holding the grid of cell buttons used to edit the MDP."""

    def __init__(self, parent, controller):
        """Create the empty page; the grid itself is built by create_buttons."""
        tk.Frame.__init__(self, parent)
        self.grid_rowconfigure(0, weight=1)
        self.grid_columnconfigure(0, weight=1)
        self.frame = tk.Frame(self)
        self.frame.pack()
        self.controller = controller

        # Grid state, populated by create_buttons(); None while no grid exists.
        # Declared here so initialize/master_reset can be called safely before
        # a grid has been built.
        self.gridmdp = None
        self.buttons = None
        self.terminals = None

    def create_buttons(self):
        """Create a fresh grid of interactive cell buttons sized by the shared height and width.

        Enables the 'Edit' and 'Build' menus and resets the grid values and
        terminals. At least one row and one column are always created.
        """
        _height = self.controller.shared_data['height'].get()
        _width = self.controller.shared_data['width'].get()
        self.controller.menu_bar.entryconfig('Edit', state=tk.NORMAL)
        self.controller.menu_bar.entryconfig('Build', state=tk.NORMAL)

        # Remove buttons left over from an earlier build so a smaller grid does
        # not leave stale cells on screen.
        for stale in self.frame.winfo_children():
            stale.destroy()

        self.gridmdp = [[0.0] * max(1, _width) for _ in range(max(1, _height))]
        self.buttons = [[None] * max(1, _width) for _ in range(max(1, _height))]
        self.terminals = []

        s = ttk.Style()
        s.theme_use('clam')
        s.configure('TButton', background=grayd, padding=0)
        s.configure('wall.TButton', background=gray2, foreground=white)
        s.configure('reward.TButton', background=gray9)
        s.configure('+term.TButton', background=green8)
        s.configure('-term.TButton', background=pblue, foreground=white)
        s.configure('=term.TButton', background=green4)

        for i in range(max(1, _height)):
            for j in range(max(1, _width)):
                # Captions count rows from the bottom; clicking opens the cell's dialog.
                self.buttons[i][j] = ttk.Button(self.frame, text=f'({_height - i - 1}, {j})',
                                                width=int(196 / max(1, _width)),
                                                command=partial(dialogbox, i, j, self.gridmdp, self.terminals,
                                                                self.buttons, _height))
                self.buttons[i][j].grid(row=i, column=j, ipady=int(336 / max(1, _height)) - 12)

    def initialize(self):
        """Open the 'Initialize' dialog for the current grid; does nothing if no grid was built."""
        if self.gridmdp is None:
            return
        _height = self.controller.shared_data['height'].get()
        _width = self.controller.shared_data['width'].get()
        initialize_dialogbox(_width, _height, self.gridmdp, self.terminals, self.buttons)

    def master_reset(self):
        """Reset all cells after confirmation; does nothing if no grid was built."""
        if self.gridmdp is None:
            return
        _height = self.controller.shared_data['height'].get()
        _width = self.controller.shared_data['width'].get()
        if tkinter.messagebox.askokcancel('Reset', 'Are you sure you want to reset all cells?'):
            external_reset(_width, _height, self.gridmdp, self.terminals, self.buttons)
556
557
558
class SolveMDP(tk.Frame):
    """Page that runs value iteration on the built grid and animates the utilities as a heatmap."""

    def __init__(self, parent, controller):
        """Create the empty page and the bookkeeping used by the animation."""
        tk.Frame.__init__(self, parent)
        self.grid_rowconfigure(0, weight=1)
        self.grid_columnconfigure(0, weight=1)
        self.frame = tk.Frame(self)
        self.frame.pack()
        self.controller = controller
        self.terminated = False  # set once the result dialogs have been shown
        self.iterations = 0  # number of animation frames (sweeps) so far
        self.epsilon = 0.001  # tolerance used in the stopping test
        self.delta = 0  # largest utility change seen in the latest sweep

    def process_data(self, terminals, _height, _width, gridmdp):
        """Convert the editor's grid into the inputs needed for solving and for display.

        Returns a tuple ``(flipped_terminals, grid_to_solve, grid_to_show)``:

        * each terminal ``(row, col)`` becomes ``(col, _height - row - 1)``;
        * ``grid_to_solve`` copies *gridmdp* with wall cells replaced by None;
        * ``grid_to_show`` copies *gridmdp* with wall cells replaced by 0.0 and
          is returned vertically flipped (``np.flipud``).
        """
        flipped_terminals = [(terminal[1], _height - terminal[0] - 1) for terminal in terminals]

        grid_to_solve = [[0.0] * max(1, _width) for _ in range(max(1, _height))]
        grid_to_show = [[0.0] * max(1, _width) for _ in range(max(1, _height))]

        for i in range(max(1, _height)):
            for j in range(max(1, _width)):
                if gridmdp[i][j] == WALL_VALUE:
                    grid_to_show[i][j] = 0.0
                    grid_to_solve[i][j] = None
                else:
                    grid_to_show[i][j] = grid_to_solve[i][j] = gridmdp[i][j]

        return flipped_terminals, grid_to_solve, np.flipud(grid_to_show)

    def create_graph(self, gridmdp, terminals, _height, _width):
        """Build the GridMDP, embed the figure in this page and start the animation.

        Also disables the 'Edit' and 'Build' menus.
        """
        self._height = _height
        self._width = _width
        self.controller.menu_bar.entryconfig('Edit', state=tk.DISABLED)
        self.controller.menu_bar.entryconfig('Build', state=tk.DISABLED)

        self.terminals, self.gridmdp, self.grid_to_show = self.process_data(terminals, _height, _width, gridmdp)
        self.sequential_decision_environment = GridMDP(self.gridmdp, terminals=self.terminals)

        self.initialize_value_iteration_parameters(self.sequential_decision_environment)

        self.canvas = FigureCanvasTkAgg(fig, self.frame)
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)
        # Keep a reference to the animation on self so it stays alive.
        self.anim = animation.FuncAnimation(fig, self.animate_graph, interval=50)
        # FigureCanvasTkAgg.show() was removed in Matplotlib 3.0; draw() is its replacement.
        self.canvas.draw()

    def animate_graph(self, i):
        """Animation callback: redraw the heatmap, then perform one value-iteration sweep.

        When the sweep's largest change falls below
        ``epsilon * (1 - gamma) / gamma``, or after more than 60 frames, the
        utilities and the best policy are shown in dialogs, exactly once.
        """
        # cmaps to use: bone_r, Oranges, inferno, BrBG, copper
        self.iterations += 1
        x_interval = max(2, len(self.gridmdp[0]))
        y_interval = max(2, len(self.gridmdp))
        x = np.linspace(0, len(self.gridmdp[0]) - 1, x_interval)
        y = np.linspace(0, len(self.gridmdp) - 1, y_interval)

        sub.clear()
        sub.imshow(self.grid_to_show, cmap='BrBG', aspect='auto', interpolation='none', extent=extents(x) + extents(y),
                   origin='lower')
        fig.tight_layout()

        U = self.U1.copy()

        # The residual is measured per sweep. Carrying the maximum over from
        # earlier sweeps would keep it from ever shrinking below the tolerance.
        self.delta = 0
        for s in self.sequential_decision_environment.states:
            self.U1[s] = self.R(s) + self.gamma * max(
                [sum([p * U[s1] for (p, s1) in self.T(s, a)]) for a in self.sequential_decision_environment.actions(s)])
            self.delta = max(self.delta, abs(self.U1[s] - U[s]))

        # States are keyed (x, y); the display grid is indexed [y][x].
        self.grid_to_show = [[0.0] * max(1, self._width) for _ in range(max(1, self._height))]
        for k, v in U.items():
            self.grid_to_show[k[1]][k[0]] = v

        converged = self.delta < self.epsilon * (1 - self.gamma) / self.gamma
        # The parentheses matter: without them `and` binds tighter than `or`,
        # and a converged run reopened the dialogs on every frame.
        if (converged or self.iterations > 60) and not self.terminated:
            self.terminated = True
            display(self.grid_to_show, self._height, self._width)

            pi = best_policy(self.sequential_decision_environment,
                             value_iteration(self.sequential_decision_environment, .01))
            display_best_policy(self.sequential_decision_environment.to_arrows(pi), self._height, self._width)

        ax = fig.gca()
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))

    def initialize_value_iteration_parameters(self, mdp):
        """Start every state's utility at 0 and cache the MDP's R, T and gamma on this page."""
        self.U1 = {s: 0 for s in mdp.states}
        self.R, self.T, self.gamma = mdp.R, mdp.T, mdp.gamma

    def value_iteration_metastep(self, mdp, iterations=20):
        """Run *iterations* value-iteration sweeps on *mdp*.

        Returns a list holding, for each sweep, the utilities as they were
        before that sweep was applied.
        """
        U_over_time = []
        U1 = {s: 0 for s in mdp.states}
        R, T, gamma = mdp.R, mdp.T, mdp.gamma

        for _ in range(iterations):
            U = U1.copy()

            for s in mdp.states:
                U1[s] = R(s) + gamma * max([sum([p * U[s1] for (p, s1) in T(s, a)]) for a in mdp.actions(s)])

            U_over_time.append(U)
        return U_over_time
671
672
673
if __name__ == '__main__':
    # Script entry point: create the root window, size it and hand control to Tk.
    app = MDPapp()
    app.geometry('1280x720')
    app.mainloop()
677
678