{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Q-learning " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "try:\n", " import google.colab\n", " IN_COLAB = True\n", "except:\n", " IN_COLAB = False\n", "\n", "if IN_COLAB:\n", " !pip install -U gymnasium pygame moviepy\n", " !pip install gymnasium[box2d]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "rng = np.random.default_rng()\n", "import matplotlib.pyplot as plt\n", "import os\n", "\n", "import gymnasium as gym\n", "print(\"gym version:\", gym.__version__)\n", "\n", "from moviepy.editor import ImageSequenceClip, ipython_display\n", "\n", "class GymRecorder(object):\n", " \"\"\"\n", " Simple wrapper over moviepy to generate a .gif with the frames of a gym environment.\n", " \n", " The environment must have the render_mode `rgb_array_list`.\n", " \"\"\"\n", " def __init__(self, env):\n", " self.env = env\n", " self._frames = []\n", "\n", " def record(self, frames):\n", " \"To be called at the end of an episode.\"\n", " for frame in frames:\n", " self._frames.append(np.array(frame))\n", "\n", " def make_video(self, filename):\n", " \"Generates the gif video.\"\n", " directory = os.path.dirname(os.path.abspath(filename))\n", " if not os.path.exists(directory):\n", " os.mkdir(directory)\n", " self.clip = ImageSequenceClip(list(self._frames), fps=self.env.metadata[\"render_fps\"])\n", " self.clip.write_gif(filename, fps=self.env.metadata[\"render_fps\"], loop=1)\n", " del self._frames\n", " self._frames = []\n", "\n", "def running_average(x, N):\n", " kernel = np.ones(N) / N\n", " return np.convolve(x, kernel, mode='same')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this short exercise, we are going to apply **Q-learning** on the Taxi environment used last time for MC control.\n", "\n", "As a reminder, Q-learning updates the Q-value of a state-action pair **after each transition**, using the update rule:\n", "\n", "$$\\Delta Q(s_t, a_t) = \\alpha \\, (r_{t+1} + \\gamma \\, \\max_{a'} \\, Q(s_{t+1}, a') - Q(s_t, a_t))$$\n", "\n", "**Q:** Update the class you designed for online MC in the last exercise so that it implements Q-learning. \n", "\n", "The main difference is that the `update()` method has to be called after each step of the episode, not at the end. It simplifies a lot the code too (no need to iterate backwards on the episode).\n", "\n", "You can use the following parameters at the beginning, but feel free to change them:\n", "\n", "* Discount factor $\\gamma = 0.9$. \n", "* Learning rate $\\alpha = 0.1$.\n", "* Epsilon-greedy action selection, with an initial exploration parameter of 1.0 and an exponential decay of $10^{-5}$ after each update (i.e. every step!).\n", "* A total number of episodes of 20000.\n", "\n", "Keep the general structure of the class: `train()` for the main loop, `test()` to run one episode without exploration, etc. \n", "\n", "Plot the training and test performance in the end and render the learned deterministic policy for one episode.\n", "\n", "*Note:* if $s_{t+1}$ is terminal (`done` is true after the transition), the target should not be $r_{t+1} + \\gamma \\, \\max_{a'} \\, Q(s_{t+1}, a')$, but simply $r_{t+1}$ as there is no next action." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Q:** Compare the performance of Q-learning to online MC. Experiment with parameters (gamma, epsilon, alpha, etc.)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9.13 ('deeprl')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "vscode": { "interpreter": { "hash": "932956c8e5d2f79d68ff59e849758b6e4ddbf01f7f22c7d8bb3532c38341d908" } } }, "nbformat": 4, "nbformat_minor": 4 }