# Path: blob/master/mitdeeplearning/lab3_old.py
# 547 views
import io
import base64
from IPython.display import HTML
import gym
import numpy as np
import cv2


def play_video(filename, width=None):
    """Return an IPython ``HTML`` object that embeds a video file inline.

    The file's bytes are base64-encoded into a ``data:video/mp4`` URI, so the
    resulting HTML is self-contained.

    Args:
        filename: Path of the video file to embed.
        width: Optional value for the ``width`` attribute of the ``<video>``
            tag. When ``None`` the attribute is omitted.

    Returns:
        An ``IPython.display.HTML`` instance wrapping the ``<video>`` element.
    """
    # Read-only binary mode is enough here, and the context manager makes
    # sure the handle is released even if reading fails.
    with io.open(filename, "rb") as video_file:
        encoded = base64.b64encode(video_file.read())

    video_width = f'width="{width}"' if width is not None else ""
    embedded = HTML(
        data="""
        <video controls {0}>
            <source src="data:video/mp4;base64,{1}" type="video/mp4" />
        </video>""".format(video_width, encoded.decode("ascii"))
    )

    return embedded


def preprocess_pong(image):
    """Reduce a Pong frame to a small binary float mask.

    Steps: keep rows 35..194, take every second row/column of channel 0,
    mark every pixel that is not background (values 144 and 109) and not
    already 0 as 1, dilate with a 3x3 kernel, downsample by 2 again and add
    a trailing channel axis.

    The input array is NOT modified.

    Args:
        image: Frame as a NumPy array indexed as (row, column, channel).
            NOTE(review): presumably a (210, 160, 3) uint8 Atari frame, which
            would give a (40, 40, 1) result - confirm against the caller.

    Returns:
        A ``float64`` array holding the dilated, downsampled mask with a
        trailing axis of length 1.
    """
    # Crop, then downsample width and height by a factor of 2 (channel 0 only).
    channel = image[35:195][::2, ::2, 0]

    # Background type 1 (144), background type 2 (109) and existing zeros map
    # to 0; everything else (paddles, ball, etc.) maps to 1. Building a fresh
    # array instead of assigning through the view keeps the caller's frame
    # untouched while preserving the frame's dtype for cv2.
    foreground = (channel != 0) & (channel != 144) & (channel != 109)
    mask = foreground.astype(channel.dtype)

    mask = cv2.dilate(mask, np.ones((3, 3), np.uint8), iterations=1)
    mask = mask[::2, ::2, np.newaxis]

    # ``np.float`` was only an alias of the builtin ``float`` (i.e. float64)
    # and no longer exists in current NumPy releases.
    return mask.astype(np.float64)


def pong_change(prev, curr):
    """Return the difference between two preprocessed Pong frames.

    Both frames go through ``preprocess_pong`` and the result is
    ``prev - curr``; it is not normalized, so entries are -1, 0 or 1.

    Args:
        prev: Earlier raw frame.
        curr: Later raw frame.

    Returns:
        Float array with the same shape as a preprocessed frame.
    """
    return preprocess_pong(prev) - preprocess_pong(curr)


class Memory:
    """Buffer of observations, actions and rewards kept in step order."""

    def __init__(self):
        self.clear()

    def clear(self):
        """Reset the buffer to three empty, parallel lists."""
        self.observations = []
        self.actions = []
        self.rewards = []

    def add_to_memory(self, new_observation, new_action, new_reward):
        """Append one (observation, action, reward) step to the buffer."""
        self.observations.append(new_observation)
        self.actions.append(new_action)
        self.rewards.append(new_reward)


def aggregate_memories(memories):
    """Concatenate several ``Memory`` buffers into a single new one.

    Args:
        memories: Iterable of ``Memory`` objects.

    Returns:
        A new ``Memory`` holding every step of every input buffer, in the
        order the buffers were given.
    """
    batch_memory = Memory()

    for memory in memories:
        for step in zip(memory.observations, memory.actions, memory.rewards):
            batch_memory.add_to_memory(*step)

    return batch_memory


def parallelized_collect_rollout(batch_size, envs, model, choose_action):
    """Run one episode in each environment, batching the action selection.

    On every iteration the frame difference of each unfinished environment
    is computed, ``choose_action`` is called once for all of them, and each
    unfinished environment is stepped with its action. The loop ends when
    every environment has reported ``done``.

    The environments are used through the classic gym interface: ``reset()``
    returns an observation and ``step()`` returns a 4-tuple
    ``(observation, reward, done, info)``.

    Args:
        batch_size: Number of environments; must equal ``len(envs)``.
        envs: Sequence of environments.
        model: Passed through to ``choose_action`` unchanged.
        choose_action: Callable invoked as
            ``choose_action(model, observations, single=False)`` whose result
            is indexed by position to obtain one action per observation.

    Returns:
        A list with one ``Memory`` per environment holding the frame
        differences, actions and rewards of that environment's episode.
    """
    assert (
        len(envs) == batch_size
    ), "Number of parallel environments must be equal to the batch size."

    memories = [Memory() for _ in range(batch_size)]
    next_observations = [single_env.reset() for single_env in envs]
    previous_frames = list(next_observations)
    done = [False] * batch_size
    rewards = [0] * batch_size

    while not all(done):
        active = [b for b, finished in enumerate(done) if not finished]

        # Only unfinished environments need a fresh frame difference;
        # finished ones are never stepped or recorded again.
        diff_frames = {
            b: pong_change(previous_frames[b], next_observations[b])
            for b in active
        }
        actions_not_done = choose_action(
            model, np.array([diff_frames[b] for b in active]), single=False
        )

        for position, b in enumerate(active):
            action = actions_not_done[position]
            current_frame = next_observations[b]
            next_observations[b], rewards[b], done[b], _ = envs[b].step(action)
            previous_frames[b] = current_frame
            memories[b].add_to_memory(diff_frames[b], action, rewards[b])

    return memories


def save_video_of_model(model, env_name, suffix=""):
    """Record one greedy episode of ``model`` in ``env_name`` as an MP4.

    The action at each step is the argmax of the model output. For
    environments whose name contains "CartPole" the raw observation is fed to
    the model; for names containing "Pong" the frame difference from
    ``pong_change`` is used.

    Args:
        model: Callable taking a batched observation and returning an object
            with a ``.numpy()`` method.
        env_name: Gym environment id; also the stem of the output file name.
        suffix: Text inserted between ``env_name`` and ``.mp4``.

    Returns:
        The name of the written video file.

    Raises:
        ValueError: If ``env_name`` contains neither "CartPole" nor "Pong".
    """
    import skvideo.io
    from pyvirtualdisplay import Display

    # NOTE(review): the virtual display is intentionally left running, as in
    # the original code. Renderers may hold on to the display connection, so
    # confirm repeated calls still work before adding ``display.stop()``.
    display = Display(visible=0, size=(400, 300))
    display.start()

    env = gym.make(env_name)
    output_video = None
    try:
        # Decide the input representation once, before any frame is written,
        # so an unsupported environment does not leave a partial video behind.
        if "CartPole" in env_name:
            use_frame_difference = False
        elif "Pong" in env_name:
            use_frame_difference = True
        else:
            raise ValueError(f"Unknown env for saving: {env_name}")

        obs = env.reset()
        prev_obs = obs

        filename = env_name + suffix + ".mp4"
        output_video = skvideo.io.FFmpegWriter(filename)

        counter = 0
        done = False
        while not done:
            frame = env.render(mode="rgb_array")
            output_video.writeFrame(frame)

            if use_frame_difference:
                input_obs = pong_change(prev_obs, obs)
            else:
                input_obs = obs

            action = model(np.expand_dims(input_obs, 0)).numpy().argmax()

            prev_obs = obs
            obs, _, done, _ = env.step(action)
            counter += 1
    finally:
        # Release the writer and the environment even if the rollout fails.
        if output_video is not None:
            output_video.close()
        env.close()

    print(f"Successfully saved {counter} frames into {filename}!")
    return filename


def save_video_of_memory(memory, filename, size=(512, 512)):
    """Write the observations stored in ``memory`` to a video file.

    Each observation is multiplied by 255 and resized to ``size`` with
    ``cv2.resize`` before being written as one frame.

    Args:
        memory: Object with an ``observations`` iterable of image arrays.
        filename: Path of the video file to write.
        size: Target frame size passed to ``cv2.resize``.

    Returns:
        ``filename``, unchanged.
    """
    import skvideo.io

    output_video = skvideo.io.FFmpegWriter(filename)
    try:
        for observation in memory.observations:
            output_video.writeFrame(cv2.resize(255 * observation, size))
    finally:
        # Make sure the writer is closed even if a frame fails to encode.
        output_video.close()

    return filename