Path: blob/main/C3 - Unsupervised Learning, Recommenders, Reinforcement Learning/week3/C3W3A1/utils.py
import base64
import random
from itertools import zip_longest

import imageio
import IPython
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
import tensorflow as tf
from statsmodels.iolib.table import SimpleTable


SEED = 0              # seed for pseudo-random number generator
MINIBATCH_SIZE = 64   # mini-batch size
TAU = 1e-3            # soft update parameter
E_DECAY = 0.995       # ε decay rate for ε-greedy policy
E_MIN = 0.01          # minimum ε value for ε-greedy policy


random.seed(SEED)


def get_experiences(memory_buffer):
    """Samples a mini-batch of experiences from the replay buffer and returns them as TensorFlow tensors."""
    experiences = random.sample(memory_buffer, k=MINIBATCH_SIZE)
    states = tf.convert_to_tensor(np.array([e.state for e in experiences if e is not None]), dtype=tf.float32)
    actions = tf.convert_to_tensor(np.array([e.action for e in experiences if e is not None]), dtype=tf.float32)
    rewards = tf.convert_to_tensor(np.array([e.reward for e in experiences if e is not None]), dtype=tf.float32)
    next_states = tf.convert_to_tensor(np.array([e.next_state for e in experiences if e is not None]), dtype=tf.float32)
    done_vals = tf.convert_to_tensor(np.array([e.done for e in experiences if e is not None]).astype(np.uint8),
                                     dtype=tf.float32)
    return (states, actions, rewards, next_states, done_vals)


def check_update_conditions(t, num_steps_upd, memory_buffer):
    """Returns True when a learning update is due: every num_steps_upd steps, once the buffer holds more than MINIBATCH_SIZE experiences."""
    if (t + 1) % num_steps_upd == 0 and len(memory_buffer) > MINIBATCH_SIZE:
        return True
    else:
        return False


def get_new_eps(epsilon):
    """Decays ε for the ε-greedy policy, never letting it fall below E_MIN."""
    return max(E_MIN, E_DECAY * epsilon)


def get_action(q_values, epsilon=0):
    """ε-greedy action selection: greedy with probability 1-ε, otherwise one of the 4 actions at random."""
    if random.random() > epsilon:
        return np.argmax(q_values.numpy()[0])
    else:
        return random.choice(np.arange(4))


def update_target_network(q_network, target_q_network):
    """Soft-updates the target network weights towards the Q-network weights using TAU."""
    for target_weights, q_net_weights in zip(target_q_network.weights, q_network.weights):
        target_weights.assign(TAU * q_net_weights + (1.0 - TAU) * target_weights)


def plot_history(reward_history, rolling_window=20, lower_limit=None,
                 upper_limit=None, plot_rw=True, plot_rm=True):
    """Plots the total points per episode and/or their rolling mean over an optional episode range."""
    if lower_limit is None or upper_limit is None:
        rh = reward_history
        xs = [x for x in range(len(reward_history))]
    else:
        rh = reward_history[lower_limit:upper_limit]
        xs = [x for x in range(lower_limit, upper_limit)]

    df = pd.DataFrame(rh)
    rollingMean = df.rolling(rolling_window).mean()

    plt.figure(figsize=(10, 7), facecolor='white')

    if plot_rw:
        plt.plot(xs, rh, linewidth=1, color='cyan')
    if plot_rm:
        plt.plot(xs, rollingMean, linewidth=2, color='magenta')

    text_color = 'black'

    ax = plt.gca()
    ax.set_facecolor('black')
    plt.grid()
    # plt.title("Total Point History", color=text_color, fontsize=40)
    plt.xlabel('Episode', color=text_color, fontsize=30)
    plt.ylabel('Total Points', color=text_color, fontsize=30)
    yNumFmt = mticker.StrMethodFormatter('{x:,}')
    ax.yaxis.set_major_formatter(yNumFmt)
    ax.tick_params(axis='x', colors=text_color)
    ax.tick_params(axis='y', colors=text_color)
    plt.show()


def display_table(initial_state, action, next_state, reward, done):
    """Returns a SimpleTable summarizing a single environment transition."""
    action_labels = ["Do nothing", "Fire right engine", "Fire main engine", "Fire left engine"]

    # Do not use column headers
    column_headers = None

    with np.printoptions(formatter={'float': '{:.3f}'.format}):
        table_info = [("Initial State:", [f"{initial_state}"]),
                      ("Action:", [f"{action_labels[action]}"]),
                      ("Next State:", [f"{next_state}"]),
                      ("Reward Received:", [f"{reward:.3f}"]),
                      ("Episode Terminated:", [f"{done}"])]

    # Generate table
    row_labels, data = zip_longest(*table_info)
    table = SimpleTable(data, column_headers, row_labels)

    return table


def embed_mp4(filename):
    """Embeds an mp4 file in the notebook."""
    video = open(filename, 'rb').read()
    b64 = base64.b64encode(video)
    tag = '''
    <video width="840" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
    Your browser does not support the video tag.
    </video>'''.format(b64.decode())
    return IPython.display.HTML(tag)


def create_video(filename, env, q_network, fps=30):
    """Renders one episode under the greedy policy of q_network and writes the frames to an mp4 file."""
    with imageio.get_writer(filename, fps=fps) as video:
        done = False
        state = env.reset()
        frame = env.render(mode="rgb_array")
        video.append_data(frame)
        while not done:
            state = np.expand_dims(state, axis=0)
            q_values = q_network(state)
            action = np.argmax(q_values.numpy()[0])
            state, _, done, _ = env.step(action)
            frame = env.render(mode="rgb_array")
            video.append_data(frame)
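

# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original assignment utilities).
# It shows one way these helpers could be wired into a DQN interaction loop,
# assuming the classic Gym API used elsewhere in this file. The environment
# name ("LunarLander-v2"), network architecture, buffer size, and step counts
# are assumptions made for the example only; the actual TD-target computation
# and gradient step are elided.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    from collections import deque, namedtuple

    import gym

    Experience = namedtuple("Experience", ["state", "action", "reward", "next_state", "done"])

    env = gym.make("LunarLander-v2")            # assumed environment
    state_size = env.observation_space.shape    # (8,) for LunarLander
    num_actions = env.action_space.n            # 4 discrete actions

    # Q-network and target network (architecture chosen for the sketch only).
    q_network = tf.keras.Sequential([
        tf.keras.layers.Input(shape=state_size),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(num_actions),
    ])
    target_q_network = tf.keras.models.clone_model(q_network)
    target_q_network.set_weights(q_network.get_weights())

    memory_buffer = deque(maxlen=100_000)       # assumed replay-buffer capacity
    epsilon = 1.0
    num_steps_upd = 4                           # assumed update frequency

    state = env.reset()
    for t in range(200):
        # ε-greedy action selection from the current Q-network.
        q_values = q_network(np.expand_dims(state, axis=0))
        action = get_action(q_values, epsilon)

        # Interact with the environment and store the transition.
        next_state, reward, done, _ = env.step(action)
        memory_buffer.append(Experience(state, action, reward, next_state, done))

        # Learn every num_steps_upd steps once enough experiences are stored.
        if check_update_conditions(t, num_steps_upd, memory_buffer):
            states, actions, rewards, next_states, done_vals = get_experiences(memory_buffer)
            # ... compute TD targets and apply a gradient step on q_network here ...
            update_target_network(q_network, target_q_network)

        state = env.reset() if done else next_state
        epsilon = get_new_eps(epsilon)

    env.close()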