{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "kzN-Q9Zv1He1" }, "source": [ "# DQN\n", "\n", "The goal of this exercise is to implement DQN and apply it to the CartPole balancing problem." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "try:\n", "    import google.colab\n", "    IN_COLAB = True\n", "except ImportError:\n", "    IN_COLAB = False\n", "\n", "if IN_COLAB:\n", "    !pip install -U gymnasium pygame moviepy\n", "    !pip install gymnasium[box2d]" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ZuVpP0LaxKM5", "outputId": "948030f3-cdfb-43a1-882a-6140df11639b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gym version: 0.26.3\n" ] } ], "source": [ "import numpy as np\n", "rng = np.random.default_rng()\n", "import matplotlib.pyplot as plt\n", "import os\n", "from IPython.display import clear_output\n", "from collections import deque\n", "\n", "import gymnasium as gym\n", "print(\"gym version:\", gym.__version__)\n", "\n", "import pygame\n", "from moviepy.editor import ImageSequenceClip, ipython_display\n", "\n", "import tensorflow as tf\n", "import logging\n", "tf.get_logger().setLevel(logging.ERROR)\n", "\n", "class GymRecorder(object):\n", "    \"\"\"\n", "    Simple wrapper over moviepy to generate a .gif with the frames of a gym environment.\n", "    \n", "    The environment must have the render_mode `rgb_array_list`.\n", "    \"\"\"\n", "    def __init__(self, env):\n", "        self.env = env\n", "        self._frames = []\n", "\n", "    def record(self, frames):\n", "        \"To be called at the end of an episode.\"\n", "        for frame in frames:\n", "            self._frames.append(np.array(frame))\n", "\n", "    def make_video(self, filename):\n", "        \"Generates the gif video.\"\n", "        directory = os.path.dirname(os.path.abspath(filename))\n", "        # Create the target directory (and any missing parents) if needed\n", "        os.makedirs(directory, exist_ok=True)\n", "        self.clip = ImageSequenceClip(list(self._frames), 
fps=self.env.metadata[\"render_fps\"])\n", "        self.clip.write_gif(filename, fps=self.env.metadata[\"render_fps\"], loop=0)\n", "        del self._frames\n", "        self._frames = []\n", "\n", "def running_average(x, N):\n", "    kernel = np.ones(N) / N\n", "    return np.convolve(x, kernel, mode='same')" ] }, { "cell_type": "markdown", "metadata": { "id": "EPakRvKRoA79" }, "source": [ "## Cartpole balancing task\n", "\n", "We are going to use the CartPole balancing problem, which can be loaded with:\n", "\n", "```python\n", "gym.make('CartPole-v0')\n", "```\n", "\n", "States have 4 continuous values (position and velocity of the cart, angle and angular velocity of the pole) and there are 2 discrete actions (pushing the cart left or right). The reward is +1 for each transition where the pole is still standing (angle of less than 12° with the vertical).\n", "\n", "In CartPole-v0, the episode ends when the pole falls or after 200 steps. In CartPole-v1, the maximum episode length is 500 steps, which is too long for us, so we stick to v0 here.\n", "\n", "The maximal (undiscounted) return is therefore 200. Can DQN learn this?" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 438 }, "id": "zBkpg0MDoIxJ", "outputId": "58411c0e-4248-4e15-9f5e-b284aeea321e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Return: 19.0\n", "MoviePy - Building file videos/cartpole.gif with imageio.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/html": [ "