Path: blob/main/C3 - Unsupervised Learning, Recommenders, Reinforcement Learning/week3/optional-labs/utils.py
import numpy as np
import matplotlib.pyplot as plt

def generate_rewards(num_states, each_step_reward, terminal_left_reward, terminal_right_reward):
    # Reward vector: each_step_reward everywhere, terminal rewards at the two ends
    rewards = [each_step_reward] * num_states
    rewards[0] = terminal_left_reward
    rewards[-1] = terminal_right_reward

    return rewards

def generate_transition_prob(num_states, num_actions, misstep_prob=0):
    # 0 is left, 1 is right
    # p[s, a, s'] is the probability of landing in state s' after taking action a in state s

    p = np.zeros((num_states, num_actions, num_states))

    for i in range(num_states):
        if i != 0:
            p[i, 0, i-1] = 1 - misstep_prob
            p[i, 1, i-1] = misstep_prob

        if i != num_states - 1:
            p[i, 1, i+1] = 1 - misstep_prob
            p[i, 0, i+1] = misstep_prob

    # Terminal states: no transitions out
    p[0] = np.zeros((num_actions, num_states))
    p[-1] = np.zeros((num_actions, num_states))

    return p

def calculate_Q_value(num_states, rewards, transition_prob, gamma, V_states, state, action):
    # Q(s, a) = R(s) + gamma * sum_s' P(s'|s, a) * V(s')
    q_sa = rewards[state] + gamma * sum([transition_prob[state, action, sp] * V_states[sp] for sp in range(num_states)])
    return q_sa

def evaluate_policy(num_states, rewards, transition_prob, gamma, policy):
    # Iterative policy evaluation: sweep over the states until V converges
    max_policy_eval = 10000
    threshold = 1e-10

    V = np.zeros(num_states)

    for i in range(max_policy_eval):
        delta = 0
        for s in range(num_states):
            v = V[s]
            V[s] = calculate_Q_value(num_states, rewards, transition_prob, gamma, V, s, policy[s])
            delta = max(delta, abs(v - V[s]))

        if delta < threshold:
            break

    return V

def improve_policy(num_states, num_actions, rewards, transition_prob, gamma, V, policy):
    # Greedy policy improvement: switch to any action with a higher Q-value
    policy_stable = True

    for s in range(num_states):
        q_best = V[s]
        for a in range(num_actions):
            q_sa = calculate_Q_value(num_states, rewards, transition_prob, gamma, V, s, a)
            if q_sa > q_best and policy[s] != a:
                policy[s] = a
                q_best = q_sa
                policy_stable = False

    return policy, policy_stable


def get_optimal_policy(num_states, num_actions, rewards, transition_prob, gamma):
    # Policy iteration: alternate evaluation and improvement until the policy stops changing
    optimal_policy = np.zeros(num_states, dtype=int)
    max_policy_iter = 10000

    for i in range(max_policy_iter):
        policy_stable = True

        V = evaluate_policy(num_states, rewards, transition_prob, gamma, optimal_policy)
        optimal_policy, policy_stable = improve_policy(num_states, num_actions, rewards, transition_prob, gamma, V, optimal_policy)

        if policy_stable:
            break

    return optimal_policy, V

def calculate_Q_values(num_states, rewards, transition_prob, gamma, optimal_policy):
    # Q-values for going left and then following the optimal policy
    q_left_star = np.zeros(num_states)

    # Q-values for going right and then following the optimal policy
    q_right_star = np.zeros(num_states)

    V_star = evaluate_policy(num_states, rewards, transition_prob, gamma, optimal_policy)

    for s in range(num_states):
        q_left_star[s] = calculate_Q_value(num_states, rewards, transition_prob, gamma, V_star, s, 0)
        q_right_star[s] = calculate_Q_value(num_states, rewards, transition_prob, gamma, V_star, s, 1)

    return q_left_star, q_right_star


def plot_optimal_policy_return(num_states, optimal_policy, rewards, V):
    # One cell per state: arrow = optimal action, black = reward, red = state value
    actions = [r"$\leftarrow$" if a == 0 else r"$\rightarrow$" for a in optimal_policy]
    actions[0] = ""
    actions[-1] = ""

    fig, ax = plt.subplots(figsize=(2*num_states, 2))

    for i in range(num_states):
        ax.text(i+0.5, 0.5, actions[i], fontsize=32, ha="center", va="center", color="orange")
        ax.text(i+0.5, 0.25, rewards[i], fontsize=16, ha="center", va="center", color="black")
        ax.text(i+0.5, 0.75, round(V[i], 2), fontsize=16, ha="center", va="center", color="firebrick")
        ax.axvline(i, color="black")
    ax.set_xlim([0, num_states])
    ax.set_ylim([0, 1])

    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.tick_params(axis='both', which='both', length=0)
    ax.set_title("Optimal policy", fontsize=16)

def plot_q_values(num_states, q_left_star, q_right_star, rewards):
    # One cell per state: Q(s, left) on the left, Q(s, right) on the right, reward below
    fig, ax = plt.subplots(figsize=(3*num_states, 2))

    for i in range(num_states):
        ax.text(i+0.2, 0.6, round(q_left_star[i], 2), fontsize=16, ha="center", va="center", color="firebrick")
        ax.text(i+0.8, 0.6, round(q_right_star[i], 2), fontsize=16, ha="center", va="center", color="firebrick")

        ax.text(i+0.5, 0.25, rewards[i], fontsize=20, ha="center", va="center", color="black")
        ax.axvline(i, color="black")
    ax.set_xlim([0, num_states])
    ax.set_ylim([0, 1])

    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.tick_params(axis='both', which='both', length=0)
    ax.set_title("Q(s,a)", fontsize=16)

def generate_visualization(terminal_left_reward, terminal_right_reward, each_step_reward, gamma, misstep_prob):
    # Build the 6-state MDP, solve it with policy iteration, and plot the results
    num_states = 6
    num_actions = 2

    rewards = generate_rewards(num_states, each_step_reward, terminal_left_reward, terminal_right_reward)
    transition_prob = generate_transition_prob(num_states, num_actions, misstep_prob)

    optimal_policy, V = get_optimal_policy(num_states, num_actions, rewards, transition_prob, gamma)
    q_left_star, q_right_star = calculate_Q_values(num_states, rewards, transition_prob, gamma, optimal_policy)

    plot_optimal_policy_return(num_states, optimal_policy, rewards, V)
    plot_q_values(num_states, q_left_star, q_right_star, rewards)
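

# --- Usage sketch (not part of the original utils.py) ---
# This module is presumably imported by the accompanying lab notebook, which calls
# generate_visualization() with its own parameter values. The block below is a minimal
# sketch for running the file standalone; the reward, gamma, and misstep_prob values
# are illustrative assumptions only.
if __name__ == "__main__":
    generate_visualization(terminal_left_reward=100,   # assumed example value
                           terminal_right_reward=40,   # assumed example value
                           each_step_reward=0,         # assumed example value
                           gamma=0.9,                  # assumed example value
                           misstep_prob=0.0)           # assumed example value
    plt.show()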