{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPython 3.6.8\n", "IPython 7.2.0\n", "\n", "numpy 1.15.4\n", "sklearn 0.20.2\n", "scipy 1.1.0\n", "matplotlib 3.0.2\n", "tensorflow 1.13.1\n", "gym 0.10.9\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -v -p numpy,sklearn,scipy,matplotlib,tensorflow,gym" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**16장 – 강화 학습**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "_이 노트북은 15장에 있는 모든 샘플 코드와 연습문제 해답을 가지고 있습니다._" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 설정" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "파이썬 2와 3을 모두 지원합니다. 공통 모듈을 임포트하고 맷플롯립 그림이 노트북 안에 포함되도록 설정하고 생성한 그림을 저장하기 위한 함수를 준비합니다:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# 파이썬 2와 파이썬 3 지원\n", "from __future__ import division, print_function, unicode_literals\n", "\n", "# 공통\n", "import numpy as np\n", "import os\n", "import sys\n", "\n", "# 일관된 출력을 위해 유사난수 초기화\n", "def reset_graph(seed=42):\n", " tf.reset_default_graph()\n", " tf.set_random_seed(seed)\n", " np.random.seed(seed)\n", "\n", "# 맷플롯립 설정\n", "from IPython.display import HTML\n", "import matplotlib\n", "import matplotlib.animation as animation\n", "import matplotlib.pyplot as plt\n", "plt.rcParams['axes.labelsize'] = 14\n", "plt.rcParams['xtick.labelsize'] = 12\n", "plt.rcParams['ytick.labelsize'] = 12\n", "\n", "# 한글출력\n", "plt.rcParams['font.family'] = 'NanumBarunGothic'\n", "plt.rcParams['axes.unicode_minus'] = False\n", "\n", "# 그림을 저장할 폴더\n", "PROJECT_ROOT_DIR = \".\"\n", "CHAPTER_ID = \"rl\"\n", "\n", "def save_fig(fig_id, tight_layout=True):\n", " path = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID, fig_id + \".png\")\n", " if tight_layout:\n", " plt.tight_layout()\n", " plt.savefig(path, format='png', dpi=300)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# OpenAI 짐(gym)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이 노트북에서는 강화 학습 알고리즘을 개발하고 비교할 수 있는 훌륭한 도구인 [OpenAI 짐(gym)](https://gym.openai.com/)을 사용합니다. 짐은 *에이전트*가 학습할 수 있는 많은 환경을 제공합니다. `gym`을 임포트해 보죠:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import gym" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "그다음 MsPacman 환경 버전 0을 로드합니다." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/haesun/anaconda3/envs/handson-ml/lib/python3.6/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated. Call .resolve and .require separately.\n", " result = entry_point.load(False)\n" ] } ], "source": [ "env = gym.make('MsPacman-v0')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`reset()` 메서드를 호출하여 환경을 초기화합니다. 이 메서드는 하나의 관측을 반환합니다:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "obs = env.reset()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "관측은 환경마다 다릅니다. 여기에서는 [width, height, channels] 크기의 3D 넘파이 배열로 저장되어 있는 RGB 이미지입니다(채널은 3개로 빨강, 초록, 파랑입니다). 잠시 후에 보겠지만 다른 환경에서는 다른 오브젝트가 반환될 수 있습니다." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(210, 160, 3)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "obs.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "환경은 `render()` 메서드를 사용하여 화면에 나타낼 수 있고 렌더링 모드를 고를 수 있습니다(렌더링 옵션은 환경마다 다릅니다). 이 경우에는 `mode=\"rgb_array\"`로 지정해서 넘파이 배열로 환경에 대한 이미지를 받겠습니다:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "img = env.render(mode=\"rgb_array\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이미지를 그려보죠:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVEAAAGoCAYAAAD/69aTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAC+NJREFUeJzt3TGu3MYdB+BlkEsEim4gpDAQwI0DNwZSuEvjSlfw65QjRN2ewZUadyoCuBHiRkAAF4Gb1JKQYzCF8J6e1rvkLn87O5zh93WPK4LD2d2f/pwdDodxHHcALPO72g0AaJkQBQgIUYCAEAUICFGAgBAFCAhRgIAQBQgIUYDA72s34LFhGNw+BazSOI7Dse0qUYDAqirR999/X7sJABdRiQIEVlWJTvnjj3+o3YTVeP+3/518TT9xjM/Meab66RSVKEBAiAIEhChAQIgCBIQoQECIAgSEKEBAiAIEmplsPyeZTLx037mJuaX2Teinum2q0f8J/TRPJQoQEKIAASEKEBjGcT3rIH+4uzvZGIskfGIxCS7lM3OeqX56st9blBng2oQoQECIAgSEKEBAiAIEhChAQIgCBIQoQKCbBUgSS57wd47eJjGX6ic+8nk5z9r6SSUKEOimEnVbG1CDShQgIEQBAkIUICBEAQJCFCAgRAEC3Uxxak2LT/ukrtae9rkVKlGAgBAFCAhRgIAx0UpqjSUZw2pXjffO52WeShQg0E0l6n9MoAaVKEBAiAIEhChAQIgCBIQoQECIAgS6meKUaG16lIn6XMpE/XJUogABIQoQEKIAASEKEBCiAAEhChAQogABIQoQ6GayffIkxKX7Jk9CLLVvqXNN9q3VT1O8d+ftu7V+WkIlChAQogABIQoQGMZxrN2GBx/u7k42ZiuLGQD1TI2nPtnvh2PbVaIAASEKEBCiAAEhChAQogABIQoQEKIAASEKEGhmAZK5xQymlJqoX2pxjPS4S7XW3rVqrR99PzIqUYCAEAUICFGAgBAFCAhRgIAQBQgIUYCAEAUINDPZvqRbPx1w7pglj5uo0U+9aa0P1zhhfm39pBIFCAhRgIAQBQgYE93VGWNZ27jOOVps89q01oe12ttSP6lEAQJCFCAgRAECQhQgIEQBAkIUICBEAQJCFCCwqsn2LU2wram1fmqtvWulH89Tqp/G/fHtKlGAwKoqUUi8ffZm8vUvf/36Ju1gW1SiAAGVKM2bq0AP/52KlGtSiQIEVKI07e2zN7+pLE9VnPfbj+0DS6lEAQJClOa9ffbm6Ljoqe1wTUIUINDNmGjyiNXWHplc61yX7lvqXHe73e79f0//2j417lmyTd67XEvf525ClO06dcnuUp5bcDkPEBjGcazdhgfD05fracwZal3q8MnSatMUp/J6+36M714Mx7arRAECQhQgIEQBAkIUICBEAQJClKY9/pX9y1+/vuhvuAYhChBwxxLNO6wuL/0bEipRgMCqKtG5OxyWqnVnRKnzKaXWYiu90Y913bqfVKIAASEKEBCiAAEhChAQogABIQoQEKIAASEKEFjVZPtaajzts0Wl+qm1Pk4mc/usnaelflKJAgSEKEDA5fxufZcHa6WfcvrwPC31k0oUICBEAQJCFCAgRAECQhQgIEQBAkIUICBEAQJCFCCwqjuWkrsUajwJsaW7Kmifz1su6cNxf3y7ShQgIEQBAkIUICBEAQJCFCAgRAECQhQgIEQBAquabF/L0icLzk3wr7Hv3GTiWvvyUWvvXYuf8VtTiQIEhChAQIgCBIZxHGu34cHw9OXixrQ0hrI1ydjYJV4/fzX5+rc/fHeV49zqfFimVBaM714Mx7arRAECfp2neXMV6OG/u1ZFCrudEKVxr5+/+k0ongrL++3H9oGlXM4DBIQozXv9/NXRS/pT2+GahChAwJgozTs1vmnck1tQiQIEVlWJ1nhiZ6K19s5pdZL4qXHP3sZDfd5yJfpQJQoQEKIAASEKEBCiAAEhChAQogABIUrTHk+o//aH7y76G65BiAIEVjXZvhar4p9nrf10WF1e+vctrbUP16alflKJAgSEKEBAiAIEjInu1jfGslal+qm3hTWm+Kydp6V+UokCBIQoQECIAgSEKEBAiAIEhChAQIgCBIQoQGAYx7F2Gx4MT1+upzFnmJsk3tKEYbi23r4f47sXw7HtKlGAgNs+6cZ//vGvo9v/9Pe/fPb6/d9wDSpRgIBKlOadqkAPX39ckapGuRaVKEBAJUq3DsdCjYlSghCleYeheBiaUJLLeYBAN5Vo8nTAGk8WTCYi1zrXpfuWOtdT+56qTGu26Zx9e3vvEi19n1WiAAEhChAQogABC5AEeltgoRdzv8qb4nQbvX0/LEACUEA3v86zXeaDUpNKFCCgEmUzjIVSgkoUIKASpXmHC42ceh1KEKJ0Q1hSg8t5gMCqKtG5yblLtTapd06NBVNqHbfUZ2LOGhfWaPG4Ndz6M6MSBQgIUYCAEAUICFGAgBAFCAhRgIAQBQgIUYDAqibb17LGSeQ9TfbubYVz713d467t86ISBQgIUYCAEAUIGBPd1RljqTWus6VzLcV71+9xl1CJAgSEKEBAiAIEhChAQIgCBIQoQECIAgSEKECgmcn2a5x8u8Y29UYft6vFifpLnhSqEgUICFGAgBAFCAhRgIAQBQgIUYCAEAUICFGAgBAFCDRzx9Kc5BGrS/dNHie7pUfRrrGf1timOVt672p8n5dSiQIEhChAYBjHsXYbHny4uzvZGAtRfFLjsq7mcXvivVu3qX56st8Px7arRAECQhQgIEQBAkI
UICBEAQJCFCAgRAECQhQgIEQBAt0sQJJY8pjUc/R2J0iN8yn13swpda69fSambOV7pRIFCAhRgIAQBQgIUYCAEAUICFGAgBAFCAhRgEA3k+1be/xBi0+MrPFU1DXaylMs545Z8rgJT/sEaIgQBQgIUYBAN2OiaxybmVKrvclxl+7b2nszp0Yfpvu2dMzUrdusEgUICFGAgBAFCAhRgIAQBQgIUYCAEAUICFGAQDeT7fmot6dj9qbW+0M5KlGAQDeVaGtL4QF9UIkCBIQoQECIAgSEKEBAiAIEhChAoJspTq1p8SmKS/V2rr2dz5QtnetSKlGAgBAFCAhRgIAx0Uq2NJbU27n2dj5TtnSuS6lEAQLdVKL+xwRqUIkCBIQoQECIAgSEKEBAiAIEhChAoJspTonWpke11t5Eb+fa2/lM2cq5qkQBAkIUICBEAQJCFCAgRAECQhQgIEQBAkIUINDNZPuppxLOTfpdum/yJMRS+5Y612TfWv00xXt33r5b66clVKIAASEKEBCiAIFhHMfabXjw4e7uZGO2spgBUM/UeOqT/X44tl0lChAQogABIQoQEKIAASEKEBCiAAEhChAQogCBZhYgmVvMAKAGlShAQIgCBIQoQGBVC5AMw7CexgA8Mo6jBUgArq2ZX+ev4aef/rzb7Xa7b77592d/P3b/2rWPO3XMUseFUv75xRef/f3XX36p1JL6NnE5f054HrpGqD0+7jnHvNZxoZT78DwMzceh2mugupwHKGATl/OnKtBLKtNrHbf0MaGkqSrz/rVT1WqvVKIAgU1UoqccqxBvecxbHhduZWsVqUoUILDpSrRGFajypCfHqs3D6U+9U4kCBDZdiQLLXFKB9j42qhIFCGy6EvXrPGR6rzLPoRIFCGzi3vl7l9wldM0KsdZx4dqmfnk/nB96uL11p+6d39TlfK1bLi9ZgARa10tonsvlPEBgU5fzh261nujcMW9xXLimXi/Zp1gKD6CATVeiAOdSiQIUIEQBAkIUICBEAQJCFCAgRAECQhQgsKl75+FWft5/dXT7V3c/37gllGayPVzZ4wC9D83DUBWm7THZHqAAlShcybEKdOrfTP071kclClCAEAUICFGAgBAFCAhRgIBf5+HKzBPtk1/nAQpQiUIBbvvsz6lKVIgCnMHlPEABQhQgIEQBAkIUICBEAQJCFCAgRAECQhQgIEQBAkIUICBEAQJCFCAgRAECQhQgIEQBAkIUICBEAQJCFCAgRAECQhQgIEQBAkIUICBEAQJCFCAgRAECQhQgIEQBAkIUICBEAQLDOI612wDQLJUoQECIAgSEKEBAiAIEhChAQIgCBIQoQECIAgSEKEBAiAIEhChAQIgCBIQoQECIAgSEKEBAiAIEhChAQIgCBIQoQECIAgSEKEBAiAIEhChAQIgCBIQoQOD/Vi1MGHYh7UUAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(5,6))\n", "plt.imshow(img)\n", "plt.axis(\"off\")\n", "save_fig(\"MsPacman\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1980년대로 돌아오신 걸 환영합니다! :)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이 환경에서는 렌더링된 이미지가 관측과 동일합니다(하지만 많은 경우에 그렇지 않습니다):" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(img == obs).all()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "환경을 그리기 위한 유틸리티 함수를 만들겠습니다:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def plot_environment(env, figsize=(5,6)):\n", " plt.figure(figsize=figsize)\n", " img = env.render(mode=\"rgb_array\")\n", " plt.imshow(img)\n", " plt.axis(\"off\")\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "환경을 어떻게 다루는지 보겠습니다. 에이전트는 \"행동 공간\"(가능한 행동의 모음)에서 하나의 행동을 선택합니다. 이 환경의 액션 공간을 다음과 같습니다:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Discrete(9)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Discrete(9)`는 가능한 행동이 정수 0에서부터 8까지있다는 의미입니다. 이는 조이스틱의 9개의 위치(0=중앙, 1=위, 2=오른쪽, 3=왼쪽, 4=아래, 5=오른쪽위, 6=왼쪽위, 7=오른쪽아래, 8=왼쪽아래)에 해당합니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "그다음 환경에게 플레이할 행동을 알려주고 게임의 다음 단계를 진행시킵니다. 왼쪽으로 110번을 진행하고 왼쪽아래로 40번을 진행해 보겠습니다:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "env.reset()\n", "for step in range(110):\n", " env.step(3) #왼쪽\n", "for step in range(40):\n", " env.step(8) #왼쪽아래" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "어디에 있을까요?" 
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAASMAAAFrCAYAAACaK+8sAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAACthJREFUeJzt3U2K3EgaBuDMYS4xePoKXhh6Y7AXBkPvB3yNrl37CNO7OkfB7BsMtRhDbRq8mCt0mz6GejFdhZxOqSplSfFGxPOsyvlDhJTyyxfKT8rjMAwHgNL+VnoCAIeDMAJCCCMggjACIggjIIIwAiIIIyCCMAIiCCMggjACIvy99ATGjseja1OgYcMwHKeeiwqj33/8sfQUgEKiwmjKP//zj9JT2NTv//rj7OOtb3cPev1sp7Z7jnNGQARhBEQQRkAEYQREEEZABGEERBBGQIQq+oymzPUyTPVxXNr3sccYS6y1HXNz2mOMkmOX3L6t57TkPWuOsYTKCIggjIAIx6TfTft8dXV2Mr22zre+3T3o9bOd2u5n19eTF8qqjIAIwgiIIIyACMIIiCCMgAjCCIggjIAIVV8OssSS22FOqalXZM3tbl2vn6vLQQAOwggIIYyACMIIiCCMgAjCCIggjIAI3fUZ7SHx9qR8O5/rtlRGQARhBESwTNvAHiV1D2V7ml4/173mpDICIggjIIIwAiIIIyCCMAIiCCMgQndf7ff61Wnidvdg6/3e0ueqMgIiCCMggjACIggjIIIwAiIIIyCCMAIiVN1ntOSXLi+9k17JMeZ6SFoZo+TYLYyR+n9gCZUREEEYARGOwzCUnsODz1dXZyfTUss79GBqaffs+vo49R6VERBBGAERhBEQQRgBEYQREEEYARGEERChistB5trRp6zVm7RHK/yS7ZuyZk/WmvMqJXV/9Hp8zlEZARGEERBBGAERhBEQQRgBEYQREEEYARGq6DNa05q3Ri05xqUS51Ra4j7p9fg8HFRGQAhhBETobpm2Rylautw9J3FOpSXuk16Pz8NBZQSEEEZABGEERBBGQARhBEQQRkAEYQREiOozSu1/2FrqdqfOq5Re98ea2z1cTz+nMgIiCCMggjACIggjIIIwAiIIIyCCMAIiRPUZXWrJT/sm3tZzyZxaGaPk2K2Mcam15jT3niVURkAEYQREOA7DUHoOD47f/Zwzmb/sVaLCErUdn8NvPx2nnlMZARGEERBBGAERhBEQQRgBEYQREEEYARGiLgeZ65m41B49FmvOdy17XDZQE/vj2+213SojIIIwAiIIIyCCMAIiCCMggjACIkR9tb+HPe6kl2jN7U7cV0u+fnYsfKn0dquMgAjCCIjQ3TKtdClaSq/bPafXfZK63SojIIIwAiIIIyCCMAIiCCMggjACIggjIEJUn9Gl/Q973IEutSeD/TkWvnbpPhmup59TGQERoiojuPfm3cuzj9/e3O08E/YijIgyFUKnzwul9limARGEETHGVdFp5XN7c/fFY49VUNRHGBHrNICmHqMNwgiI0N0J7EtvuTnXy3Tpe9Z6/V5jlHS6DNtiWdbC57TH8bkXlREQQRgBEbpbpl1aiu7x6xmpY5R0e3N39tu1NZdrLXxOrfy6y+GgMgJCdFcZkW1cAZ3rK9qiQiKDMCLKOGTOBY4QapdlGhBBGBFl3GE99ff9v2mLZRoxpsJGCPVBZQREiKqM9riN7KUS5zQntYekBY6Fr625T1RGQARhBEQQRkAEYQREEEZABGEERBBGQISoPqM9pN5yc2u9bvecXvdJ6narjIAIwgiI0N0yrXQpWsqa213bZRFTHAtZVEZABGEERBBGQARhBEQQRkAEYQREEEZAhOMwDKXn8OD43c85k/nLXE9Nar8G/ajt+Bx+++k49ZzKCIggjIAIwgiIIIyACMIIiCCMgAhV30JkydeaqXe540uXfk5LPtdWxrjUWnOae88SKiMggjACIujAfkRtHa70pbbjUwc2EE8YARGEERBBGAERhBEQQRgBEYQRECHqcpA1f6k0scdiSuIlAGuOsaZW9kkLx+faVEZABGEERBBGQARhBEQQRkAEYQREEEZAhKg+oz200PdxOOiROdXK/mhljCVURkAEYQRE6G6ZVroUXcse21HTvmplf7QyxhIqIyCCMAIiCCMggjACIggjIIIwAiIIIyBCFX1GJfsiUnsyEtlX+0vtS1pyq1qVERBBGAERhBEQQRgBEYQREEEYARGEERChij6j2rx59/Ls47c3d6uNkXh70jXnVHLstea0xxhz/Txr7aslYywhjFY0FUKnz68ZStAKyzQggspoJeOq6Pbm7qt/j1/z5t3Lb66OEi8DWHNOJceuaYwlc0rct4eDymgztzd3XwXOuceA/xNGQATLtA2cnsh+7MQ2oDICQggjIIIw2oAT13A5YQRE6O4E9pLbYU457b8Y9xONK6HTzusSJ7T36BVZc99OKdnLlGjL43lv3YXRlsYhcy5wfKsG0yzTgAjCaEXjE9VTf9//G/iSZdpKpsJGCMHTqIyACMIIiGCZtoGa7gg4N6c9tmMte2xf4ue6h73u9KgyAiIIIyCCMAIiOGe0gV5vT1pSK7dfTdznbjsLdEUYARGEERBBGAERhBEQQRgBEXy1H6y2OyfWZo/92wKXgwBdEUZABGEERBBGQARhBEQQRkAEYQRE0Ge0gcRbhy5R03bUNNc5rWzHEiojIIIwAiJYpm2glZK6pu2oaa5zErfDnR6BrggjIIIwAiIIIyCCMAIiCCMggjACInTXZ9RzH8fWatuO2uZ7TgvbcE9lBESoujL637//+9Vjz9+/fnju/m8gX7VhdC6Ixo8/f//6i7+BbJZpQITqKqPHlmbj16mIoB7NVUbP378WQlCh5sIIqFN1y7THTJ3YvnfpbT2X/LTvWmPM9ZC0MkbJsVsYo+TxOfeeJVRGQARhBEQ4DsNQeg4PPl9dnZ3MuBQcL8POnah+7Hlge1NLu2fX18ep96iMgAjCCIhQXRg9tY/IEg3qUl0YAW2qts9oqvJREUGdVEZABGEERKhimTbXjg60QWUERBBGQISoy0GOx2POZIDVDcMweTlIFeeMlvrw4fvD4XA4vH3768PfY2/f/lrFGPTtlxcvHv7+4dOngjPZlmUaEKHZZdqHD98/VCXnKpaxpdXLU8dQHbHEfUU0roZqr5K6XKadLpueGkxpY9CvqbC5f/yXFy+qDKQplmlAhGYro1PnKpgax4B7P3z6dHYpVyuVERChm8poj0pFNcRWHjuZ3YJuwghqdXqi+lwItbBcs0wDInRTGTmBTc1aqHwe02zT4+Hw9H6fbwmOp4whmFhi6pzQuM/o9LF0c02PlmlAhKaXaXt0ROu6ppRaqqGnanqZNrbHFfWu2mcLNS7HplimAfG6qYyA8lRGQDxhBEQQRkAEY
QREEEZAhKabHqnLx+tXD3+/uvr46GvOueR9U6+lDJUREEGfEcU9pWo5fc3c8+Pnzr3vKRUY2+jy10Gox30gPLYEu9RU6IzHu3+NUCrPMg2IoDKiCqfV01QVpcKplzCiCk89Z/Tx+pVAqpRlGhBBZUSzxt+cjU9U+zYtk8oIiKDPiGKe+lX+XN/Q3GsfG0tVtL+5PiNhBOzGzdWAeMIIiCCMgAjCCIggjIAIwgiIIIyACMIIiCCMgAjCCIggjIAIwgiIIIyACMIIiCCMgAjCCIggjIAIwgiIIIyACMIIiCCMgAjCCIggjIAIwgiIEPUjjkC/VEZABGEERBBGQARhBEQQRkAEYQREEEZABGEERBBGQARhBEQQRkAEYQREEEZABGEERBBGQARhBEQQRkAEYQREEEZABGEERBBGQARhBEQQRkAEYQREEEZAhD8B961RhWBu3+kAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_environment(env)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "사실 `step()` 함수는 여러 개의 중요한 객체를 반환해 줍니다:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "obs, reward, done, info = env.step(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "앞서 본 것처럼 관측은 보이는 환경을 설명합니다. 여기서는 210x160 RGB 이미지입니다:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(210, 160, 3)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "obs.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "환경은 마지막 스텝에서 받을 수 있는 보상을 알려 줍니다:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.0" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reward" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "게임이 종료되면 환경은 `done=True`를 반환합니다:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "done" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "마지막으로 `info`는 환경의 내부 상태에 관한 추가 정보를 제공하는 딕셔너리입니다. 디버깅에는 유용하지만 에이전트는 학습을 위해서 이 정보를 사용하면 안됩니다(학습이 아니고 속이는 셈이므로)." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'ale.lives': 3}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "info" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "10번의 스텝마다 랜덤한 방향을 선택하는 식으로 전체 게임(3개의 팩맨)을 플레이하고 각 프레임을 저장해 보겠습니다:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "frames = []\n", "\n", "n_max_steps = 1000\n", "n_change_steps = 10\n", "\n", "obs = env.reset()\n", "for step in range(n_max_steps):\n", " img = env.render(mode=\"rgb_array\")\n", " frames.append(img)\n", " if step % n_change_steps == 0:\n", " action = env.action_space.sample() # play randomly\n", " obs, reward, done, info = env.step(action)\n", " if done:\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이제 애니메이션으로 한번 보죠:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "def update_scene(num, frames, patch):\n", " plt.close() # 이전 그래프를 닫지 않으면 두 개의 그래프가 출력되는 matplotlib의 버그로 보입니다.\n", " patch.set_data(frames[num])\n", " return patch,\n", "\n", "def plot_animation(frames, figsize=(5,6), repeat=False, interval=40):\n", " fig = plt.figure(figsize=figsize)\n", " patch = plt.imshow(frames[0])\n", " plt.axis('off')\n", " return animation.FuncAnimation(fig, update_scene, fargs=(frames, patch), \n", " frames=len(frames), repeat=repeat, interval=interval)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "video = plot_animation(frames)\n", "HTML(video.to_html5_video()) # HTML5 동영상으로 만들어 줍니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "환경을 더 이상 사용하지 않으면 환경을 종료하여 자원을 반납합니다:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "env.close()" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "첫 번째 에이전트를 학습시키기 위해 간단한 Cart-Pole 환경을 사용하겠습니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 간단한 Cart-Pole 환경" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Cart-Pole은 아주 간단한 환경으로 왼쪽이나 오른쪽으로 움직일 수 있는 카트와 카트 위에 수직으로 서 있는 막대로 구성되어 있습니다. 에이전트는 카트를 왼쪽이나 오른쪽으로 움직여서 막대가 넘어지지 않도록 유지시켜야 합니다." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/haesun/anaconda3/envs/handson-ml/lib/python3.6/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated. Call .resolve and .require separately.\n", " result = entry_point.load(False)\n" ] } ], "source": [ "env = gym.make(\"CartPole-v0\")" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "obs = env.reset()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0.01592466, -0.02766193, -0.02049984, 0.01750777])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "obs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "관측은 4개의 부동소수로 구성된 1D 넘파이 배열입니다. 각각 카트의 수평 위치, 속도, 막대의 각도(0=수직), 각속도를 나타냅니다. 이 환경을 렌더링하려면 먼저 몇 가지 이슈를 해결해야 합니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 렌더링 이슈 해결하기" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "일부 환경(Cart-Pole을 포함하여)은 `rgb_array` 모드를 설정하더라도 별도의 창을 띄우기 위해 디스플레이 접근이 필수적입니다. 일반적으로 이 창을 무시하면 됩니다. 주피터가 헤드리스(headless) 서버로 (즉 스크린이 없이) 실행중이면 예외가 발생합니다. 이를 피하는 한가지 방법은 Xvfb 같은 가짜 X 서버를 설치하는 것입니다. `xvfb-run` 명령을 사용해 주피터를 실행합니다:\n", "\n", " $ xvfb-run -s \"-screen 0 1400x900x24\" jupyter notebook\n", " \n", "주피터가 헤드리스 서버로 실행 중이지만 Xvfb를 설치하기 번거롭다면 Cart-Pole에 대해서는 다음 렌더링 함수를 사용할 수 있습니다:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "from PIL import Image, ImageDraw\n", "\n", "try:\n", " from pyglet.gl import gl_info\n", " openai_cart_pole_rendering = True # 문제없음, OpenAI 짐의 렌더링 함수를 사용합니다\n", "except Exception:\n", " openai_cart_pole_rendering = False # 가능한 X 서버가 없다면, 자체 렌더링 함수를 사용합니다\n", "\n", "def render_cart_pole(env, obs):\n", " if openai_cart_pole_rendering:\n", " # OpenAI 짐의 렌더링 함수를 사용합니다\n", " return env.render(mode=\"rgb_array\")\n", " else:\n", " # Cart-Pole 환경을 위한 렌더링 (OpenAI 짐이 처리할 수 없는 경우)\n", " img_w = 600\n", " img_h = 400\n", " cart_w = img_w // 12\n", " cart_h = img_h // 15\n", " pole_len = img_h // 3.5\n", " pole_w = img_w // 80 + 1\n", " x_width = 2\n", " max_ang = 0.2\n", " bg_col = (255, 255, 255)\n", " cart_col = 0x000000 # 파랑 초록 빨강\n", " pole_col = 0x669acc # 파랑 초록 빨강\n", "\n", " pos, vel, ang, ang_vel = obs\n", " img = Image.new('RGB', (img_w, img_h), bg_col)\n", " draw = ImageDraw.Draw(img)\n", " cart_x = pos * img_w // x_width + img_w // x_width\n", " cart_y = img_h * 95 // 100\n", " top_pole_x = cart_x + pole_len * np.sin(ang)\n", " top_pole_y = cart_y - cart_h // 2 - pole_len * np.cos(ang)\n", " draw.line((0, cart_y, img_w, cart_y), fill=0)\n", " draw.rectangle((cart_x - cart_w // 2, cart_y - cart_h // 2, cart_x + cart_w // 2, cart_y + cart_h // 2), fill=cart_col) # draw cart\n", " draw.line((cart_x, cart_y - cart_h // 2, top_pole_x, top_pole_y), fill=pole_col, width=pole_w) # draw pole\n", " return np.array(img)\n", "\n", "def plot_cart_pole(env, obs):\n", " img = render_cart_pole(env, obs)\n", " plt.imshow(img)\n", " plt.axis(\"off\")\n", " plt.show()" ] }, { "cell_type": "code", 
"execution_count": 27, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXsAAAD/CAYAAAAQaHZxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAABDFJREFUeJzt3NFJw2AYQFEjncI1nMM16kztGs7RNVwjvkmJgkIb/5h7DhTaQsv3kF4+wk+neZ4fANi3x9EDALA+sQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCDiMHmDBfzcAfDXd+gU2e4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAgMPoAWBLLufXL+89H08DJoH7stkDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9/OByfh09AtxM7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOzhyvPxNHoEWIXYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPQnTNP36scbnYTSxBwg4jB4Atujt/fj5/OXpPHASuA+bPSxch/671/AfiT1AgNgDBIg9LCzv0btnzx5M8zyPnuHapoZhP/7ySOTGflPsw80X8KZO4zijzB64jrm3eywQm4q9jYi12Oypc88eIEDsAQLEHiBA7AECxB4gQOwBAsQeIGBT5+xhLc6+U2ezBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIOIweYGEaPQDAHtnsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQI+ALOaKEVPVYdnAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_cart_pole(env, obs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "행동 공간을 확인해 보겠습니다:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Discrete(2)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "네 딱 두 개의 행동이 있네요. 왼쪽이나 오른쪽 방향으로 가속합니다. 막대가 넘어지기 전까지 카트를 왼쪽으로 밀어보죠:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "obs = env.reset()\n", "while True:\n", " obs, reward, done, info = env.step(0)\n", " if done:\n", " break" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaIAAAEYCAYAAAAeWvJ8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAABThJREFUeJzt3MFx01AUQFHMuAraIG3QRlxT3AZtQBu0YRYMAyRxYklfvsI+Z8abLDx/Ec31k5+8O51OHwCg8rE+AAD3TYgASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABI7esDPOP3hgD+H7sRb2IiAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECILWvDwC37vvx8OJvnx+fgpPANpmIIPD9eHg1UHCPhAgipiL4RYgASAkRrMjtN3ifEEHAbTn4Q4gASAkRACkhAiAlRLASiwpwGSGCK7OoAP8SIliBaQguJ0QApIQIgJQQwRX5fgheEiIAUkIEg1lUgGmECK7EbTl4nRABkBIiAFJCBEBKiGAgiwownRDBFVhUgPOECICUEMEg527LmYbgbUIEQEqIAEgJEQApIYIBrG3DfEIEQEqIYEU25uB9QgRASogASAkRACkhgoVszMEyQgQrsagAlxEiWMA0BMsJEazANASXEyIAUkIEQEqIAEgJEcxkUQHGECIYzKICTCNEAKSECICUEAGQEiKYwaICjCNEMJBFBZhOiABICRFMdO62nGkI5hEiAFJCBEBKiABICRFMYG0bxhMiGMCiAswnRACkhAiAlBABkBIiuJBFBViHEMFCFhVgGSECICVEcAG/LwfrESIAUkIEQEqIAEgJEQApIYJ3eH4I1iVE3JXdbjf5dc7D4bj4PQAhgjd9e3qsjwA3b18fALbs64+XIfry6RicBG6XiQgmei1OwHxCBEBKiGCGh4PbczCKEMEZ5xYVfEcEY+1Op1N9hr9t6jDcniWr1L/DNGca2th1BqMMeTZBiLgr1TM9G7vOYJQhF9Sm1rc9+Met8r/NLRr1AWtTIfKpkbWZiGB7LCsAkBIiAFJCBEBKiABICREAKSECICVEAKQ29RwRrM3zPLA9JiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEjt6wM8s6sPAMB1mYgASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASP0EnbRmEsNMcIUAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "img = render_cart_pole(env, obs)\n", "plt.imshow(img)\n", "plt.axis(\"off\")\n", "save_fig(\"cart_pole_plot\")" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(400, 600, 3)" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "막대가 실제로 넘어지지 않더라도 너무 기울어지면 게임이 끝납니다. 환경을 다시 초기화하고 이번에는 오른쪽으로 밀어보겠습니다:" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "obs = env.reset()\n", "while True:\n", " obs, reward, done, info = env.step(1)\n", " if done:\n", " break" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXsAAAD/CAYAAAAQaHZxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAABKNJREFUeJzt3MFNW0EUQFF+5CrSRmgjbUAZiDJCG2kjaSNtOBsWljHY2InHf+45kjdIWLMw109vvli22+0dAHP7MvoAAPx/Yg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPELAZfYA9/ncDwFvLpW9gsgcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxJ+n3y+PoI8BVbUYfAK5F4Ckz2ZNwKPTiT4nYk/Dt4cfoI8BQYk+a6Z4KsQcIEHuAALEnw96eMrEHCBB78lzSUiD2AAFiT4q9PVViDxAg9nBnb8/8xB4gQOwBAsSeHJe0FIk9vLK3Z2ZiDxAg9iRZ5VAj9gABYg877O2ZldgDBIg9QIDYk+WSlhKxhz329sxI7Ekz3VMh9gABYg8QIPZwgL09sxF78uztKRB7eIfpnpmIPUCA2AMEiD3c2dszP7EHCBB7+IBLWmYh9gABYg8QIPbwyiUtMxN7OMLenhmIPUCA2MMOqxxmJfYAAWIPJ7C3Z+3EHiBA7AECxB72uKRlRmIPJ7K3Z83EHiBA7AECxB4OsLdnNmIPn2Bvz1qJPbzDdM9MxB4+yXTPGok9QIDYAwSIPXzA3p5ZiD1AgNjDGVzSsjZiT9ayLCe9Lv39Y+8D1yD2cMT948voI8DFNqMPAGvx88/D3k98CbAeJns409PTr9FHgJOJPZzg7VQP6yL2cILn5/vRR4CLiD2c6ftXO3vWY9lut6PPsOumDsPcrv045I39rbEuF39Yb+ppHM8iMzOfb871LwaFm4q9yYdrMtlTYmcPECD2AAFiDxAg9gABYg8QIPYAAWIPEHBTz9nDNXnunRKTPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AwGb0AfYsow8AMCOTPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUDAX68dV0s9rwzRAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_cart_pole(env, obs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "아까 말했던 것과 같은 상황인 것 같습니다. 어떻게 막대가 똑 바로 서있게 만들 수 있을까요? 이를 위한 *정책*을 만들어야 합니다. 이 정책은 에이전트가 각 스텝에서 행동을 선택하기 위해 사용할 전략입니다. 어떤 행동을 할지 결정하기 위해 지난 행동이나 관측을 사용할 수 있습니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 하드 코딩 정책" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "간단한 정책을 하드 코딩해 보겠습니다. 막대가 왼쪽으로 기울어지면 카트를 왼쪽으로 밀고 반대의 경우는 오른쪽으로 밉니다. 작동이 되는지 확인해 보죠:" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "frames = []\n", "\n", "n_max_steps = 1000\n", "n_change_steps = 10\n", "\n", "obs = env.reset()\n", "for step in range(n_max_steps):\n", " img = render_cart_pole(env, obs)\n", " frames.append(img)\n", "\n", " # hard-coded policy\n", " position, velocity, angle, angular_velocity = obs\n", " if angle < 0:\n", " action = 0\n", " else:\n", " action = 1\n", "\n", " obs, reward, done, info = env.step(action)\n", " if done:\n", " break" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "video = plot_animation(frames, figsize=(6,4))\n", "HTML(video.to_html5_video()) # HTML5 동영상으로 만들어 줍니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "아니네요, 불안정해서 몇 번 움직이고 막대가 너무 기울어져 게임이 끝났습니다. 더 똑똑한 정책이 필요합니다!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 신경망 정책" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "관측을 입력으로 받고 각 관측에 대해 선택할 행동을 출력하는 신경망을 만들어 보겠습니다. 행동을 선택하기 위해 네트워크는 먼저 각 행동에 대한 확률을 추정하고 그다음 추정된 확률을 기반으로 랜덤하게 행동을 선택합니다. Cart-Pole 환경의 경우에는 두 개의 행동(왼쪽과 오른쪽)이 있으므로 하나의 출력 뉴런만 있으면 됩니다. 행동 0(왼쪽)에 대한 확률 `p`를 출력할 것입니다. 행동 1(오른쪽)에 대한 확률은 `1 - p`가 됩니다." ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From :12: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use keras.layers.dense instead.\n", "WARNING:tensorflow:From /home/haesun/anaconda3/envs/handson-ml/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Colocations handled automatically by placer.\n", "WARNING:tensorflow:From :18: multinomial (from tensorflow.python.ops.random_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.random.categorical instead.\n" ] } ], "source": [ "import tensorflow as tf\n", "\n", "# 1. 네트워크 구조를 설정합니다\n", "n_inputs = 4 # == env.observation_space.shape[0]\n", "n_hidden = 4 # 간단한 작업이므로 너무 많은 뉴런이 필요하지 않습니다\n", "n_outputs = 1 # 왼쪽으로 가속할 확률을 출력합니다\n", "initializer = tf.variance_scaling_initializer()\n", "\n", "# 2. 네트워크를 만듭니다\n", "X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n", "hidden = tf.layers.dense(X, n_hidden, activation=tf.nn.elu,\n", " kernel_initializer=initializer)\n", "outputs = tf.layers.dense(hidden, n_outputs, activation=tf.nn.sigmoid,\n", " kernel_initializer=initializer)\n", "\n", "# 3. 
추정된 확률을 기반으로 랜덤하게 행동을 선택합니다\n", "p_left_and_right = tf.concat(axis=1, values=[outputs, 1 - outputs])\n", "action = tf.multinomial(tf.log(p_left_and_right), num_samples=1)\n", "\n", "init = tf.global_variables_initializer()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이 환경은 각 관측이 환경의 모든 상태를 포함하고 있기 때문에 지난 행동과 관측은 무시해도 괜찮습니다. 숨겨진 상태가 있다면 이 정보를 추측하기 위해 이전 행동과 상태를 고려해야 합니다. 예를 들어, 속도가 없고 카트의 위치만 있다면 현재 속도를 예측하기 위해 현재의 관측뿐만 아니라 이전 관측도 고려해야 합니다. 관측에 잡음이 있을 때도 같은 경우입니다. 현재 상태를 근사하게 추정하기 위해 과거 몇 개의 관측을 사용하는 것이 좋을 것입니다. 이 문제는 아주 간단해서 현재 관측에 잡음이 없고 환경의 모든 상태가 담겨 있습니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "정책 네트워크에서 만든 확률을 기반으로 가장 높은 확률을 가진 행동을 고르지 않고 왜 랜덤하게 행동을 선택하는지 궁금할 수 있습니다. 이런 방식이 에이전트가 새 행동을 *탐험*하는 것과 잘 동작하는 행동을 *이용*하는 것 사이에 균형을 맞추게 합니다. 만약 어떤 레스토랑에 처음 방문했다고 가정합시다. 모든 메뉴에 대한 선호도가 동일하므로 랜덤하게 하나를 고릅니다. 이 메뉴가 맛이 좋았다면 다음에 이를 주문할 가능성을 높일 것입니다. 하지만 100% 확률이 되어서는 안됩니다. 그렇지 않으면 다른 메뉴를 전혀 선택하지 않게 되고 더 좋을 수 있는 메뉴를 시도해 보지 못하게 됩니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "정책 신경망을 랜덤하게 초기화하고 게임 하나를 플레이해 보겠습니다:" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "n_max_steps = 1000\n", "frames = []\n", "\n", "with tf.Session() as sess:\n", " init.run()\n", " obs = env.reset()\n", " for step in range(n_max_steps):\n", " img = render_cart_pole(env, obs)\n", " frames.append(img)\n", " action_val = action.eval(feed_dict={X: obs.reshape(1, n_inputs)})\n", " obs, reward, done, info = env.step(action_val[0][0])\n", " if done:\n", " break\n", "\n", "env.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "랜덤하게 초기화한 정책 네트워크가 얼마나 잘 동작하는지 확인해 보겠습니다:" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "video = plot_animation(frames, figsize=(6,4))\n", "HTML(video.to_html5_video()) # HTML5 동영상으로 만들어 줍니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "음.. 별로 좋지 않네요. 신경망이 더 잘 학습되어야 합니다. 먼저 앞서 사용한 기본 정책을 학습할 수 있는지 확인해 보겠습니다. 막대가 왼쪽으로 기울어지면 왼쪽으로 움직이고 오른쪽으로 기울어지면 오른쪽으로 이동해야 합니다. 다음 코드는 같은 신경망이지만 타깃 확률 `y`와 훈련 연산(`cross_entropy`, `optimizer`, `training_op`)을 추가했습니다:" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "reset_graph()\n", "\n", "n_inputs = 4\n", "n_hidden = 4\n", "n_outputs = 1\n", "\n", "learning_rate = 0.01\n", "\n", "initializer = tf.variance_scaling_initializer()\n", "\n", "X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n", "y = tf.placeholder(tf.float32, shape=[None, n_outputs])\n", "\n", "hidden = tf.layers.dense(X, n_hidden, activation=tf.nn.elu, kernel_initializer=initializer)\n", "logits = tf.layers.dense(hidden, n_outputs)\n", "outputs = tf.nn.sigmoid(logits) # 행동 0(왼쪽)에 대한 확률\n", "p_left_and_right = tf.concat(axis=1, values=[outputs, 1 - outputs])\n", "action = tf.multinomial(tf.log(p_left_and_right), num_samples=1)\n", "\n", "cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)\n", "optimizer = tf.train.AdamOptimizer(learning_rate)\n", "training_op = optimizer.minimize(cross_entropy)\n", "\n", "init = tf.global_variables_initializer()\n", "saver = tf.train.Saver()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "동일한 네트워크를 동시에 10개의 다른 환경에서 플레이하고 1,000번 반복동안 훈련시키겠습니다. 완료되면 환경을 리셋합니다." 
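] }, { "cell_type": "markdown", "metadata": {}, "source": [ "다음 셀에서 타깃 확률 `y`는 간단한 규칙으로 만듭니다. 관측의 세 번째 값인 막대 각도(`obs[2]`)가 음수이면, 즉 막대가 왼쪽으로 기울어져 있으면 왼쪽으로 갈 확률을 1로, 그렇지 않으면 0으로 둡니다. 아래는 이 규칙을 가상의 관측 하나에 적용해 본 참고용 셀입니다(원래 노트북에는 없는 셀입니다):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 가상의 관측 하나로 타깃 확률 규칙을 확인해 봅니다(순서: 카트 위치, 속도, 막대 각도, 각속도)\n", "sample_obs = np.array([0.01, -0.02, -0.03, 0.04])\n", "[1.] if sample_obs[2] < 0 else [0.]  # 각도가 음수이므로 왼쪽으로 갈 확률은 1이 됩니다"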
] }, { "cell_type": "code", "execution_count": 40, "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/haesun/anaconda3/envs/handson-ml/lib/python3.6/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated. Call .resolve and .require separately.\n", " result = entry_point.load(False)\n" ] } ], "source": [ "n_environments = 10\n", "n_iterations = 1000\n", "\n", "envs = [gym.make(\"CartPole-v0\") for _ in range(n_environments)]\n", "observations = [env.reset() for env in envs]\n", "\n", "with tf.Session() as sess:\n", " init.run()\n", " for iteration in range(n_iterations):\n", " target_probas = np.array([([1.] if obs[2] < 0 else [0.]) for obs in observations]) # angle<0 이면 proba(left)=1. 이 되어야 하고 그렇지 않으면 proba(left)=0. 이 되어야 합니다\n", " action_val, _ = sess.run([action, training_op], feed_dict={X: np.array(observations), y: target_probas})\n", " for env_index, env in enumerate(envs):\n", " obs, reward, done, info = env.step(action_val[env_index][0])\n", " observations[env_index] = obs if not done else env.reset()\n", " saver.save(sess, \"./my_policy_net_basic.ckpt\")\n", "\n", "for env in envs:\n", " env.close()" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "def render_policy_net(model_path, action, X, n_max_steps = 1000):\n", " frames = []\n", " env = gym.make(\"CartPole-v0\")\n", " obs = env.reset()\n", " with tf.Session() as sess:\n", " saver.restore(sess, model_path)\n", " for step in range(n_max_steps):\n", " img = render_cart_pole(env, obs)\n", " frames.append(img)\n", " action_val = action.eval(feed_dict={X: obs.reshape(1, n_inputs)})\n", " obs, reward, done, info = env.step(action_val[0][0])\n", " if done:\n", " break\n", " env.close()\n", " return frames " ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /home/haesun/anaconda3/envs/handson-ml/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use standard file APIs to check for files with this prefix.\n", "INFO:tensorflow:Restoring parameters from ./my_policy_net_basic.ckpt\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "frames = render_policy_net(\"./my_policy_net_basic.ckpt\", action, X)\n", "video = plot_animation(frames, figsize=(6,4))\n", "HTML(video.to_html5_video()) # HTML5 동영상으로 만들어 줍니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "정책을 잘 학습한 것 같네요. 이제 스스로 더 나은 정책을 학습할 수 있는지 알아 보겠습니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 정책 그래디언트" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "신경망을 훈련하기 위해 타깃 확률 `y`를 정의할 필요가 있습니다. 행동이 좋다면 이 확률을 증가시켜야 하고 반대로 나쁘면 이를 감소시켜야 합니다. 하지만 행동이 좋은지 나쁜지 어떻게 알 수 있을까요? 대부분의 행동으로 인한 영향은 뒤늦게 나타나는 것이 문제입니다. 게임에서 이기거나 질 때 어떤 행동이 이런 결과에 영향을 미쳤는지 명확하지 않습니다. 마지막 행동일까요? 아니면 마지막 10개의 행동일까요? 아니면 50번 스텝 앞의 행동일까요? 이를 *신용 할당 문제*라고 합니다.\n", "\n", "*정책 그래디언트* 알고리즘은 먼저 여러번 게임을 플레이하고 성공한 게임에서의 행동을 조금 더 높게 실패한 게임에서는 조금 더 낮게 되도록 하여 이 문제를 해결합니다. 먼저 게임을 진행해 보고 다시 어떻게 한 것인지 살펴 보겠습니다." 
] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From :21: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.cast instead.\n" ] } ], "source": [ "import tensorflow as tf\n", "\n", "reset_graph()\n", "\n", "n_inputs = 4\n", "n_hidden = 4\n", "n_outputs = 1\n", "\n", "learning_rate = 0.01\n", "\n", "initializer = tf.variance_scaling_initializer()\n", "\n", "X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n", "\n", "hidden = tf.layers.dense(X, n_hidden, activation=tf.nn.elu, kernel_initializer=initializer)\n", "logits = tf.layers.dense(hidden, n_outputs)\n", "outputs = tf.nn.sigmoid(logits) # 행동 0(왼쪽)에 대한 확률\n", "p_left_and_right = tf.concat(axis=1, values=[outputs, 1 - outputs])\n", "action = tf.multinomial(tf.log(p_left_and_right), num_samples=1)\n", "\n", "y = 1. - tf.to_float(action)\n", "cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)\n", "optimizer = tf.train.AdamOptimizer(learning_rate)\n", "grads_and_vars = optimizer.compute_gradients(cross_entropy)\n", "gradients = [grad for grad, variable in grads_and_vars]\n", "gradient_placeholders = []\n", "grads_and_vars_feed = []\n", "for grad, variable in grads_and_vars:\n", " gradient_placeholder = tf.placeholder(tf.float32, shape=grad.get_shape())\n", " gradient_placeholders.append(gradient_placeholder)\n", " grads_and_vars_feed.append((gradient_placeholder, variable))\n", "training_op = optimizer.apply_gradients(grads_and_vars_feed)\n", "\n", "init = tf.global_variables_initializer()\n", "saver = tf.train.Saver()" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "def discount_rewards(rewards, discount_rate):\n", " discounted_rewards = np.zeros(len(rewards))\n", " cumulative_rewards = 0\n", " for step in reversed(range(len(rewards))):\n", " cumulative_rewards = rewards[step] + cumulative_rewards * discount_rate\n", " discounted_rewards[step] = cumulative_rewards\n", " return discounted_rewards\n", "\n", "def discount_and_normalize_rewards(all_rewards, discount_rate):\n", " all_discounted_rewards = [discount_rewards(rewards, discount_rate) for rewards in all_rewards]\n", " flat_rewards = np.concatenate(all_discounted_rewards)\n", " reward_mean = flat_rewards.mean()\n", " reward_std = flat_rewards.std()\n", " return [(discounted_rewards - reward_mean)/reward_std for discounted_rewards in all_discounted_rewards]" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-22., -40., -50.])" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "discount_rewards([10, 0, -50], discount_rate=0.8)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[array([-0.28435071, -0.86597718, -1.18910299]),\n", " array([1.26665318, 1.0727777 ])]" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "discount_and_normalize_rewards([[10, 0, -50], [10, 20]], discount_rate=0.8)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "반복: 0" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/haesun/anaconda3/envs/handson-ml/lib/python3.6/site-packages/gym/envs/registration.py:14: 
PkgResourcesDeprecationWarning: Parameters to load are deprecated. Call .resolve and .require separately.\n", " result = entry_point.load(False)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "반복: 249" ] } ], "source": [ "env = gym.make(\"CartPole-v0\")\n", "\n", "n_games_per_update = 10\n", "n_max_steps = 1000\n", "n_iterations = 250\n", "save_iterations = 10\n", "discount_rate = 0.95\n", "\n", "with tf.Session() as sess:\n", " init.run()\n", " for iteration in range(n_iterations):\n", " print(\"\\r반복: {}\".format(iteration), end=\"\")\n", " all_rewards = []\n", " all_gradients = []\n", " for game in range(n_games_per_update):\n", " current_rewards = []\n", " current_gradients = []\n", " obs = env.reset()\n", " for step in range(n_max_steps):\n", " action_val, gradients_val = sess.run([action, gradients], feed_dict={X: obs.reshape(1, n_inputs)})\n", " obs, reward, done, info = env.step(action_val[0][0])\n", " current_rewards.append(reward)\n", " current_gradients.append(gradients_val)\n", " if done:\n", " break\n", " all_rewards.append(current_rewards)\n", " all_gradients.append(current_gradients)\n", "\n", " all_rewards = discount_and_normalize_rewards(all_rewards, discount_rate=discount_rate)\n", " feed_dict = {}\n", " for var_index, gradient_placeholder in enumerate(gradient_placeholders):\n", " mean_gradients = np.mean([reward * all_gradients[game_index][step][var_index]\n", " for game_index, rewards in enumerate(all_rewards)\n", " for step, reward in enumerate(rewards)], axis=0)\n", " feed_dict[gradient_placeholder] = mean_gradients\n", " sess.run(training_op, feed_dict=feed_dict)\n", " if iteration % save_iterations == 0:\n", " saver.save(sess, \"./my_policy_net_pg.ckpt\")" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "env.close()" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from ./my_policy_net_pg.ckpt\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "frames = render_policy_net(\"./my_policy_net_pg.ckpt\", action, X, n_max_steps=1000)\n", "video = plot_animation(frames, figsize=(6,4))\n", "HTML(video.to_html5_video()) # HTML5 동영상으로 만들어 줍니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 마르코프 연쇄" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태: 0 0 3 \n", "상태: 0 1 2 1 2 1 2 1 2 1 3 \n", "상태: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n", "상태: 0 3 \n", "상태: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n", "상태: 0 1 3 \n", "상태: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 ...\n", "상태: 0 0 3 \n", "상태: 0 0 0 1 2 1 2 1 3 \n", "상태: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n" ] } ], "source": [ "transition_probabilities = [\n", " [0.7, 0.2, 0.0, 0.1], # s0에서 s0, s1, s2, s3으로\n", " [0.0, 0.0, 0.9, 0.1], # s1에서 ...\n", " [0.0, 1.0, 0.0, 0.0], # s2에서 ...\n", " [0.0, 0.0, 0.0, 1.0], # s3에서 ...\n", " ]\n", "\n", "n_max_steps = 50\n", "\n", "def print_sequence(start_state=0):\n", " current_state = start_state\n", " print(\"상태:\", end=\" \")\n", " for step in range(n_max_steps):\n", " print(current_state, end=\" \")\n", " if current_state == 3:\n", " break\n", " current_state = np.random.choice(range(4), 
p=transition_probabilities[current_state])\n", " else:\n", " print(\"...\", end=\"\")\n", " print()\n", "\n", "for _ in range(10):\n", " print_sequence()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 마르코프 결정 과정" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "policy_fire\n", "상태 (+보상): 0 (10) 0 (10) 0 1 (-50) 2 2 2 (40) 0 (10) 0 (10) 0 (10) ... 전체 보상 = 210\n", "상태 (+보상): 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 1 (-50) 2 2 (40) 0 (10) ... 전체 보상 = 70\n", "상태 (+보상): 0 (10) 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) ... 전체 보상 = 70\n", "상태 (+보상): 0 1 (-50) 2 1 (-50) 2 (40) 0 (10) 0 1 (-50) 2 (40) 0 ... 전체 보상 = -10\n", "상태 (+보상): 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 1 (-50) 2 (40) 0 (10) 0 (10) ... 전체 보상 = 290\n", "요약: 평균=121.1, 표준 편차=129.333766, 최소=-330, 최대=470\n", "\n", "policy_random\n", "상태 (+보상): 0 1 (-50) 2 1 (-50) 2 (40) 0 1 (-50) 2 2 (40) 0 ... 전체 보상 = -60\n", "상태 (+보상): 0 (10) 0 0 0 0 0 (10) 0 0 0 (10) 0 ... 전체 보상 = -30\n", "상태 (+보상): 0 1 1 (-50) 2 (40) 0 0 1 1 1 1 ... 전체 보상 = 10\n", "상태 (+보상): 0 (10) 0 (10) 0 0 0 0 1 (-50) 2 (40) 0 0 ... 전체 보상 = 0\n", "상태 (+보상): 0 0 (10) 0 1 (-50) 2 (40) 0 0 0 0 (10) 0 (10) ... 전체 보상 = 40\n", "요약: 평균=-22.1, 표준 편차=88.152740, 최소=-380, 최대=200\n", "\n", "policy_safe\n", "상태 (+보상): 0 1 1 1 1 1 1 1 1 1 ... 전체 보상 = 0\n", "상태 (+보상): 0 1 1 1 1 1 1 1 1 1 ... 전체 보상 = 0\n", "상태 (+보상): 0 (10) 0 (10) 0 (10) 0 1 1 1 1 1 1 ... 전체 보상 = 30\n", "상태 (+보상): 0 (10) 0 1 1 1 1 1 1 1 1 ... 전체 보상 = 10\n", "상태 (+보상): 0 1 1 1 1 1 1 1 1 1 ... 전체 보상 = 0\n", "요약: 평균=22.3, 표준 편차=26.244312, 최소=0, 최대=170\n", "\n" ] } ], "source": [ "transition_probabilities = [\n", " [[0.7, 0.3, 0.0], [1.0, 0.0, 0.0], [0.8, 0.2, 0.0]], # s0에서, 행동 a0이 선택되면 0.7의 확률로 상태 s0로 가고 0.3의 확률로 상태 s1로 가는 식입니다.\n", " [[0.0, 1.0, 0.0], None, [0.0, 0.0, 1.0]],\n", " [None, [0.8, 0.1, 0.1], None],\n", " ]\n", "\n", "rewards = [\n", " [[+10, 0, 0], [0, 0, 0], [0, 0, 0]],\n", " [[0, 0, 0], [0, 0, 0], [0, 0, -50]],\n", " [[0, 0, 0], [+40, 0, 0], [0, 0, 0]],\n", " ]\n", "\n", "possible_actions = [[0, 1, 2], [0, 2], [1]]\n", "\n", "def policy_fire(state):\n", " return [0, 2, 1][state]\n", "\n", "def policy_random(state):\n", " return np.random.choice(possible_actions[state])\n", "\n", "def policy_safe(state):\n", " return [0, 0, 1][state]\n", "\n", "class MDPEnvironment(object):\n", " def __init__(self, start_state=0):\n", " self.start_state=start_state\n", " self.reset()\n", " def reset(self):\n", " self.total_rewards = 0\n", " self.state = self.start_state\n", " def step(self, action):\n", " next_state = np.random.choice(range(3), p=transition_probabilities[self.state][action])\n", " reward = rewards[self.state][action][next_state]\n", " self.state = next_state\n", " self.total_rewards += reward\n", " return self.state, reward\n", "\n", "def run_episode(policy, n_steps, start_state=0, display=True):\n", " env = MDPEnvironment()\n", " if display:\n", " print(\"상태 (+보상):\", end=\" \")\n", " for step in range(n_steps):\n", " if display:\n", " if step == 10:\n", " print(\"...\", end=\" \")\n", " elif step < 10:\n", " print(env.state, end=\" \")\n", " action = policy(env.state)\n", " state, reward = env.step(action)\n", " if display and step < 10:\n", " if reward:\n", " print(\"({})\".format(reward), end=\" \")\n", " if display:\n", " print(\"전체 보상 =\", env.total_rewards)\n", " return env.total_rewards\n", "\n", "for policy in (policy_fire, policy_random, policy_safe):\n", " all_totals = []\n", " 
print(policy.__name__)\n", " for episode in range(1000):\n", " all_totals.append(run_episode(policy, n_steps=100, display=(episode<5)))\n", " print(\"요약: 평균={:.1f}, 표준 편차={:1f}, 최소={}, 최대={}\".format(np.mean(all_totals), np.std(all_totals), np.min(all_totals), np.max(all_totals)))\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Q-러닝" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Q-러닝은 에이전트가 플레이하는 것(가령, 랜덤하게)을 보고 점진적으로 Q-가치 추정을 향상시킵니다. 정확한 (또는 충분히 이에 가까운) Q-가치가 추정되면 최적의 정책은 가장 높은 Q-가치(즉, 그리디 정책)를 가진 행동을 선택하는 것이 됩니다." ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "n_states = 3\n", "n_actions = 3\n", "n_steps = 20000\n", "alpha = 0.01\n", "gamma = 0.99\n", "exploration_policy = policy_random\n", "q_values = np.full((n_states, n_actions), -np.inf)\n", "for state, actions in enumerate(possible_actions):\n", " q_values[state][actions]=0\n", "\n", "env = MDPEnvironment()\n", "for step in range(n_steps):\n", " action = exploration_policy(env.state)\n", " state = env.state\n", " next_state, reward = env.step(action)\n", " next_value = np.max(q_values[next_state]) # 그리디한 정책\n", " q_values[state, action] = (1-alpha)*q_values[state, action] + alpha*(reward + gamma * next_value)" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "def optimal_policy(state):\n", " return np.argmax(q_values[state])" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[39.13508139, 38.88079412, 35.23025716],\n", " [18.9117071 , -inf, 20.54567816],\n", " [ -inf, 72.53192111, -inf]])" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "q_values" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태 (+보상): 0 (10) 0 (10) 0 1 (-50) 2 (40) 0 (10) 0 1 (-50) 2 (40) 0 (10) ... 전체 보상 = 230\n", "상태 (+보상): 0 (10) 0 (10) 0 (10) 0 1 (-50) 2 2 1 (-50) 2 (40) 0 (10) ... 전체 보상 = 90\n", "상태 (+보상): 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) ... 전체 보상 = 170\n", "상태 (+보상): 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) ... 전체 보상 = 220\n", "상태 (+보상): 0 1 (-50) 2 (40) 0 (10) 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 (10) ... 전체 보상 = -50\n", "요약: 평균=125.6, 표준 편차=127.363464, 최소=-290, 최대=500\n", "\n" ] } ], "source": [ "all_totals = []\n", "for episode in range(1000):\n", " all_totals.append(run_episode(optimal_policy, n_steps=100, display=(episode<5)))\n", "print(\"요약: 평균={:.1f}, 표준 편차={:1f}, 최소={}, 최대={}\".format(np.mean(all_totals), np.std(all_totals), np.min(all_totals), np.max(all_totals)))\n", "print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# DQN 알고리즘으로 미스팩맨 게임 학습하기" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 미스팩맨 환경 만들기" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/haesun/anaconda3/envs/handson-ml/lib/python3.6/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated. 
Call .resolve and .require separately.\n", " result = entry_point.load(False)\n" ] }, { "data": { "text/plain": [ "(210, 160, 3)" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env = gym.make(\"MsPacman-v0\")\n", "obs = env.reset()\n", "obs.shape" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Discrete(9)" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 전처리" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이미지 전처리는 선택 사항이지만 훈련 속도를 크게 높여 줍니다." ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "mspacman_color = 210 + 164 + 74\n", "\n", "def preprocess_observation(obs):\n", " img = obs[1:176:2, ::2] # 자르고 크기를 줄입니다.\n", " img = img.sum(axis=2) # 흑백 스케일로 변환합니다.\n", " img[img==mspacman_color] = 0 # 대비를 높입니다.\n", " img = (img // 3 - 128).astype(np.int8) # -128~127 사이로 정규화합니다.\n", " return img.reshape(88, 80, 1)\n", "\n", "img = preprocess_observation(obs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "노트 `preprocess_observation()` 함수가 책에 있는 것과 조금 다릅니다. 64비트 부동소수를 -1.0~1.0 사이로 나타내지 않고 부호있는 바이트(-128~127 사이)로 표현합니다. 이렇게 하는 이유는 재생 메모리가 약 8배나 적게 소모되기 때문입니다(52GB에서 6.5GB로). 정밀도를 감소시켜도 눈에 띄이게 훈련에 미치는 영향은 없습니다." ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAArkAAAGoCAYAAABCPP0XAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3X28bFlZH/jfg6ICiolIgDRoDyI00qi0wfCqOKFpfNe0Lxi5DBMzRCRqy+BrEImjBvGFi0aJhNHWwxAwUaI0b6IRaZGOIEHSqGTUQXM7Dc1bQDAq4po/9j7ddavrVJ17zqlTtVd9v5/P/XSf2nvttfaqOnWf+9RTa1VrLQAA0JPbbHoAAABw0gS5AAB0R5ALAEB3BLkAAHRHkAsAQHcEuQAAdEeQCwBAdwS5E1ZVH1lVz6qqt1bV71fVb1TVZeOxr6+qq2fO/SdVdc2Cazy9qs4ecP3nV9XbZv68s6r+n/HYdVX1iAVtXjnX5o6HuI/bVNW/r6oHHP7umVdV/7CqfnDT4wCAbSDInbYnJrkkyf1ba/dN8qNJXrTk/EdW1dtn/yR5ykEnt9Ye21q7eP9Pku9O8v5F51bVZVX1B0k+OclfJPnz8dBNVfXpK+7jm5L8UWvtP4/Xun1VPbOqWlU9aK6fj6iqH6yqt1TV66vq12aOfcX4+Jur6jVVde8V/S66j7tU1U+N/2j47aq6tqruP3P8E6pqbxzbXefaPqmq/qCqrq+qa6rqLguu//Sqesf4j4TXjn8eOXP8tlX17VX1pqr63ar6rap6WVU9fDx+dVX96dj+t6vqDVX1hUnSWvvFJPetqssv9L4BoDeC3GmrDM/h/vP4kSvOf0Vr7a6zf5L88NIOqj52Pzuc5M5Jblp0Xmvtja21S1prlyS5LMlPJPlAki9urb15yfVvl+Tb58bxbUl+P8mfLGjyfyc5l+TS1toDk3zVeJ17JfmpJF/aWvv0JD+d5Bfm+vrRqvqqmZ8/oap+taruPHPaZUle2Vq7b2vts5P8hyQ/MnP8e5P8/IL7eMR4Hw9vrV2a5A3jWBd5cWvtQa21hyb5xiQvrqpPHI+9IMkDklzeWvuM1tpDkjwuyYdm2v/k2P6zk3xLkp+dOfYvkzzzgH4BYGesCorYbs9J8qlJfq+qWpK3J/maJedfUVXn5h67Y4aA8CCXJHn++N97JXn1zLF/XlVXtNa+c8w0XprkERmyub+Z5H1JvqmqPiPJrxwQ7H5Jkje31t65/0Br7elJUlXfM3vimJm9W5L/Jcl/qqr3ZsguvzvJlRmC0z8cT99L8sNVdf/W2n8ZH/v+JC+vqtsneVmSlyf5kbm+Xz43vhsz83vSWvtn41jm7+Orkzx/5lrPTvLOqvr41tr7Ftz3/vXeWFXvT3LxGCjfJ8lntdY+NHPOu5K864BLfHKS350593VV9YlVdb/W2lsO6hcAeifInaiq+rQk907ya0l+PclHJ7lDks+rqiszlwVtrT0vyfPGtk9Jcklr7Z9cYLffl+S9VfXx48+vSPLa8f/vlSHb+AOttZuDrjHD+rkZss6LfFZmgrQVHp7ks5P8UGvtyVX16CQvq6pPSXLPJH+0f2Jr7cNV9Sfj4/9lfOzd40f5L0vy9CRPHj/iX2gsN/jeJF93iLHdMzMZ3tbae6vqfUkuXnZ/VfVlSf4qyfVJHp/kl/cD3Kr6xiRfm+TjMgTwTx6bfcPY7s4Zfoe/fO6yv5thXgW5AOwsQe50XZzkYeP/X5Tkc5I8K0PG73cyBMBHVlX/IkN29GOS3KOq3pYhUP3LJC8cT/ud1tp1VfXKcQz7bQ+65l5rbf6LUR+b5B2HHNbfSfKbrbVfTZLW2ivGcT1sHNuH587/69y6JOdvZ8he35ghMD9orHfKGAy31n7jEGM7bP9J8uVV9ZkZsrDXJXlQa+0vxnm73f5JrbUfT/LjVfUdGTLp+36ytfaMcZyfmeSlVfWomcztB5J8fABghwlyJ6q19rIMQdh+PeinttZ+uKp
enyHD93FJXlJVd88QSNX452/2r7GgdOHrWmuvHP//GRmC5jsm+bMk72utzbZ91Ey7L8/h6rv/asFjf5Kh5OIwbsqtv/j2NxmCy3MZAv9Z9xgfT5JU1X2S/GKSr0/y20n+XVXdfr88Yua8u2UoZXhma+0FhxzbuSSfNHON2ye502z/M17cWvv68Qtj/ypDUJoMz9OTq+o2s3O9TGvtTVX1uiRX5JbM7T2S/H+HHDcAdMkXzyauql6U5NP2f26tPXBcCeGfjz+fa63dPcMXkl7WWrv7kj+vnLnO/2yt/Y8kP5fkocuCrtbanye5PMkfHPDnVa21D7TWFgW51yT5Xw95u9ckedT+qglV9bAMgeXrMnxh6wvHoD5V9ZUZVnn4nZn2X5rkf2+tXdta+8sMmeo7jVnbjO0+OUMJyP91AQFuMtQAf+1MKceTkrx2tt53XmvtpRlKKZ4xPvTzGQL2n6iqj5059e4HXWMsqfj7Sf7T+PPHJ7lfhhIWANhZMrnTd58MS4f9zNzj70ryp2vs9weSvHX/h9bai5O8eP6kqrokQ+3uQq21t1TVH1XV5y/40tf8ue+oqq9N8gtV9aEMpRNfNn6x631jDetLxmN/luTzZ4Pz1toz5673oQyrG8z6kSR3SfKtVfWt42N/2Vr73BVj+/Wq+okkvzH2/9+TPGZZm9E3J7m+qn6+tXZtVX1eku9Kcl1VfThDoH5Tzl/hYb8md391jae21vZro/+PJD/bWvvgIfoGgG5Va23TY+AYqupNGVYc+NCCw29prV0xnvfPkvxgkvcuOO8PW2uPOOD6r86wasJfLDj8Pa21g5bJ2m9/SYalyy5ecs59k/ybJJ83u6oAF2bMSL88yRWttUXPMwDsDEEuW6GqPjfJu1tr1296LFNVVZ+T5F2ttd/b9FhgW1RVNX/RwU5Sk8tWaK39hgD3eFprrxHgsutq2E783Lgz4Lkk76+q247HXlhVj5859yl1yw6Qr5+7zrlxCcRU1Y+NP8/++ZTx2CXjKi/77S6uqg/X+dubv62q7l9V91rwhd9U1d9bcP6HqupTq+rR4ydq820+f+78p9VBS9uc3+6hNW7P3puqus/cc/Q14+NX17C1/fx29583lsu9u4ZdNj9i5tjbq+riA/r56ap60iHHVGPfV1fVI6rqN2eOPbCGXTVvqqo3VtUVM8dufq1W1f9ZVd9wYbNBoiYXgI7MlkZV1VMzt7nK3Lk/nJndFg/K+rbWvinD9uP7551Lctslw3jnohKt/aB5wfXfkLnVYWrYdv2gbdR/OslDcksZ2YeSfGeGpQb/0UGDqqo7ZNiJcX+b8L+f5IcyLBV52yR/nOSbWms3jF/wfXaG1Xoqw/c8vuVC/yG9oo+PSvJjuWUt9VdmWL98fjnGVNVbs3hpxLskuXNr7V2ttbdm/KJuVT0/w9rxB43rEzMsh/mFSX4vyS8n+adJfvIQt3XvJAuXlhz7fViGlX/ukGHnzTcuOO8OSV6S5B+31l5WVQ/MsBzkA1tr87t9PivDBki/3lr7/UOMj9FWBbk17NoFcKpaayszYEzLmGm9Ksk/qKrvTfKEDEHSK8bjT06yv8HKXyf5W0l+PMMuihtRwyY/NyV5T4axvnvRea21fzzT5gFJnpbkTRmWR1zmGzOssrO/6ssLMnxx9d+OWeCzGbYF/9oMO2r+amvtX479PDnDdyceOv58myQvTfKU/TW6x7H8i9bal8z0uayP78oQlN4vwyfLL8+wVfmttptvrd1n/rGq+sgMS1P+2Yr7XuQxGXbifMN4rR8Yx7U0yK2qz0ryGRm+nPyi1tp531dprT125tynJfmILHbvJB8YlwNNa+31NXzH5rLcejOnv6mqH83whe/5zX9YQrkCAF2pYU3slyb5YJKHtNae1lq7a5Jf2j+ntfaj+8snZtgq/E8zBFkn4c4Lyg++4hDtnpnkkRk2irmptfbX4+P3rarnjx9v37WqvrKqvr+qfitDBvc9GQLF76+qx44Z0kX+tyT/bubnG8axVpKPyhDo3zBz7BOq6iPGj/HvNHMs48o1351htZsHjBnbf5tx+cpD9vHVSX6stfY3473+RJZvTT/vThkCxb886ISq+uwM2eh590oymxX9vSSfsqyzMfv73CTfmuH19YIx0J495yFjWcIXJHlUFmRxR29Ncruq+qKxpOHBGYLn1x9w/kuSPLKq/vayMXK+rcrknvvmb970EACYqKq6XZJvSPKUDMvz/UqSa6rqQeNjB/mWJG9vrf3W+LH4HZLcdea6Vyf5styycUuyeEWbtNbeljF7V1XvSvL3xscOLFdY4KYkj65h3evbJnl7kqszZPg+JkMW8HUZNqx538y9PyzJvRatSV7DBjWX5Pxtxv9hhmDt25PcPsNa5N85HntCho/a35Eh0/3mzGURW2tvqGFN8l9I0pJ8yVgyMGtZH+dtxz7+/z2Xzsz57pRh98r9e7wkwxbp+3XPL8mwhftnZsESlzl/u/kDP80ZA/QvzVB28cLW2r8eM9k/m+TaqvqnrbU3j6d/aYb7vEOGbeifnWFnz2tmr9la+/MaNgT6oQxZ8/+W5DGttUUbCKW19oEatqq/f5LXHDRWzieTC0AvHpPkAUke3Fr7+XFDm0dk+Ch/YTncmGG9KsmlVXVpa+0+Y3b37XOnPn1u85w/uvXVFrptVd2pqu6fBRnFqrpLVV1fVddn2J79bIYM48uS/McMW6a/tw3bmT94fPxrMmwi89qZtq/PULv5pPGxj5vr6mMz1InOfrz+ggx1sJ+UIaj/q9yyJvePjXNw0fjnuiSLvrB2UYY1yz+U5O8uOL6sj/nt0A/aCv0g90jytrnHzrXWLh7/vKi19gUZ/rEz7w9z/nbp9835AfesB2UIzJ/UWvvuZMhkt9bOZFij/kVVdY/x8W/PkOm9V4b64ouTfMeii7bW3tRau7y1do/W2kNaa7Ob+LwkQ8A+y5btF2irMrnL3P0X7rbpIWzcuStvPPCY+cHrY7ll80MfWms/k+RnxqDynq21Px4/Bn9WktTM4gNj5vO7MmQnH5qhFvKVVfWEcTfCpWpYseGeST565rEHZAhOkiGg/FCGGuD3ZAgYFwWJN2XIwN42ye2SvGP24/eqevTMuS/NsCPjSq21D8w99M4MwejfTXJu/Oj9kRmyrx9O8uGqemaGYPmqDF9g+6z9sYw1q39eVZ/QWnvP+NiXZ6gHviJDPHFNVX1ba22/7nlVH/vbob9tHOMnZfFW6Ae5X4Yvsi1UVXfMUB6xyAuTfPdYY/v7GYLYn150YmvtdRl2llx07LkZShj2+7x3hp1Cf6W19q9W3UBVfUuGHTLn3SVDDfUbZh6zZfsFkskFoDdfnKG+daHx4+dfzhBUPri19t9aa7+U5CtyfnZv34eTPHWsrf3jsaThuiTfk5kgt7X2n8cs8OckeVNr7W6ttU9pw3brX5zzA5b9Nm3MON8vyS8uqy9trf31GLy+PAdvo/75CwLcjKtGvDTJPxgfem+GFRNma4W/MrfsZPlfx59nj70zyf9Ibp7Dy5M8qrX231trf5
oh2P2CmTar+thL8sSxJvU2GUpN/v1B97/A/XJ+XW1LcrequmH8aP8VSb5kUcPW2rsyBPJ7Sf7fDGUcP3UBfd9KDdus7yX5vtba0w7TprX2rNbaveb/ZKZ+fLz2p2X4R9NbjjPGXTOZTC4AHNOTk3ywtdaq6oo2s+13cnPG7nXzjVprX5fk6xZdcKwDnXf7DJnhee/JsA37hbg+Q93m7HgefsBYXpgly2ZlKIX4kQxbf3+4qr4oyTOr6psz1BHfmGHVg2Qo/XhWVb0xQ5D/wSRfvD9nY9B83tqtrbUbM7PU2iH6eMY4pt8Zf35Nku+fu6erMmR9F/mYJFeMKz/82vg8ffT8STWzNu7ceH8tyacdcO39tg/KkPVd5YOttfvlgIzvkut/X4b655sWHJ79kuA3JHn2OO8ckiAXgB59QS3YeCFDjei3zQe4a3DXA/pPVT2vtbZoDdxPP6DN+zJkYY+ltfbaqnpLVX3VWLP82xlqlhed+18zrCF73D6X9fE/M6xNu6z92QyB8Ea01q7L3BrGa/BRGWqm5z04yS+NJRCfkeELklwAQS4AXWmtXZ1hNYLjXOPuhzzvDzIXBLVh98YL+vu1tfbqDMHOkbXWHnOI056Y5HHH6WdqWmuPn/nxXx+yzV1Xn3Xo/p+X5Hnjjw+bO/bUJE9dcYlLknx1O2BTEw4myAWAHTEuL/a8lSeyNVprv7zpMUyVL54BANAdQS4AAN1RrgAwbb5tDey6hTvWdRHkrlrkfdlC+EddQH8d7Va1PapNjPW053WVKT3Pp72pw5Se5038/gAwTcoVAADojiAXAIDu1DZtnnHDVVcdOBgfQ57+x9hMi9fHcsvm56KzZxfWc03E0jfxM2fOnNY4Jmlvb2/lOeaQ41j1GvP6Wu0Qv6cL38NlcgEA6I4gFwCA7ghyAQDojiAXAIDuCHIBAOiOIBcAgO4IcgEA6E4X2/oCcDTHXcNz6u1PwqbvYdPtV9n0+Dbd/iRs+h423f6odj7IXbZA/FH1svD+OuaGgdfIwXqZGwA2S7kCAADd6SKTuyqbJDMEALBbqrWl256fqhuuuurAwSwLVI8T5E7p49ZlY11Hn8oV1ue0n69NvCaPahNjvejs2YX7nk/E0jfx06gXnLJVtYKJOeR4tqGmd+oO8Xu68D1cuQIAAN0R5AIA0B1BLgAA3RHkAgDQHUEuAADd6WIJsSnZxHJnm/jWPUe3batoeI0AMEUyuQAAdEeQCwBAdwS5AAB0R03uKdtEfaOaymk57efL6wOAHnUR5PpLGgCAWV0EuQAczao94c+cOXOs9se1qv8erHsOWc5r7Pi2dQ7V5AIA0B1BLgAA3RHkAgDQHUEuAADdEeQCANAdQS4AAN3Z+SXEprTGrk0CWMZGIwBwi50PcgGYruOu89vLGFifbXh+t2EMU6RcAQCA7ghyAQDojiAXAIDuqMkFYLK2oRZxG8bA+mzD87sNY5gimVwAALojyAUAoDuCXAAAutNFTe65K29cenzZgvXL2p52u0302ctYVzF3xspiav2AXsnkAgDQHUEuAADd6aJc4TgfUR617Wm320SfUxrruq67C3O3K2MFYLfI5AIA0B1BLgAA3RHkAgDQHUEuAADdqdbapsdwsxuuuurAwfjCCXAcy9bYvejs2TrFoZy0pW/i1sEFpm5vb2/VKQvfw2VyAQDojiAXAIDuCHIBAOjOZDaDWLVn/TLrqOddNp511Q8fZw4OMqWxbpspzd0mxnrav3cAMEsmFwCA7ghyAQDojiAXAIDuTKYmF4CTt2r9yVXr7PbefhvGsO3tV9n0+KbefhvGsOn2RyWTCwBAdwS5AAB0R5ALAEB3qrWl256fqhuuumotg7FO7sGmNNZtM6W5s07uchedPbtw3/OJWPq+ua5aN4DTcoja+YXv4Tv/xbNt+0t6XcHIUUxprNtmSnO3ibFu4h+JAOwW5QoAAHRHkAsAQHcEuQAAdGfna3JPu/5vSvWGUxrrtpnS3G1irFOaHwCmSSYXAIDuCHIBAOiOIBcAgO4IcgEA6I4gFwCA7ghyAQDozs4vIQbQs0Ps+b7UmTNnTmgki60a36b7Pwk93MM2m/r8rnv8yW78ni0ikwsAQHe2KpNrgfjlpjQ/UxrrtjF3B1vX3LSza7ksABskkwsAQHe2KpMLF+K6S1994LEHXf+IUxsHALB9BLlMzrLgdv4cwS4A7CblCgAAdEcml8k4KDt73aWvXvjYsjYAQN8EuQAc2bavv3kaa5AeVw/3sM16mN9N38Om+z8q5QpMznWXvvpWdbmLHgMAdpcgFwCA7nRRrnDuyhuXHl+2gPyytutYeH4TY11Hu030ee6tw38X1deuqrnd9blb11iPakpjBWCaughyAdiMTdfibbr/k9DDPWyzHuZ30/ew6f6PSpDL5CyqvVWPCwDM6iLIPc5HlKf98eYmxnra7dbV53WXvvWow9n5uVtHu+OY0lgBmCZfPAMAoDuCXAAAuiPIBQCgO4JcAAC6I8gFAKA7glwm40HXP+K8TR9mf152DADYPYJcAAC608U6ueyW+QztfAZ32bkAwG7YqiB31b70R3XaC8iv6z7WYV1zM6U5OCpzd/rMDQCHpVwBAIDubFUmF4ALc+bMmU0PYdJ6mL/TvocnPOEJS48/97nPPaWRnI4eXiObtu453NvbW/i4TC4AAN0R5AIA0B1BLgAA3VGTCwAcaL4Gd1XN7YWeD+sikwsAQHcEuQAAdGfnyxWWLS5/2ptIbJtVC+8fdX6mNK9H3XxgXXPXC793AKzbzge5ALvsoPUl961a3/K47Y9r3eM/iWtMvf2111679Pgq235/297+pK5xHNswB0ehXAEAgO4IcgEA6M7Olyuo/zuYuTk6c7ec+QFg3aq1tukx3OyGq65ay2DW8RfqUb+QtG3WFWz08sWiTdyH19bB1jU3F509W2u58Ck4c+bM9ryJH8Gmaw1ZbX7d2wtlndzN6/33bG9vb+F7uHIFAAC6I8gFAKA7glwAALojyAUAoDuCXAAAuiPIBQCgO4JcAAC6s1WbQRx1Xc1NrCs6pbVemRavraM76ty1syc8EOiIdW6ZKplcAAC6I8gFAKA7glwAALqzVTW5AFyYVXvSH9em97Rf9/2dhnXPYQ9zdBzmd/ttag5lcgEA6I4gFwCA7ghyAQDojiAXAIDu7PwXz5ZtJLFsYfl1tNtEn+sa667blef5tMcKAIclkwsAQHcEuQAAdGfnyxWO+tHoabfbRJ/b/rHxNY974YHHvujnHnOKI7m1XXmee31tcXir1r/c9Dq7U7DuOZz6c3DcNVa9Ro9vqnMokwsAQHd2PpPL9CzL4M6fs+mMLgCwGTK5AAB0RyaXyTgoO3vN41648LFlbYCTsa21eFNiDtfL/B7fVOdQJhcAgO4Icpmcax73wlvV5S56DADYXYJcAAC6oyaXyVlUX6vmFgCYtVVB7rL97LfNlMa6zBQX5V9UltBTqYLX1tH1MncAHJ9yBQAAuiPIBQCgO4JcAAC6I8gFAKA7glwAALojyAUAoDtbtYQYABfmuHvK7
+3tndBIjua444d18xpdv3W9j8nkMhlf9HOPOW/Th9mflx0DAHbPzmdyly0eP8WNEk7SqoX1NzU/88HrfHC77NzTsq1zty383gGwbjK5AAB0R5ALAEB3BLkAAHRn52ty1f8dbF1zs6petQdeV8uZHwDWTSYXAIDu7HwmF4CjW7XO7qr1L7e9/TaMwTqtx7Pp5+cknt9Nj2Gqr1GZXAAAuiPIBQCgO4JcAAC6oyYXgCM7bi3e1NtvyxhO0xOf+MSlx5/znOec0kgOZ9PPTw+vsam9RvfJ5AIA0B1BLgAA3dmqcoVeFojv5T6Owxwc3ZTm7rQ39ljX3LSza7ksABu0VUEuALBd5mtwV9XcXuj5sC6CXCbrzc94za0e+/Tv+Jzzju3/DADsFjW5AAB0RyaXyVmUwZ0/JqMLALtNkEsX5oNawS0A7DblCgAAdEcml8mZz86++RmvWVrCAADsHplcAAC600Umd9WC9MsWkF/Wdh0Lz29irOtot4k+D2o3m9k9KKNr7tYz1qOa0li33d7e3qaHcCxTH/9hnDlzZtNDOFHz6+D2zmt0/dY1xzK5AAB0R5ALAEB3uihXOM5HlKf98eYmxnra7U67z8N86czcnXy745jSWAGYJplcAAC600Uml91iuTAAYBWZXAAAuiOTS5ds5wsAu02QC8CRrVrfctPrb07Bts/hc57znI32f1zbPr9TMNU5FOQyOftZ2kW1uTK4AECiJhcAgA7J5DJZsrYAwEG2KshdtS/9UfWygPyy+VnHPa56PjbR5zrswtztwu8Am7GttXhTsu45XFVP2Tuv0eOb6hwqVwAAoDuCXAAAuiPIBQCgO4JcAAC6I8gFAKA7glwAALojyAUAoDuCXAAAurNVm0Fswq5vsLCuTQJ62ZhgF+ZuV16TAOwWmVwAALojyAUAoDs7X64AMGVT3VP+sHq/v5Nw2nP0/Oc//4LOf+xjH7umkZyOHl6De3t7G+1/3XN40P3tfJB72vV/m6g33JU+12EX5m4X7hGA3aNcAQCA7ghyAQDozs6XKwAARzdfc3uhNbuwLjK5AAB0R5ALAEB3BLkAAHRHTS7ADlu1fuaq9S2P2/641j3+k7hGb+0vdN3bqd3ftrU/qWscxzbMwVHI5AIA0J3JZHK3bfH4bRtPL8zrtExp44pzV954giMBYNvJ5AIA0J1qrW16DDe74aqrDhyMDN/yTNQ65mdV5stzcrBNzN1pvz6mZtn8XHT2bJ3iUE7UmTNntudN/Ag2XWvI8a1aF/dCa3g5eb3/nu3t7S18D5fJBQCgO4JcAAC6I8gFAKA7k1ldAQDYPDW4TIVMLgAA3RHkAgDQnS7KFY6zXNNRl11aR7tVbY9qE0tLnfa8Tm08y0xpqbhd+P0BYJq6CHIBgNOh5papUK4AAEB3BLkAAHSni3KF49ThHbXtabc7jin1ua6xbtt4tqlPvz8A9KiLIBdgV63ak/64pr6n/SqHmb91z8GqMWy6/3Xb9P313v822NRrTLkCAADdEeQCANAdQS4AAN0R5AIA0B1BLgAA3RHkAgDQHUEuAADd2fl1cs9deeOJX7OXBes3cR/reD5WWcd99vIaWMXvD5teA3TT/W/DGDbd/7pt+v423f82jGHT/R+VTC4AAN0R5AIA0B1BLgAA3dn5mlwAjm7TtXib7n8bxrDp/tdt0/e36f63YQyb7v+oZHIBAOiOIBcAgO4IcgEA6I4gFwCA7nTxxbNVC9Jv0+Lymxjrsj6X9XecsR61z22zC3Pn9weAHsnkAgDQHUEuAADd6aJcYUofUW5irEft8zhjndJzsswuzN2UnqspjRWAzZLJBQCgO4JcAAC6I8gFAKA7XdTkArDYtu85v+3j2wWeg/5t+jk+bv97e3tHaieTCwBAdwS5AAB0R5ALAEB3BLkAAHTHF886ce7KG0+9TwvzH2wTzwcAcAuZXAAAutNFJndV1kzGEQBgt3QR5AJwNKvWn1zP61QiAAAGAElEQVS1vuXU25+ETY9h2+dw28e3yqb7P4kxTL39USlXAACgO4JcAAC6I8gFAKA7anIBdthxa+Gm3v4kbHoM2z6H2z6+be//JMYw9fZHJZMLAEB3ZHJP2a4sd7bsPqd0j73cxzK78poEYLfI5AIA0B1BLgAA3RHkAgDQHTW5p2xX6ht7uc9e7mOZXbhHAHZPF0Guv6QBAJilXAEAgO4IcgEA6I4gFwCA7ghyAQDojiAXAIDudLG6AgCbsbe3t9brnzlzZq3X3wabvsd1P4errPv+Nz2/22BXf09lcgEA6M7OZ3KntMbulMZ6VL3cYy/3scqu3CcA0yOTCwBAdwS5AAB0R5ALAEB3BLkAAHRHkAsAQHd2fnUFgF22av3MbV3/ct82jP+4Y9h0+2236fnZhvndhjEcx6bGL5MLAEB3BLkAAHSni3KFc1feuPT4sgXrl7U97Xab6LOXsa5i7owVgN3SRZALwNFsey3fKtsw/uOOYdPtt92m52cb5ncbxnAcmxq/cgUAALrTRSb3OB9RHrXtabfbRJ9TGuu6rrsLc7crYwVgt8jkAgDQHUEuAADdEeQCANAdQS4AAN0R5AIA0B1BLgAA3RHkAgDQHUEuAADdmcxmEKv2rAcAgH2TCXIBOHl7e3tLj29qz/ltsmqO1s1zcDybfv7YHOUKAAB0R5ALAEB3tqpc4e7PfvamhwDsoHb27KaHAMAJk8kFAKA7W5XJXadXveqBSZLLL3/9eT/P2j825T7hpLzisstu/v9Hv/GNGxwJAFy4nQhyX/WqBy4NNGfPS44feB4muD3pPuGk7Ae3s4HtbMA7fwwAtpFyBQAAurMTmdzLL3/9rbKph8myHqe/RddeNA7YNsuytPvHFmV7YROs89v/HPR+f4dhDo5GJhcAgO7sRCZ3kdmM6mnVw26iT1gHGV0Atp1MLgAA3dnZTO4mMqmyt0zRKy677FaZ2vnVFmBT1CL2Pwe9399hmIOjkckFAKA7O5vJBZY7zHq5s4+rywVgm+xskOuLZ3A4AlgApmhng1wA1PoB/RLkZvGmDT32CYexqCRhvnTBNr8AbDtfPAMAoDs7k8ld5za+29QnnAaZWwC2nUwuAADd2ZlM7r5l2dV11cVuok84qkV1tzK3AEyNTC4AAN3ZuUzuPtv6wnKytwBMWbXWNj2Gm1XV9gwG2Bmttdr0GI5h6fumdXCBqdvb21t1ysL3cOUKAAB0R5ALAEB3BLkAAHRHkAsAQHcEuQAAdGdnlxCDk3Lt2Yff6rGHX3XtBkYCAOyTyQUAoDsyuXBE8xnc/ezttWcffvMxGV223ar1J1ets9t7+20Yw7a3X2XT45t6+20Yw6bbH5UgFy7QQcHt7M/75wh2AWAzlCsAANAdQS4AAN2p1pZue36qqmp7BgMHWFWucNhz2B6ttYX7nk/E0vfNddW6AZyWQ9TOL3wPl8kFAKA7glwAALpjdQW4QLNLhc3+d/7x2ccAgNMlyIUjOijYnT0GAGyGcgUAALojkwvHJGsLANtHJhcAgO4IcgEA6I4gFwCA7ghyAQDojiAXAIDuVGtL
tz0/VVW1PYMBdkZrbeG+5xPhfRPYdQvfw2VyAQDojiAXAIDuCHIBAOiOHc8AJqxqs+XE73//+8/7+Y53vONO9Q/b4Oqrrz7v58c//vEbGcemHPT9MplcAAC6I8gFAKA7glwAALpjnVxg5015ndzTft+cr4Fd5aRrZDfdP7B9DnoPl8kFAKA7glwAALojyAUAoDvWyQWYsG2vJ9709z423T+wOTK5AAB0R5ALAEB3BLkAAHRHkAsAQHcEuQAAdEeQCwBAdwS5AAB0R5ALAEB3BLkAAHRHkAsAQHfKlocAAPRGJhcAgO4IcgEA6I4gFwCA7ghyAQDojiAXAIDuCHIBAOiOIBcAgO4IcgEA6I4gFwCA7ghyAQDojiAXAIDuCHIBAOiOIBcAgO4IcgEA6I4gFwCA7ghyAQDojiAXAIDuCHIBAOiOIBcAgO4IcgEA6I4gFwCA7ghyAQDojiAXAIDuCHIBAOiOIBcAgO78/8KMFwBuquGUAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(10, 6))\n", "plt.subplot(121)\n", "plt.title(\"원본 관측 (160×210 RGB)\")\n", "plt.imshow(obs)\n", "plt.axis(\"off\")\n", "plt.subplot(122)\n", "plt.title(\"전처리된 관측 (88×80 그레이스케일)\")\n", "plt.imshow(img.reshape(88, 80), interpolation=\"nearest\", cmap=\"gray\")\n", "plt.axis(\"off\")\n", "save_fig(\"preprocessing_plot\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## DQN 만들기" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "reset_graph()\n", "\n", "input_height = 88\n", "input_width = 80\n", "input_channels = 1\n", "conv_n_maps = [32, 64, 64]\n", "conv_kernel_sizes = [(8,8), (4,4), (3,3)]\n", "conv_strides = [4, 2, 1]\n", "conv_paddings = [\"SAME\"] * 3 \n", "conv_activation = [tf.nn.relu] * 3\n", "n_hidden_in = 64 * 11 * 10 # conv3은 11x10 크기의 64개의 맵을 가집니다\n", "n_hidden = 512\n", "hidden_activation = tf.nn.relu\n", "n_outputs = env.action_space.n # 9개의 행동이 가능합니다\n", "initializer = tf.variance_scaling_initializer()\n", "\n", "def q_network(X_state, name):\n", " prev_layer = X_state / 128.0 # 픽셀 강도를 [-1.0, 1.0] 범위로 스케일 변경합니다.\n", " with tf.variable_scope(name) as scope:\n", " for n_maps, kernel_size, strides, padding, activation in zip(\n", " conv_n_maps, conv_kernel_sizes, conv_strides,\n", " conv_paddings, conv_activation):\n", " prev_layer = tf.layers.conv2d(\n", " prev_layer, filters=n_maps, kernel_size=kernel_size,\n", " strides=strides, padding=padding, activation=activation,\n", " kernel_initializer=initializer)\n", " last_conv_layer_flat = tf.reshape(prev_layer, shape=[-1, n_hidden_in])\n", " hidden = tf.layers.dense(last_conv_layer_flat, n_hidden,\n", " activation=hidden_activation,\n", " kernel_initializer=initializer)\n", " outputs = tf.layers.dense(hidden, n_outputs,\n", " kernel_initializer=initializer)\n", " trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,\n", " scope=scope.name)\n", " trainable_vars_by_name = {var.name[len(scope.name):]: var\n", " for var in trainable_vars}\n", " return outputs, trainable_vars_by_name" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From :26: conv2d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use keras.layers.conv2d instead.\n" ] } ], "source": [ "X_state = tf.placeholder(tf.float32, shape=[None, input_height, input_width,\n", " input_channels])\n", "online_q_values, online_vars = q_network(X_state, name=\"q_networks/online\")\n", "target_q_values, target_vars = q_network(X_state, name=\"q_networks/target\")\n", "\n", "copy_ops = [target_var.assign(online_vars[var_name])\n", " for var_name, target_var in target_vars.items()]\n", "copy_online_to_target = tf.group(*copy_ops)" ] }, { "cell_type": "code", "execution_count": 62, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "{'/conv2d/kernel:0': ,\n", " '/conv2d/bias:0': ,\n", " '/conv2d_1/kernel:0': ,\n", " '/conv2d_1/bias:0': ,\n", " '/conv2d_2/kernel:0': ,\n", " '/conv2d_2/bias:0': ,\n", " '/dense/kernel:0': ,\n", " '/dense/bias:0': ,\n", " '/dense_1/kernel:0': ,\n", " '/dense_1/bias:0': }" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "online_vars" ] }, { "cell_type": "code", "execution_count": 63, "metadata": 
{}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /home/haesun/anaconda3/envs/handson-ml/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.cast instead.\n" ] } ], "source": [ "learning_rate = 0.001\n", "momentum = 0.95\n", "\n", "with tf.variable_scope(\"train\"):\n", " X_action = tf.placeholder(tf.int32, shape=[None])\n", " y = tf.placeholder(tf.float32, shape=[None, 1])\n", " q_value = tf.reduce_sum(online_q_values * tf.one_hot(X_action, n_outputs),\n", " axis=1, keepdims=True)\n", " error = tf.abs(y - q_value)\n", " clipped_error = tf.clip_by_value(error, 0.0, 1.0)\n", " linear_error = 2 * (error - clipped_error)\n", " loss = tf.reduce_mean(tf.square(clipped_error) + linear_error)\n", "\n", " global_step = tf.Variable(0, trainable=False, name='global_step')\n", " optimizer = tf.train.MomentumOptimizer(learning_rate, momentum, use_nesterov=True)\n", " training_op = optimizer.minimize(loss, global_step=global_step)\n", "\n", "init = tf.global_variables_initializer()\n", "saver = tf.train.Saver()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "노트: 처음 책을 쓸 때는 타깃 Q-가치(y)와 예측 Q-가치(q_value) 사이의 제곱 오차를 사용했습니다. 하지만 매우 잡음이 많은 경험 때문에 작은 오차(1.0 이하)에 대해서만 손실에 이차식을 사용하고, 큰 오차에 대해서는 위의 계산식처럼 선형적인 손실(절대 오차의 두 배)을 사용하는 것이 더 낫습니다. 이렇게 하면 큰 오차가 모델 파라미터를 너무 많이 변경하지 못합니다. 또 몇 가지 하이퍼파라미터를 조정했습니다(작은 학습률을 사용하고 논문에 따르면 적응적 경사 하강법 알고리즘이 이따금 나쁜 성능을 낼 수 있으므로 Adam 최적화대신 네스테로프 가속 경사를 사용합니다). 아래에서 몇 가지 다른 하이퍼파라미터도 수정했습니다(재생 메모리 크기 확대, e-그리디 정책을 위한 감쇠 단계 증가, 할인 계수 증가, 온라인 DQN에서 타깃 DQN으로 복사 빈도 축소 등입니다)." ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [], "source": [ "from collections import deque\n", "\n", "replay_memory_size = 500000\n", "replay_memory = deque([], maxlen=replay_memory_size)\n", "\n", "def sample_memories(batch_size):\n", " indices = np.random.permutation(len(replay_memory))[:batch_size]\n", " cols = [[], [], [], [], []] # 상태, 행동, 보상, 다음 상태, 계속\n", " for idx in indices:\n", " memory = replay_memory[idx]\n", " for col, value in zip(cols, memory):\n", " col.append(value)\n", " cols = [np.array(col) for col in cols]\n", " return cols[0], cols[1], cols[2].reshape(-1, 1), cols[3], cols[4].reshape(-1, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### ReplayMemory 클래스를 사용한 방법 ==================" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "랜덤 억세스(random access)가 훨씬 빠르기 때문에 deque 대신에 ReplayMemory 클래스를 사용합니다(기여해 준 @NileshPS 님 감사합니다). 또 기본적으로 중복을 허용하여 샘플하면 큰 재생 메모리에서 중복을 허용하지 않고 샘플링하는 것보다 훨씬 빠릅니다." 
] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [], "source": [ "class ReplayMemory:\n", " def __init__(self, maxlen):\n", " self.maxlen = maxlen\n", " self.buf = np.empty(shape=maxlen, dtype=np.object)\n", " self.index = 0\n", " self.length = 0\n", " \n", " def append(self, data):\n", " self.buf[self.index] = data\n", " self.length = min(self.length + 1, self.maxlen)\n", " self.index = (self.index + 1) % self.maxlen\n", " \n", " def sample(self, batch_size, with_replacement=True):\n", " if with_replacement:\n", " indices = np.random.randint(self.length, size=batch_size) # 더 빠름\n", " else:\n", " indices = np.random.permutation(self.length)[:batch_size]\n", " return self.buf[indices]" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [], "source": [ "replay_memory_size = 500000\n", "replay_memory = ReplayMemory(replay_memory_size)" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [], "source": [ "def sample_memories(batch_size):\n", " cols = [[], [], [], [], []] # 상태, 행동, 보상, 다음 상태, 계속\n", " for memory in replay_memory.sample(batch_size):\n", " for col, value in zip(cols, memory):\n", " col.append(value)\n", " cols = [np.array(col) for col in cols]\n", " return cols[0], cols[1], cols[2].reshape(-1, 1), cols[3], cols[4].reshape(-1, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### =============================================" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [], "source": [ "eps_min = 0.1\n", "eps_max = 1.0\n", "eps_decay_steps = 2000000\n", "\n", "def epsilon_greedy(q_values, step):\n", " epsilon = max(eps_min, eps_max - (eps_max-eps_min) * step/eps_decay_steps)\n", " if np.random.rand() < epsilon:\n", " return np.random.randint(n_outputs) # 랜덤 행동\n", " else:\n", " return np.argmax(q_values) # 최적 행동" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [], "source": [ "n_steps = 4000000 # 전체 훈련 스텝 횟수\n", "training_start = 10000 # 10,000번 게임을 반복한 후에 훈련을 시작합니다\n", "training_interval = 4 # 4번 게임을 반복하고 훈련 스텝을 실행합니다\n", "save_steps = 1000 # 1,000번 훈련 스텝마다 모델을 저장합니다\n", "copy_steps = 10000 # 10,000번 훈련 스텝마다 온라인 DQN을 타깃 DQN으로 복사합니다\n", "discount_rate = 0.99\n", "skip_start = 90 # 게임의 시작 부분은 스킵합니다 (시간 낭비이므로).\n", "batch_size = 50\n", "iteration = 0 # 게임 반복횟수\n", "checkpoint_path = \"./my_dqn.ckpt\"\n", "done = True # 환경을 리셋해야 합니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "학습 과정을 트래킹하기 위해 몇 개의 변수가 필요합니다:" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [], "source": [ "loss_val = np.infty\n", "game_length = 0\n", "total_max_q = 0\n", "mean_max_q = 0.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이제 훈련 반복 루프입니다!" 
] }, { "cell_type": "code", "execution_count": 71, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from ./my_dqn.ckpt\n", "반복 13992\t훈련 스텝 3999999/4000000 (100.0)%\t손실 1.739694\t평균 최대-Q 221.029893 " ] } ], "source": [ "with tf.Session() as sess:\n", " if os.path.isfile(checkpoint_path + \".index\"):\n", " saver.restore(sess, checkpoint_path)\n", " else:\n", " init.run()\n", " copy_online_to_target.run()\n", " while True:\n", " step = global_step.eval()\n", " if step >= n_steps:\n", " break\n", " iteration += 1\n", " print(\"\\r반복 {}\\t훈련 스텝 {}/{} ({:.1f})%\\t손실 {:5f}\\t평균 최대-Q {:5f} \".format(\n", " iteration, step, n_steps, step * 100 / n_steps, loss_val, mean_max_q), end=\"\")\n", " if done: # 게임이 종료되면 다시 시작합니다\n", " obs = env.reset()\n", " for skip in range(skip_start): # 게임 시작 부분은 스킵합니다\n", " obs, reward, done, info = env.step(0)\n", " state = preprocess_observation(obs)\n", "\n", " # 온라인 DQN이 해야할 행동을 평가합니다\n", " q_values = online_q_values.eval(feed_dict={X_state: [state]})\n", " action = epsilon_greedy(q_values, step)\n", "\n", " # 온라인 DQN으로 게임을 플레이합니다.\n", " obs, reward, done, info = env.step(action)\n", " next_state = preprocess_observation(obs)\n", "\n", " # 재생 메모리에 기록합니다\n", " replay_memory.append((state, action, reward, next_state, 1.0 - done))\n", " state = next_state\n", "\n", " # 트래킹을 위해 통계값을 계산합니다 (책에는 없습니다)\n", " total_max_q += q_values.max()\n", " game_length += 1\n", " if done:\n", " mean_max_q = total_max_q / game_length\n", " total_max_q = 0.0\n", " game_length = 0\n", "\n", " if iteration < training_start or iteration % training_interval != 0:\n", " continue # 워밍엄 시간이 지난 후에 일정 간격으로 훈련합니다\n", " \n", " # 메모리에서 샘플링하여 타깃 Q-가치를 얻기 위해 타깃 DQN을 사용합니다\n", " X_state_val, X_action_val, rewards, X_next_state_val, continues = (\n", " sample_memories(batch_size))\n", " next_q_values = target_q_values.eval(\n", " feed_dict={X_state: X_next_state_val})\n", " max_next_q_values = np.max(next_q_values, axis=1, keepdims=True)\n", " y_val = rewards + continues * discount_rate * max_next_q_values\n", "\n", " # 온라인 DQN을 훈련시킵니다\n", " _, loss_val = sess.run([training_op, loss], feed_dict={\n", " X_state: X_state_val, X_action: X_action_val, y: y_val})\n", "\n", " # 온라인 DQN을 타깃 DQN으로 일정 간격마다 복사합니다\n", " if step % copy_steps == 0:\n", " copy_online_to_target.run()\n", "\n", " # 일정 간격으로 저장합니다\n", " if step % save_steps == 0:\n", " saver.save(sess, checkpoint_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "아래 셀에서 에이전트를 테스트하기 위해 언제든지 위의 셀을 중지할 수 있습니다. 그런다음 다시 위의 셀을 실행하면 마지막으로 저장된 파라미터를 로드하여 훈련을 다시 시작할 것입니다." 
] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from ./my_dqn.ckpt\n" ] } ], "source": [ "frames = []\n", "n_max_steps = 10000\n", "\n", "with tf.Session() as sess:\n", " saver.restore(sess, checkpoint_path)\n", "\n", " obs = env.reset()\n", " for step in range(n_max_steps):\n", " state = preprocess_observation(obs)\n", "\n", " # 온라인 DQN이 해야할 행동을 평가합니다\n", " q_values = online_q_values.eval(feed_dict={X_state: [state]})\n", " action = np.argmax(q_values)\n", "\n", " # 온라인 DQN이 게임을 플레이합니다\n", " obs, reward, done, info = env.step(action)\n", "\n", " img = env.render(mode=\"rgb_array\")\n", " frames.append(img)\n", "\n", " if done:\n", " break" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ "video = plot_animation(frames, figsize=(5,6))\n", "HTML(video.to_html5_video()) # HTML5 동영상으로 만들어 줍니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 추가 자료" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 브레이크아웃(Breakout)을 위한 전처리" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "다음은 Breakout-v0 아타리 게임을 위한 DQN을 훈련시키기 위해 사용할 수 있는 전처리 함수입니다:" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [], "source": [ "def preprocess_observation(obs):\n", " img = obs[34:194:2, ::2] # 자르고 크기를 줄입니다.\n", " return np.mean(img, axis=2).reshape(80, 80) / 255.0" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/haesun/anaconda3/envs/handson-ml/lib/python3.6/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated. 
Call .resolve and .require separately.\n", " result = entry_point.load(False)\n" ] } ], "source": [ "env = gym.make(\"Breakout-v0\")\n", "obs = env.reset()\n", "for step in range(10):\n", " obs, _, _, _ = env.step(1)\n", "\n", "img = preprocess_observation(obs)" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlUAAAF2CAYAAABH1m23AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAGglJREFUeJzt3Xu4blVdL/DvzyAFCSoktVAJUVHQgkLJMLGjgJdS00qPPdVJO6e8BlGWXSS7GSJqqT2dzCzJY3QxFRFM0zIVL+AlvJ28oOERAS94RRF+5485l7x78a7N3jD2Xmvt/fk8z3xY75xjjjnmu/bm/T6/Mea7q7sDAMCNc5P1HgAAwK5AqAIAGECoAgAYQKgCABhAqAIAGECoAgAYQKgCABhAqNrJqmqPqnpWVX2wqt5fVf9aVUfOx36hql600PYxVXXWkj5Oqapnr9H/GVV10cJ2WVX9zXzsvKo6dsk55646Z99tuI+bVNXfV9UR2373rFZVP1ZVf7Te4wDgxhOqdr5fTHJokrt2952TnJ7kb7fS/r5VdcniluTktRp3909190ErW5LfSvL5ZW2r6siq+kCS2yW5MsmX50OXVtXdruc+npjkw939zrmvvavq1Krqqjp61XW+qar+qKreW1Vvr6rXLRx7+Lz/PVX1b1V1x+u57rL7uGVV/dkcUt9WVW+sqrsuHP/2qnrxPLZbrTr3cVX1gaq6sKrOqqpbLun/lKr61BxK3zRv9104vmdVPbmq3lVV766qN1fV2VV1r/n4i6rq4/P5b6uqd1TVA5Oku/8xyZ2r6n7be98AbCxC1c5Xmd73lfd+j+tpf05332pxS3LaVi9Qtc9K9SvJAUkuXdauuy/o7kO7+9AkRyZ5XpIvJvmR7n7PVvrfK8mTV43jV5O8P8nHlpzyF0kuTnJ4dx+V5Cfmfg5J8mdJHtzdd0vywiT/sOpap1fVTyy8/vaqem1VHbDQ7Mgk53b3nbv77kn+KckzF44/LcmZS+7j2Pk+7tXdhyd5xzzWZV7W3Ud39w8meUKSl1XVLeZjL0lyRJL7dff3dPc9k/x0kqsWzn/+fP7dk5yY5K8Wjv1hklPXuC4Am8T1faAz3p8muUOS91VVJ7kkySO30v74qrp41b59MwWQtRya5Iz5v4ckecPCsd+oquO7+9fnSsrhSY7NVK369yRXJHliVX1PktesEa5+NMl7uvuylR3dfUqSVNVTFxvOladbJ/nuJG+tqs9mqp59OsnDMoWhD83NX5zktKq6a3f/x7zv95O8uqr2TnJ2klcneeaqa7961fg+mYU/2939+Hksq+/jJ5OcsdDXc5JcVlX7dfcVS+57pb8LqurzSQ6ag9mdknxfd1+10ObyJJev0cXtkrx7oe1bquoWVXVYd793resCsLEJVTtRVd0lyR2TvC7J65PcNMnNk9ynqh6WVVWe7n5BkhfM556c5NDufsx2Xvb3kny2qvabX5+T5E3zz4dkqqb8QXd/40N+riDdO1NVbZnvy0IouB73SnL3JM/o7pOq6oQkZ1fV7ZMcnOTDKw27++qq+ti8/z/mfZ+ep8bOTnJKkpPmKbOl5um7pyV59DaM7eAsVLC6+7NVdUWSg7Z2f1X1kCRfS3Jhkp9N8oqVQFVVT0jyqCTfkikwnjSf9tj5vAMy/b176Kpu353pfRWqADYpoWrnOijJMfPP35Xkh5I8K1NF4/xMgesGq6rfyVT9uVmS21TVRZmC0VeTvHRudn53n1dV585jWDl3rT5f3N2rF1Lvk+RT2zis70jy79392iTp7nPmcR0zj+3qVe2/nutOS39bpurcJzMFwbXGun/m8NXd/7oNY9vW6yfJQ6vqezNVmc5LcnR3Xzm/b3utNOruP0nyJ1X1a5kqhSue391Pn8f5vUleVVXHLVSmvphkvwCwaQlVO1F3n53pQ39lPc8duvu0qnp7pgrGtyR5ZVUdmOmDu+btmpU+lkwFPrq7z51/fnqmkLZvki8kuaK7F889buG8h2bb1tR9bcm+j2WawtwWl+a6C+WvyRRmLs4UNBfdZt6fJKmqOyX5xyS/kORtSf6uqvZemW5caHfrTFODp3b3S7ZxbBcnue1CH3sn2X/x+gte1t2/MC8wf26mEJRMv6eTquomi+/11nT3u6rqLUmOz7WVqdsk+eg2jhuADchC9XVQVX+b5C4rr7v7qPlJvd+YX1/c3QdmWsB8dncfuJXt3IV+vtLdn0vy10l+cGsf8t395ST3S/KBNbZ/7u4vdveyUHVWkh/exts9K8lxK0/1VdUxmYLMWzIt8H7gHCJTVT+e6SnE8xfOf3CS/9Hdb+zur2aqxO0/V6Uyn3e7TFOqv7sdgSqZ1nA9amFq9HFJ3rS4Xmu17n5VpqnJp8+7zswUEJ9XVfssND1wrT7mKcp7JHnr/Hq/JIdlmhIGYJNSqVofd8r0VQp/uWr/5Uk+vgOv+wdJPrjyortfluRlqxtV1aGZ1l4t1d3vraoPV9X9lywSX932U1X1qCT/UFVXZZqKfMi8EPyKeQ3SK+djX0hy/8Uw2N2nrurvqkxP3y16ZpJbJvmVqvqVed9Xu/ve1zO211fV85L863z9/5fkEVs7Z/akJBdW1Znd/caquk+SpyQ5r6quzhQML82WTyCurKlaefrzN7t7ZW3bzyf5q+7+0jZcG4ANqrp7vcew26mqd2V6Iu6qJYff293Hz+0en+SPknx2SbsPdfexa/T/hkxP9V255PBTu3utrw1YOf/QTF/lcNBW2tw5yZ8nuc/iU29sn7ni9uokx3f3st8zAJuEUMUNVlX3TvLp7r5wvceyWVXVDyW5vLvft95jAeDGEaoAAAawUB0AYAChCgBggA319N/8z7YAu6juXutb+gE2PZUqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAATbUVyrs6k4//fRhfZ100kmb/hpr9b+z7Cr3sZa17m+jjhdgs1OpAgAYQKgCABjA9N8GsCtM5e2sa+wMpscAuCFUqgAABhCqAAAGMP0Hq2zvdKXpQgASlSoAgCGEKgCAAUz/sdva3mm7zfYUIwA7l0oVAMAAQhUAwACm/zaAnTGttKtcY6TNNl4ANjaVKgCAAYQqgF1UVdV6jwF2K929YbYkbbPZdt1tvf8fs6tvSS5KcnGSj8///UKSPe
djL03yswttT05yyby9fVU/Fyc5ZP75j+fXi9vt52OHJrlo4byDklw9j2Nxu2uSQ5JcvGTM37+k/VVJ7pDkhCRvWHLO/Ve1/+0ktQ3vzw8m+Zv1/j3toN/9f1v1Ozpm3v+GJPdN8vQkpyy0f8Tc7lNJfnlh/8229nc1yb8keeA2jqmS/F6SU5L8bJIzFo4dn+SCJJcleWOSoxaOnZfk2PnnZyd50Hq/v9u6WVMFsIvo7oNWfq6q30zyfd191RptT0ty2kL76vlTbFW7JyZ54kK7i5PsuZVhXLY4joXzDlljHO/IFMYW216S5PPL2lfVC5PcM8mV866rkvx6poD339caVFXdPMlfJLnX/PoeSZ6RKUTsmeQjSZ7Y3Z+oqm/OFCbvnSkYnJvkpO6+eq3+17jmAzIFip6vc8F8jS9U1X5J/jzJYUm+KcmLu/v3l/Sxb5L/u6T7myTZu7v3SZLufl2SA+dz/n2+3lrjukuSZyY5JtP7/Iaqen93n70Nt3WHJF9co9835Npg/W2ZgtyydrdL8uIkD+jud1TVA5O8oqoO6e4vrWr+lCTnV9Vbu/uybRjfujL9B7CLqarbJ/mlJKdU1dPmkPLgheMnVdXF83ZRVX0uydPWa7zzmO5SVbeoqpsk2S/Jp5e16+6f6+5Du/vQJI9M8qEkf5/k56/nEk9IcvbCB/NLkvxpd989yZGZqjanzseekimgHJbk8CR3SXLiwlhvXlX/UlXfubDvuKr6i4XX3zyP60ndfY/5Gvsl+bW5yXOSfLq7D8tUrXtEVT18yf1+vrtvtXpL8gOZKpE3xM8leWF3f7S7P53kWfO+raqqh8338NvLppa7+9juPqi7b5/knzJVoZY5Msl75kCd7n5Vpnu545I+v5zkrzP9Tja8DVWp8jQWwI1TVXdK8vIkX0pyz+7+7Uwfgi9dadPdpyc5fW5fSd6d5NWDhnBAVV20at/JSd51PeedmuSMJG9Ncml3f33+3L5zVZ2RKYT8V6ZK0/cmuU+mIPSZJEck+f2qekeSM7v7a0v6/5lsGRw+MY+1knxzkm+d9yXJT2YKQ9ckuaaqnpfktzJX9rr7S1V1WpLXVNWDktwt0zTXAxb6vzrT1Np3zK/3SvItST4xB8eHZwoX6e4vzoHskZmC2La4RdYOLUmyZ1UdnWTfJccOSXLmwuv3ZQqda6qqgzNV9h6R6b18TrasYN4kyQ9lquztk2mq9TlJbr+ku3ckOXyuFr4tyY8muXmSD65x+b9L8qaq+uX5d7JhbahQBcANU1V7JXlspgDzpCSvSXLW/MF68lZOPTHJJd395qr6YKYPt1st9PuiJA/JllM+a00pXpRpKitVdXmS75/3rTn9t8SlSU6Yp8f2zLTm60VJPpZpSuuOSd6S5NTuvmLh3o/JtA7sOoGqqvbOND347oXdP5bkVUmenGTvJGdlmkZMkoOTfHih7YfnfYv3enZVfSXJ6zK9N8d19yULx6+uqvsmObuqnpNk/yTP7e7nV9Wt52t+ZNU1fmZb3qDZ/kk+uXCPJ2SqDq2MYY9MQe+71zi/1vh5y0ZVe2aaVv2DJL8z3/frk7yyqs5K8oTu/mimma+fTPK1JN+V5DaZwv23ZloX9Q3d/V9V9chMwf62ST6QaZ3Wl5eNobs/NP8OD8y0XnDDEqoAdg2PyFSx+YGFIHNspgrEddZKzccfnmmacI+qOry77zTvv3hV01O6+9nX6eD67VlV+yf5zkzVi9XXv2WmUJJMH67fn6lidXWmCtSfJ/lsd7+2qh6cZGXN0SPn85fd0+MyvQeLU2P7JLkm167DSqbpv3OT/E6mz8LnZ1pn9EuZQsbi+qmvZ/lymYOSfDZTQDog1waalTDy8iTP6O4XzCHx/1TVL2WqEvU8puu7xlpuk2mR/qLzuvvYhdevnNc5rfahTCFzxZ2zZYhc9OBMYe8h3f32JOnur1TV8Ul+Nck/VdUx8/v9i/Pv+41JHtHdr6iq31vWaXe/PlM1a5m/yVSVXPSlTFOPG5pQBbAL6O6/TPKXVbV/VR3c3R/p7q9nWi+zRQCZKztPSfLQTB9sRyY5t6r+57y+ZavmwHBwkpsu7DsiySvnl9dkqmadkykcXZLpg3K1SzNVmPbMND32qe7+6kKfJyy0fVWuDWBb1d2rF1JfluSrmcLdxVV1i0xPxP3ovPj86qo6NcnbM4WqizOFvIvm82877/uGqnp8kh/PNA15cJIzq+qnuvv8ucn3JNm/u18wj+mKuWL1jCTPyxSiFisv17nG9TgsW1a6tlBV3561F6u/MMk/z4v+P5+pWvnkZQ27+++zZEpyft/+cN5WrnlUpqris7v7Fdd3A/MU6kOWHPquJP+ROehV1c0yVeY+dn19rjcL1QF2LT+SaxdcX8e8hugVmULMD3T3f3X3yzOt8Tl0ySlXJ/nNeUH7R+YpwvOSPDULoaq739ndB2ZaV/Ou7r51d9++u4/q7h/JtI5mCz35XKaA8I+LgWpJ26/PYenVmaaLlm33XxKoMj/V+KpMXzuQTNWly+d7XvHjuXZNz4szVV1qXiv02CwEi3kq6rD5el/o7ndnmk588EJ/H0+yb1Xdez7nm5I8LMkH5ycyX5rkcfOxvZI8Jtu+nirz9d+/8PqaJHevqk9U1UczTQUes+zE7n5fkl9J8tpMTyT+9TY++bem+YnC5yZ5THf/7205p7tP7u5DVm9Jzl/V9IeTvLm7lz4RupGoVAHsHk5K8qXu7qo6fvWC3+5+S6a1Slm1/9FJHr2sw6paFsL2zrwAe5XPZF4cvx0uzFTZWRzPvdYYy0szrQdby7MzTe/91bze6UFJTq2qJ2VaB/bJJI+a2z59br/y4f5vuXbqceWJtF9cNa73Z/q+rJXXl85Py506VwlvmikEPX5ucmKSP6uq8zNNN/5DpirP4j2dli2D36K9kjy3qv54vqenZnrvt1BVj112cne/JNMU6Jrm6eHTttZm9sHuPj7JPbah7WL/ZyQ5Nsnnlhxe/GqFx2bVuqyNSqgC2PU8YMm6qGT6EP3VnfAE1a3WuH6q6gVrVBzutsY5V2SqMt0o3f2mqnpvVf1Ed5/Z3W/L9IG+rO1XkvyvAdc8J9MU6LJjn8lUHdva+Sdn6w8Z7FBrTf0NdrMsWW+XKaC9fa70fb27X7aDxzFE9XW/623dPOtZz9o4gwGGO/HEE/2zKayb+bujfnplnRMbX1X9VJKXr3rwYMNSqQJgtzB/3YJAtYl09xnrPYbtYaE6AMAAQhUAwACm/wBunO1eC7rsSyuB9XED15Yv/UusUgUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADDApv4HlU866aT1HgLs1k4//fT1HgLAhqFSBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAU
AUAMMAe6z0AgN3NBRdcsN5DAHYAlSoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABfKM6wE62zz77rPcQgB1ApQoAYAChCgBgAKEKAGCATb2m6rwTTljvIcBu7c3rPQCADUSlCgBgAKEKAGAAoQoAYAChCgBggE29UB1gM7ryyivXewjADqBSBQAwgFAFADCAUAUAMMCmXlN1zSGfX+8hAAAkUakCABhCqAIAGECoAgAYQKgCABhAqAIAGGBTP/0HsBnttdde6z0EYAdQqQIAGECoAgAYYFNP/31m3y+v9xAAAJKoVAEADCFUAQAMIFQBAAwgVAEADCBUAQAMsKmf/gPYjP7zP/9zvYcAzA455JBhfalUAQAMIFQBAAwgVAEADLCp11R95tCvrfcQYPd2+XoPAGDjUKkCABhAqAIAGECoAgAYQKgCABhAqAIAGGBTP/0HsBkdcMAB6z0EYAdQqQIAGECoAgAYYFNP/73kmtuu9xBgt3bceg8AYANRqQIAGECoAgAYQKgCABhAqAIAGECoAgAYYFM//QewGR111FHrPQRg1t3D+lKpAgAYQKgCABhgU0//fe2lp6z3EGD3dtyb13sEABuGShUAwABCFQDAAEIVAMAAQhUAwABCFQDAAEIVAMAAQhUAwABCFQDAAEIVAMAAm/ob1f/lnKPXewiwW3vQcaev9xAANgyVKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAH2WO8BALB7uuCCC7Z4feSRR67TSGAMlSoAgAGEKgCAAYQqAIABhCoAgAGEKgCAATz9B8C68LQfuxqVKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAF8pcIu4LwTTvjGz0efc846jgQAdl8qVQAAAwhVAAADCFUAAAMIVQAAAwhVAAADePpvF+CJPwBYfypVAAADCFUAAAMIVQAAAwhVAAADCFUAAAMIVQAAAwhVAAADCFUAAAMIVQAAAwhVAAADCFUAAAMIVQAAAwhVAAADCFUAAAMIVQAAAwhVAAADCFUAAAMIVQAAAwhVAAADCFUAAAMIVQAAAwhVAAADCFUAAAMIVQAAAwhVAAADCFUAAAMIVQAAAwhVAAADCFUAAAMIVQAAAwhVAAADCFUAAAMIVQAAAwhVAAADCFUAAAMIVQAAAwhVAAADCFUAAAMIVQAAAwhVAAADCFUAAAMIVQAAAwhVAAADCFUAAAPssd4DANjMLrzwwvUeArupCy64YIdf48gjj9zh11hv73znO7f7nCOOOGLpfpUqAIABhCoAgAGEKgCAAYQqAIABhCoAgAE8/QdwIxx++OG1ved0944YCgznz+r2UakCABhAqAIAGGBDTf+d9a1fXO8hsBs474QTdvg1jj7nnB1+jY3gnq95zfadcOKJO2YgABuAShUAwABCFQDAAEIVAMAAQhUAwABCFQDAAEIVAMAAG+orFWBn2F2+7gCAnUulCgBgAKEKAGAA03/ADba9U6n+aVZgV1Yb6V+grqqNMxhguO6u9R4DwI5i+g8AYAChCgBgAKEKAGAAoQoAYAChCgBgAKEKAGAAoQoAYAChCgBgAKEKAGAAoQoAYAChCgBgAKEKAGAAoQoAYAChCgBgAKEKAGAAoQoAYAChCgBgAKEKAGAAoQoAYAChCgBgAKEKAGCA6u71HgMAwKanUgUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMMD/B5H0wxbowGb1AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(10, 6))\n", "plt.subplot(121)\n", "plt.title(\"원본 관측 (160×210 RGB)\")\n", "plt.imshow(obs)\n", "plt.axis(\"off\")\n", "plt.subplot(122)\n", "plt.title(\"전처리된 관측 (80×80 그레이스케일)\")\n", "plt.imshow(img, interpolation=\"nearest\", cmap=\"gray\")\n", "plt.axis(\"off\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "여기서 볼 수 있듯이 하나의 이미지는 볼의 방향과 속도에 대한 정보가 없습니다. 이 정보들은 이 게임에 아주 중요합니다. 이런 이유로 실제로 몇 개의 연속된 관측을 연결하여 환경의 상태를 표현하는 것이 좋습니다. 한 가지 방법은 관측당 하나의 채널을 할당하여 멀티 채널 이미지를 만드는 것입니다. 다른 방법은 `np.max()` 함수를 사용해 최근의 관측을 모두 싱글 채널 이미지로 합치는 것입니다. 여기에서는 이전 이미지를 흐리게하여 DQN이 현재와 이전을 구분할 수 있도록 했습니다." ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [], "source": [ "from collections import deque\n", "\n", "def combine_observations_multichannel(preprocessed_observations):\n", " return np.array(preprocessed_observations).transpose([1, 2, 0])\n", "\n", "def combine_observations_singlechannel(preprocessed_observations, dim_factor=0.5):\n", " dimmed_observations = [obs * dim_factor**index\n", " for index, obs in enumerate(reversed(preprocessed_observations))]\n", " return np.max(np.array(dimmed_observations), axis=0)\n", "\n", "n_observations_per_state = 3\n", "preprocessed_observations = deque([], maxlen=n_observations_per_state)\n", "\n", "obs = env.reset()\n", "for step in range(10):\n", " obs, _, _, _ = env.step(1)\n", " preprocessed_observations.append(preprocess_observation(obs))" ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlEAAAEuCAYAAACu4EdXAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAC+JJREFUeJzt3W2opGd9x/HfPy5Zk12ptUlqodLUbkEQX+SkKSm0TWpD6QM+Va0BEYOYptD6kIONLbRFQqQW05MoIizVui9EfOELW1sQlNZiDAE3JwopaGLElrSNREJ93hp3r76YO2Eczzm75785mTO7nw8MzLln5rqvCeTiO9c8bI0xAgDA7lyw7AkAAKwiEQUA0CCiAAAaRBQAQIOIAgBoEFEAAA0iCgCgQUSd46rqcFWNqvrK3GXM3f7hqnrjdP1VVfW1hcu75u77SFVdvsU55h/39unYsaq6oaquq6q75u47qurZ28z1o1V1ww7P5Zaq+uBOj6mquxbm/6Kqum26XFdV95z+vxqwH1XV3VX1ioVji2vAX01r1eLl21W1sc241jBaDix7Ajw9xhhHnrheVT/c5m6Hk9wzxrh+l2N/LMnHprFrN4+tqoeSHJr+/Ikkn9zh7pclefg0Q/5skpePMb4wd47X7GZOwL51SZL/2ekOY4xbk9y6eLyqbk3y+DaPsYbRYifqPPXEq7MkrzzLcZ4/t8P1QJLvVNUzzvTxY4xfGGM8d4zx3CT/sMN5KsnvJfn1qjq4cPN7pufzM60nAex7VfWCJL+Y5EXNIZ6XLQLGGsbZsBN1npr+h09Vffgsx/lqkiPTWGtJbh9jnDzdi7lpkfqpJD+dZC3J5mlO9fYk/5rkC0neX1U3jjFOTbe9ZYxxbBq3+UyA/WoKkHcm+VCSP6+q42OM+xbuc1mSv99hmF9O8sKqemWS28YY9yTWMM6OiDr3nUryUFV9Ze7Y13a4/0uqav7228cY7zvDc92S5Ojc3xtJTib58tyx/8hsETmV5HuZbc0/mGTL9/mr6oLMFp/XJfmVMcY3q+qqJP+2wxb3P1bVD5L8cIzxgunYm5PclOShM3wuwD5QVQeS/G2Sn0/yq0muTvLPVXXLGGP+ReA3ktxwhsN+a5vj1jB2RUSd48YY38v0Kmsb70/y9bm/P7Hbz0QlSVW9Osmrk3xx7vB6Ztvn75ibz+U7jPG+JP+9cPhIkt9Ocs0Y45vTGDdW1WsX5j3vpfOfJ5i8N8lnktx2mqcC7C+3JPm5JNdO69m/VNU1Sd44BUqSZNrV+cb0Ie2/2Gasj48x3rbVDdYwOkTUeaKq3pzZK5lFl2W2UPzYq5uqujjJszJbBP53h7FfkeT2zF4hfqSqHjvNXP40yVu2uOknk/xxkq8+cWCM8UCSa6bHvTzJH2b2ivRkZp/nOrnN3A8meU5miy+wuv4myakxxpjC4w+SXJ7ZLtDRJN/P3O769LbYscVBavYt5Gu3OoE1jC4RdZ4YY7w3s1cyP6Kqjs39+Z0kV09v551K8t0kjyX5UrZYlKbHvynJjUl+a4zx5ar6zcxetR3a6v7TXN6d5N1bjLXt57Oq6mVJ7sxsu/7zSZ6R5MVJPpDkXXN3fSyzb8f8X5L/SnJXkh9sNy6wv40xTiZJVf1Zkt9P8tbMdosuTvLSJH89XVqsYZwNEXWeqKp3JPmTzD43sOifkh/9mu82Y2x1+O+SfGCM8f1pjP9M8oaFOFsc57ZpLlvtbn18m4e9OMmxMcZn5u87bcH/WpJ7p/OvbXM+YLX9TpJ3jjHunv7+bpIPVtXvJvmNJA8kSVW9NclfZuufQvjEFsesYbT5iYPzy4Ekz9zicnV3wDHGiScWn6fIL21z/NNJXl9V11b
VxVX1rGlr/LrMvvECnNs+leRtVXVFVT2zqp5dVddnFiCfXbjvBZmtd4uXaxcHtYZxNmqMcfp7wT5QVS9J8kdJnp/Z5wi+lOSOMcbnljoxYM9NP3Pwhsw+E/W8JCeS/HuS94wxji9zbmfKGnbuEVEAAA3ezgMAaBBRAAANT/u38+64445dv3+4vr6+F1MBGjY2Nnb9mJtvvvlc+rcsdr2G+ac8YP9ofoxpy/+J7UQBADSIKACABhEFANAgogAAGkQUAECDiAIAaBBRAAANIgoAoEFEAQA0iCgAgAYRBQDQIKIAABpEFABAg4gCAGgQUQAADSIKAKBBRAEANIgoAIAGEQUA0CCiAAAaRBQAQIOIAgBoEFEAAA0iCgCgQUQBADSIKACABhEFANAgogAAGkQUAECDiAIAaBBRAAANIgoAoEFEAQA0iCgAgAYRBQDQIKIAABpEFABAw4FlT+BMrK+vL3sKAG2bm5vLngKwB+xEAQA0iCgAgAYRBQDQIKIAABpEFABAg4gCAGgQUQAADSIKAKBBRAEANKzEL5YfPHhw2VMAaDt8+PCypwDsATtRAAANIgoAoEFEAQA0iCgAgAYRBQDQIKIAABpEFABAg4gCAGhYiR/bfPzxx5c9BYC2EydOLHsKwB6wEwUA0CCiAAAaRBQAQIOIAgBoEFEAAA0iCgCgQUQBADSIKACABhEFANCwEr9YfuGFFy57CgBtF1100bKnAOwBO1EAAA0iCgCgQUQBADSIKACABhEFANAgogAAGkQUAECDiAIAaFiJH9t89NFHlz0FYHLJJZcsewor58EHH1z2FIDJkSNHnrKx7EQBADSIKACABhEFANAgogAAGkQUAECDiAIAaBBRAAANIgoAoEFEAQA0rMQvlh86dGjZUwBou/TSS5c9BWAP2IkCAGgQUQAADSIKAKBBRAEANIgoAIAGEQUA0CCiAAAaRBQAQMNK/NjmnXfeuewpAJONjY1lT2HlXHXVVcueAjAZYzxlY9mJAgBoEFEAAA0iCgCgQUQBADSIKACABhEFANAgogAAGkQUAECDiAIAaBBRAAANIgoAoEFEAQA0iCgAgAYRBQDQIKIAABpEFABAg4gCAGgQUQAADSIKAKBBRAEANIgoAIAGEQUA0CCiAAAaRBQAQIOIAgBoEFEAAA0iCgCgQUQBADSIKACABhEFANAgogAAGkQUAECDiAIAaBBRAAANIgoAoEFEAQA0iCgAgAYRBQDQIKIAABpEFABAg4gCAGgQUQAADSIKAKBBRAEANIgoAIAGEQUA0CCiAAAaRBQAQIOIAgBoEFEAAA0iCgCgQUQBADSIKACABhEFANAgogAAGkQUAECDiAIAaBBRAAANIgoAoEFEAQA0iCgAgAYRBQDQIKIAABpEFABAg4gCAGgQUQAADSIKAKBBRAEANIgoAIAGEQUA0CCiAAAaRBQAQIOIAgBoEFEAAA0iCgCg4cCyJ8DeWp+7vrG0WQD0XHnllU9ev/fee5c4E/hxdqIAABpEFABAg4gCAGjwmahzkQ9CASvspptuevL68ePHlzgT2JmdKACABhEFANDg7bxznHfzgFXmZw3Yz+xEAQA0iCgAgAZv5620bb6G5z08YAVsbm4+eX1tbe3J60ePHl3GdGDX7EQBADSIKACABhEFANAgogAAGkQUAECDb+etNF/DA1bX/DfyYBXZiQIAaBBRAAANIgoAoEFEAQA0iCgAgAYRBQDQIKIAABpEFABAg4gCAGgQUQAADSIKAKBBRAEANIgoAIAGEQUA0CCiAAAaRBQAQIOIAgBoEFEAAA0iCgCgQUQBADSIKACABhEFANAgogAAGkQUAECDiAIAaBBRAAANIgoAoEFEAQA0iCgAgAYRBQDQIKIAABpEFABAg4gCAGgQUQAADSIKAKBBRAEANIgoAIAGEQUA0CCiAAAaRBQAQIOIAgBoEFEAAA0iCgCgQUQBADSIKACABhEFANAgogAAGkQUAECDiAIAaBBRAAANIgoAoEFEAQA0iCgAgAYRBQDQIKIAABpEFABAg4gCAGgQUQAADSIKAKBBRAEANBx4uk/4yCOPPN2nhCTJ+vr6np9jY2Njz8+xbA8//PCyp7BU999//7KnwHlqc3Nzz8+xtra25+dYtvvuu2/Xj7niiiu2PG4nCgCgQUQBADSIKACABhEFANAgogAAGmqMsew5AACsHDtRAAANIgoAoEFEAQA0iCgAgAYRBQDQIKIAABpEFABAg4gCAGgQUQAADSIKAKBBRAEANIgoAIAGEQUA0CCiAAAaRBQAQIOIAgBoEFEAAA0iCgCgQUQBADSIKACABhEFANAgogAAGkQUAECDiAIAaBBRAAAN/w/FrWRo1t7eUAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "img1 = combine_observations_multichannel(preprocessed_observations)\n", "img2 = combine_observations_singlechannel(preprocessed_observations)\n", "\n", "plt.figure(figsize=(10, 6))\n", "plt.subplot(121)\n", "plt.title(\"멀티 채널 상태\")\n", "plt.imshow(img1, interpolation=\"nearest\")\n", "plt.axis(\"off\")\n", "plt.subplot(122)\n", "plt.title(\"싱글 채널 상태\")\n", "plt.imshow(img2, interpolation=\"nearest\", cmap=\"gray\")\n", "plt.axis(\"off\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 연습문제 해답" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. to 7." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "부록 A 참조." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 8. BipedalWalker-v2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*문제: 정책 그래디언트를 사용해 OpenAI 짐의 ‘BypedalWalker-v2’를 훈련시켜보세요*" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [], "source": [ "import gym" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n" ] } ], "source": [ "env = gym.make(\"BipedalWalker-v2\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "노트: 만약 `BipedalWalker-v2` 환경을 만들 때 \"`module 'Box2D._Box2D' has no attribute 'RAND_LIMIT'`\"와 같은 이슈가 발생하면 다음과 같이 해보세요:\n", "```\n", "$ pip uninstall Box2D-kengz\n", "$ pip install git+https://github.com/pybox2d/pybox2d\n", "```" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [], "source": [ "obs = env.reset()" ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [], "source": [ "img = env.render(mode=\"rgb_array\")" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXsAAAD/CAYAAAAQaHZxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAB7FJREFUeJzt3dF100gAhlFlD02sKMNpI23EbdBGaIM2cBmIMrwPS4JiO7ZsS9HM/Pc+7ULMEWTm02isyA/7/b4DoG3/rH0AACxP7AECiD1AALEHCCD2AAHEHiCA2AMEEHuAAGIPEEDsAQJ8WfsAxoah8+wGJun7tY/gvWFY+whoWd93D/f+GUXFHs4pLfBjp47NCYCSiD3FKzny53x03E4CrEHsKVKtgZ/CSYA1iD3FaDnwU9gKYkliz6rSA3+JqwDmIvasQuTv8/rvJ/pMJfZ8GoGfn5U/U4k9ixP5z2f/n0NizyIEvjyuArKJPbMR+Dq5Csgg9txF4NvkKqA9HoTGzYQ+T9/7vtfKyp6rmOh0nZV/jcSeiwSeqez/l0vsOUngmcvhWBL/dYg9bwSepQn9esQ+nMBDBnfjBBN6yGFlH0bgIZPYBxB4SmC/fl1i3yiBB8bEvjEiT4ms6tcn9g0QeOASsa+YyANTiX2FRB64lthXQNypmf36Moh9wUQemKsDYl8YgYdsSzVA7Asg8LTKFs55nzn3xX5FIg951pr3Yr8SoYcMpcx1sV/JMJQzCIB5lDynxX5Fr/uZJQ8QuFXCfn1Nc1fsC2CVD+WrfY6KPcAJtcf9kNgXwuoe1tfyHPSxhAVJ2OMkQ61judbjnkLsC9PyYIMaDEOb81DsAU5oLfpiX6DWBhnUrJW5KPYFa2WQkaXFcdvCAkzsASaqOfhiX7iaBxd5EsZrrat8sa9AjQMLWlfbvBT7StQ2sCBBTat8sa9ILYOKTMnjs4boi31lSh9QkKzk6It9hUodTEC5PAgNYGbjBVkpD1ezsq+U1T0lMR4/Vsq/jdhXrJRBBJxXwl6+2Fdu7QEETLfmfBX7Bgg+1GOtVb7YN0LwWYuxd5vPjr7YN8Skg/p81rwV+8YIPtTnM1b5Yt8gwYc6LRl9sQduZmGxjCX+XcW+USYh1G3uVb7YN6yEH+QA7jPXHBZ7gABiH8DqniUYV3UR+xAmJnMynuoj9kFMUMgl9mEEHzKJPUAAsQ9kdc89jJ86iX0oExayiH0wwYccPnA83DCU84HI3M6Jm0vEHsGfkehSKrGn67r2gi+68J7YUyzBhvmIPW9e43q4whddqJ/Yc0TcoT1uvQQIIPYAAcQeIIDYAwQQe4AAYg8QQOwBAog9QACxBwgg9gABxB4ggNgDBBB7gABiDxBA7AECiD1AALEHCCD2AAHEHiCA2AMEEHuAAGIPEEDsAQKIPUAAsQcIIPYAAcQeIIDYAwQQe4AAYg8QQOwBAog9QACxBwgg9gABxB4ggNgDBBB7gABiDxBA7AECiD1AALEHCCD2AAHEHiDAl7UPINV2+23S1728TPs6gHMe9vv92sfwZhi6cg7mRlMjvumfJ33dbvj+9t/CD5n6vnu4988Q+5ltt98mh/wa4+h3nfBDkjlibxunEocnkPEVhPADl4h9pcbxP9w6En/gkNg34KNVv+gDr9x62aAl3jMA6ib2AAHEHiCA2AME8AbtAg7viQdYmx+qKsT28bHb9Ju3/98Nu+7l588VjwgoxRw/VGUbByCA2Bdq02+67ePj2ocBNELsAQKIPUAAsQcIUO2tl1+/Hr85/etX7M08AGdVG/tXv/7erXjyBNB1TgIA1cd+bBz+scOTgPgDaezZF2Kz2XS7Ybf2YQCNampl//WDVlrJA+mqj/048KIOcFr1sRd4gMvs2QMEqH5l3xpv0gJLsLIvzKbfvD3q2MPQgLmIPUAAsS/E88vL0a/Z0gHmYs++EN+326Nf80lVwFys7AECiH0hdrvdu8+gfdV/8HA3gGuIPUAAsQcIIPYAAcS+QONbLgfP/gFmIPYAAcQeIIDYAwQQ+wq41x64l9gX4NSjEgDmJPYAAcS+Am6/BO7lqZcFGd9f74mXy9n+fvr7P8ePI3rv3FOmr3zty78/LrwA/vc2Rv+MsR/d/WPnYb8vZ9X4tHt6fzD3TMRLr7/ntZdev9ZrL73e37nruq7rV/qYgGF8nBf+zi+DE0OCdwuPkcMx+vL84+67NIqK/fb7UzkHA59keP7gN0YT3lVBG7b9QdxH3+Nzi5A5Ym8bB1bWf7/8NdvNQSTOXBm4KljPyZX66Hs15Xu9FLGHChyt+s6sAo9ODKdcsfWVfFWx/f101TbjydV5IZ8uKvbQmEnvSVwRoHcnj5neX1n66uNohX3jcfe7j3+vNmIPnPXu5HFv+P68fvv8wdXHrW/kH7zumiuhFGIPfLqb9q4F+y5+qAoggNgDBBB7gABiDxBA7AECiD1AALEHCCD2AAHEHiCA2AMEEHuAAGIPEEDsAQKIPUAAsQcIIPYAAcQeIIDYAwQQe4AAYg8QQOwBAog9QACxBwgg9gABxB4ggNgDBBB7gABiDxBA7AECiD1AALEHCCD2AAHEHiCA2AMEEHuAAGIPEEDsAQKIPUAAsQcIIPYAAcQeIIDYAwQQe4AAYg8QQOwBAog9QACxBwgg9gABHvb7/drHAMDCrOwBAog9QACxBwgg9gABxB4ggNgDBBB7gABiDxBA7AECiD1AALEHCCD2AAHEHiCA2AMEEHuAAGIPEEDsAQKIPUAAsQcIIPYAAcQeIIDYAwQQe4AAYg8QQOwBAvwHYhqbfbIagE0AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(img)\n", "plt.axis(\"off\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 2.74713873e-03, 2.63876747e-06, -3.44056450e-04, -1.60000169e-02,\n", " 9.22344476e-02, 7.98698165e-04, 8.60068083e-01, 7.68243025e-04,\n", " 1.00000000e+00, 3.26178409e-02, 7.98681285e-04, 8.53665292e-01,\n", " -5.43449074e-04, 1.00000000e+00, 4.40813839e-01, 4.45819944e-01,\n", " 4.61422592e-01, 4.89549994e-01, 5.34102559e-01, 6.02460802e-01,\n", " 7.09148586e-01, 8.85931492e-01, 1.00000000e+00, 1.00000000e+00])" ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ], "source": [ "obs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이 24개의 숫자에 대한 의미는 [온라인 문서](https://github.com/openai/gym/wiki/BipedalWalker-v2)를 참고하세요." ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Box(4,)" ] }, "execution_count": 85, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-1., -1., -1., -1.], dtype=float32)" ] }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space.low" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1., 1., 1., 1.], dtype=float32)" ] }, "execution_count": 87, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space.high" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이는 각 다리의 엉덩이 관절의 토크와 발목 관절 토크를 제어하는 연속적인 4D 행동 공간입니다(-1에서 1까지). 연속적인 행동 공간을 다루기 위한 한 가지 방법은 이를 불연속적으로 나누는 것입니다. 예를 들어, 가능한 토크 값을 3개의 값 -1.0, 0.0, 1.0으로 제한할 수 있습니다. 이렇게 하면 가능한 행동은 $3^4=81$개가 됩니다." ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [], "source": [ "from itertools import product" ] }, { "cell_type": "code", "execution_count": 89, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(81, 4)" ] }, "execution_count": 89, "metadata": {}, "output_type": "execute_result" } ], "source": [ "possible_torques = np.array([-1.0, 0.0, 1.0])\n", "possible_actions = np.array(list(product(possible_torques, possible_torques, possible_torques, possible_torques)))\n", "possible_actions.shape" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [], "source": [ "tf.reset_default_graph()\n", "\n", "# 1. 네트워크 구조를 정의합니다\n", "n_inputs = env.observation_space.shape[0] # == 24\n", "n_hidden = 10\n", "n_outputs = len(possible_actions) # == 625\n", "initializer = tf.variance_scaling_initializer()\n", "\n", "# 2. 신경망을 만듭니다\n", "X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n", "\n", "hidden = tf.layers.dense(X, n_hidden, activation=tf.nn.selu,\n", " kernel_initializer=initializer)\n", "logits = tf.layers.dense(hidden, n_outputs,\n", " kernel_initializer=initializer)\n", "outputs = tf.nn.softmax(logits)\n", "\n", "# 3. 추정 확률에 기초하여 무작위한 행동을 선택합니다\n", "action_index = tf.squeeze(tf.multinomial(logits, num_samples=1), axis=-1)\n", "\n", "# 4. 
훈련\n", "learning_rate = 0.01\n", "\n", "y = tf.one_hot(action_index, depth=len(possible_actions))\n", "cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits)\n", "optimizer = tf.train.AdamOptimizer(learning_rate)\n", "grads_and_vars = optimizer.compute_gradients(cross_entropy)\n", "gradients = [grad for grad, variable in grads_and_vars]\n", "gradient_placeholders = []\n", "grads_and_vars_feed = []\n", "for grad, variable in grads_and_vars:\n", " gradient_placeholder = tf.placeholder(tf.float32, shape=grad.get_shape())\n", " gradient_placeholders.append(gradient_placeholder)\n", " grads_and_vars_feed.append((gradient_placeholder, variable))\n", "training_op = optimizer.apply_gradients(grads_and_vars_feed)\n", "\n", "init = tf.global_variables_initializer()\n", "saver = tf.train.Saver()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "아직 훈련되지 않았지만 이 정책 네트워크를 실행해 보죠." ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [], "source": [ "def run_bipedal_walker(model_path=None, n_max_steps = 1000):\n", " env = gym.make(\"BipedalWalker-v2\")\n", " frames = []\n", " with tf.Session() as sess:\n", " if model_path is None:\n", " init.run()\n", " else:\n", " saver.restore(sess, model_path)\n", " obs = env.reset()\n", " for step in range(n_max_steps):\n", " img = env.render(mode=\"rgb_array\")\n", " frames.append(img)\n", " action_index_val = action_index.eval(feed_dict={X: obs.reshape(1, n_inputs)})\n", " action = possible_actions[action_index_val]\n", " obs, reward, done, info = env.step(action[0])\n", " if done:\n", " break\n", " env.close()\n", " return frames" ] }, { "cell_type": "code", "execution_count": 92, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 92, "metadata": {}, "output_type": "execute_result" } ], "source": [ "frames = run_bipedal_walker()\n", "video = plot_animation(frames)\n", "HTML(video.to_html5_video())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "안되네요, 걷지를 못합니다. 그럼 훈련시켜 보죠!" 
] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration: 1000/1000" ] } ], "source": [ "n_games_per_update = 10\n", "n_max_steps = 1000\n", "n_iterations = 1000\n", "save_iterations = 10\n", "discount_rate = 0.95\n", "\n", "with tf.Session() as sess:\n", " init.run()\n", " for iteration in range(n_iterations):\n", " print(\"\\rIteration: {}/{}\".format(iteration + 1, n_iterations), end=\"\")\n", " all_rewards = []\n", " all_gradients = []\n", " for game in range(n_games_per_update):\n", " current_rewards = []\n", " current_gradients = []\n", " obs = env.reset()\n", " for step in range(n_max_steps):\n", " action_index_val, gradients_val = sess.run([action_index, gradients],\n", " feed_dict={X: obs.reshape(1, n_inputs)})\n", " action = possible_actions[action_index_val]\n", " obs, reward, done, info = env.step(action[0])\n", " current_rewards.append(reward)\n", " current_gradients.append(gradients_val)\n", " if done:\n", " break\n", " all_rewards.append(current_rewards)\n", " all_gradients.append(current_gradients)\n", "\n", " all_rewards = discount_and_normalize_rewards(all_rewards, discount_rate=discount_rate)\n", " feed_dict = {}\n", " for var_index, gradient_placeholder in enumerate(gradient_placeholders):\n", " mean_gradients = np.mean([reward * all_gradients[game_index][step][var_index]\n", " for game_index, rewards in enumerate(all_rewards)\n", " for step, reward in enumerate(rewards)], axis=0)\n", " feed_dict[gradient_placeholder] = mean_gradients\n", " sess.run(training_op, feed_dict=feed_dict)\n", " if iteration % save_iterations == 0:\n", " saver.save(sess, \"./my_bipedal_walker_pg.ckpt\")" ] }, { "cell_type": "code", "execution_count": 94, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "INFO:tensorflow:Restoring parameters from ./my_bipedal_walker_pg.ckpt\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/haesun/anaconda3/envs/handson-ml/lib/python3.6/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated. Call .resolve and .require separately.\n", " result = entry_point.load(False)\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 94, "metadata": {}, "output_type": "execute_result" } ], "source": [ "frames = run_bipedal_walker(\"./my_bipedal_walker_pg.ckpt\")\n", "video = plot_animation(frames)\n", "HTML(video.to_html5_video())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "최상의 결과는 아니지만 적어도 직립해서 (느리게) 오른쪽으로 이동합니다(훨씬 더 오랜 학습이 필요할 것 같습니다 :). 이 문제에 대한 더 좋은 방법은 액터-크리틱(actor-critic) 알고리즘을 사용하는 것입니다. 이 방법은 행동 공간을 이산화할 필요가 없으므로 훨씬 빠르게 수렴합니다. 이에 대한 더 자세한 내용은 Yash Patel가 쓴 멋진 [블로그 포스트](https://towardsdatascience.com/reinforcement-learning-w-keras-openai-actor-critic-models-f084612cfd69)를 참고하세요." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 9.\n", "**Comming soon**" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" }, "nav_menu": {}, "toc": { "navigate_menu": true, "number_sections": true, "sideBar": true, "threshold": 6, "toc_cell": false, "toc_section_display": "block", "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 1 }