GitHub Repository: greyhatguy007/Machine-Learning-Specialization-Coursera
Path: blob/main/C3 - Unsupervised Learning, Recommenders, Reinforcement Learning/week3/optional-labs/utils.py
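"""Utility functions for the week-3 optional lab: policy iteration on a small MDP.

The environment is a one-dimensional line of states (6 in generate_visualization)
with two actions, 0 = left and 1 = right; the leftmost and rightmost states are
terminal.
"""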
import numpy as np
import matplotlib.pyplot as plt


def generate_rewards(num_states, each_step_reward, terminal_left_reward, terminal_right_reward):
    # Per-state rewards: every state receives each_step_reward, except the two
    # terminal states at either end of the line.
    rewards = [each_step_reward] * num_states
    rewards[0] = terminal_left_reward
    rewards[-1] = terminal_right_reward

    return rewards


def generate_transition_prob(num_states, num_actions, misstep_prob = 0):
    # 0 is left, 1 is right
    # p[s, a, s'] is the probability of ending up in state s' after taking
    # action a in state s; a misstep moves the agent in the opposite direction.

    p = np.zeros((num_states, num_actions, num_states))

    for i in range(num_states):
        if i != 0:
            p[i, 0, i-1] = 1 - misstep_prob
            p[i, 1, i-1] = misstep_prob

        if i != num_states - 1:
            p[i, 1, i+1] = 1 - misstep_prob
            p[i, 0, i+1] = misstep_prob

    # Terminal states: no transitions out
    p[0] = np.zeros((num_actions, num_states))
    p[-1] = np.zeros((num_actions, num_states))

    return p


def calculate_Q_value(num_states, rewards, transition_prob, gamma, V_states, state, action):
    # Bellman backup for a single state-action pair:
    # Q(s, a) = R(s) + gamma * sum over s' of P(s' | s, a) * V(s')
    q_sa = rewards[state] + gamma * sum([transition_prob[state, action, sp] * V_states[sp] for sp in range(num_states)])
    return q_sa


def evaluate_policy(num_states, rewards, transition_prob, gamma, policy):
    # Iterative policy evaluation: repeatedly back up V under the given policy
    # until the largest update falls below the threshold.
    max_policy_eval = 10000
    threshold = 1e-10

    V = np.zeros(num_states)

    for i in range(max_policy_eval):
        delta = 0
        for s in range(num_states):
            v = V[s]
            V[s] = calculate_Q_value(num_states, rewards, transition_prob, gamma, V, s, policy[s])
            delta = max(delta, abs(v - V[s]))

        if delta < threshold:
            break

    return V


def improve_policy(num_states, num_actions, rewards, transition_prob, gamma, V, policy):
    # Greedy policy improvement: switch to any action whose Q-value beats the
    # current state value, and report whether the policy changed.
    policy_stable = True

    for s in range(num_states):
        q_best = V[s]
        for a in range(num_actions):
            q_sa = calculate_Q_value(num_states, rewards, transition_prob, gamma, V, s, a)
            if q_sa > q_best and policy[s] != a:
                policy[s] = a
                q_best = q_sa
                policy_stable = False

    return policy, policy_stable


def get_optimal_policy(num_states, num_actions, rewards, transition_prob, gamma):
    # Policy iteration: alternate policy evaluation and greedy improvement
    # until the policy no longer changes.
    optimal_policy = np.zeros(num_states, dtype=int)
    max_policy_iter = 10000

    for i in range(max_policy_iter):
        policy_stable = True

        V = evaluate_policy(num_states, rewards, transition_prob, gamma, optimal_policy)
        optimal_policy, policy_stable = improve_policy(num_states, num_actions, rewards, transition_prob, gamma, V, optimal_policy)

        if policy_stable:
            break

    return optimal_policy, V


def calculate_Q_values(num_states, rewards, transition_prob, gamma, optimal_policy):
    # Q-value for taking "left" first, then following the optimal policy
    q_left_star = np.zeros(num_states)

    # Q-value for taking "right" first, then following the optimal policy
    q_right_star = np.zeros(num_states)

    V_star = evaluate_policy(num_states, rewards, transition_prob, gamma, optimal_policy)

    for s in range(num_states):
        q_left_star[s] = calculate_Q_value(num_states, rewards, transition_prob, gamma, V_star, s, 0)
        q_right_star[s] = calculate_Q_value(num_states, rewards, transition_prob, gamma, V_star, s, 1)

    return q_left_star, q_right_star


def plot_optimal_policy_return(num_states, optimal_policy, rewards, V):
    # Arrow showing the greedy action in each non-terminal state
    actions = [r"$\leftarrow$" if a == 0 else r"$\rightarrow$" for a in optimal_policy]
    actions[0] = ""
    actions[-1] = ""

    fig, ax = plt.subplots(figsize=(2*num_states, 2))

    for i in range(num_states):
        ax.text(i+0.5, 0.5, actions[i], fontsize=32, ha="center", va="center", color="orange")
        ax.text(i+0.5, 0.25, rewards[i], fontsize=16, ha="center", va="center", color="black")
        ax.text(i+0.5, 0.75, round(V[i], 2), fontsize=16, ha="center", va="center", color="firebrick")
        ax.axvline(i, color="black")
    ax.set_xlim([0, num_states])
    ax.set_ylim([0, 1])

    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.tick_params(axis='both', which='both', length=0)
    ax.set_title("Optimal policy", fontsize=16)


def plot_q_values(num_states, q_left_star, q_right_star, rewards):
    fig, ax = plt.subplots(figsize=(3*num_states, 2))

    for i in range(num_states):
        ax.text(i+0.2, 0.6, round(q_left_star[i], 2), fontsize=16, ha="center", va="center", color="firebrick")
        ax.text(i+0.8, 0.6, round(q_right_star[i], 2), fontsize=16, ha="center", va="center", color="firebrick")

        ax.text(i+0.5, 0.25, rewards[i], fontsize=20, ha="center", va="center", color="black")
        ax.axvline(i, color="black")
    ax.set_xlim([0, num_states])
    ax.set_ylim([0, 1])

    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.tick_params(axis='both', which='both', length=0)
    ax.set_title("Q(s,a)", fontsize=16)


def generate_visualization(terminal_left_reward, terminal_right_reward, each_step_reward, gamma, misstep_prob):
    # Build the 6-state MDP, solve it with policy iteration, and plot the
    # resulting optimal policy, state values, and Q-values.
    num_states = 6
    num_actions = 2

    rewards = generate_rewards(num_states, each_step_reward, terminal_left_reward, terminal_right_reward)
    transition_prob = generate_transition_prob(num_states, num_actions, misstep_prob)

    optimal_policy, V = get_optimal_policy(num_states, num_actions, rewards, transition_prob, gamma)
    q_left_star, q_right_star = calculate_Q_values(num_states, rewards, transition_prob, gamma, optimal_policy)

    plot_optimal_policy_return(num_states, optimal_policy, rewards, V)
    plot_q_values(num_states, q_left_star, q_right_star, rewards)
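# Usage sketch (not part of the original file): one way these helpers might be
# driven directly. The parameter values below are illustrative assumptions,
# not values prescribed by the lab notebook.
if __name__ == "__main__":
    generate_visualization(terminal_left_reward=100, terminal_right_reward=40,
                           each_step_reward=0, gamma=0.5, misstep_prob=0)
    plt.show()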