{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# DQN in pytorch\n", "\n", "The goal of this exercise is to implement DQN using pytorch and to apply it to the cartpole balancing problem. \n", "\n", "The code is adapted from the PyTorch DQN tutorial." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "try:\n", " import google.colab\n", " IN_COLAB = True\n", "except ImportError:\n", " IN_COLAB = False\n", "\n", "if IN_COLAB:\n", " !pip install -U gymnasium pygame swig\n", " !pip install -U moviepy==1.0.3\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gym version: 1.0.0\n", "Device: mps\n" ] } ], "source": [ "# Default libraries\n", "import math\n", "import random\n", "import os\n", "import time\n", "import numpy as np\n", "rng = np.random.default_rng()\n", "import matplotlib.pyplot as plt\n", "from collections import namedtuple, deque\n", "\n", "# Gymnasium\n", "import gymnasium as gym\n", "print(\"gym version:\", gym.__version__)\n", "\n", "# pytorch\n", "import torch\n", "import torch.nn.functional as F\n", "\n", "# Select the hardware\n", "if torch.cuda.is_available(): # GPU\n", " device = torch.device(\"cuda\")\n", "elif torch.backends.mps.is_available(): # Metal (macOS)\n", " device = torch.device(\"mps\")\n", "else: # CPU\n", " device = torch.device(\"cpu\")\n", "print(f\"Device: {device}\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from moviepy.editor import ImageSequenceClip, ipython_display\n", "\n", "class GymRecorder(object):\n", " \"\"\"\n", " Simple wrapper over moviepy to generate a .gif with the frames of a gym environment.\n", " \n", " The environment must have the render_mode `rgb_array_list`.\n", " \"\"\"\n", " def __init__(self, env):\n", " self.env = env\n", " self._frames = []\n", "\n", " def record(self, frames):\n", " \"To be called at the end of an 
episode.\"\n", " for frame in frames:\n", " self._frames.append(np.array(frame))\n", "\n", " def make_video(self, filename):\n", " \"Generates the gif video.\"\n", " directory = os.path.dirname(os.path.abspath(filename))\n", " if not os.path.exists(directory):\n", " os.mkdir(directory)\n", " self.clip = ImageSequenceClip(list(self._frames), fps=self.env.metadata[\"render_fps\"])\n", " self.clip.write_gif(filename, fps=self.env.metadata[\"render_fps\"], loop=0)\n", " del self._frames\n", " self._frames = []" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cartpole balancing task\n", "\n", "We are going to use the Cartpole balancing problem, which can be loaded with:\n", "\n", "```python\n", "gym.make('CartPole-v0', render_mode=\"rgb_array_list\")\n", "```\n", "\n", "States have 4 continuous values (position and speed of the cart, angle and speed of the pole) and 2 discrete outputs (going left or right). The reward is +1 for each transition where the pole is still standing (angle of less than 30° with the vertical). \n", "\n", "In CartPole-v0, the episode ends when the pole fails or after 200 steps. In CartPole-v1, the maximum episode length is 500 steps, which is too long for us, so we stick to v0 here.\n", "\n", "The maximal (undiscounted) return is therefore 200. Can DQN learn this?" 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Return: 40.0\n" ] } ], "source": [ "# Create the environment\n", "env = gym.make('CartPole-v0', render_mode=\"rgb_array_list\")\n", "recorder = GymRecorder(env)\n", "\n", "# Sample the initial state\n", "state, info = env.reset()\n", "\n", "# One episode:\n", "done = False\n", "return_episode = 0\n", "while not done:\n", "\n", " # Select an action randomly\n", " action = env.action_space.sample()\n", " \n", " # Sample a single transition\n", " next_state, reward, terminal, truncated, info = env.step(action)\n", "\n", " # End of the episode\n", " done = terminal or truncated\n", "\n", " # Update undiscounted return\n", " return_episode += reward\n", " \n", " # Go to the next state\n", " state = next_state\n", "\n", "print(\"Return:\", return_episode)\n", "recorder.record(env.render())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "video = \"videos/cartpole_random.gif\"\n", "recorder.make_video(video)\n", "ipython_display(video)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Value network in pytorch\n", "\n", "As the state in Cartpole has only four dimensions, we do not need a CNN for the value network. A simple MLP with a couple of hidden layers is enough.\n", "\n", "**Q:** Create an MLP class in pytorch with four inputs, two outputs (one Q-value per action) and two hidden layers of 128 neurons each (you can change this later). If possible, make it parameterizable, i.e. have the constructor take the number of inputs, outputs and hidden neurons as arguments. The activation function for the hidden layers should be ReLU." 
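, "\n",
"\n",
"For comparison, the same architecture (two ReLU hidden layers) can also be built with `torch.nn.Sequential` instead of subclassing `torch.nn.Module`. This is only a sketch of an alternative. The helper name `make_mlp` is hypothetical and not part of the exercise.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"def make_mlp(nb_inputs, nb_hidden, nb_outputs):\n",
"    # Hypothetical helper: inputs -> hidden -> hidden -> one output per action\n",
"    return torch.nn.Sequential(\n",
"        torch.nn.Linear(nb_inputs, nb_hidden),\n",
"        torch.nn.ReLU(),\n",
"        torch.nn.Linear(nb_hidden, nb_hidden),\n",
"        torch.nn.ReLU(),\n",
"        torch.nn.Linear(nb_hidden, nb_outputs),\n",
"    )\n",
"\n",
"net = make_mlp(4, 128, 2)\n",
"print(net(torch.zeros(4)).shape)  # torch.Size([2]): one Q-value per action\n",
"```\n",
"\n",
"Subclassing `torch.nn.Module`, as asked above, is more flexible when `forward()` needs custom logic. `Sequential` is more compact for a plain stack of layers."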
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class MLP(torch.nn.Module):\n", " \"Value network for DQN on Cartpole.\"\n", "\n", " def __init__(self, nb_observations, nb_hidden1, nb_hidden2, nb_actions):\n", " super(MLP, self).__init__()\n", " \n", " # Layers\n", " self.fc1 = torch.nn.Linear(nb_observations, nb_hidden1)\n", " self.fc2 = torch.nn.Linear(nb_hidden1, nb_hidden2)\n", " self.fc3 = torch.nn.Linear(nb_hidden2, nb_actions)\n", "\n", " def forward(self, x):\n", "\n", " x = F.relu(self.fc1(x))\n", " x = F.relu(self.fc2(x))\n", " x = self.fc3(x)\n", " \n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Q:** Create a network, an environment, get the initial state using `env.reset()` and pass it to the `forward()` method of your NN. What happens?\n", "\n", "Do not forget to send the network to your device, especially if you have a GPU. Create the network using something like:\n", "\n", "```python\n", "net = MLP(...).to(device)\n", "```" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[-0.00195839 -0.04332381 0.0268923 -0.02864526]\n" ] }, { "ename": "TypeError", "evalue": "linear(): argument 'input' (position 1) must be Tensor, not numpy.ndarray", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[8], line 12\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28mprint\u001b[39m(state)\n\u001b[1;32m 11\u001b[0m \u001b[38;5;66;03m# Predict the Q-values from the initial state\u001b[39;00m\n\u001b[0;32m---> 12\u001b[0m Q_values \u001b[38;5;241m=\u001b[39m \u001b[43mnet\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstate\u001b[49m\u001b[43m)\u001b[49m\n", "Cell \u001b[0;32mIn[7], line 
14\u001b[0m, in \u001b[0;36mMLP.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m---> 14\u001b[0m x \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mrelu(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfc1\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 15\u001b[0m x \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mrelu(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfc2(x))\n\u001b[1;32m 16\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfc3(x)\n", "File \u001b[0;32m~/Teaching/DeepReinforcementLearning/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/Teaching/DeepReinforcementLearning/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip 
the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", "File \u001b[0;32m~/Teaching/DeepReinforcementLearning/.venv/lib/python3.12/site-packages/torch/nn/modules/linear.py:125\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 125\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n", "\u001b[0;31mTypeError\u001b[0m: linear(): argument 'input' (position 1) must be Tensor, not numpy.ndarray" ] } ], "source": [ "# Create the environment\n", "env = gym.make('CartPole-v0')\n", "\n", "# Create the value network\n", "net = MLP(env.observation_space.shape[0], 128, 128, env.action_space.n).to(device)\n", "\n", "# Sample the initial state\n", "state, info = env.reset()\n", "print(state)\n", "\n", "# Predict the Q-values from the initial state\n", "Q_values = net.forward(state)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alright: the network expects a pytorch tensor, not a numpy array, and pytorch does not perform the conversion automatically. We need to cast the state vector into a tensor first.\n", "\n", "To cast a numpy vector of shape (4,) into a tensor, one simply needs to call:\n", "\n", "```python\n", "state = torch.tensor(state, dtype=torch.float32, device=device)\n", "```\n", "\n", "The dtype must be set to `torch.float32` for floating-point numbers, while integers (e.g. action indices) should use `torch.long`. Do not forget to send the tensor to your device if you plan to pass it to your network.\n", "\n", "**Q:** Pass the new tensor to your network. What is the shape of the output tensor?" 
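, "\n",
"\n",
"To see why the dtype is set explicitly, here is a small sketch with a made-up state. A numpy array built from Python floats is `float64`, and `torch.tensor()` keeps that dtype by default, whereas the weights of `torch.nn.Linear` are `float32`. CartPole happens to return `float32` observations, but states from other sources may not. Adding a leading batch dimension with `unsqueeze(0)` gives the shape used later for minibatches.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import torch\n",
"\n",
"state = np.array([0.01, -0.02, 0.03, -0.04])  # made-up state, float64 by default\n",
"print(torch.tensor(state).dtype)  # torch.float64\n",
"\n",
"x = torch.tensor(state, dtype=torch.float32)\n",
"print(x.dtype, x.shape)  # torch.float32 torch.Size([4])\n",
"print(x.unsqueeze(0).shape)  # torch.Size([1, 4])\n",
"```"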
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[-0.00917135 0.04095928 -0.0089457 -0.01949579] (4,)\n", "tensor([-0.0092, 0.0410, -0.0089, -0.0195], device='mps:0') torch.Size([4])\n", "tensor([ 0.0196, -0.0394], device='mps:0', grad_fn=) torch.Size([2])\n" ] } ], "source": [ "# Create the environment\n", "env = gym.make('CartPole-v0')\n", "\n", "# Create the value network\n", "net = MLP(env.observation_space.shape[0], 128, 128, env.action_space.n).to(device)\n", "\n", "# Sample the initial state\n", "state, info = env.reset()\n", "print(state, state.shape)\n", "\n", "# Cast the state to a tensor\n", "state = torch.tensor(state, dtype=torch.float32, device=device)\n", "print(state, state.shape)\n", "\n", "# Predict the Q-values from the initial state\n", "Q_values = net.forward(state)\n", "print(Q_values, Q_values.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The value network outputs one Q-value per action, great. Now, let's identify the **greedy** action, i.e. the one with the highest Q-value. The cartpole environment expects the action as a Python integer (0 or 1), so we need the index of the element with the highest Q-value. \n", "\n", "Have a look at these two methods of `Tensor`:\n", "\n", "* ``Tensor.argmax``: \n", "* ``Tensor.item``: \n", "\n", "**Q:** Find a way to obtain the index (as a Python integer) of the element with the highest value in the tensor of Q-values. Check that it works. 
" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Greedy action: 0\n" ] } ], "source": [ "greedy_action = Q_values.argmax().item()\n", "print(f\"Greedy action: {greedy_action}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Q:** Create a dummy agent class (as in the previous exercises) storing a value network and acting using $\\epsilon$-greedy action selection. Add a `test()` method that runs a few episodes and optionally records them. \n", "\n", "The constructor should accept several hyperparameters, such as the `config` dictionary in the following skeleton:\n", "\n", "```python\n", "class RandomDQNAgent:\n", " def __init__(self, env, config): ...\n", " def act(self, state): ...\n", " def test(self, nb_episodes, recorder=None): ...\n", "```\n", "\n", "but feel free to pass the hyperparameters one by one.\n", "\n", "To prepare for learning, implement a schedule for `epsilon` in the `act()` method: epsilon should start at a high value of 0.9 and decay exponentially towards 0.05 as actions are taken. The value of epsilon follows this formula:\n", "\n", "$$\n", " \\epsilon = 0.05 + (0.9 - 0.05) \\times \\exp ( - \\dfrac{t}{1000})\n", "$$\n", "\n", "where $t$ is the number of steps since the start. The values 0.05, 0.9 and 1000 should be parameters of the class." 
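, "\n",
"\n",
"A minimal sketch of this schedule as a standalone function shows how quickly epsilon decays. The name `epsilon_schedule` is hypothetical. Epsilon equals 0.9 at $t=0$, is about 0.36 after 1000 steps, and drops below 0.06 after 5000 steps.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def epsilon_schedule(t, eps_start=0.9, eps_end=0.05, eps_decay=1000):\n",
"    # Exponential decay from eps_start towards eps_end, eps_decay is the time constant in steps\n",
"    return eps_end + (eps_start - eps_end) * math.exp(-t / eps_decay)\n",
"\n",
"for t in [0, 1000, 5000, 20000]:\n",
"    print(t, round(epsilon_schedule(t), 4))\n",
"```"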
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "class RandomDQNAgent:\n", " \"\"\"\n", " Random deep Q-learning agent.\n", " \"\"\"\n", " \n", " def __init__(self, env, config):\n", "\n", " self.env = env\n", " self.config = config\n", " self.epsilon = self.config['eps_start']\n", "\n", " # Number of actions\n", " self.n_actions = self.env.action_space.n\n", "\n", " # Dimension of the state space\n", " self.state, info = self.env.reset()\n", " self.n_observations = len(self.state)\n", "\n", " # Value network\n", " self.value_net = MLP(self.n_observations, config['nb_hidden'], config['nb_hidden'], self.n_actions).to(device)\n", "\n", " self.steps_done = 0\n", " \n", " \n", " def act(self, state):\n", " \"Returns an action using epsilon-greedy action selection.\"\n", "\n", " # Decay epsilon exponentially\n", " self.epsilon = self.config['eps_end'] + (self.config['eps_start'] - self.config['eps_end']) * math.exp(-1. * self.steps_done / self.config['eps_decay'])\n", "\n", " # Keep track of time\n", " self.steps_done += 1\n", " \n", " # epsilon-greedy action selection\n", " if rng.random() < self.epsilon:\n", " return self.env.action_space.sample()\n", " else:\n", " with torch.no_grad():\n", " return self.value_net(state).argmax().item()\n", "\n", " \n", " def test(self, nb_episodes, recorder=None):\n", " \"Runs test episodes. Note: act() recomputes epsilon at every call, so the actions remain epsilon-greedy.\"\n", " previous_epsilon = self.epsilon\n", " self.epsilon = 0.0\n", "\n", " for episode in range(nb_episodes):\n", " \n", " # Reset\n", " state, _ = self.env.reset()\n", " state = torch.tensor(state, dtype=torch.float32, device=device)\n", "\n", " # Sample the episode\n", " done = False\n", " return_episode = 0\n", " while not done: \n", " action = self.act(state)\n", " next_state, reward, terminal, truncated, info = self.env.step(action)\n", " return_episode += reward\n", " done = terminal or truncated\n", " state = 
torch.tensor(next_state, dtype=torch.float32, device=device)\n", "\n", 
" print(f\"Episode {episode}: return {return_episode}, epsilon: {self.epsilon:.4f}\")\n", " \n", " self.epsilon = previous_epsilon\n", " \n", " if recorder is not None:\n", " recorder.record(self.env.render())\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Episode 0: return 14.0, epsilon: 0.8890\n", "Episode 1: return 12.0, epsilon: 0.8790\n", "Episode 2: return 19.0, epsilon: 0.8634\n", "Episode 3: return 9.0, epsilon: 0.8561\n", "Episode 4: return 42.0, epsilon: 0.8230\n", "Episode 5: return 31.0, epsilon: 0.7994\n", "Episode 6: return 11.0, epsilon: 0.7912\n", "Episode 7: return 46.0, epsilon: 0.7579\n", "Episode 8: return 9.0, epsilon: 0.7515\n", "Episode 9: return 28.0, epsilon: 0.7321\n" ] } ], "source": [ "# Create the environment\n", "env = gym.make('CartPole-v0', render_mode=\"rgb_array_list\")\n", "recorder = GymRecorder(env)\n", "\n", "# Hyperparameters\n", "config = {}\n", "config['nb_hidden'] = 128 # number of hidden neurons in each layer\n", "config['eps_start'] = 0.9 # starting value of epsilon\n", "config['eps_end'] = 0.05 # final value of epsilon\n", "config['eps_decay'] = 1000 # time constant (in steps) of the exponential decay of epsilon: higher means a slower decay\n", "\n", "# Create the agent\n", "agent = RandomDQNAgent(env, config)\n", "\n", "# Run 10 test episodes\n", "agent.test(10, recorder)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "video = \"videos/cartpole-random2.gif\"\n", "recorder.make_video(video)\n", "ipython_display(video)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Target network\n", "\n", "The original DQN algorithm uses two neural networks:\n", "\n", "1. The value network $Q_\\theta(s, a)$, which learns to predict the Q-values of the current state.\n", "2. 
The target network $Q_{\\theta'}(s, a)$, which is used to predict the Q-values of the next state.\n", "\n", "The target network is a copy of the value network (same structure and parameters), but its parameters are only updated from time to time, by copying those of the value network.\n", "\n", "**Q:** Create two MLPs of the same size and predict the Q-values of a single state. What happens?" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([ 0.0960, -0.0482], device='mps:0', grad_fn=)\n", "tensor([0.0574, 0.0539], device='mps:0', grad_fn=)\n" ] } ], "source": [ "# Create the environment\n", "env = gym.make('CartPole-v0')\n", "\n", "# Create the value network\n", "value_net = MLP(env.observation_space.shape[0], 128, 128, env.action_space.n).to(device)\n", "\n", "# Create the target network\n", "target_net = MLP(env.observation_space.shape[0], 128, 128, env.action_space.n).to(device)\n", "\n", "# Sample the initial state\n", "state, _ = env.reset()\n", "\n", "# Cast the state to a tensor\n", "state = torch.tensor(state, dtype=torch.float32, device=device)\n", "\n", "# Predict the Q-values for both networks\n", "Q_value = value_net.forward(state)\n", "Q_target = target_net.forward(state)\n", "\n", "print(Q_value)\n", "print(Q_target)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The two MLPs are initialized with different random parameters, so their predictions differ. We need a way to copy the weights of one network into another. \n", "\n", "Fortunately, it is very easy to save/load the parameters of a pytorch network:\n", "\n", "```python\n", "params = net.state_dict()\n", "net.load_state_dict(params)\n", "```\n", "\n", "**Q:** Apply these methods to copy the weights of the value network into the target network. Check that they now predict the same thing." 
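, "\n",
"\n",
"The sketch below illustrates what `load_state_dict()` does. It uses a single `torch.nn.Linear` layer as a stand-in for the MLP, an assumption made to keep the example short. After the copy, both networks compute the same outputs. Because the values are copied rather than shared, later changes to the value network do not affect the target network, which is what keeps the target network fixed between two updates.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"value_net = torch.nn.Linear(4, 2)\n",
"target_net = torch.nn.Linear(4, 2)\n",
"\n",
"# Hard update: copy the parameters of the value network into the target network\n",
"target_net.load_state_dict(value_net.state_dict())\n",
"x = torch.randn(4)\n",
"print(torch.equal(value_net(x), target_net(x)))  # True\n",
"\n",
"# Modify the value network in place (as an optimizer step would)\n",
"with torch.no_grad():\n",
"    value_net.weight.add_(1.0)\n",
"print(torch.equal(value_net(x), target_net(x)))  # False: the target network kept the old values\n",
"```"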
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([ 0.0139, -0.0111], device='mps:0', grad_fn=)\n", "tensor([ 0.0139, -0.0111], device='mps:0', grad_fn=)\n" ] } ], "source": [ "# Create the environment\n", "env = gym.make('CartPole-v0')\n", "\n", "# Create the value network\n", "value_net = MLP(env.observation_space.shape[0], 128, 128, env.action_space.n).to(device)\n", "\n", "# Create the target network\n", "target_net = MLP(env.observation_space.shape[0], 128, 128, env.action_space.n).to(device)\n", "\n", "# Update the target network\n", "target_net.load_state_dict(value_net.state_dict())\n", "\n", "# Sample the initial state\n", "state, _ = env.reset()\n", "\n", "# Cast the state to a tensor\n", "state = torch.tensor(state, dtype=torch.float32, device=device)\n", "\n", "# Predict the Q-values for both networks\n", "Q_value = value_net.forward(state)\n", "Q_target = target_net.forward(state)\n", "\n", "print(Q_value)\n", "print(Q_target)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Experience Replay Memory\n", "\n", "The second important component of DQN is the experience replay memory (ERM), or replay buffer. It is a limited-size buffer that stores $(s, a, r, s', d)$ transitions, where $d$ is a boolean indicating whether the episode ended in the next state $s'$ (in gymnasium, this is the boolean `done = terminal or truncated`).\n", "\n", "Below is a simple implementation of an ERM. The important data structure here is `deque` (double-ended queue), which behaves like a list when `append()` is called until its capacity (`maxlen`) is reached: from then on, appending a new element discards the oldest one. \n", "\n", "`batch = sample(batch_size)` randomly samples a minibatch from the ERM and returns its $(s, a, r, s', d)$ transitions, already cast into pytorch tensors. These tensors are accessed with `batch.state`, `batch.action`, etc." 
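, "\n",
"\n",
"The two Python mechanisms used by the class below can be tried in isolation. This sketch uses made-up integer transitions. A `deque` with `maxlen=3` keeps only the three most recent items, and `Transition(*zip(*transitions))` turns a list of transitions into a single `Transition` whose fields are tuples.\n",
"\n",
"```python\n",
"from collections import deque, namedtuple\n",
"\n",
"Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))\n",
"\n",
"memory = deque([], maxlen=3)\n",
"for i in range(5):\n",
"    memory.append(Transition(i, i % 2, 1.0, i + 1, False))\n",
"print([t.state for t in memory])  # [2, 3, 4]: the two oldest transitions were discarded\n",
"\n",
"batch = Transition(*zip(*memory))\n",
"print(batch.state)  # (2, 3, 4)\n",
"print(batch.action)  # (0, 1, 0)\n",
"```"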
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Named tuples are tuples whose fields can also be accessed by name\n", "Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))\n", "\n", "class ReplayMemory(object):\n", " \"Simple Experience Replay Memory using uniform sampling.\"\n", "\n", " def __init__(self, capacity):\n", " self.memory = deque([], maxlen=capacity)\n", "\n", " def append(self, state, action, reward, next_state, done):\n", " \"Appends a transition (s, a, r, s', done) to the buffer.\"\n", "\n", " # Get numpy arrays even if it is a torch tensor\n", " if isinstance(state, (torch.Tensor,)): state = state.numpy(force=True)\n", " if isinstance(next_state, (torch.Tensor,)): next_state = next_state.numpy(force=True)\n", " \n", " # Append to the buffer\n", " self.memory.append(Transition(state, action, reward, next_state, done))\n", "\n", " def sample(self, batch_size):\n", " \"Returns a minibatch of (s, a, r, s', done)\"\n", "\n", " # Sample the batch\n", " transitions = random.sample(self.memory, batch_size)\n", " \n", " # Transpose the batch: a list of Transitions becomes a Transition of tuples\n", " batch = Transition(*zip(*transitions))\n", " \n", " # Cast to tensors (stacking the numpy arrays first is faster than converting a list of arrays)\n", " states = torch.tensor(np.array(batch.state), dtype=torch.float32, device=device)\n", " actions = torch.tensor(batch.action, dtype=torch.long, device=device)\n", " rewards = torch.tensor(batch.reward, dtype=torch.float32, device=device)\n", " next_states = torch.tensor(np.array(batch.next_state), dtype=torch.float32, device=device)\n", " dones = torch.tensor(batch.done, dtype=torch.bool, device=device)\n", "\n", " return Transition(states, actions, rewards, next_states, dones)\n", "\n", " def __len__(self):\n", " return len(self.memory)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Q:** Modify your random DQN agent so that it stores a replay buffer of capacity 10000 and 
appends every transition to it. Run a few episodes, sample a small minibatch and have a look at the data you obtain." 
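, "\n",
"\n",
"When sampling, keep in mind that `random.sample()` raises a `ValueError` if the population is smaller than the requested sample. A minibatch can therefore only be drawn once enough transitions are stored. The following is a minimal sketch that uses a plain list as a stand-in for the replay buffer and an arbitrary threshold of twice the batch size.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"memory = list(range(5))  # stand-in for a buffer holding 5 transitions\n",
"batch_size = 10\n",
"\n",
"try:\n",
"    random.sample(memory, batch_size)\n",
"except ValueError as err:\n",
"    print('Cannot sample:', err)\n",
"\n",
"# Guard: only sample once the buffer holds enough transitions (the threshold is an assumption)\n",
"minibatch = random.sample(memory, batch_size) if len(memory) >= 2 * batch_size else None\n",
"print(minibatch)  # None\n",
"```"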
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "class RandomDQNAgent:\n", " \"\"\"\n", " Random deep Q-learning agent with memory.\n", " \"\"\"\n", " \n", " def __init__(self, env, config):\n", "\n", " self.env = env\n", " self.config = config\n", " self.epsilon = self.config['eps_start']\n", "\n", " # Number of actions\n", " self.n_actions = self.env.action_space.n\n", "\n", " # Dimension of the state space\n", " self.state, info = self.env.reset()\n", " self.n_observations = len(self.state)\n", "\n", " # Value network\n", " self.value_net = MLP(self.n_observations, config['nb_hidden'], config['nb_hidden'], self.n_actions).to(device)\n", "\n", " # Replay buffer\n", " self.memory = ReplayMemory(capacity=self.config['buffer_limit'])\n", "\n", " self.steps_done = 0\n", " \n", " \n", " def act(self, state):\n", " \"Returns an action using epsilon-greedy action selection.\"\n", "\n", " # Decay epsilon exponentially\n", " self.epsilon = self.config['eps_end'] + (self.config['eps_start'] - self.config['eps_end']) * math.exp(-1. 
* self.steps_done / self.config['eps_decay'])\n", "\n", " # Keep track of time\n", " self.steps_done += 1\n", " \n", " # epsilon-greedy action selection\n", " if rng.random() < self.epsilon:\n", " return self.env.action_space.sample()\n", " else:\n", " with torch.no_grad():\n", " return self.value_net(state).argmax().item()\n", "\n", " \n", " def test(self, nb_episodes, recorder=None):\n", " \"Runs test episodes. Note: act() recomputes epsilon at every call, so the actions remain epsilon-greedy.\"\n", " previous_epsilon = self.epsilon\n", " self.epsilon = 0.0\n", "\n", " for episode in range(nb_episodes):\n", " \n", " # Reset\n", " state, _ = self.env.reset()\n", " state = torch.tensor(state, dtype=torch.float32, device=device)\n", "\n", " # Sample the episode\n", " done = False\n", " return_episode = 0\n", " while not done: \n", " action = self.act(state)\n", " next_state, reward, terminal, truncated, info = self.env.step(action)\n", " return_episode += reward\n", " done = terminal or truncated\n", " \n", " # Append the transition to the replay buffer\n", " self.memory.append(state, action, reward, next_state, done)\n", "\n", " state = torch.tensor(next_state, dtype=torch.float32, device=device)\n", "\n", " print(f\"Episode {episode}: return {return_episode}, epsilon: {self.epsilon:.4f}\")\n", " \n", " self.epsilon = previous_epsilon\n", " \n", " if recorder is not None:\n", " recorder.record(self.env.render())\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Episode 0: return 24.0, epsilon: 0.8807\n", "Episode 1: return 12.0, epsilon: 0.8708\n", "Episode 2: return 19.0, epsilon: 0.8553\n", "Episode 3: return 32.0, epsilon: 0.8300\n", "Episode 4: return 9.0, epsilon: 0.8230\n", "Episode 5: return 17.0, epsilon: 0.8099\n", "Episode 6: return 17.0, epsilon: 0.7971\n", "Episode 7: return 15.0, epsilon: 0.7860\n", "Episode 8: return 33.0, epsilon: 0.7621\n", "Episode 9: 
return 12.0, epsilon: 0.7536\n" ] } ], "source": [ "# Create the 
environment\n", "env = gym.make('CartPole-v0')\n", "\n", "# Hyperparameters\n", "config = {}\n", "config['nb_hidden'] = 128 # number of hidden neurons in each layer\n", "config['eps_start'] = 0.9 # starting value of epsilon\n", "config['eps_end'] = 0.05 # final value of epsilon\n", "config['eps_decay'] = 1000 # rate of exponential decay of epsilon, higher means a slower decay\n", "config['buffer_limit'] = 1000 # maximum number of transitions in the replay buffer\n", "\n", "# Create the agent\n", "agent = RandomDQNAgent(env, config)\n", "\n", "# Make 10 evaluation episodes\n", "agent.test(10)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "States: torch.Size([10, 4]) tensor([[-0.1452, -1.4138, 0.1999, 2.2473],\n", " [-0.1104, -0.2476, 0.2082, 0.7865],\n", " [-0.0767, -0.8002, 0.0919, 1.1989],\n", " [-0.0178, -0.6156, 0.0382, 0.8665],\n", " [-0.0724, -0.2315, 0.1183, 0.4170],\n", " [ 0.1283, 0.2030, -0.0597, -0.2070],\n", " [ 0.0689, 0.3966, 0.0116, -0.4653],\n", " [ 0.1142, -0.7662, -0.0327, 1.1176],\n", " [ 0.1016, -0.7656, -0.0328, 1.0440],\n", " [ 0.0760, 1.0158, -0.1276, -1.5086]], device='mps:0')\n", "Actions: torch.Size([10]) tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 1], device='mps:0')\n", "Rewards: torch.Size([10]) tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='mps:0')\n", "Next states: torch.Size([10, 4]) tensor([[-0.1735, -1.2211, 0.2449, 2.0223],\n", " [-0.1154, -0.0559, 0.2240, 0.5658],\n", " [-0.0927, -0.6064, 0.1158, 0.9363],\n", " [-0.0301, -0.4211, 0.0556, 0.5861],\n", " [-0.0770, -0.0382, 0.1266, 0.1638],\n", " [ 0.1324, 0.0088, -0.0638, 0.0662],\n", " [ 0.0768, 0.2013, 0.0023, -0.1690],\n", " [ 0.0989, -0.9609, -0.0104, 1.3998],\n", " [ 0.0863, -0.9603, -0.0120, 1.3262],\n", " [ 0.0963, 1.2122, -0.1577, -1.8382]], device='mps:0')\n", "Dones: torch.Size([10]) tensor([ True, True, False, False, False, False, False, False, False, False],\n", " 
device='mps:0')\n" ] } ], "source": [ "# Sample the ERM\n", "batch = agent.memory.sample(10)\n", "\n", "print(\"States:\", batch.state.shape, batch.state)\n", "print(\"Actions:\", batch.action.shape, batch.action)\n", "print(\"Rewards:\", batch.reward.shape, batch.reward)\n", "print(\"Next states:\", batch.next_state.shape, batch.next_state)\n", "print(\"Dones:\", batch.done.shape, batch.done)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Q:** Use the value network stored in your agent to predict the Q-values of all actions for the states contained in the minibatch. Do NOT use a for loop. Check the shape of the resulting tensor. " ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([10, 2])\n", "tensor([[ 0.1905, -0.1435],\n", " [ 0.0781, -0.0912],\n", " [ 0.1190, -0.1079],\n", " [ 0.0915, -0.0930],\n", " [ 0.0475, -0.0733],\n", " [ 0.0260, -0.0398],\n", " [ 0.0280, -0.0336],\n", " [ 0.1112, -0.1073],\n", " [ 0.1075, -0.1021],\n", " [ 0.0572, -0.0072]], device='mps:0', grad_fn=)\n" ] } ], "source": [ "Q_values = agent.value_net.forward(batch.state)\n", "\n", "print(Q_values.shape)\n", "print(Q_values)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Q:** The previous tensor contains the Q-values of all actions in the sampled states. We now only want the Q-value of the action that was actually taken (whose index is in `batch.action`). The resulting tensor should therefore be a vector of length `batch_size`. How can we do that?\n", "\n", "*Hint:* it would take months of practice to master all the indexing methods available in pytorch: . Meanwhile, numpy-style indexing could be useful. 
Check what the following statements do:\n", "\n", "```python\n", "N = 10\n", "A = torch.randn((N, 2))\n", "B = torch.randint(0, 2, (N,))\n", "C = A[range(N), B]\n", "```" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ 0.1905, -0.1435],\n", " [ 0.0781, -0.0912],\n", " [ 0.1190, -0.1079],\n", " [ 0.0915, -0.0930],\n", " [ 0.0475, -0.0733],\n", " [ 0.0260, -0.0398],\n", " [ 0.0280, -0.0336],\n", " [ 0.1112, -0.1073],\n", " [ 0.1075, -0.1021],\n", " [ 0.0572, -0.0072]], device='mps:0', grad_fn=)\n", "tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 1], device='mps:0')\n", "tensor([-0.1435, -0.0912, -0.1079, -0.0930, -0.0733, 0.0260, 0.0280, 0.1112,\n", " 0.1075, -0.0072], device='mps:0', grad_fn=)\n" ] } ], "source": [ "print(Q_values)\n", "print(batch.action)\n", "Q_taken = Q_values[range(10), batch.action]\n", "print(Q_taken)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## DQN agent\n", "\n", "Good, we should now have all the elementary bricks to create our DQN agent. 
\n", "\n", "Reminder from the lecture:\n", "\n", "---\n", "\n", "* Initialize value network $Q_{\\theta}$ and target network $Q_{\\theta'}$.\n", "\n", "* Initialize experience replay memory $\\mathcal{D}$ of maximal size $N$.\n", "\n", "* for $t \\in [0, T_\\text{total}]$:\n", "\n", "    * Select an action $a_t$ based on $Q_\\theta(s_t, a)$ ($\\epsilon$-greedy), observe $s_{t+1}$ and $r_{t+1}$.\n", "\n", "    * Store $(s_t, a_t, r_{t+1}, s_{t+1})$ in the experience replay memory.\n", "\n", "    * Every $T_\\text{train}$ steps:\n", "\n", "        * Sample a minibatch $\\mathcal{D}_s$ randomly from $\\mathcal{D}$.\n", "\n", "        * For each transition $(s_k, a_k, r_k, s'_k)$ in the minibatch:\n", "\n", "            * Compute the target value $t_k = r_k + \\gamma \\, \\max_{a'} Q_{\\theta'}(s'_k, a')$ using the target network.\n", "\n", "        * Update the value network $Q_{\\theta}$ on $\\mathcal{D}_s$ to minimize:\n", "\n", "            $$\\mathcal{L}(\\theta) = \\mathbb{E}_{\\mathcal{D}_s}[(t_k - Q_\\theta(s_k, a_k))^2]$$\n", "\n", "    * Every $T_\\text{target}$ steps:\n", "\n", "        * Update the target network: $\\theta' \\leftarrow \\theta$.\n", "\n", "---\n", "\n", "Create a DQN agent class inspired by the notebooks on MC or TD. The constructor should create the value and target networks and make sure that their parameters are identical. It should also create an empty replay buffer.\n", "\n", "The `act()` method implements $\\epsilon$-greedy action selection, with an exponentially decaying schedule for $\\epsilon$. The greedy action is read from the value network.\n", "\n", "The `train()` and `test()` methods run training and test episodes as usual, with optional rendering. The `train()` method should return (or store in the object) the return of each episode, so that we can plot it at the end. As the reward is +1 per step, the return of an episode is simply its length.\n", "\n", "The main difficulty will be the `update()` method, where learning is supposed to happen. 
It should sample a minibatch from the replay memory, compute a vector of Bellman targets $r_t + \\gamma \\, \\max_a Q(s_{t+1}, a)$ for each transition in the batch, compute the loss function (the mean squared error between these targets and the predicted Q-values), backpropagate the gradients and apply the optimizer (Adam, but feel free to pick your preferred optimizer). Refer to the previous notebook on pytorch if you do not know how to do that.\n", "\n", "The main tricky part is when $V(s_{t+1}) = \\max_a Q(s_{t+1}, a)$ has to be predicted by the **target** network. You do not want the target network to learn from the minibatch, so there is no need to compute its gradients, which saves computational time. You can make sure that the target network is purely in inference mode with the following context:\n", "\n", "```python\n", "with torch.no_grad():\n", "    next_Q_values = target_net(batch.next_state)\n", "```\n", "\n", "Of course, you want the Q-value of the greedy action in the next state, not the vector of all Q-values, so check the documentation of `Tensor.max()` and its `dim` argument.\n", "\n", "Importantly, when the next state $s'$ is terminal (either the agent failed or the 200 steps are over), the Bellman target should simply be $r_t$ instead of $r_t + \\gamma \\, \\max_a Q(s_{t+1}, a)$, as no action will be taken in the next state. This is why we saved the `done` booleans in the replay buffer. As they are booleans, you can use them directly as a mask for indexing:\n", "\n", "```python\n", "Q = torch.randn((batch_size,))\n", "Q[batch.done] = 0.0\n", "```\n", "\n", "A minor detail: do not start learning until the replay buffer is full enough, otherwise you will not be able to fill your minibatch. Usually, learning only starts once the buffer contains two or three times the batch size. Use `len(memory)` to know the current number of stored transitions.\n", "\n", "Here is a set of suggested hyperparameters to help you start. 
Of course, you are strongly advised to modify them and observe their influence, time permitting.\n", "\n", "* $\\gamma = 0.99$.\n", "* MLP with two hidden layers of 128 neurons, Adam optimizer with a fixed learning rate of 0.001.\n", "* Replay buffer of maximum capacity 10000, batch size of 128.\n", "* Target network updated every 120 steps.\n", "* Epsilon-greedy action selection, with the schedule:\n", "\n", "$$\n", "    \\epsilon = 0.05 + (0.9 - 0.05) \\, \\exp ( - \\dfrac{t}{1000})\n", "$$\n", "\n", "where $t$ is the total number of steps.\n", "\n", "\n", "**Q:** Train a DQN on cartpole for 250 episodes. How would you characterize the learning process (speed, stability, etc.)? If possible, do several runs and vary the hyperparameters." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "class DQNAgent:\n", "    \"DQN agent.\"\n", "    \n", "    def __init__(self, env, config):\n", "\n", "        # Parameters\n", "        self.env = env\n", "        self.config = config\n", "\n", "        # Number of actions\n", "        self.n_actions = self.env.action_space.n\n", "\n", "        # Dimension of the observation space\n", "        self.state, info = self.env.reset()\n", "        self.n_observations = len(self.state)\n", "\n", "        # Value network\n", "        self.value_net = MLP(self.n_observations, config['nb_hidden'], config['nb_hidden'], self.n_actions).to(device)\n", "\n", "        # Target network\n", "        self.target_net = MLP(self.n_observations, config['nb_hidden'], config['nb_hidden'], self.n_actions).to(device)\n", "\n", "        # Copy the value weights into the target network\n", "        self.target_net.load_state_dict(self.value_net.state_dict())\n", "\n", "        # Optimizer\n", "        self.optimizer = torch.optim.Adam(self.value_net.parameters(), lr=self.config['learning_rate'])\n", "\n", "        # Loss function\n", "        self.loss_function = torch.nn.MSELoss()\n", "        \n", "        # Replay buffer\n", "        self.memory = ReplayMemory(self.config['buffer_limit'])\n", "\n", "        self.steps_done = 0\n", "        self.episode_durations = []\n", "\n", "\n", "    def 
act(self, state):\n", "\n", " # Decay epsilon exponentially\n", " self.epsilon = self.config['eps_end'] + (self.config['eps_start'] - self.config['eps_end']) * math.exp(-1. * self.steps_done / self.config['eps_decay'])\n", "\n", " # Keep track of time\n", " self.steps_done += 1\n", " \n", " # epsilon-greedy action selection\n", " if rng.random() < self.epsilon:\n", " return self.env.action_space.sample()\n", " else:\n", " with torch.no_grad():\n", " return self.value_net(state).argmax(dim=0).item()\n", "\n", " def update(self):\n", "\n", " # Only learn when the replay buffer is full enough\n", " if len(self.memory) < 2 * self.config['batch_size']:\n", " return\n", " \n", " # Sample a batch\n", " batch = self.memory.sample(self.config['batch_size'])\n", "\n", " # Compute Q(s_t, a) with the current value network.\n", " Q_values = self.value_net(batch.state)[range(self.config['batch_size']), batch.action]\n", " \n", " # Compute Q(s_{t+1}, a*) for all next states.\n", " # If the next state is terminal, set the value to zero.\n", " # Do not compute gradients.\n", " with torch.no_grad():\n", " next_Q_values = self.target_net(batch.next_state).max(dim=1).values\n", " next_Q_values[batch.done] = 0.0\n", "\n", " # Compute the target Q values\n", " targets = (next_Q_values * self.config['gamma']) + batch.reward\n", "\n", " # Compute loss\n", " loss = self.loss_function(Q_values, targets)\n", "\n", " # Reinitialize the gradients\n", " self.optimizer.zero_grad()\n", "\n", " # Backpropagation\n", " loss.backward()\n", "\n", " # In-place gradient clipping (optional)\n", " #torch.nn.utils.clip_grad_value_(self.value_net.parameters(), 100)\n", "\n", " # Optimizer step\n", " self.optimizer.step()\n", "\n", " def train(self, num_episodes):\n", " \n", " for i_episode in range(num_episodes):\n", "\n", " tstart = time.time()\n", "\n", " # Initialize the environment and get its state\n", " state, _ = self.env.reset()\n", "\n", " # Transform the state into a tensor\n", " state = 
torch.tensor(state, dtype=torch.float32, device=device)\n", "\n", "            done = False\n", "            steps_episode = 0\n", "            while not done:\n", "                \n", "                # Select an action\n", "                action = self.act(state)\n", "                \n", "                # Perform the action\n", "                next_state, reward, terminated, truncated, _ = self.env.step(action)\n", "                \n", "                # Terminal state\n", "                done = terminated or truncated\n", "\n", "                # Store the transition in memory\n", "                self.memory.append(state, action, reward, next_state, done)\n", "\n", "                # Move to the next state\n", "                state = torch.tensor(next_state, dtype=torch.float32, device=device)\n", "\n", "                # Perform one step of the optimization (on the value network)\n", "                self.update()\n", "\n", "                # Update the target network's weights\n", "                if self.steps_done % self.config['target_update_period'] == 0:\n", "                    self.target_net.load_state_dict(self.value_net.state_dict())\n", "\n", "                # Finish episode\n", "                steps_episode += 1\n", "                if done:\n", "                    self.episode_durations.append(steps_episode)\n", "                    print(f\"Episode {i_episode+1}, duration {steps_episode}, epsilon {self.epsilon:.4f} done in {time.time() - tstart}\")" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Episode 1, duration 20, epsilon 0.8840 done in 0.010928869247436523\n", "Episode 2, duration 39, epsilon 0.8521 done in 0.02133011817932129\n", "Episode 3, duration 35, epsilon 0.8245 done in 0.021078109741210938\n", "Episode 4, duration 20, epsilon 0.8092 done in 0.010505914688110352\n", "Episode 5, duration 44, epsilon 0.7765 done in 0.026645898818969727\n", "Episode 6, duration 25, epsilon 0.7586 done in 0.016075849533081055\n", "Episode 7, duration 13, epsilon 0.7494 done in 0.006777763366699219\n", "Episode 8, duration 34, epsilon 0.7260 done in 0.021413803100585938\n", "Episode 9, duration 41, epsilon 0.6989 done in 2.526124954223633\n", "Episode 10, duration 21, epsilon 0.6854 done in 0.14540505409240723\n", "Episode 11, 
duration 17, epsilon 0.6747 done in 0.12012696266174316\n", "Episode 12, duration 14, epsilon 0.6660 done in 0.09074711799621582\n", "Episode 13, duration 13, epsilon 0.6580 done in 0.08664298057556152\n", "Episode 14, duration 28, epsilon 0.6412 done in 0.19265198707580566\n", "Episode 15, duration 22, epsilon 0.6284 done in 0.15033173561096191\n", "Episode 16, duration 22, epsilon 0.6158 done in 0.15267610549926758\n", "Episode 17, duration 23, epsilon 0.6029 done in 0.15503692626953125\n", "Episode 18, duration 15, epsilon 0.5947 done in 0.10272073745727539\n", "Episode 19, duration 17, epsilon 0.5855 done in 0.11703014373779297\n", "Episode 20, duration 15, epsilon 0.5775 done in 0.10234189033508301\n", "Episode 21, duration 14, epsilon 0.5702 done in 0.1012122631072998\n", "Episode 22, duration 21, epsilon 0.5594 done in 0.14426207542419434\n", "Episode 23, duration 32, epsilon 0.5434 done in 0.21706485748291016\n", "Episode 24, duration 27, epsilon 0.5302 done in 0.18805289268493652\n", "Episode 25, duration 27, epsilon 0.5174 done in 0.19099903106689453\n", "Episode 26, duration 22, epsilon 0.5073 done in 0.16488099098205566\n", "Episode 27, duration 33, epsilon 0.4924 done in 0.2926771640777588\n", "Episode 28, duration 63, epsilon 0.4654 done in 0.49425673484802246\n", "Episode 29, duration 108, epsilon 0.4229 done in 0.7813999652862549\n", "Episode 30, duration 84, epsilon 0.3928 done in 0.6466999053955078\n", "Episode 31, duration 66, epsilon 0.3709 done in 0.4575388431549072\n", "Episode 32, duration 45, epsilon 0.3568 done in 0.31013035774230957\n", "Episode 33, duration 177, epsilon 0.3070 done in 1.2206060886383057\n", "Episode 34, duration 123, epsilon 0.2773 done in 0.8495826721191406\n", "Episode 35, duration 156, epsilon 0.2445 done in 1.1377429962158203\n", "Episode 36, duration 151, epsilon 0.2172 done in 1.1700310707092285\n", "Episode 37, duration 166, epsilon 0.1916 done in 1.1666131019592285\n", "Episode 38, duration 173, epsilon 0.1691 
done in 1.2102069854736328\n", "Episode 39, duration 200, epsilon 0.1475 done in 1.4016377925872803\n", "Episode 40, duration 190, epsilon 0.1307 done in 1.4297728538513184\n", "Episode 41, duration 168, epsilon 0.1182 done in 1.1653170585632324\n", "Episode 42, duration 200, epsilon 0.1058 done in 1.3941991329193115\n", "Episode 43, duration 157, epsilon 0.0977 done in 1.0960440635681152\n", "Episode 44, duration 174, epsilon 0.0901 done in 1.2252748012542725\n", "Episode 45, duration 200, epsilon 0.0828 done in 1.3784756660461426\n", "Episode 46, duration 173, epsilon 0.0776 done in 1.213353157043457\n", "Episode 47, duration 200, epsilon 0.0726 done in 1.4108171463012695\n", "Episode 48, duration 162, epsilon 0.0692 done in 1.1317131519317627\n", "Episode 49, duration 176, epsilon 0.0661 done in 1.2339060306549072\n", "Episode 50, duration 175, epsilon 0.0635 done in 1.2066349983215332\n", "Episode 51, duration 179, epsilon 0.0613 done in 1.2518279552459717\n", "Episode 52, duration 199, epsilon 0.0593 done in 1.370851993560791\n", "Episode 53, duration 157, epsilon 0.0579 done in 1.0979859828948975\n", "Episode 54, duration 152, epsilon 0.0568 done in 1.072814702987671\n", "Episode 55, duration 157, epsilon 0.0558 done in 1.0916438102722168\n", "Episode 56, duration 165, epsilon 0.0549 done in 1.1536707878112793\n", "Episode 57, duration 155, epsilon 0.0542 done in 1.0730979442596436\n", "Episode 58, duration 164, epsilon 0.0536 done in 1.1358988285064697\n", "Episode 59, duration 173, epsilon 0.0530 done in 1.1930968761444092\n", "Episode 60, duration 200, epsilon 0.0525 done in 1.3894200325012207\n", "Episode 61, duration 162, epsilon 0.0521 done in 1.123974084854126\n", "Episode 62, duration 191, epsilon 0.0517 done in 1.3286561965942383\n", "Episode 63, duration 182, epsilon 0.0514 done in 1.265678882598877\n", "Episode 64, duration 200, epsilon 0.0512 done in 1.3921267986297607\n", "Episode 65, duration 155, epsilon 0.0510 done in 1.0708959102630615\n", 
"Episode 66, duration 122, epsilon 0.0509 done in 0.8449039459228516\n", "Episode 67, duration 169, epsilon 0.0508 done in 1.171288013458252\n", "Episode 68, duration 155, epsilon 0.0506 done in 1.065803050994873\n", "Episode 69, duration 171, epsilon 0.0505 done in 1.1903810501098633\n", "Episode 70, duration 158, epsilon 0.0505 done in 1.0794060230255127\n", "Episode 71, duration 135, epsilon 0.0504 done in 0.9425179958343506\n", "Episode 72, duration 173, epsilon 0.0503 done in 1.2061538696289062\n", "Episode 73, duration 134, epsilon 0.0503 done in 1.0127298831939697\n", "Episode 74, duration 140, epsilon 0.0503 done in 1.1216182708740234\n", "Episode 75, duration 166, epsilon 0.0502 done in 1.1763079166412354\n", "Episode 76, duration 184, epsilon 0.0502 done in 1.2722370624542236\n", "Episode 77, duration 200, epsilon 0.0502 done in 1.3994901180267334\n", "Episode 78, duration 162, epsilon 0.0501 done in 1.1298630237579346\n", "Episode 79, duration 141, epsilon 0.0501 done in 0.9746568202972412\n", "Episode 80, duration 165, epsilon 0.0501 done in 1.2154066562652588\n", "Episode 81, duration 163, epsilon 0.0501 done in 1.1442301273345947\n", "Episode 82, duration 164, epsilon 0.0501 done in 1.1488800048828125\n", "Episode 83, duration 128, epsilon 0.0501 done in 0.8947670459747314\n", "Episode 84, duration 94, epsilon 0.0501 done in 0.6654331684112549\n", "Episode 85, duration 115, epsilon 0.0500 done in 0.7962470054626465\n", "Episode 86, duration 105, epsilon 0.0500 done in 0.7088890075683594\n", "Episode 87, duration 134, epsilon 0.0500 done in 0.8996789455413818\n", "Episode 88, duration 127, epsilon 0.0500 done in 0.8609750270843506\n", "Episode 89, duration 137, epsilon 0.0500 done in 0.9249551296234131\n", "Episode 90, duration 126, epsilon 0.0500 done in 0.8459548950195312\n", "Episode 91, duration 116, epsilon 0.0500 done in 0.7772917747497559\n", "Episode 92, duration 91, epsilon 0.0500 done in 0.6441352367401123\n", "Episode 93, duration 91, 
epsilon 0.0500 done in 0.6781599521636963\n", "Episode 94, duration 127, epsilon 0.0500 done in 0.8687968254089355\n", "Episode 95, duration 129, epsilon 0.0500 done in 0.8872959613800049\n", "Episode 96, duration 111, epsilon 0.0500 done in 0.7792479991912842\n", "Episode 97, duration 131, epsilon 0.0500 done in 0.8846781253814697\n", "Episode 98, duration 141, epsilon 0.0500 done in 0.936420202255249\n", "Episode 99, duration 105, epsilon 0.0500 done in 0.6909458637237549\n", "Episode 100, duration 146, epsilon 0.0500 done in 0.9847288131713867\n", "Episode 101, duration 127, epsilon 0.0500 done in 0.8729619979858398\n", "Episode 102, duration 119, epsilon 0.0500 done in 0.7925238609313965\n", "Episode 103, duration 107, epsilon 0.0500 done in 0.7179620265960693\n", "Episode 104, duration 151, epsilon 0.0500 done in 1.0024669170379639\n", "Episode 105, duration 110, epsilon 0.0500 done in 0.7429869174957275\n", "Episode 106, duration 105, epsilon 0.0500 done in 0.7670748233795166\n", "Episode 107, duration 101, epsilon 0.0500 done in 0.7440199851989746\n", "Episode 108, duration 120, epsilon 0.0500 done in 0.864722728729248\n", "Episode 109, duration 106, epsilon 0.0500 done in 0.7149121761322021\n", "Episode 110, duration 112, epsilon 0.0500 done in 0.8089501857757568\n", "Episode 111, duration 93, epsilon 0.0500 done in 0.6416912078857422\n", "Episode 112, duration 103, epsilon 0.0500 done in 0.7595870494842529\n", "Episode 113, duration 99, epsilon 0.0500 done in 0.6787068843841553\n", "Episode 114, duration 90, epsilon 0.0500 done in 0.6839339733123779\n", "Episode 115, duration 99, epsilon 0.0500 done in 0.70619797706604\n", "Episode 116, duration 113, epsilon 0.0500 done in 0.7977390289306641\n", "Episode 117, duration 107, epsilon 0.0500 done in 0.7442677021026611\n", "Episode 118, duration 104, epsilon 0.0500 done in 0.7060117721557617\n", "Episode 119, duration 96, epsilon 0.0500 done in 0.6665740013122559\n", "Episode 120, duration 99, epsilon 0.0500 
done in 0.6954801082611084\n", "Episode 121, duration 109, epsilon 0.0500 done in 0.7392399311065674\n", "Episode 122, duration 122, epsilon 0.0500 done in 0.822786808013916\n", "Episode 123, duration 113, epsilon 0.0500 done in 0.7565281391143799\n", "Episode 124, duration 132, epsilon 0.0500 done in 0.8871691226959229\n", "Episode 125, duration 116, epsilon 0.0500 done in 0.7746288776397705\n", "Episode 126, duration 115, epsilon 0.0500 done in 0.7686679363250732\n", "Episode 127, duration 131, epsilon 0.0500 done in 0.8958530426025391\n", "Episode 128, duration 161, epsilon 0.0500 done in 1.087691068649292\n", "Episode 129, duration 160, epsilon 0.0500 done in 1.0682952404022217\n", "Episode 130, duration 169, epsilon 0.0500 done in 1.1190180778503418\n", "Episode 131, duration 171, epsilon 0.0500 done in 1.1476161479949951\n", "Episode 132, duration 195, epsilon 0.0500 done in 1.300534963607788\n", "Episode 133, duration 200, epsilon 0.0500 done in 1.3386471271514893\n", "Episode 134, duration 200, epsilon 0.0500 done in 1.3347461223602295\n", "Episode 135, duration 200, epsilon 0.0500 done in 1.3350942134857178\n", "Episode 136, duration 200, epsilon 0.0500 done in 1.3486897945404053\n", "Episode 137, duration 200, epsilon 0.0500 done in 1.3442697525024414\n", "Episode 138, duration 200, epsilon 0.0500 done in 1.3415100574493408\n", "Episode 139, duration 200, epsilon 0.0500 done in 1.3514528274536133\n", "Episode 140, duration 200, epsilon 0.0500 done in 1.324605941772461\n", "Episode 141, duration 200, epsilon 0.0500 done in 1.3301770687103271\n", "Episode 142, duration 200, epsilon 0.0500 done in 1.3422541618347168\n", "Episode 143, duration 200, epsilon 0.0500 done in 1.3305821418762207\n", "Episode 144, duration 200, epsilon 0.0500 done in 1.346459150314331\n", "Episode 145, duration 200, epsilon 0.0500 done in 1.3368988037109375\n", "Episode 146, duration 200, epsilon 0.0500 done in 1.324411153793335\n", "Episode 147, duration 200, epsilon 0.0500 done in 
1.3370959758758545\n", "Episode 148, duration 200, epsilon 0.0500 done in 1.3372869491577148\n", "Episode 149, duration 200, epsilon 0.0500 done in 1.3436431884765625\n", "Episode 150, duration 200, epsilon 0.0500 done in 1.3790981769561768\n", "Episode 151, duration 200, epsilon 0.0500 done in 1.4742920398712158\n", "Episode 152, duration 200, epsilon 0.0500 done in 1.44557523727417\n", "Episode 153, duration 200, epsilon 0.0500 done in 1.4291670322418213\n", "Episode 154, duration 200, epsilon 0.0500 done in 1.4063260555267334\n", "Episode 155, duration 200, epsilon 0.0500 done in 1.405716896057129\n", "Episode 156, duration 188, epsilon 0.0500 done in 1.293168067932129\n", "Episode 157, duration 184, epsilon 0.0500 done in 1.514660120010376\n", "Episode 158, duration 200, epsilon 0.0500 done in 1.5891220569610596\n", "Episode 159, duration 184, epsilon 0.0500 done in 1.4287211894989014\n", "Episode 160, duration 200, epsilon 0.0500 done in 1.6346118450164795\n", "Episode 161, duration 200, epsilon 0.0500 done in 1.6076791286468506\n", "Episode 162, duration 200, epsilon 0.0500 done in 1.5353331565856934\n", "Episode 163, duration 188, epsilon 0.0500 done in 1.4301581382751465\n", "Episode 164, duration 164, epsilon 0.0500 done in 1.2197020053863525\n", "Episode 165, duration 119, epsilon 0.0500 done in 0.8470640182495117\n", "Episode 166, duration 91, epsilon 0.0500 done in 0.705880880355835\n", "Episode 167, duration 112, epsilon 0.0500 done in 0.8511989116668701\n", "Episode 168, duration 121, epsilon 0.0500 done in 0.9285330772399902\n", "Episode 169, duration 165, epsilon 0.0500 done in 1.4010276794433594\n", "Episode 170, duration 200, epsilon 0.0500 done in 1.4931089878082275\n", "Episode 171, duration 158, epsilon 0.0500 done in 1.1297760009765625\n", "Episode 172, duration 184, epsilon 0.0500 done in 1.2905008792877197\n", "Episode 173, duration 161, epsilon 0.0500 done in 1.2188148498535156\n", "Episode 174, duration 169, epsilon 0.0500 done in 
1.2725789546966553\n", "Episode 175, duration 173, epsilon 0.0500 done in 1.2186980247497559\n", "Episode 176, duration 193, epsilon 0.0500 done in 1.338731050491333\n", "Episode 177, duration 166, epsilon 0.0500 done in 1.316972017288208\n", "Episode 178, duration 198, epsilon 0.0500 done in 1.4319920539855957\n", "Episode 179, duration 195, epsilon 0.0500 done in 1.3895812034606934\n", "Episode 180, duration 200, epsilon 0.0500 done in 1.4180948734283447\n", "Episode 181, duration 200, epsilon 0.0500 done in 1.3860301971435547\n", "Episode 182, duration 200, epsilon 0.0500 done in 1.544933795928955\n", "Episode 183, duration 200, epsilon 0.0500 done in 1.376188039779663\n", "Episode 184, duration 200, epsilon 0.0500 done in 1.37398099899292\n", "Episode 185, duration 191, epsilon 0.0500 done in 1.3100950717926025\n", "Episode 186, duration 200, epsilon 0.0500 done in 1.3645508289337158\n", "Episode 187, duration 200, epsilon 0.0500 done in 1.3671917915344238\n", "Episode 188, duration 71, epsilon 0.0500 done in 0.4840209484100342\n", "Episode 189, duration 170, epsilon 0.0500 done in 1.1554310321807861\n", "Episode 190, duration 200, epsilon 0.0500 done in 1.3751280307769775\n", "Episode 191, duration 200, epsilon 0.0500 done in 1.3681631088256836\n", "Episode 192, duration 190, epsilon 0.0500 done in 1.296057939529419\n", "Episode 193, duration 200, epsilon 0.0500 done in 1.369154930114746\n", "Episode 194, duration 163, epsilon 0.0500 done in 1.1157150268554688\n", "Episode 195, duration 172, epsilon 0.0500 done in 1.1656968593597412\n", "Episode 196, duration 200, epsilon 0.0500 done in 1.3778979778289795\n", "Episode 197, duration 200, epsilon 0.0500 done in 1.3623039722442627\n", "Episode 198, duration 200, epsilon 0.0500 done in 1.3735718727111816\n", "Episode 199, duration 200, epsilon 0.0500 done in 1.3809599876403809\n", "Episode 200, duration 200, epsilon 0.0500 done in 1.380000114440918\n", "Episode 201, duration 200, epsilon 0.0500 done in 
1.3676831722259521\n", "Episode 202, duration 200, epsilon 0.0500 done in 1.3898499011993408\n", "Episode 203, duration 200, epsilon 0.0500 done in 1.3679068088531494\n", "Episode 204, duration 200, epsilon 0.0500 done in 1.3534259796142578\n", "Episode 205, duration 200, epsilon 0.0500 done in 1.36385178565979\n", "Episode 206, duration 186, epsilon 0.0500 done in 1.2612950801849365\n", "Episode 207, duration 177, epsilon 0.0500 done in 1.2120819091796875\n", "Episode 208, duration 80, epsilon 0.0500 done in 0.5482821464538574\n", "Episode 209, duration 170, epsilon 0.0500 done in 1.1629600524902344\n", "Episode 210, duration 180, epsilon 0.0500 done in 1.2311630249023438\n", "Episode 211, duration 200, epsilon 0.0500 done in 1.3580238819122314\n", "Episode 212, duration 200, epsilon 0.0500 done in 1.3839151859283447\n", "Episode 213, duration 200, epsilon 0.0500 done in 1.3652369976043701\n", "Episode 214, duration 200, epsilon 0.0500 done in 1.3564629554748535\n", "Episode 215, duration 200, epsilon 0.0500 done in 1.3642499446868896\n", "Episode 216, duration 200, epsilon 0.0500 done in 1.3575990200042725\n", "Episode 217, duration 200, epsilon 0.0500 done in 1.3559317588806152\n", "Episode 218, duration 200, epsilon 0.0500 done in 1.3548388481140137\n", "Episode 219, duration 200, epsilon 0.0500 done in 1.3770561218261719\n", "Episode 220, duration 200, epsilon 0.0500 done in 1.478581190109253\n", "Episode 221, duration 183, epsilon 0.0500 done in 1.3361449241638184\n", "Episode 222, duration 162, epsilon 0.0500 done in 1.1335511207580566\n", "Episode 223, duration 196, epsilon 0.0500 done in 1.3585941791534424\n", "Episode 224, duration 200, epsilon 0.0500 done in 1.4101898670196533\n", "Episode 225, duration 200, epsilon 0.0500 done in 1.4771928787231445\n", "Episode 226, duration 200, epsilon 0.0500 done in 1.5035011768341064\n", "Episode 227, duration 119, epsilon 0.0500 done in 0.856168270111084\n", "Episode 228, duration 200, epsilon 0.0500 done in 
1.439486026763916\n", "Episode 229, duration 200, epsilon 0.0500 done in 1.316842794418335\n", "Episode 230, duration 200, epsilon 0.0500 done in 1.5013949871063232\n", "Episode 231, duration 200, epsilon 0.0500 done in 1.3974170684814453\n", "Episode 232, duration 200, epsilon 0.0500 done in 1.3904950618743896\n", "Episode 233, duration 200, epsilon 0.0500 done in 1.399580955505371\n", "Episode 234, duration 200, epsilon 0.0500 done in 1.395763874053955\n", "Episode 235, duration 200, epsilon 0.0500 done in 1.4145090579986572\n", "Episode 236, duration 200, epsilon 0.0500 done in 1.4335989952087402\n", "Episode 237, duration 200, epsilon 0.0500 done in 1.4076240062713623\n", "Episode 238, duration 146, epsilon 0.0500 done in 1.0287110805511475\n", "Episode 239, duration 192, epsilon 0.0500 done in 1.3662900924682617\n", "Episode 240, duration 28, epsilon 0.0500 done in 0.20001626014709473\n", "Episode 241, duration 148, epsilon 0.0500 done in 1.019770860671997\n", "Episode 242, duration 115, epsilon 0.0500 done in 0.8170418739318848\n", "Episode 243, duration 154, epsilon 0.0500 done in 1.069025993347168\n", "Episode 244, duration 200, epsilon 0.0500 done in 1.3859591484069824\n", "Episode 245, duration 200, epsilon 0.0500 done in 1.4307832717895508\n", "Episode 246, duration 200, epsilon 0.0500 done in 1.416254997253418\n", "Episode 247, duration 200, epsilon 0.0500 done in 1.357828140258789\n", "Episode 248, duration 200, epsilon 0.0500 done in 1.391690969467163\n", "Episode 249, duration 200, epsilon 0.0500 done in 1.431091070175171\n", "Episode 250, duration 200, epsilon 0.0500 done in 1.4289429187774658\n" ] } ], "source": [ "# Hyperparameters\n", "config = {}\n", "config['nb_hidden'] = 128 # number of hidden neurons in each layer\n", "config['batch_size'] = 128 # number of transitions sampled from the replay buffer\n", "config['gamma'] = 0.99 # discount factor\n", "config['eps_start'] = 0.9 # starting value of epsilon\n", "config['eps_end'] = 0.05 # final 
value of epsilon\n", "config['eps_decay'] = 1000 # rate of exponential decay of epsilon, higher means a slower decay\n", "config['learning_rate'] = 1e-3 # learning rate of the optimizer\n", "config['target_update_period'] = 120 # update period (in steps) of the target network\n", "config['buffer_limit'] = 10000 # maximum number of transitions in the replay buffer\n", "\n", "# Create the environment\n", "env = gym.make('CartPole-v0')\n", "\n", "# Create the agent\n", "agent = DQNAgent(env, config)\n", "\n", "# Train the agent\n", "agent.train(num_episodes=250)\n" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'Returns')" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(10, 6))\n", "plt.plot(agent.episode_durations)\n", "plt.xlabel(\"Episodes\")\n", "plt.ylabel(\"Returns\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**A:** Training is slow and unstable. In the run above, the agent reaches the maximal return of 200 for the first time at episode 39, but its performance then drops to roughly between 90 and 150 (episodes 84 to 127) before recovering, and sudden failures (e.g. a return of 71 at episode 188 or 28 at episode 240) still happen late in training, although $\\epsilon$ has reached its minimal value of 0.05. Different runs lead to very different convergence profiles. There is much to improve..." ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 2 }