GitHub Repository: rasbt/machine-learning-book
Path: blob/main/ch19/cartpole/main.py
# coding: utf-8

# Python Machine Learning, PyTorch Edition by Sebastian Raschka (https://sebastianraschka.com), Yuxi (Hayden) Liu
# (https://www.mlexample.com/) & Vahid Mirjalili (http://vahidmirjalili.com), Packt Publishing Ltd. 2021
#
# Code Repository: https://github.com/rasbt/machine-learning-book
#
# Code License: MIT License (https://github.com/rasbt/machine-learning-book/LICENSE.txt)

#################################################################################
# Chapter 19 - Reinforcement Learning for Decision Making in Complex Environments
#################################################################################

# Script: cartpole/main.py

import gym
import numpy as np
import torch
import torch.nn as nn
import random
import matplotlib.pyplot as plt
from collections import namedtuple
from collections import deque

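# Note: this script assumes the classic Gym API (gym < 0.26), where
# env.reset() returns only the observation and env.step() returns
# (obs, reward, done, info).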

random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

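# One agent-environment step (state, action, reward, next_state, done),
# stored in the replay memory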
Transition = namedtuple(
    'Transition', ('state', 'action', 'reward',
                   'next_state', 'done'))


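# Deep Q-network (DQN) agent: an epsilon-greedy policy over a small
# Q-network, trained by experience replay on stored transitions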
class DQNAgent:
    def __init__(
            self, env, discount_factor=0.95,
            epsilon_greedy=1.0, epsilon_min=0.01,
            epsilon_decay=0.995, learning_rate=1e-3,
            max_memory_size=2000):
        self.env = env
        self.state_size = env.observation_space.shape[0]
        self.action_size = env.action_space.n

        self.memory = deque(maxlen=max_memory_size)

        self.gamma = discount_factor
        self.epsilon = epsilon_greedy
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.lr = learning_rate
        self._build_nn_model()

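    # Q-network: a small MLP mapping a state vector to one Q-value per
    # action, trained with MSE loss and the Adam optimizer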
    def _build_nn_model(self):
        self.model = nn.Sequential(nn.Linear(self.state_size, 256),
                                   nn.ReLU(),
                                   nn.Linear(256, 128),
                                   nn.ReLU(),
                                   nn.Linear(128, 64),
                                   nn.ReLU(),
                                   nn.Linear(64, self.action_size))

        self.loss_fn = nn.MSELoss()
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), self.lr)

    def remember(self, transition):
        self.memory.append(transition)

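    # Epsilon-greedy action selection: explore with probability epsilon,
    # otherwise act greedily with respect to the predicted Q-values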
    def choose_action(self, state):
        if np.random.rand() <= self.epsilon:
            return np.random.choice(self.action_size)
        with torch.no_grad():
            q_values = self.model(torch.tensor(state, dtype=torch.float32))[0]
        return torch.argmax(q_values).item()  # returns action

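    # Update the Q-network on a mini-batch: regression targets follow the
    # Q-learning rule, target = r + gamma * max_a' Q(next_s, a')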
    def _learn(self, batch_samples):
        batch_states, batch_targets = [], []
        for transition in batch_samples:
            s, a, r, next_s, done = transition

            with torch.no_grad():
                if done:
                    target = r
                else:
                    pred = self.model(torch.tensor(next_s, dtype=torch.float32))[0]
                    target = r + self.gamma * pred.max()

                # Build the target vector inside no_grad so the targets
                # stay detached from the computation graph
                target_all = self.model(torch.tensor(s, dtype=torch.float32))[0]
                target_all[a] = target

            batch_states.append(s.flatten())
            batch_targets.append(target_all)
            self._adjust_epsilon()

        self.optimizer.zero_grad()
        pred = self.model(torch.tensor(batch_states, dtype=torch.float32))

        loss = self.loss_fn(pred, torch.stack(batch_targets))
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def _adjust_epsilon(self):
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

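    # Sample a random mini-batch from the replay memory and learn from it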
    def replay(self, batch_size):
        samples = random.sample(self.memory, batch_size)
        return self._learn(samples)


def plot_learning_history(history):
    fig = plt.figure(1, figsize=(14, 5))
    ax = fig.add_subplot(1, 1, 1)
    episodes = np.arange(len(history)) + 1
    plt.plot(episodes, history, lw=4,
             marker='o', markersize=10)
    ax.tick_params(axis='both', which='major', labelsize=15)
    plt.xlabel('Episodes', size=20)
    plt.ylabel('Total rewards', size=20)
    plt.show()


# General settings
EPISODES = 200
batch_size = 32
init_replay_memory_size = 500

if __name__ == '__main__':
    env = gym.make('CartPole-v1')
    agent = DQNAgent(env)
    state = env.reset()
    state = np.reshape(state, [1, agent.state_size])

    # Filling up the replay-memory
    for i in range(init_replay_memory_size):
        action = agent.choose_action(state)
        next_state, reward, done, _ = env.step(action)
        next_state = np.reshape(next_state, [1, agent.state_size])
        agent.remember(Transition(state, action, reward,
                                  next_state, done))
        if done:
            state = env.reset()
            state = np.reshape(state, [1, agent.state_size])
        else:
            state = next_state

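    # Training loop: run each episode with epsilon-greedy actions, then
    # update the Q-network on one replayed mini-batch per episode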
    total_rewards, losses = [], []
    for e in range(EPISODES):
        state = env.reset()
        if e % 10 == 0:
            env.render()
        state = np.reshape(state, [1, agent.state_size])
        for i in range(500):
            action = agent.choose_action(state)
            next_state, reward, done, _ = env.step(action)
            next_state = np.reshape(next_state, [1, agent.state_size])
            agent.remember(Transition(state, action, reward,
                                      next_state, done))
            state = next_state
            if e % 10 == 0:
                env.render()
            if done:
                total_rewards.append(i)
                print(f'Episode: {e}/{EPISODES}, Total reward: {i}')
                break
        loss = agent.replay(batch_size)
        losses.append(loss)
    plot_learning_history(total_rewards)