{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Policy-Gradient and Actor-Critic Methods\n", "\n", "> 그로킹 심층 강화학습 중 11장 내용인 \"정책 경사법과 액터-크리틱 학습 방법들\"에 대한 내용입니다.\n", "\n", "- hide: true\n", "- toc: true\n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, Reinforcement_Learning, Grokking_Deep_Reinforcement_Learning]\n", "- permalink: /book/:title:output_ext\n", "- search_exclude: false" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: 실행을 위해 아래의 패키지들을 설치해주기 바랍니다." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#collapse\n", "!pip install tqdm numpy scikit-learn pyglet setuptools && \\\n", "!pip install gym asciinema pandas tabulate tornado==5.* PyBullet && \\\n", "!pip install git+https://github.com/pybox2d/pybox2d#egg=Box2D && \\\n", "!pip install git+https://github.com/mimoralea/gym-bandits#egg=gym-bandits && \\\n", "!pip install git+https://github.com/mimoralea/gym-walk#egg=gym-walk && \\\n", "!pip install git+https://github.com/mimoralea/gym-aima#egg=gym-aima && \\\n", "!pip install gym[atari]\n", "!pip install torch torchvision" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import warnings ; warnings.filterwarnings('ignore')\n", "import os\n", "os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'\n", "os.environ['CUDA_VISIBLE_DEVICES']=''\n", "os.environ['OMP_NUM_THREADS'] = '1'\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "import torch.multiprocessing as mp\n", "import threading\n", "\n", "import numpy as np\n", "from IPython.display import display\n", "from collections import namedtuple, deque\n", "import matplotlib.pyplot as plt\n", "import matplotlib.pylab as pylab\n", "from itertools import cycle, count\n", "from textwrap import wrap\n", "\n", "import matplotlib\n", "import subprocess\n", "import os.path\n", "import tempfile\n", "import random\n", "import base64\n", "import pprint\n", "import glob\n", "import time\n", "import json\n", "import sys\n", "import gym\n", "import io\n", "import os\n", "import gc\n", "import platform\n", "\n", "from gym import wrappers\n", "from subprocess import check_output\n", "from IPython.display import HTML\n", "\n", "LEAVE_PRINT_EVERY_N_SECS = 30\n", "ERASE_LINE = '\\x1b[2K'\n", "EPS = 1e-6\n", "RESULTS_DIR = os.path.join('.', 'gym-results')\n", "SEEDS = (12, 34, 56, 78, 90)\n", "\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "plt.style.use('fivethirtyeight')\n", "params = {\n", " 'figure.figsize': (15, 8),\n", " 'font.size': 24,\n", " 'legend.fontsize': 20,\n", " 'axes.titlesize': 28,\n", " 'axes.labelsize': 24,\n", " 'xtick.labelsize': 20,\n", " 'ytick.labelsize': 20\n", "}\n", "pylab.rcParams.update(params)\n", "np.set_printoptions(suppress=True)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.cuda.is_available()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def get_make_env_fn(**kargs):\n", " def make_env_fn(env_name, seed=None, render=None, record=False,\n", " unwrapped=False, monitor_mode=None, \n", " inner_wrappers=None, outer_wrappers=None):\n", " mdir = tempfile.mkdtemp()\n", " env = None\n", " if render:\n", " try:\n", " env = gym.make(env_name, render=render)\n", " except:\n", " pass\n", " if env is None:\n", " env = gym.make(env_name)\n", " if seed is not None: env.seed(seed)\n", " env = env.unwrapped if unwrapped else env\n", " if inner_wrappers:\n", " for wrapper in inner_wrappers:\n", " env = wrapper(env)\n", " env = wrappers.Monitor(\n", " env, mdir, force=True, \n", " mode=monitor_mode, \n", " video_callable=lambda e_idx: record) if monitor_mode else env\n", " if outer_wrappers:\n", " for wrapper in outer_wrappers:\n", " env = wrapper(env)\n", " return env\n", " return make_env_fn, kargs" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def get_videos_html(env_videos, title, max_n_videos=5):\n", " videos = np.array(env_videos)\n", " if len(videos) == 0:\n", " return\n", " \n", " n_videos = max(1, min(max_n_videos, len(videos)))\n", " idxs = np.linspace(0, len(videos) - 1, n_videos).astype(int) if n_videos > 1 else [-1,]\n", " videos = videos[idxs,...]\n", "\n", " strm = '