Path: blob/main/intermediate_source/reinforcement_q_learning.py
# -*- coding: utf-8 -*-
"""
Reinforcement Learning (DQN) Tutorial
=====================================
**Author**: `Adam Paszke <https://github.com/apaszke>`_
            `Mark Towers <https://github.com/pseudo-rnd-thoughts>`_


This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
on the CartPole-v1 task from `Gymnasium <https://gymnasium.farama.org>`__.

You might find it helpful to read the original `Deep Q Learning (DQN) <https://arxiv.org/abs/1312.5602>`__ paper.

**Task**

The agent has to decide between two actions - moving the cart left or
right - so that the pole attached to it stays upright. You can find more
information about the environment and other more challenging environments at
`Gymnasium's website <https://gymnasium.farama.org/environments/classic_control/cart_pole/>`__.

.. figure:: /_static/img/cartpole.gif
   :alt: CartPole

   CartPole

As the agent observes the current state of the environment and chooses
an action, the environment *transitions* to a new state, and also
returns a reward that indicates the consequences of the action. In this
task, rewards are +1 for every incremental timestep and the environment
terminates if the pole falls over too far or the cart moves more than 2.4
units away from center. This means better performing scenarios will run
for a longer duration, accumulating a larger return.

The CartPole task is designed so that the inputs to the agent are 4 real
values representing the environment state (position, velocity, etc.).
We take these 4 inputs without any scaling and pass them through a
small fully-connected network with 2 outputs, one for each action.
The network is trained to predict the expected value for each action,
given the input state. The action with the highest expected value is
then chosen.


**Packages**


First, let's import the needed packages. We need
`gymnasium <https://gymnasium.farama.org/>`__ for the environment,
installed by using `pip`. This is a fork of the original OpenAI
Gym project and has been maintained by the same team since Gym v0.19.
If you are running this in Google Colab, run:

.. code-block:: bash

   %%bash
   pip3 install gymnasium[classic_control]

We'll also use the following from PyTorch:

-  neural networks (``torch.nn``)
-  optimization (``torch.optim``)
-  automatic differentiation (``torch.autograd``)

"""

import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

env = gym.make("CartPole-v1")

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device(
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)


# To ensure reproducibility during training, you can fix the random seeds
# by uncommenting the lines below. This makes the results consistent across
# runs, which is helpful for debugging or comparing different approaches.
#
# That said, allowing randomness can be beneficial in practice, as it lets
# the model explore different training trajectories.


# seed = 42
# random.seed(seed)
# torch.manual_seed(seed)
# env.reset(seed=seed)
# env.action_space.seed(seed)
# env.observation_space.seed(seed)
# if torch.cuda.is_available():
#     torch.cuda.manual_seed(seed)
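
######################################################################
# As a quick, illustrative look at the Gymnasium API we rely on below
# (``demo_env`` and the ``demo_*`` variables are throwaway names used only
# for this check, so the ``env`` created above for training stays untouched):
# ``reset()`` returns an initial observation plus an info dict, and ``step()``
# returns the next observation, the reward, and the ``terminated`` and
# ``truncated`` flags.

demo_env = gym.make("CartPole-v1")
demo_obs, demo_info = demo_env.reset()
print(demo_obs.shape)  # (4,): cart position, cart velocity, pole angle, pole angular velocity
demo_obs, demo_reward, demo_terminated, demo_truncated, demo_info = demo_env.step(
    demo_env.action_space.sample())
print(demo_reward, demo_terminated, demo_truncated)
demo_env.close()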

######################################################################
# Replay Memory
# -------------
#
# We'll be using experience replay memory for training our DQN. It stores
# the transitions that the agent observes, allowing us to reuse this data
# later. By sampling from it randomly, the transitions that build up a
# batch are decorrelated. It has been shown that this greatly stabilizes
# and improves the DQN training procedure.
#
# For this, we're going to need two classes:
#
# - ``Transition`` - a named tuple representing a single transition in
#   our environment. It essentially maps (state, action) pairs
#   to their (next_state, reward) result, with the state being the
#   environment observation described above.
# - ``ReplayMemory`` - a cyclic buffer of bounded size that holds the
#   transitions observed recently. It also implements a ``.sample()``
#   method for selecting a random batch of transitions for training.
#

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
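
######################################################################
# A minimal usage sketch (``demo_memory`` and the placeholder tensors below
# are purely illustrative): once ``capacity`` is reached the oldest
# transitions are evicted, and ``sample`` draws a uniformly random batch.

demo_memory = ReplayMemory(3)
for i in range(5):
    demo_memory.push(torch.tensor([[float(i)]]),      # state
                     torch.tensor([[0]]),             # action
                     torch.tensor([[float(i + 1)]]),  # next_state
                     torch.tensor([1.0]))             # reward
print(len(demo_memory))       # 3 - the two oldest transitions were dropped
print(demo_memory.sample(2))  # a random batch of 2 Transition tuples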

######################################################################
# Now, let's define our model. But first, let's quickly recap what a DQN is.
#
# DQN algorithm
# -------------
#
# Our environment is deterministic, so all equations presented here are
# also formulated deterministically for the sake of simplicity. In the
# reinforcement learning literature, they would also contain expectations
# over stochastic transitions in the environment.
#
# Our aim will be to train a policy that tries to maximize the discounted,
# cumulative reward
# :math:`R_{t_0} = \sum_{t=t_0}^{\infty} \gamma^{t - t_0} r_t`, where
# :math:`R_{t_0}` is also known as the *return*. The discount,
# :math:`\gamma`, should be a constant between :math:`0` and :math:`1`
# that ensures the sum converges. A lower :math:`\gamma` makes
# rewards from the uncertain far future less important for our agent
# than the ones in the near future that it can be fairly confident
# about. It also encourages agents to collect reward closer in time
# than equivalent rewards that are temporally far away in the future.
#
# The main idea behind Q-learning is that if we had a function
# :math:`Q^*: State \times Action \rightarrow \mathbb{R}`, that could tell
# us what our return would be, if we were to take an action in a given
# state, then we could easily construct a policy that maximizes our
# rewards:
#
# .. math:: \pi^*(s) = \arg\!\max_a \ Q^*(s, a)
#
# However, we don't know everything about the world, so we don't have
# access to :math:`Q^*`. But, since neural networks are universal function
# approximators, we can simply create one and train it to resemble
# :math:`Q^*`.
#
# For our training update rule, we'll use the fact that every :math:`Q`
# function for some policy obeys the Bellman equation:
#
# .. math:: Q^{\pi}(s, a) = r + \gamma Q^{\pi}(s', \pi(s'))
#
# The difference between the two sides of the equality is known as the
# temporal difference error, :math:`\delta`:
#
# .. math:: \delta = Q(s, a) - (r + \gamma \max_{a'} Q(s', a'))
#
# To minimize this error, we will use the `Huber
# loss <https://en.wikipedia.org/wiki/Huber_loss>`__. The Huber loss acts
# like the mean squared error when the error is small, but like the mean
# absolute error when the error is large - this makes it more robust to
# outliers when the estimates of :math:`Q` are very noisy. We calculate
# this over a batch of transitions, :math:`B`, sampled from the replay
# memory:
#
# .. math::
#
#    \mathcal{L} = \frac{1}{|B|}\sum_{(s, a, s', r) \ \in \ B} \mathcal{L}(\delta)
#
# .. math::
#
#    \text{where} \quad \mathcal{L}(\delta) = \begin{cases}
#      \frac{1}{2}{\delta^2}  & \text{for } |\delta| \le 1, \\
#      |\delta| - \frac{1}{2} & \text{otherwise.}
#    \end{cases}
#
# Q-network
# ^^^^^^^^^
#
# Our model will be a feed-forward neural network that takes in the
# current environment observation. It has two
# outputs, representing :math:`Q(s, \mathrm{left})` and
# :math:`Q(s, \mathrm{right})` (where :math:`s` is the input to the
# network). In effect, the network is trying to predict the *expected return* of
# taking each action given the current input.
#

class DQN(nn.Module):

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)
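
######################################################################
# A quick, illustrative shape check (``demo_net`` and the zero tensor below
# are throwaway placeholders): for CartPole the network maps a batch of
# 4-dimensional observations to 2 Q-values, one per action.

demo_net = DQN(n_observations=4, n_actions=2)
demo_input = torch.zeros(1, 4)      # a single (batched) dummy observation
print(demo_net(demo_input).shape)   # torch.Size([1, 2])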

######################################################################
# Training
# --------
#
# Hyperparameters and utilities
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# This cell instantiates our model and its optimizer, and defines some
# utilities:
#
# - ``select_action`` - will select an action according to an epsilon
#   greedy policy. Simply put, we'll sometimes use our model for choosing
#   the action, and sometimes we'll just sample one uniformly. The
#   probability of choosing a random action will start at ``EPS_START``
#   and will decay exponentially towards ``EPS_END``. ``EPS_DECAY``
#   controls the rate of the decay.
# - ``plot_durations`` - a helper for plotting the duration of episodes,
#   along with an average over the last 100 episodes (the measure used in
#   the official evaluations). The plot will be underneath the cell
#   containing the main training loop, and will update after every
#   episode.
#

# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer

BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.01
EPS_DECAY = 2500
TAU = 0.005
LR = 3e-4


# Get the number of actions from the gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)


steps_done = 0


def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # t.max(1) will return the largest column value of each row.
            # second column on max result is index of where max element was
            # found, so we pick action with the larger expected reward.
            return policy_net(state).max(1).indices.view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)
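
######################################################################
# To get a feel for this schedule, the values below evaluate the same decay
# formula used in ``select_action`` at a few hypothetical step counts
# (purely illustrative; the global ``steps_done`` counter is not modified).

for demo_step in (0, 1000, 5000, 20000):
    demo_eps = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * demo_step / EPS_DECAY)
    print(f"steps_done={demo_step:>6}: eps_threshold={demo_eps:.3f}")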

episode_durations = []


def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())


######################################################################
# Training loop
# ^^^^^^^^^^^^^
#
# Finally, the code for training our model.
#
# Here, you can find an ``optimize_model`` function that performs a
# single step of the optimization. It first samples a batch, concatenates
# all the tensors into a single one, computes :math:`Q(s_t, a_t)` and
# :math:`V(s_{t+1}) = \max_a Q(s_{t+1}, a)`, and combines them into our
# loss. By definition we set :math:`V(s) = 0` if :math:`s` is a terminal
# state. We also use a target network to compute :math:`V(s_{t+1})` for
# added stability. The target network is updated at every step with a
# `soft update <https://arxiv.org/pdf/1509.02971.pdf>`__ controlled by
# the hyperparameter ``TAU``, which was previously defined.
#

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1).values
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()
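
######################################################################
# The ``gather`` call above is the one non-obvious indexing step: given the
# Q-values for every action, it picks out the Q-value of the action that was
# actually taken in each transition. A small standalone illustration with
# made-up numbers:

demo_q_values = torch.tensor([[1.0, 2.0],
                              [3.0, 4.0]])    # Q(s, a) for 2 states, 2 actions
demo_actions = torch.tensor([[1], [0]])       # the action taken in each state
print(demo_q_values.gather(1, demo_actions))  # tensor([[2.], [3.]])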

######################################################################
#
# Below, you can find the main training loop. At the beginning we reset
# the environment and obtain the initial ``state`` Tensor. Then, we sample
# an action, execute it, observe the next state and the reward (always
# 1), and optimize our model once. When the episode ends (our model
# fails), we restart the loop.
#
# Below, ``num_episodes`` is set to 600 if a GPU is available, otherwise 50
# episodes are scheduled so training does not take too long. However, 50
# episodes is insufficient to observe good performance on CartPole.
# You should see the model consistently achieve 500 steps within 600 training
# episodes. Training RL agents can be a noisy process, so restarting training
# can produce better results if convergence is not observed.
#

if torch.cuda.is_available() or torch.backends.mps.is_available():
    num_episodes = 600
else:
    num_episodes = 50

for i_episode in range(num_episodes):
    # Initialize the environment and get its state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 − τ)θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()

######################################################################
# Here is the diagram that illustrates the overall resulting data flow.
#
# .. figure:: /_static/img/reinforcement_learning_diagram.jpg
#
# Actions are chosen either randomly or based on a policy, getting the next
# step sample from the gym environment. We record the results in the
# replay memory and also run an optimization step on every iteration.
# Optimization picks a random batch from the replay memory to do training of the
# new policy. The "older" target_net is also used in optimization to compute the
# expected Q values. A soft update of its weights is performed at every step.
#
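
######################################################################
# A minimal follow-up sketch (``eval_env`` and the ``eval_*`` variables below
# are throwaway names for this check only): once training has finished, the
# learned policy can be run greedily, always picking the action with the
# highest predicted Q-value, to see how long it keeps the pole balanced.

eval_env = gym.make("CartPole-v1")
eval_obs, _ = eval_env.reset()
eval_state = torch.tensor(eval_obs, dtype=torch.float32, device=device).unsqueeze(0)
eval_return = 0.0
with torch.no_grad():
    for _ in count():
        eval_action = policy_net(eval_state).max(1).indices.view(1, 1)
        eval_obs, eval_reward, eval_terminated, eval_truncated, _ = eval_env.step(eval_action.item())
        eval_return += eval_reward
        if eval_terminated or eval_truncated:
            break
        eval_state = torch.tensor(eval_obs, dtype=torch.float32, device=device).unsqueeze(0)
eval_env.close()
print(f'Greedy evaluation return: {eval_return}')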