Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
rasbt
GitHub Repository: rasbt/machine-learning-book
Path: blob/main/ch19/gridworld/agent.py
1247 views
1
# coding: utf-8
2
3
# Python Machine Learning, PyTorch Edition by Sebastian Raschka (https://sebastianraschka.com), Yuxi (Hayden) Liu
4
# (https://www.mlexample.com/) & Vahid Mirjalili (http://vahidmirjalili.com), Packt Publishing Ltd. 2021
5
#
6
# Code Repository: https://github.com/rasbt/machine-learning-book
7
#
8
# Code License: MIT License (https://github.com/rasbt/machine-learning-book/blob/main/LICENSE.txt)
9
10
#################################################################################
11
# Chapter 19 - Reinforcement Learning for Decision Making in Complex Environments
12
#################################################################################
13
14
# Script: agent.py
15
16
from collections import defaultdict
17
import numpy as np
18
19
20
class Agent:
    """Tabular Q-learning agent with an epsilon-greedy behavior policy.

    Maintains a state -> action-value table (lazily initialized to zeros)
    and updates it with one-step Q-learning. Exploration is controlled by
    an epsilon that decays multiplicatively after every update, down to a
    floor of ``epsilon_min``.
    """

    def __init__(
            self, env,
            learning_rate=0.01,
            discount_factor=0.9,
            epsilon_greedy=0.9,
            epsilon_min=0.1,
            epsilon_decay=0.95):
        """Set up the agent.

        Parameters:
            env: Environment exposing ``nA``, the number of discrete
                actions (Gym-style discrete environment; only ``nA`` is
                read here).
            learning_rate (float): Step size for the TD update.
            discount_factor (float): Gamma, the future-reward discount.
            epsilon_greedy (float): Initial exploration probability.
            epsilon_min (float): Lower bound on epsilon after decay.
            epsilon_decay (float): Multiplicative decay factor applied
                to epsilon after each learning step.
        """
        self.env = env
        self.lr = learning_rate
        self.gamma = discount_factor
        self.epsilon = epsilon_greedy
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay

        # Q-table: maps state -> np.ndarray of per-action values,
        # lazily created as zeros the first time a state is seen.
        self.q_table = defaultdict(lambda: np.zeros(self.env.nA))

    def choose_action(self, state):
        """Return an action chosen epsilon-greedily for *state*.

        With probability ``epsilon`` a uniformly random action is taken;
        otherwise the greedy action, with ties broken uniformly at
        random rather than always favoring the lowest action index.
        """
        if np.random.uniform() < self.epsilon:
            # Explore: sample an action uniformly at random.
            action = np.random.choice(self.env.nA)
        else:
            # Exploit: evaluate actions in a random permutation so that
            # argmax breaks ties randomly instead of deterministically.
            q_vals = self.q_table[state]
            perm_actions = np.random.permutation(self.env.nA)
            perm_q_argmax = np.argmax([q_vals[a] for a in perm_actions])
            action = perm_actions[perm_q_argmax]
        return action

    def _learn(self, transition):
        """Apply one Q-learning update from a single transition.

        Parameters:
            transition: Tuple ``(s, a, r, next_s, done)`` — state,
                action, reward, successor state, and terminal flag.
        """
        s, a, r, next_s, done = transition
        q_val = self.q_table[s][a]
        if done:
            # Terminal transition: no bootstrapped future value.
            q_target = r
        else:
            q_target = r + self.gamma * np.max(self.q_table[next_s])

        # Move Q(s, a) a fraction of the way toward the TD target.
        self.q_table[s][a] += self.lr * (q_target - q_val)

        # Decay exploration after every learning step.
        self._adjust_epsilon()

    def _adjust_epsilon(self):
        """Multiplicatively decay epsilon, clamped at ``epsilon_min``.

        NOTE: the original multiplied first and only stopped once below
        the floor, which let epsilon undershoot ``epsilon_min`` (e.g.
        0.105 * 0.95 = 0.09975 < 0.1). Clamping with ``max`` enforces
        the floor exactly.
        """
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon * self.epsilon_decay,
                               self.epsilon_min)