{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "dK5SmJpnMM89" }, "source": [ "**Chapter 18 – Reinforcement Learning**" ] }, { "cell_type": "markdown", "metadata": { "id": "u9vrGAroMM9B" }, "source": [ "_This notebook contains all the sample code in chapter 18._" ] }, { "cell_type": "markdown", "metadata": { "id": "B1XZUYIUMM9C" }, "source": [ "\n", " \n", "
\n", " Run in Google Colab\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "R7Ch9S54MM9C" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "NYj6MAH1MM9D" }, "source": [ "First, let's import a few modules, make Matplotlib plot figures inline, and prepare a function to save the figures. We also check that Python 3.5 or later is installed (Python 2.x has reached end of life, so use Python 3), as well as Scikit-Learn ≥0.20 and TensorFlow ≥2.0." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "I6OcREmCMM9D" }, "outputs": [], "source": [ "# Python ≥3.5 is required\n", "import sys\n", "assert sys.version_info >= (3, 5)\n", "\n", "# Is this notebook running on Colab?\n", "IS_COLAB = \"google.colab\" in sys.modules\n", "\n", "if IS_COLAB:\n", "    !apt update && apt install -y libpq-dev libsdl2-dev swig xorg-dev xvfb\n", "    %pip install -U tf-agents==0.13.0 pyvirtualdisplay\n", "    %pip install -U gym~=0.21.0\n", "    %pip install -U gym[box2d,atari,accept-rom-license]\n", "    %pip install pyglet==1.5.27\n", "\n", "# Scikit-Learn ≥0.20 is required\n", "import sklearn\n", "assert sklearn.__version__ >= \"0.20\"\n", "\n", "# TensorFlow ≥2.0 is required\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "assert tf.__version__ >= \"2.0\"\n", "\n", "if not tf.config.list_physical_devices('GPU'):\n", "    print(\"No GPU was detected. 
LSTMs and CNNs can be very slow without a GPU.\")\n", "    if IS_COLAB:\n", "        print(\"Go to Runtime > Change runtime type and select GPU as the hardware accelerator.\")\n", "\n", "# Common imports\n", "import numpy as np\n", "import os\n", "\n", "# To make this notebook's output stable across runs\n", "np.random.seed(42)\n", "tf.random.set_seed(42)\n", "\n", "# To plot pretty figures\n", "%matplotlib inline\n", "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "mpl.rc('axes', labelsize=14)\n", "mpl.rc('xtick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n", "\n", "# To get smooth animations\n", "import matplotlib.animation as animation\n", "mpl.rc('animation', html='jshtml')\n", "\n", "# Where to save the figures\n", "PROJECT_ROOT_DIR = \".\"\n", "CHAPTER_ID = \"rl\"\n", "IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n", "os.makedirs(IMAGES_PATH, exist_ok=True)\n", "\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "    path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", "    print(\"Saving figure\", fig_id)\n", "    if tight_layout:\n", "        plt.tight_layout()\n", "    plt.savefig(path, format=fig_extension, dpi=resolution)" ] }, { "cell_type": "markdown", "metadata": { "id": "HXIRyZZ8MM9F" }, "source": [ "# Introduction to OpenAI gym" ] }, { "cell_type": "markdown", "metadata": { "id": "2vuke9TNMM9G" }, "source": [ "This notebook uses [OpenAI gym](https://gym.openai.com/), a great toolkit for developing and evaluating reinforcement learning algorithms. It provides many environments for your learning agents to interact with. 
Let's start by importing `gym`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "btRLkc6iMM9H" }, "outputs": [], "source": [ "import gym" ] }, { "cell_type": "markdown", "metadata": { "id": "_7k2ardmMM9H" }, "source": [ "Let's list all the available environments:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OWKEgVSnMM9I", "outputId": "7385a9ab-2e43-490a-e703-7716829ac08f" }, "outputs": [ { "data": { "text/plain": [ "dict_values([EnvSpec(Copy-v0), EnvSpec(RepeatCopy-v0), EnvSpec(ReversedAddition-v0), EnvSpec(ReversedAddition3-v0), EnvSpec(DuplicatedInput-v0), EnvSpec(Reverse-v0), EnvSpec(CartPole-v0), EnvSpec(CartPole-v1), EnvSpec(MountainCar-v0), EnvSpec(MountainCarContinuous-v0), EnvSpec(Pendulum-v0), EnvSpec(Acrobot-v1), EnvSpec(LunarLander-v2), EnvSpec(LunarLanderContinuous-v2), EnvSpec(BipedalWalker-v3), EnvSpec(BipedalWalkerHardcore-v3), EnvSpec(CarRacing-v0), EnvSpec(Blackjack-v0), EnvSpec(KellyCoinflip-v0), EnvSpec(KellyCoinflipGeneralized-v0), EnvSpec(FrozenLake-v0), EnvSpec(FrozenLake8x8-v0), EnvSpec(CliffWalking-v0), EnvSpec(NChain-v0), EnvSpec(Roulette-v0), EnvSpec(Taxi-v3), EnvSpec(GuessingGame-v0), EnvSpec(HotterColder-v0), EnvSpec(Reacher-v2), EnvSpec(Pusher-v2), EnvSpec(Thrower-v2), EnvSpec(Striker-v2), EnvSpec(InvertedPendulum-v2), EnvSpec(InvertedDoublePendulum-v2), EnvSpec(HalfCheetah-v2), EnvSpec(HalfCheetah-v3), EnvSpec(Hopper-v2), EnvSpec(Hopper-v3), EnvSpec(Swimmer-v2), EnvSpec(Swimmer-v3), EnvSpec(Walker2d-v2), EnvSpec(Walker2d-v3), EnvSpec(Ant-v2), EnvSpec(Ant-v3), EnvSpec(Humanoid-v2), EnvSpec(Humanoid-v3), EnvSpec(HumanoidStandup-v2), EnvSpec(FetchSlide-v1), EnvSpec(FetchPickAndPlace-v1), EnvSpec(FetchReach-v1), EnvSpec(FetchPush-v1), EnvSpec(HandReach-v0), EnvSpec(HandManipulateBlockRotateZ-v0), EnvSpec(HandManipulateBlockRotateZTouchSensors-v0), EnvSpec(HandManipulateBlockRotateZTouchSensors-v1), EnvSpec(HandManipulateBlockRotateParallel-v0), EnvSpec(HandManipulateBlockRotateParallelTouchSensors-v0), 
EnvSpec(HandManipulateBlockRotateParallelTouchSensors-v1), EnvSpec(HandManipulateBlockRotateXYZ-v0), EnvSpec(HandManipulateBlockRotateXYZTouchSensors-v0), EnvSpec(HandManipulateBlockRotateXYZTouchSensors-v1), EnvSpec(HandManipulateBlockFull-v0), EnvSpec(HandManipulateBlock-v0), EnvSpec(HandManipulateBlockTouchSensors-v0), EnvSpec(HandManipulateBlockTouchSensors-v1), EnvSpec(HandManipulateEggRotate-v0), EnvSpec(HandManipulateEggRotateTouchSensors-v0), EnvSpec(HandManipulateEggRotateTouchSensors-v1), EnvSpec(HandManipulateEggFull-v0), EnvSpec(HandManipulateEgg-v0), EnvSpec(HandManipulateEggTouchSensors-v0), EnvSpec(HandManipulateEggTouchSensors-v1), EnvSpec(HandManipulatePenRotate-v0), EnvSpec(HandManipulatePenRotateTouchSensors-v0), EnvSpec(HandManipulatePenRotateTouchSensors-v1), EnvSpec(HandManipulatePenFull-v0), EnvSpec(HandManipulatePen-v0), EnvSpec(HandManipulatePenTouchSensors-v0), EnvSpec(HandManipulatePenTouchSensors-v1), EnvSpec(FetchSlideDense-v1), EnvSpec(FetchPickAndPlaceDense-v1), EnvSpec(FetchReachDense-v1), EnvSpec(FetchPushDense-v1), EnvSpec(HandReachDense-v0), EnvSpec(HandManipulateBlockRotateZDense-v0), EnvSpec(HandManipulateBlockRotateZTouchSensorsDense-v0), EnvSpec(HandManipulateBlockRotateZTouchSensorsDense-v1), EnvSpec(HandManipulateBlockRotateParallelDense-v0), EnvSpec(HandManipulateBlockRotateParallelTouchSensorsDense-v0), EnvSpec(HandManipulateBlockRotateParallelTouchSensorsDense-v1), EnvSpec(HandManipulateBlockRotateXYZDense-v0), EnvSpec(HandManipulateBlockRotateXYZTouchSensorsDense-v0), EnvSpec(HandManipulateBlockRotateXYZTouchSensorsDense-v1), EnvSpec(HandManipulateBlockFullDense-v0), EnvSpec(HandManipulateBlockDense-v0), EnvSpec(HandManipulateBlockTouchSensorsDense-v0), EnvSpec(HandManipulateBlockTouchSensorsDense-v1), EnvSpec(HandManipulateEggRotateDense-v0), EnvSpec(HandManipulateEggRotateTouchSensorsDense-v0), EnvSpec(HandManipulateEggRotateTouchSensorsDense-v1), EnvSpec(HandManipulateEggFullDense-v0), 
EnvSpec(HandManipulateEggDense-v0), EnvSpec(HandManipulateEggTouchSensorsDense-v0), EnvSpec(HandManipulateEggTouchSensorsDense-v1), EnvSpec(HandManipulatePenRotateDense-v0), EnvSpec(HandManipulatePenRotateTouchSensorsDense-v0), EnvSpec(HandManipulatePenRotateTouchSensorsDense-v1), EnvSpec(HandManipulatePenFullDense-v0), EnvSpec(HandManipulatePenDense-v0), EnvSpec(HandManipulatePenTouchSensorsDense-v0), EnvSpec(HandManipulatePenTouchSensorsDense-v1), EnvSpec(Adventure-v0), EnvSpec(Adventure-v4), EnvSpec(AdventureDeterministic-v0), EnvSpec(AdventureDeterministic-v4), EnvSpec(AdventureNoFrameskip-v0), EnvSpec(AdventureNoFrameskip-v4), EnvSpec(Adventure-ram-v0), EnvSpec(Adventure-ram-v4), EnvSpec(Adventure-ramDeterministic-v0), EnvSpec(Adventure-ramDeterministic-v4), EnvSpec(Adventure-ramNoFrameskip-v0), EnvSpec(Adventure-ramNoFrameskip-v4), EnvSpec(AirRaid-v0), EnvSpec(AirRaid-v4), EnvSpec(AirRaidDeterministic-v0), EnvSpec(AirRaidDeterministic-v4), EnvSpec(AirRaidNoFrameskip-v0), EnvSpec(AirRaidNoFrameskip-v4), EnvSpec(AirRaid-ram-v0), EnvSpec(AirRaid-ram-v4), EnvSpec(AirRaid-ramDeterministic-v0), EnvSpec(AirRaid-ramDeterministic-v4), EnvSpec(AirRaid-ramNoFrameskip-v0), EnvSpec(AirRaid-ramNoFrameskip-v4), EnvSpec(Alien-v0), EnvSpec(Alien-v4), EnvSpec(AlienDeterministic-v0), EnvSpec(AlienDeterministic-v4), EnvSpec(AlienNoFrameskip-v0), EnvSpec(AlienNoFrameskip-v4), EnvSpec(Alien-ram-v0), EnvSpec(Alien-ram-v4), EnvSpec(Alien-ramDeterministic-v0), EnvSpec(Alien-ramDeterministic-v4), EnvSpec(Alien-ramNoFrameskip-v0), EnvSpec(Alien-ramNoFrameskip-v4), EnvSpec(Amidar-v0), EnvSpec(Amidar-v4), EnvSpec(AmidarDeterministic-v0), EnvSpec(AmidarDeterministic-v4), EnvSpec(AmidarNoFrameskip-v0), EnvSpec(AmidarNoFrameskip-v4), EnvSpec(Amidar-ram-v0), EnvSpec(Amidar-ram-v4), EnvSpec(Amidar-ramDeterministic-v0), EnvSpec(Amidar-ramDeterministic-v4), EnvSpec(Amidar-ramNoFrameskip-v0), EnvSpec(Amidar-ramNoFrameskip-v4), EnvSpec(Assault-v0), EnvSpec(Assault-v4), 
EnvSpec(AssaultDeterministic-v0), EnvSpec(AssaultDeterministic-v4), EnvSpec(AssaultNoFrameskip-v0), EnvSpec(AssaultNoFrameskip-v4), EnvSpec(Assault-ram-v0), EnvSpec(Assault-ram-v4), EnvSpec(Assault-ramDeterministic-v0), EnvSpec(Assault-ramDeterministic-v4), EnvSpec(Assault-ramNoFrameskip-v0), EnvSpec(Assault-ramNoFrameskip-v4), EnvSpec(Asterix-v0), EnvSpec(Asterix-v4), EnvSpec(AsterixDeterministic-v0), EnvSpec(AsterixDeterministic-v4), EnvSpec(AsterixNoFrameskip-v0), EnvSpec(AsterixNoFrameskip-v4), EnvSpec(Asterix-ram-v0), EnvSpec(Asterix-ram-v4), EnvSpec(Asterix-ramDeterministic-v0), EnvSpec(Asterix-ramDeterministic-v4), EnvSpec(Asterix-ramNoFrameskip-v0), EnvSpec(Asterix-ramNoFrameskip-v4), EnvSpec(Asteroids-v0), EnvSpec(Asteroids-v4), EnvSpec(AsteroidsDeterministic-v0), EnvSpec(AsteroidsDeterministic-v4), EnvSpec(AsteroidsNoFrameskip-v0), EnvSpec(AsteroidsNoFrameskip-v4), EnvSpec(Asteroids-ram-v0), EnvSpec(Asteroids-ram-v4), EnvSpec(Asteroids-ramDeterministic-v0), EnvSpec(Asteroids-ramDeterministic-v4), EnvSpec(Asteroids-ramNoFrameskip-v0), EnvSpec(Asteroids-ramNoFrameskip-v4), EnvSpec(Atlantis-v0), EnvSpec(Atlantis-v4), EnvSpec(AtlantisDeterministic-v0), EnvSpec(AtlantisDeterministic-v4), EnvSpec(AtlantisNoFrameskip-v0), EnvSpec(AtlantisNoFrameskip-v4), EnvSpec(Atlantis-ram-v0), EnvSpec(Atlantis-ram-v4), EnvSpec(Atlantis-ramDeterministic-v0), EnvSpec(Atlantis-ramDeterministic-v4), EnvSpec(Atlantis-ramNoFrameskip-v0), EnvSpec(Atlantis-ramNoFrameskip-v4), EnvSpec(BankHeist-v0), EnvSpec(BankHeist-v4), EnvSpec(BankHeistDeterministic-v0), EnvSpec(BankHeistDeterministic-v4), EnvSpec(BankHeistNoFrameskip-v0), EnvSpec(BankHeistNoFrameskip-v4), EnvSpec(BankHeist-ram-v0), EnvSpec(BankHeist-ram-v4), EnvSpec(BankHeist-ramDeterministic-v0), EnvSpec(BankHeist-ramDeterministic-v4), EnvSpec(BankHeist-ramNoFrameskip-v0), EnvSpec(BankHeist-ramNoFrameskip-v4), EnvSpec(BattleZone-v0), EnvSpec(BattleZone-v4), EnvSpec(BattleZoneDeterministic-v0), EnvSpec(BattleZoneDeterministic-v4), 
EnvSpec(BattleZoneNoFrameskip-v0), EnvSpec(BattleZoneNoFrameskip-v4), EnvSpec(BattleZone-ram-v0), EnvSpec(BattleZone-ram-v4), EnvSpec(BattleZone-ramDeterministic-v0), EnvSpec(BattleZone-ramDeterministic-v4), EnvSpec(BattleZone-ramNoFrameskip-v0), EnvSpec(BattleZone-ramNoFrameskip-v4), EnvSpec(BeamRider-v0), EnvSpec(BeamRider-v4), EnvSpec(BeamRiderDeterministic-v0), EnvSpec(BeamRiderDeterministic-v4), EnvSpec(BeamRiderNoFrameskip-v0), EnvSpec(BeamRiderNoFrameskip-v4), EnvSpec(BeamRider-ram-v0), EnvSpec(BeamRider-ram-v4), EnvSpec(BeamRider-ramDeterministic-v0), EnvSpec(BeamRider-ramDeterministic-v4), EnvSpec(BeamRider-ramNoFrameskip-v0), EnvSpec(BeamRider-ramNoFrameskip-v4), EnvSpec(Berzerk-v0), EnvSpec(Berzerk-v4), EnvSpec(BerzerkDeterministic-v0), EnvSpec(BerzerkDeterministic-v4), EnvSpec(BerzerkNoFrameskip-v0), EnvSpec(BerzerkNoFrameskip-v4), EnvSpec(Berzerk-ram-v0), EnvSpec(Berzerk-ram-v4), EnvSpec(Berzerk-ramDeterministic-v0), EnvSpec(Berzerk-ramDeterministic-v4), EnvSpec(Berzerk-ramNoFrameskip-v0), EnvSpec(Berzerk-ramNoFrameskip-v4), EnvSpec(Bowling-v0), EnvSpec(Bowling-v4), EnvSpec(BowlingDeterministic-v0), EnvSpec(BowlingDeterministic-v4), EnvSpec(BowlingNoFrameskip-v0), EnvSpec(BowlingNoFrameskip-v4), EnvSpec(Bowling-ram-v0), EnvSpec(Bowling-ram-v4), EnvSpec(Bowling-ramDeterministic-v0), EnvSpec(Bowling-ramDeterministic-v4), EnvSpec(Bowling-ramNoFrameskip-v0), EnvSpec(Bowling-ramNoFrameskip-v4), EnvSpec(Boxing-v0), EnvSpec(Boxing-v4), EnvSpec(BoxingDeterministic-v0), EnvSpec(BoxingDeterministic-v4), EnvSpec(BoxingNoFrameskip-v0), EnvSpec(BoxingNoFrameskip-v4), EnvSpec(Boxing-ram-v0), EnvSpec(Boxing-ram-v4), EnvSpec(Boxing-ramDeterministic-v0), EnvSpec(Boxing-ramDeterministic-v4), EnvSpec(Boxing-ramNoFrameskip-v0), EnvSpec(Boxing-ramNoFrameskip-v4), EnvSpec(Breakout-v0), EnvSpec(Breakout-v4), EnvSpec(BreakoutDeterministic-v0), EnvSpec(BreakoutDeterministic-v4), EnvSpec(BreakoutNoFrameskip-v0), EnvSpec(BreakoutNoFrameskip-v4), EnvSpec(Breakout-ram-v0), 
EnvSpec(Breakout-ram-v4), EnvSpec(Breakout-ramDeterministic-v0), EnvSpec(Breakout-ramDeterministic-v4), EnvSpec(Breakout-ramNoFrameskip-v0), EnvSpec(Breakout-ramNoFrameskip-v4), EnvSpec(Carnival-v0), EnvSpec(Carnival-v4), EnvSpec(CarnivalDeterministic-v0), EnvSpec(CarnivalDeterministic-v4), EnvSpec(CarnivalNoFrameskip-v0), EnvSpec(CarnivalNoFrameskip-v4), EnvSpec(Carnival-ram-v0), EnvSpec(Carnival-ram-v4), EnvSpec(Carnival-ramDeterministic-v0), EnvSpec(Carnival-ramDeterministic-v4), EnvSpec(Carnival-ramNoFrameskip-v0), EnvSpec(Carnival-ramNoFrameskip-v4), EnvSpec(Centipede-v0), EnvSpec(Centipede-v4), EnvSpec(CentipedeDeterministic-v0), EnvSpec(CentipedeDeterministic-v4), EnvSpec(CentipedeNoFrameskip-v0), EnvSpec(CentipedeNoFrameskip-v4), EnvSpec(Centipede-ram-v0), EnvSpec(Centipede-ram-v4), EnvSpec(Centipede-ramDeterministic-v0), EnvSpec(Centipede-ramDeterministic-v4), EnvSpec(Centipede-ramNoFrameskip-v0), EnvSpec(Centipede-ramNoFrameskip-v4), EnvSpec(ChopperCommand-v0), EnvSpec(ChopperCommand-v4), EnvSpec(ChopperCommandDeterministic-v0), EnvSpec(ChopperCommandDeterministic-v4), EnvSpec(ChopperCommandNoFrameskip-v0), EnvSpec(ChopperCommandNoFrameskip-v4), EnvSpec(ChopperCommand-ram-v0), EnvSpec(ChopperCommand-ram-v4), EnvSpec(ChopperCommand-ramDeterministic-v0), EnvSpec(ChopperCommand-ramDeterministic-v4), EnvSpec(ChopperCommand-ramNoFrameskip-v0), EnvSpec(ChopperCommand-ramNoFrameskip-v4), EnvSpec(CrazyClimber-v0), EnvSpec(CrazyClimber-v4), EnvSpec(CrazyClimberDeterministic-v0), EnvSpec(CrazyClimberDeterministic-v4), EnvSpec(CrazyClimberNoFrameskip-v0), EnvSpec(CrazyClimberNoFrameskip-v4), EnvSpec(CrazyClimber-ram-v0), EnvSpec(CrazyClimber-ram-v4), EnvSpec(CrazyClimber-ramDeterministic-v0), EnvSpec(CrazyClimber-ramDeterministic-v4), EnvSpec(CrazyClimber-ramNoFrameskip-v0), EnvSpec(CrazyClimber-ramNoFrameskip-v4), EnvSpec(Defender-v0), EnvSpec(Defender-v4), EnvSpec(DefenderDeterministic-v0), EnvSpec(DefenderDeterministic-v4), EnvSpec(DefenderNoFrameskip-v0), 
EnvSpec(DefenderNoFrameskip-v4), EnvSpec(Defender-ram-v0), EnvSpec(Defender-ram-v4), EnvSpec(Defender-ramDeterministic-v0), EnvSpec(Defender-ramDeterministic-v4), EnvSpec(Defender-ramNoFrameskip-v0), EnvSpec(Defender-ramNoFrameskip-v4), EnvSpec(DemonAttack-v0), EnvSpec(DemonAttack-v4), EnvSpec(DemonAttackDeterministic-v0), EnvSpec(DemonAttackDeterministic-v4), EnvSpec(DemonAttackNoFrameskip-v0), EnvSpec(DemonAttackNoFrameskip-v4), EnvSpec(DemonAttack-ram-v0), EnvSpec(DemonAttack-ram-v4), EnvSpec(DemonAttack-ramDeterministic-v0), EnvSpec(DemonAttack-ramDeterministic-v4), EnvSpec(DemonAttack-ramNoFrameskip-v0), EnvSpec(DemonAttack-ramNoFrameskip-v4), EnvSpec(DoubleDunk-v0), EnvSpec(DoubleDunk-v4), EnvSpec(DoubleDunkDeterministic-v0), EnvSpec(DoubleDunkDeterministic-v4), EnvSpec(DoubleDunkNoFrameskip-v0), EnvSpec(DoubleDunkNoFrameskip-v4), EnvSpec(DoubleDunk-ram-v0), EnvSpec(DoubleDunk-ram-v4), EnvSpec(DoubleDunk-ramDeterministic-v0), EnvSpec(DoubleDunk-ramDeterministic-v4), EnvSpec(DoubleDunk-ramNoFrameskip-v0), EnvSpec(DoubleDunk-ramNoFrameskip-v4), EnvSpec(ElevatorAction-v0), EnvSpec(ElevatorAction-v4), EnvSpec(ElevatorActionDeterministic-v0), EnvSpec(ElevatorActionDeterministic-v4), EnvSpec(ElevatorActionNoFrameskip-v0), EnvSpec(ElevatorActionNoFrameskip-v4), EnvSpec(ElevatorAction-ram-v0), EnvSpec(ElevatorAction-ram-v4), EnvSpec(ElevatorAction-ramDeterministic-v0), EnvSpec(ElevatorAction-ramDeterministic-v4), EnvSpec(ElevatorAction-ramNoFrameskip-v0), EnvSpec(ElevatorAction-ramNoFrameskip-v4), EnvSpec(Enduro-v0), EnvSpec(Enduro-v4), EnvSpec(EnduroDeterministic-v0), EnvSpec(EnduroDeterministic-v4), EnvSpec(EnduroNoFrameskip-v0), EnvSpec(EnduroNoFrameskip-v4), EnvSpec(Enduro-ram-v0), EnvSpec(Enduro-ram-v4), EnvSpec(Enduro-ramDeterministic-v0), EnvSpec(Enduro-ramDeterministic-v4), EnvSpec(Enduro-ramNoFrameskip-v0), EnvSpec(Enduro-ramNoFrameskip-v4), EnvSpec(FishingDerby-v0), EnvSpec(FishingDerby-v4), EnvSpec(FishingDerbyDeterministic-v0), 
EnvSpec(FishingDerbyDeterministic-v4), EnvSpec(FishingDerbyNoFrameskip-v0), EnvSpec(FishingDerbyNoFrameskip-v4), EnvSpec(FishingDerby-ram-v0), EnvSpec(FishingDerby-ram-v4), EnvSpec(FishingDerby-ramDeterministic-v0), EnvSpec(FishingDerby-ramDeterministic-v4), EnvSpec(FishingDerby-ramNoFrameskip-v0), EnvSpec(FishingDerby-ramNoFrameskip-v4), EnvSpec(Freeway-v0), EnvSpec(Freeway-v4), EnvSpec(FreewayDeterministic-v0), EnvSpec(FreewayDeterministic-v4), EnvSpec(FreewayNoFrameskip-v0), EnvSpec(FreewayNoFrameskip-v4), EnvSpec(Freeway-ram-v0), EnvSpec(Freeway-ram-v4), EnvSpec(Freeway-ramDeterministic-v0), EnvSpec(Freeway-ramDeterministic-v4), EnvSpec(Freeway-ramNoFrameskip-v0), EnvSpec(Freeway-ramNoFrameskip-v4), EnvSpec(Frostbite-v0), EnvSpec(Frostbite-v4), EnvSpec(FrostbiteDeterministic-v0), EnvSpec(FrostbiteDeterministic-v4), EnvSpec(FrostbiteNoFrameskip-v0), EnvSpec(FrostbiteNoFrameskip-v4), EnvSpec(Frostbite-ram-v0), EnvSpec(Frostbite-ram-v4), EnvSpec(Frostbite-ramDeterministic-v0), EnvSpec(Frostbite-ramDeterministic-v4), EnvSpec(Frostbite-ramNoFrameskip-v0), EnvSpec(Frostbite-ramNoFrameskip-v4), EnvSpec(Gopher-v0), EnvSpec(Gopher-v4), EnvSpec(GopherDeterministic-v0), EnvSpec(GopherDeterministic-v4), EnvSpec(GopherNoFrameskip-v0), EnvSpec(GopherNoFrameskip-v4), EnvSpec(Gopher-ram-v0), EnvSpec(Gopher-ram-v4), EnvSpec(Gopher-ramDeterministic-v0), EnvSpec(Gopher-ramDeterministic-v4), EnvSpec(Gopher-ramNoFrameskip-v0), EnvSpec(Gopher-ramNoFrameskip-v4), EnvSpec(Gravitar-v0), EnvSpec(Gravitar-v4), EnvSpec(GravitarDeterministic-v0), EnvSpec(GravitarDeterministic-v4), EnvSpec(GravitarNoFrameskip-v0), EnvSpec(GravitarNoFrameskip-v4), EnvSpec(Gravitar-ram-v0), EnvSpec(Gravitar-ram-v4), EnvSpec(Gravitar-ramDeterministic-v0), EnvSpec(Gravitar-ramDeterministic-v4), EnvSpec(Gravitar-ramNoFrameskip-v0), EnvSpec(Gravitar-ramNoFrameskip-v4), EnvSpec(Hero-v0), EnvSpec(Hero-v4), EnvSpec(HeroDeterministic-v0), EnvSpec(HeroDeterministic-v4), EnvSpec(HeroNoFrameskip-v0), 
EnvSpec(HeroNoFrameskip-v4), EnvSpec(Hero-ram-v0), EnvSpec(Hero-ram-v4), EnvSpec(Hero-ramDeterministic-v0), EnvSpec(Hero-ramDeterministic-v4), EnvSpec(Hero-ramNoFrameskip-v0), EnvSpec(Hero-ramNoFrameskip-v4), EnvSpec(IceHockey-v0), EnvSpec(IceHockey-v4), EnvSpec(IceHockeyDeterministic-v0), EnvSpec(IceHockeyDeterministic-v4), EnvSpec(IceHockeyNoFrameskip-v0), EnvSpec(IceHockeyNoFrameskip-v4), EnvSpec(IceHockey-ram-v0), EnvSpec(IceHockey-ram-v4), EnvSpec(IceHockey-ramDeterministic-v0), EnvSpec(IceHockey-ramDeterministic-v4), EnvSpec(IceHockey-ramNoFrameskip-v0), EnvSpec(IceHockey-ramNoFrameskip-v4), EnvSpec(Jamesbond-v0), EnvSpec(Jamesbond-v4), EnvSpec(JamesbondDeterministic-v0), EnvSpec(JamesbondDeterministic-v4), EnvSpec(JamesbondNoFrameskip-v0), EnvSpec(JamesbondNoFrameskip-v4), EnvSpec(Jamesbond-ram-v0), EnvSpec(Jamesbond-ram-v4), EnvSpec(Jamesbond-ramDeterministic-v0), EnvSpec(Jamesbond-ramDeterministic-v4), EnvSpec(Jamesbond-ramNoFrameskip-v0), EnvSpec(Jamesbond-ramNoFrameskip-v4), EnvSpec(JourneyEscape-v0), EnvSpec(JourneyEscape-v4), EnvSpec(JourneyEscapeDeterministic-v0), EnvSpec(JourneyEscapeDeterministic-v4), EnvSpec(JourneyEscapeNoFrameskip-v0), EnvSpec(JourneyEscapeNoFrameskip-v4), EnvSpec(JourneyEscape-ram-v0), EnvSpec(JourneyEscape-ram-v4), EnvSpec(JourneyEscape-ramDeterministic-v0), EnvSpec(JourneyEscape-ramDeterministic-v4), EnvSpec(JourneyEscape-ramNoFrameskip-v0), EnvSpec(JourneyEscape-ramNoFrameskip-v4), EnvSpec(Kangaroo-v0), EnvSpec(Kangaroo-v4), EnvSpec(KangarooDeterministic-v0), EnvSpec(KangarooDeterministic-v4), EnvSpec(KangarooNoFrameskip-v0), EnvSpec(KangarooNoFrameskip-v4), EnvSpec(Kangaroo-ram-v0), EnvSpec(Kangaroo-ram-v4), EnvSpec(Kangaroo-ramDeterministic-v0), EnvSpec(Kangaroo-ramDeterministic-v4), EnvSpec(Kangaroo-ramNoFrameskip-v0), EnvSpec(Kangaroo-ramNoFrameskip-v4), EnvSpec(Krull-v0), EnvSpec(Krull-v4), EnvSpec(KrullDeterministic-v0), EnvSpec(KrullDeterministic-v4), EnvSpec(KrullNoFrameskip-v0), EnvSpec(KrullNoFrameskip-v4), 
EnvSpec(Krull-ram-v0), EnvSpec(Krull-ram-v4), EnvSpec(Krull-ramDeterministic-v0), EnvSpec(Krull-ramDeterministic-v4), EnvSpec(Krull-ramNoFrameskip-v0), EnvSpec(Krull-ramNoFrameskip-v4), EnvSpec(KungFuMaster-v0), EnvSpec(KungFuMaster-v4), EnvSpec(KungFuMasterDeterministic-v0), EnvSpec(KungFuMasterDeterministic-v4), EnvSpec(KungFuMasterNoFrameskip-v0), EnvSpec(KungFuMasterNoFrameskip-v4), EnvSpec(KungFuMaster-ram-v0), EnvSpec(KungFuMaster-ram-v4), EnvSpec(KungFuMaster-ramDeterministic-v0), EnvSpec(KungFuMaster-ramDeterministic-v4), EnvSpec(KungFuMaster-ramNoFrameskip-v0), EnvSpec(KungFuMaster-ramNoFrameskip-v4), EnvSpec(MontezumaRevenge-v0), EnvSpec(MontezumaRevenge-v4), EnvSpec(MontezumaRevengeDeterministic-v0), EnvSpec(MontezumaRevengeDeterministic-v4), EnvSpec(MontezumaRevengeNoFrameskip-v0), EnvSpec(MontezumaRevengeNoFrameskip-v4), EnvSpec(MontezumaRevenge-ram-v0), EnvSpec(MontezumaRevenge-ram-v4), EnvSpec(MontezumaRevenge-ramDeterministic-v0), EnvSpec(MontezumaRevenge-ramDeterministic-v4), EnvSpec(MontezumaRevenge-ramNoFrameskip-v0), EnvSpec(MontezumaRevenge-ramNoFrameskip-v4), EnvSpec(MsPacman-v0), EnvSpec(MsPacman-v4), EnvSpec(MsPacmanDeterministic-v0), EnvSpec(MsPacmanDeterministic-v4), EnvSpec(MsPacmanNoFrameskip-v0), EnvSpec(MsPacmanNoFrameskip-v4), EnvSpec(MsPacman-ram-v0), EnvSpec(MsPacman-ram-v4), EnvSpec(MsPacman-ramDeterministic-v0), EnvSpec(MsPacman-ramDeterministic-v4), EnvSpec(MsPacman-ramNoFrameskip-v0), EnvSpec(MsPacman-ramNoFrameskip-v4), EnvSpec(NameThisGame-v0), EnvSpec(NameThisGame-v4), EnvSpec(NameThisGameDeterministic-v0), EnvSpec(NameThisGameDeterministic-v4), EnvSpec(NameThisGameNoFrameskip-v0), EnvSpec(NameThisGameNoFrameskip-v4), EnvSpec(NameThisGame-ram-v0), EnvSpec(NameThisGame-ram-v4), EnvSpec(NameThisGame-ramDeterministic-v0), EnvSpec(NameThisGame-ramDeterministic-v4), EnvSpec(NameThisGame-ramNoFrameskip-v0), EnvSpec(NameThisGame-ramNoFrameskip-v4), EnvSpec(Phoenix-v0), EnvSpec(Phoenix-v4), EnvSpec(PhoenixDeterministic-v0), 
EnvSpec(PhoenixDeterministic-v4), EnvSpec(PhoenixNoFrameskip-v0), EnvSpec(PhoenixNoFrameskip-v4), EnvSpec(Phoenix-ram-v0), EnvSpec(Phoenix-ram-v4), EnvSpec(Phoenix-ramDeterministic-v0), EnvSpec(Phoenix-ramDeterministic-v4), EnvSpec(Phoenix-ramNoFrameskip-v0), EnvSpec(Phoenix-ramNoFrameskip-v4), EnvSpec(Pitfall-v0), EnvSpec(Pitfall-v4), EnvSpec(PitfallDeterministic-v0), EnvSpec(PitfallDeterministic-v4), EnvSpec(PitfallNoFrameskip-v0), EnvSpec(PitfallNoFrameskip-v4), EnvSpec(Pitfall-ram-v0), EnvSpec(Pitfall-ram-v4), EnvSpec(Pitfall-ramDeterministic-v0), EnvSpec(Pitfall-ramDeterministic-v4), EnvSpec(Pitfall-ramNoFrameskip-v0), EnvSpec(Pitfall-ramNoFrameskip-v4), EnvSpec(Pong-v0), EnvSpec(Pong-v4), EnvSpec(PongDeterministic-v0), EnvSpec(PongDeterministic-v4), EnvSpec(PongNoFrameskip-v0), EnvSpec(PongNoFrameskip-v4), EnvSpec(Pong-ram-v0), EnvSpec(Pong-ram-v4), EnvSpec(Pong-ramDeterministic-v0), EnvSpec(Pong-ramDeterministic-v4), EnvSpec(Pong-ramNoFrameskip-v0), EnvSpec(Pong-ramNoFrameskip-v4), EnvSpec(Pooyan-v0), EnvSpec(Pooyan-v4), EnvSpec(PooyanDeterministic-v0), EnvSpec(PooyanDeterministic-v4), EnvSpec(PooyanNoFrameskip-v0), EnvSpec(PooyanNoFrameskip-v4), EnvSpec(Pooyan-ram-v0), EnvSpec(Pooyan-ram-v4), EnvSpec(Pooyan-ramDeterministic-v0), EnvSpec(Pooyan-ramDeterministic-v4), EnvSpec(Pooyan-ramNoFrameskip-v0), EnvSpec(Pooyan-ramNoFrameskip-v4), EnvSpec(PrivateEye-v0), EnvSpec(PrivateEye-v4), EnvSpec(PrivateEyeDeterministic-v0), EnvSpec(PrivateEyeDeterministic-v4), EnvSpec(PrivateEyeNoFrameskip-v0), EnvSpec(PrivateEyeNoFrameskip-v4), EnvSpec(PrivateEye-ram-v0), EnvSpec(PrivateEye-ram-v4), EnvSpec(PrivateEye-ramDeterministic-v0), EnvSpec(PrivateEye-ramDeterministic-v4), EnvSpec(PrivateEye-ramNoFrameskip-v0), EnvSpec(PrivateEye-ramNoFrameskip-v4), EnvSpec(Qbert-v0), EnvSpec(Qbert-v4), EnvSpec(QbertDeterministic-v0), EnvSpec(QbertDeterministic-v4), EnvSpec(QbertNoFrameskip-v0), EnvSpec(QbertNoFrameskip-v4), EnvSpec(Qbert-ram-v0), EnvSpec(Qbert-ram-v4), 
EnvSpec(Qbert-ramDeterministic-v0), EnvSpec(Qbert-ramDeterministic-v4), EnvSpec(Qbert-ramNoFrameskip-v0), EnvSpec(Qbert-ramNoFrameskip-v4), EnvSpec(Riverraid-v0), EnvSpec(Riverraid-v4), EnvSpec(RiverraidDeterministic-v0), EnvSpec(RiverraidDeterministic-v4), EnvSpec(RiverraidNoFrameskip-v0), EnvSpec(RiverraidNoFrameskip-v4), EnvSpec(Riverraid-ram-v0), EnvSpec(Riverraid-ram-v4), EnvSpec(Riverraid-ramDeterministic-v0), EnvSpec(Riverraid-ramDeterministic-v4), EnvSpec(Riverraid-ramNoFrameskip-v0), EnvSpec(Riverraid-ramNoFrameskip-v4), EnvSpec(RoadRunner-v0), EnvSpec(RoadRunner-v4), EnvSpec(RoadRunnerDeterministic-v0), EnvSpec(RoadRunnerDeterministic-v4), EnvSpec(RoadRunnerNoFrameskip-v0), EnvSpec(RoadRunnerNoFrameskip-v4), EnvSpec(RoadRunner-ram-v0), EnvSpec(RoadRunner-ram-v4), EnvSpec(RoadRunner-ramDeterministic-v0), EnvSpec(RoadRunner-ramDeterministic-v4), EnvSpec(RoadRunner-ramNoFrameskip-v0), EnvSpec(RoadRunner-ramNoFrameskip-v4), EnvSpec(Robotank-v0), EnvSpec(Robotank-v4), EnvSpec(RobotankDeterministic-v0), EnvSpec(RobotankDeterministic-v4), EnvSpec(RobotankNoFrameskip-v0), EnvSpec(RobotankNoFrameskip-v4), EnvSpec(Robotank-ram-v0), EnvSpec(Robotank-ram-v4), EnvSpec(Robotank-ramDeterministic-v0), EnvSpec(Robotank-ramDeterministic-v4), EnvSpec(Robotank-ramNoFrameskip-v0), EnvSpec(Robotank-ramNoFrameskip-v4), EnvSpec(Seaquest-v0), EnvSpec(Seaquest-v4), EnvSpec(SeaquestDeterministic-v0), EnvSpec(SeaquestDeterministic-v4), EnvSpec(SeaquestNoFrameskip-v0), EnvSpec(SeaquestNoFrameskip-v4), EnvSpec(Seaquest-ram-v0), EnvSpec(Seaquest-ram-v4), EnvSpec(Seaquest-ramDeterministic-v0), EnvSpec(Seaquest-ramDeterministic-v4), EnvSpec(Seaquest-ramNoFrameskip-v0), EnvSpec(Seaquest-ramNoFrameskip-v4), EnvSpec(Skiing-v0), EnvSpec(Skiing-v4), EnvSpec(SkiingDeterministic-v0), EnvSpec(SkiingDeterministic-v4), EnvSpec(SkiingNoFrameskip-v0), EnvSpec(SkiingNoFrameskip-v4), EnvSpec(Skiing-ram-v0), EnvSpec(Skiing-ram-v4), EnvSpec(Skiing-ramDeterministic-v0), 
EnvSpec(Skiing-ramDeterministic-v4), EnvSpec(Skiing-ramNoFrameskip-v0), EnvSpec(Skiing-ramNoFrameskip-v4), EnvSpec(Solaris-v0), EnvSpec(Solaris-v4), EnvSpec(SolarisDeterministic-v0), EnvSpec(SolarisDeterministic-v4), EnvSpec(SolarisNoFrameskip-v0), EnvSpec(SolarisNoFrameskip-v4), EnvSpec(Solaris-ram-v0), EnvSpec(Solaris-ram-v4), EnvSpec(Solaris-ramDeterministic-v0), EnvSpec(Solaris-ramDeterministic-v4), EnvSpec(Solaris-ramNoFrameskip-v0), EnvSpec(Solaris-ramNoFrameskip-v4), EnvSpec(SpaceInvaders-v0), EnvSpec(SpaceInvaders-v4), EnvSpec(SpaceInvadersDeterministic-v0), EnvSpec(SpaceInvadersDeterministic-v4), EnvSpec(SpaceInvadersNoFrameskip-v0), EnvSpec(SpaceInvadersNoFrameskip-v4), EnvSpec(SpaceInvaders-ram-v0), EnvSpec(SpaceInvaders-ram-v4), EnvSpec(SpaceInvaders-ramDeterministic-v0), EnvSpec(SpaceInvaders-ramDeterministic-v4), EnvSpec(SpaceInvaders-ramNoFrameskip-v0), EnvSpec(SpaceInvaders-ramNoFrameskip-v4), EnvSpec(StarGunner-v0), EnvSpec(StarGunner-v4), EnvSpec(StarGunnerDeterministic-v0), EnvSpec(StarGunnerDeterministic-v4), EnvSpec(StarGunnerNoFrameskip-v0), EnvSpec(StarGunnerNoFrameskip-v4), EnvSpec(StarGunner-ram-v0), EnvSpec(StarGunner-ram-v4), EnvSpec(StarGunner-ramDeterministic-v0), EnvSpec(StarGunner-ramDeterministic-v4), EnvSpec(StarGunner-ramNoFrameskip-v0), EnvSpec(StarGunner-ramNoFrameskip-v4), EnvSpec(Tennis-v0), EnvSpec(Tennis-v4), EnvSpec(TennisDeterministic-v0), EnvSpec(TennisDeterministic-v4), EnvSpec(TennisNoFrameskip-v0), EnvSpec(TennisNoFrameskip-v4), EnvSpec(Tennis-ram-v0), EnvSpec(Tennis-ram-v4), EnvSpec(Tennis-ramDeterministic-v0), EnvSpec(Tennis-ramDeterministic-v4), EnvSpec(Tennis-ramNoFrameskip-v0), EnvSpec(Tennis-ramNoFrameskip-v4), EnvSpec(TimePilot-v0), EnvSpec(TimePilot-v4), EnvSpec(TimePilotDeterministic-v0), EnvSpec(TimePilotDeterministic-v4), EnvSpec(TimePilotNoFrameskip-v0), EnvSpec(TimePilotNoFrameskip-v4), EnvSpec(TimePilot-ram-v0), EnvSpec(TimePilot-ram-v4), EnvSpec(TimePilot-ramDeterministic-v0), 
EnvSpec(TimePilot-ramDeterministic-v4), EnvSpec(TimePilot-ramNoFrameskip-v0), EnvSpec(TimePilot-ramNoFrameskip-v4), EnvSpec(Tutankham-v0), EnvSpec(Tutankham-v4), EnvSpec(TutankhamDeterministic-v0), EnvSpec(TutankhamDeterministic-v4), EnvSpec(TutankhamNoFrameskip-v0), EnvSpec(TutankhamNoFrameskip-v4), EnvSpec(Tutankham-ram-v0), EnvSpec(Tutankham-ram-v4), EnvSpec(Tutankham-ramDeterministic-v0), EnvSpec(Tutankham-ramDeterministic-v4), EnvSpec(Tutankham-ramNoFrameskip-v0), EnvSpec(Tutankham-ramNoFrameskip-v4), EnvSpec(UpNDown-v0), EnvSpec(UpNDown-v4), EnvSpec(UpNDownDeterministic-v0), EnvSpec(UpNDownDeterministic-v4), EnvSpec(UpNDownNoFrameskip-v0), EnvSpec(UpNDownNoFrameskip-v4), EnvSpec(UpNDown-ram-v0), EnvSpec(UpNDown-ram-v4), EnvSpec(UpNDown-ramDeterministic-v0), EnvSpec(UpNDown-ramDeterministic-v4), EnvSpec(UpNDown-ramNoFrameskip-v0), EnvSpec(UpNDown-ramNoFrameskip-v4), EnvSpec(Venture-v0), EnvSpec(Venture-v4), EnvSpec(VentureDeterministic-v0), EnvSpec(VentureDeterministic-v4), EnvSpec(VentureNoFrameskip-v0), EnvSpec(VentureNoFrameskip-v4), EnvSpec(Venture-ram-v0), EnvSpec(Venture-ram-v4), EnvSpec(Venture-ramDeterministic-v0), EnvSpec(Venture-ramDeterministic-v4), EnvSpec(Venture-ramNoFrameskip-v0), EnvSpec(Venture-ramNoFrameskip-v4), EnvSpec(VideoPinball-v0), EnvSpec(VideoPinball-v4), EnvSpec(VideoPinballDeterministic-v0), EnvSpec(VideoPinballDeterministic-v4), EnvSpec(VideoPinballNoFrameskip-v0), EnvSpec(VideoPinballNoFrameskip-v4), EnvSpec(VideoPinball-ram-v0), EnvSpec(VideoPinball-ram-v4), EnvSpec(VideoPinball-ramDeterministic-v0), EnvSpec(VideoPinball-ramDeterministic-v4), EnvSpec(VideoPinball-ramNoFrameskip-v0), EnvSpec(VideoPinball-ramNoFrameskip-v4), EnvSpec(WizardOfWor-v0), EnvSpec(WizardOfWor-v4), EnvSpec(WizardOfWorDeterministic-v0), EnvSpec(WizardOfWorDeterministic-v4), EnvSpec(WizardOfWorNoFrameskip-v0), EnvSpec(WizardOfWorNoFrameskip-v4), EnvSpec(WizardOfWor-ram-v0), EnvSpec(WizardOfWor-ram-v4), EnvSpec(WizardOfWor-ramDeterministic-v0), 
EnvSpec(WizardOfWor-ramDeterministic-v4), EnvSpec(WizardOfWor-ramNoFrameskip-v0), EnvSpec(WizardOfWor-ramNoFrameskip-v4), EnvSpec(YarsRevenge-v0), EnvSpec(YarsRevenge-v4), EnvSpec(YarsRevengeDeterministic-v0), EnvSpec(YarsRevengeDeterministic-v4), EnvSpec(YarsRevengeNoFrameskip-v0), EnvSpec(YarsRevengeNoFrameskip-v4), EnvSpec(YarsRevenge-ram-v0), EnvSpec(YarsRevenge-ram-v4), EnvSpec(YarsRevenge-ramDeterministic-v0), EnvSpec(YarsRevenge-ramDeterministic-v4), EnvSpec(YarsRevenge-ramNoFrameskip-v0), EnvSpec(YarsRevenge-ramNoFrameskip-v4), EnvSpec(Zaxxon-v0), EnvSpec(Zaxxon-v4), EnvSpec(ZaxxonDeterministic-v0), EnvSpec(ZaxxonDeterministic-v4), EnvSpec(ZaxxonNoFrameskip-v0), EnvSpec(ZaxxonNoFrameskip-v4), EnvSpec(Zaxxon-ram-v0), EnvSpec(Zaxxon-ram-v4), EnvSpec(Zaxxon-ramDeterministic-v0), EnvSpec(Zaxxon-ramDeterministic-v4), EnvSpec(Zaxxon-ramNoFrameskip-v0), EnvSpec(Zaxxon-ramNoFrameskip-v4), EnvSpec(CubeCrash-v0), EnvSpec(CubeCrashSparse-v0), EnvSpec(CubeCrashScreenBecomesBlack-v0), EnvSpec(MemorizeDigits-v0)])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gym.envs.registry.all()" ] }, { "cell_type": "markdown", "metadata": { "id": "QR-DxVNOMM9J" }, "source": [ "The Cart-Pole is a very simple environment composed of a cart that can move left or right and a pole placed vertically on top of it. The agent must move the cart left or right to keep the pole upright." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DauG0fsUMM9K" }, "outputs": [], "source": [ "env = gym.make('CartPole-v1')" ] }, { "cell_type": "markdown", "metadata": { "id": "22Z8jxZBMM9K" }, "source": [ "Let's initialize the environment by calling its `reset()` method. This returns an observation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ln5prFuAMM9K" }, "outputs": [], "source": [ "env.seed(42)\n", "obs = env.reset()" ] }, { "cell_type": "markdown", "metadata": { "id": "hAKkSMyZMM9K" }, "source": [ "Observations vary depending on the environment. In this case it is a 1D NumPy array composed of 4 floats: they represent the cart's horizontal position, its velocity, the angle of the pole (0 = vertical), and the angular velocity."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3MooBNyqMM9L", "outputId": "4f1e7271-119a-47e2-ef1e-b172d9dd0535" }, "outputs": [ { "data": { "text/plain": [ "array([-0.01258566, -0.00156614, 0.04207708, -0.00180545])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "obs" ] }, { "cell_type": "markdown", "metadata": { "id": "YFnea-CUMM9L" }, "source": [ "환경은 `render()` 메서드를 호출하여 시각화할 수 있습니다. 그리고 렌더링 모드(환경에 따른 렌더링 옵션)를 선택할 수 있습니다." ] }, { "cell_type": "markdown", "metadata": { "id": "B3xZQ77_MM9L" }, "source": [ "**경고**: (Cart-Pole을 포함해) 일부 환경은 화면 접근 권한이 필요합니다. `mode=\"rgb_array\"`로 지정하더라도 별도의 윈도우를 엽니다. 일반적으로 이 윈도우를 무시할 수 있습니다. 하지만 주피터를 백엔드(headless) 서버로 실행한다면 예외가 발생합니다. 이를 피하는 한 가지 방법은 [Xvfb](http://en.wikipedia.org/wiki/Xvfb) 같은 가짜 X 서버를 설치하는 것입니다. 데비안이나 우분투에서는 다음과 같이 설치합니다:\n", "\n", "```bash\n", "$ apt update\n", "$ apt install -y xvfb\n", "```\n", "\n", "그다음 `xvfb-run` 명령으로 주피터를 실행합니다:\n", "\n", "```bash\n", "$ xvfb-run -s \"-screen 0 1400x900x24\" jupyter notebook\n", "```\n", "\n", "또는 Xvfb를 감싼 [pyvirtualdisplay](https://github.com/ponty/pyvirtualdisplay) 파이썬 라이브러리를 설치할 수 있습니다:\n", "\n", "```bash\n", "%pip install -U pyvirtualdisplay\n", "```\n", "\n", "그다음 다음 코드를 실행합니다:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IRCSe2PCMM9M" }, "outputs": [], "source": [ "try:\n", " import pyvirtualdisplay\n", " display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()\n", "except ImportError:\n", " pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_sZ-HW8dMM9M", "outputId": "745dfe7c-f163-4372-8f0b-9df0972820e3" }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.render()" ] }, { "cell_type": "markdown", "metadata": { "id": "ecaApWqyMM9M" }, "source": [ "이 예에서는 `mode=\"rgb_array\"`로 지정해 환경 이미지를 넘파이 배열로 받을 것입니다:" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": { "id": "Nv1aNcK_MM9M", "outputId": "b00724e1-d49b-4daf-c403-e5f282e930af" }, "outputs": [ { "data": { "text/plain": [ "(400, 600, 3)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img = env.render(mode=\"rgb_array\")\n", "img.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HmRfeGldMM9N" }, "outputs": [], "source": [ "def plot_environment(env, figsize=(5,4)):\n", " plt.figure(figsize=figsize)\n", " img = env.render(mode=\"rgb_array\")\n", " plt.imshow(img)\n", " plt.axis(\"off\")\n", " return img" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CO1XfyXcMM9N", "outputId": "450d1e22-3f7c-410c-9db5-1ecf3586085c" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAASUAAADICAYAAACuyvefAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAF4klEQVR4nO3dzY8bdx3H8e/Y3sfulhSiKlSoIr2UWy9BKDlHAokD/wF3cuN/yYkLh4gr/wRCWiEFISHoAVVERXS3PKQJWbRee8fDgQjqbHcDnk39sfN63Wb8oK+0o7e8Y89vmq7rCiDFYNkDAHyeKAFRRAmIIkpAFFECoogSEGX0ksf9XgB4FZqLHvBJCYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBktewBWX9d1Vd3sc3uaqqappmmWNhOrS5TobfzkqD7+xc+qex6m4eZ23fjgu7X/zvtLnoxVJEr0NpuO6/jTj+Y+Le2/8y1RYiHOKdFbOx2f2zfc2F7CJKwDUaK38WdH8+eUmqa23/r68gZipYkSV6B7YdsJbhYnSkAUUQKiiBK9dF33798pwRURJXo7+fuf5raHG1u1+ca15QzDyhMlepudTee2m+GoBn4SwIJECYgiSkAUUQKiiBL9dF21k5O5Xc1gaIUAFiZK9DJrp3X67K9z+7b2r9dwc2dJE7HqRImr1wzKpSYsSpSAKKIERBElIIoo0cvZ+FmdnRzP7du+dqPKt28sSJToZXY2rVk7f5nJaGvXTwJYmCgBUUQJiCJKQBRRopeundaLa3QPRpvLGYa1IEr0Mn5yVF17Nrdv56vfWNI0rANRopcvXArXN2/0IEpAFFECoogSPbmTCVdLlOjl5PGf57ab4UZt7H
5lSdOwDkSJXtrJeG67GQxrtL23pGlYB6IERBElIIooAVFEiYV1XVez6encvqYZWLaEXkSJhXWztsZPDuf2be695UQ3vYgSV6sZPL+bCSzG0QNEESUgiigBUUSJhbWTk5qePJvbt/Xm9WoGDisW5+hhYV07rdl0/jKT0fZeNU5004OjB4giSkAUUQKiiBIL62btuTW6m8FoSdOwLkSJhY2f/uXctW+7199d0jSsC1Ficd2sXlwO1zdv9OUIAqKIEhBFlIAozRfe4fS/3D/nNdS2bT148KCOjo4ufd77107r3cEn/9k+a2f1y0/3a7L99qWvu3v3bt26detKZmVlXbgSoO9vOadt27p//349fPjw0uf96Affrh9+7zv12eTt2hyc1nYd1f2f/LT+ePjk0tft7OyIEhcSJRZ22u7WwePv19Pp12rYtPXNrYPqup8veyxWnHNKLOyT8Xv1dHq9qgbVdhv10T8/qMlsZ9ljseJEiYV1L5wWmHXDarvhkqZhXYgSC7v55se1P3pcVV0N6qzee+M3tT08XvZYrLhLzynNZrMvaw6C/K9/91/99nf14aMf1+PJjdocjGu3Duvwb/946eu6rnNsveYGlywEeGmUDg4OrnwY8k2n0zo+fvknnl//4bCqDqvq9//X+z969Mix9Zq7c+fOhY9dGqXLXsj6mkwmtbf36u7ddvPmTccWF3JOCYgiSkAUUQKiiBIQRZSAKK5945zhcFj37t2rw8PDV/L+t2/ffiXvy3qwdAmwDBcuXeLfNyCKKAFRRAmIIkpAFFECoogSEEWUgCiiBEQRJSCKKAFRRAmIIkpAFFECoogSEEWUgCiiBEQRJSCKKAFRRAmIIkpAFFECoogSEEWUgCiiBEQRJSCKKAFRRAmIIkpAFFECoogSEEWUgCiiBEQRJSCKKAFRRAmIIkpAFFECoogSEGX0ksebL2UKgOd8UgKiiBIQRZSAKKIERBElIIooAVH+Bc6iAHz8aqUYAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_environment(env)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "Xx5PVNLmMM9N" }, "source": [ "환경과 상호작용하는 방법을 알아 보죠. 에이전트는 \"행동 공간\"(가능한 행동의 집합)에서 하나의 행동을 선택해야 합니다. 이 환경의 행동 공간을 다음처럼 확인해 보죠:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EqjV6_ESMM9N", "outputId": "6397e2dd-fcdc-40b7-d5f0-4e6510899b50" }, "outputs": [ { "data": { "text/plain": [ "Discrete(2)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space" ] }, { "cell_type": "markdown", "metadata": { "id": "oIvYOWJIMM9N" }, "source": [ "네 단 두 개의 행동이 가능합니다: 왼쪽 또는 오른쪽으로 가속합니다." ] }, { "cell_type": "markdown", "metadata": { "id": "IUM_t8GpMM9O" }, "source": [ "막대가 오른쪽으로 기울어져 있기 때문에(`obs[2] > 0`), 카트를 오른쪽으로 가속해 보죠:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0uXGZvqxMM9O", "outputId": "85c1d33f-51fa-4a8a-96fe-6863fbca6411" }, "outputs": [ { "data": { "text/plain": [ "array([-0.01261699, 0.19292789, 0.04204097, -0.28092127])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "action = 1 # 오른쪽으로 가속\n", "obs, reward, done, info = env.step(action)\n", "obs" ] }, { "cell_type": "markdown", "metadata": { "id": "1GFdgA3nMM9O" }, "source": [ "이제 카트가 오른쪽으로 움직였습니다(`obs[1] > 0`). 막대가 여전히 오른쪽으로 기울어져 있습니다(`obs[2] > 0`). 하지만 각속도가 음수이므로(`obs[3] < 0`) 다음 스텝에서는 왼쪽으로 기울 것 같습니다." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "M-gO_rXbMM9O", "outputId": "96a3fb1b-741a-4e15-9508-16b1f3ddcd52" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saving figure cart_pole_plot\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAADvCAYAAADM8A71AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAGHklEQVR4nO3dzWpcBRjH4fdMkklMG2s1+BFF0aIWP6s7wYXQ4kLoBbjpRYilvQGhdyBFvAbrSqwrN5aCGyuIINUK1o9qqmlsYhsnx40Uy8Q2mbbz76HPswrvDHPexeHHMDlzpmnbtgAYv156AYA7lQADhAgwQIgAA4QIMEDI5HUed4kEwI1rNhp6BwwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAymV4ARtW27ZW/m6YJbgKjEWA668fPP6wLP3xVD730RvWmZqqqqun1avsDu6rpTYS3g+sTYDrr8vL5Wvn1+zp9/N0rs97UTD3/5js1ObM9uBlsjs+AAUIEGCBEgAFCBBggRIABQgQYIESAAUIEGCBEgAFCBBggRIABQgSYThpcXq211aWh+fTcfW7EQ2cIMJ3019IvtXz266H5fU+9UhP9uwIbwdYJMECIAAOECDBAiAADhAgwQIgAA4QIMECIAAOECDBAiAADhAgwQIgA00mXls4NzZqJqZratjOwDYxGgOmkX7/6dGg2Nbujdj7+UmAbGI0AA4QIMECIAAOECDBAiAADhAgwQIgAA4QIMECIAAOECDBAiAADhAgwQIgA0zl//ny6Vs//ODSf3/1qVeOUpjucrXTO2uqFGlxeGZrP3PNgNU0T2AhGI8AAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAkyntG1bf69eGJr3Jvs10Z8JbASjE2C6pV2vX059MjSenX+05hZ2BxaC0QkwndO27YZzd0KjawQYIESAAUIEGCBEgAFCBBggRIABQgQYIGQyvQCsrKzUYDDY1HPb9UG16+tD88FgUMvLy5t6jX6/X9PT01vaEW6F5v8uav/XNR+Em2H//v114sSJTT13otfU+2+9Xg/du/2q+Zff/VZvv/fppl7j0KFDdfDgwS3vCTdgw28JeQdM3NLSUi0uLm7quRO9plbWpuvcpUeqqqqptu7t/1Rrf69t+jVWV1dH3hVuJgGmU9pq6os/Xqv++rNXJk9sO1Xr7c/RvWAU/glHp7z89GM1OfPgfyZNffvnC/XByd9jO8GoBJhOuf/+J2t6+urPf9uqOnX6XGYhuAECTKfs7J+rqd6lq2Z3TVwcmkEX+AyYTjl/4WKtL
R6vsxefqYfn5+qe7f167u7PatvE8E3a4XYnwHTKRye/qY9OHqmqqhd3PVAL83P1cVWdXdzcNcBwO7nmdcBHjhxxHTC33NGjR+vMmTNjO96+fftq7969YzseHD58eOvXAR84cODWbAP/cezYsbEGeM+ePc5tbgvXDPDCwsK49uAO1u/3x3q8ubk55za3BVdBAIQIMECIAAOECDBAiAADhPgiBnE7duyo+fn5sR1vdnZ2bMeCa3FDduK28osYN4NfxCBgwy9iCDDArbdhgH0GDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIQIMECLAACECDBAiwAAhAgwQIsAAIZPXebwZyxYAdyDvgAFCBBggRIABQgQYIESAAUIEGCDkH6bUxi6unNOlAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_environment(env)\n", "save_fig(\"cart_pole_plot\")" ] }, { "cell_type": "markdown", "metadata": { "id": "7tsgxZK_MM9O" }, "source": [ "요청한 대로 실행되는 것 같습니다!" ] }, { "cell_type": "markdown", "metadata": { "id": "6nS0v0NiMM9P" }, "source": [ "환경은 이전 스텝에서 얼마나 많은 보상을 받는지 에이전트에게 알려 줍니다:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WHa5a1p7MM9P", "outputId": "e2b9da6c-28c3-4011-e409-61d7d35c81ea" }, "outputs": [ { "data": { "text/plain": [ "1.0" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reward" ] }, { "cell_type": "markdown", "metadata": { "id": "Sttrrhw9MM9P" }, "source": [ "게임이 끝나면 환경은 `done=True`를 반환합니다:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fjignPnfMM9P", "outputId": "6466a7d0-d8aa-4767-dbc2-df8e39450ce1" }, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "done" ] }, { "cell_type": "markdown", "metadata": { "id": "8ysEdYE7MM9P" }, "source": [ "마지막으로 `info`는 훈련이나 디버깅에 유용한 추가적인 정보를 담은 환경에 특화된 딕셔너리입니다. 예를 들어 일부 게임에서는 얼마나 많은 에이전트의 생명이 몇 개가 남아 있는지 나타낼 수 있습니다." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NaW2h9CoMM9Q", "outputId": "bd2439d8-808a-483a-a62d-9d3e4be4e6c8" }, "outputs": [ { "data": { "text/plain": [ "{}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "info" ] }, { "cell_type": "markdown", "metadata": { "id": "3tktQI60MM9Q" }, "source": [ "환경이 재설정된 순간부터 종료될 때까지 스텝 시퀀스를 \"에피소드\"라고 합니다. 에피소드 끝에서 (즉, `step()`이 `done=True`를 반환할 때), 계속하기 전에 환경을 재설정해야 합니다." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Hb5bmiIwMM9Q" }, "outputs": [], "source": [ "if done:\n", " obs = env.reset()" ] }, { "cell_type": "markdown", "metadata": { "id": "wPl8DcS8MM9R" }, "source": [ "그럼 어떻게 막대를 똑바로 유지할 수 있을까요? 이를 위해 정책을 정의해야 합니다. 에이전트가 매 스텝마다 행동을 선택하기 위해 사용할 전략입니다. 어떤 행동을 선택할지 결정하기 위해 지난 행동과 관측을 모두 사용할 수 있습니다." ] }, { "cell_type": "markdown", "metadata": { "id": "1MuS0SmWMM9R" }, "source": [ "# 간단한 하드 코딩 정책" ] }, { "cell_type": "markdown", "metadata": { "id": "z6xtsBnHMM9R" }, "source": [ "간단한 정책을 하드 코딩해 보죠. 막대가 왼쪽으로 기울어지면 카트를 왼쪽으로 움직이고 오른쪽으로 기울어지면 반대로 움직입니다. 어떻게 작동하는지 확인해 보죠:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fIxTVS-JMM9R" }, "outputs": [], "source": [ "env.seed(42)\n", "\n", "def basic_policy(obs):\n", " angle = obs[2]\n", " return 0 if angle < 0 else 1\n", "\n", "totals = []\n", "for episode in range(500):\n", " episode_rewards = 0\n", " obs = env.reset()\n", " for step in range(200):\n", " action = basic_policy(obs)\n", " obs, reward, done, info = env.step(action)\n", " episode_rewards += reward\n", " if done:\n", " break\n", " totals.append(episode_rewards)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4-ZUL55bMM9R", "outputId": "385eb081-57df-4caa-b7c9-5fedb00bc2e6" }, "outputs": [ { "data": { "text/plain": [ "(41.718, 8.858356280936096, 24.0, 68.0)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.mean(totals), np.std(totals), np.min(totals), np.max(totals)" ] }, { "cell_type": "markdown", "metadata": { "id": "XjK9VwcYMM9S" }, "source": [ "예상대로 이 전략은 너무 단순합니다. 최대로 막대를 유지한 스텝 횟수가 68입니다. 이 환경은 에이전트가 막대를 200 스텝 이상 유지해야 해결된 것으로 간주합니다." 
] }, { "cell_type": "markdown", "metadata": { "id": "21BN-tYuMM9S" }, "source": [ "하나의 에피소드를 시각화해 보죠:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cQpd8b7WMM9S" }, "outputs": [], "source": [ "env.seed(42)\n", "\n", "frames = []\n", "\n", "obs = env.reset()\n", "for step in range(200):\n", " img = env.render(mode=\"rgb_array\")\n", " frames.append(img)\n", " action = basic_policy(obs)\n", "\n", " obs, reward, done, info = env.step(action)\n", " if done:\n", " break" ] }, { "cell_type": "markdown", "metadata": { "id": "Gos2Se6AMM9S" }, "source": [ "애니메이션을 출력합니다:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cBPZOwgaMM9T" }, "outputs": [], "source": [ "def update_scene(num, frames, patch):\n", " patch.set_data(frames[num])\n", " return patch,\n", "\n", "def plot_animation(frames, repeat=False, interval=40):\n", " fig = plt.figure()\n", " patch = plt.imshow(frames[0])\n", " plt.axis('off')\n", " anim = animation.FuncAnimation(\n", " fig, update_scene, fargs=(frames, patch),\n", " frames=len(frames), repeat=repeat, interval=interval)\n", " plt.close()\n", " return anim" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DVYZ7NU6MM9T" }, "outputs": [], "source": [ "plot_animation(frames)" ] }, { "cell_type": "markdown", "metadata": { "id": "xTqme9K4MM9T" }, "source": [ "확실히 이 방법은 불안정해서 약간 흔들리면 막대가 너무 기울어져 게임이 끝납니다. 이 보다는 더 똑똑한 전략이 필요합니다!" ] }, { "cell_type": "markdown", "metadata": { "id": "gjAWzygUMM9T" }, "source": [ "# 신경망 정책" ] }, { "cell_type": "markdown", "metadata": { "id": "zroA9QpAMM9T" }, "source": [ "관측을 입력으로 받고 각 관측에 대해 선택할 행동의 확률을 출력하는 신경망을 만들어 보죠. 행동을 선택하기 위해 신경망은 각 행동의 확률을 추정합니다. 이 추정된 확률에 따라 랜덤하게 행동을 선택합니다. Cart-Pole 환경의 경우 두 개의 가능한 행동이 있습니다(왼쪽과 오른쪽). 따라서 하나의 출력 뉴런만 있으면 됩니다. 이 뉴런은 행동 0(왼쪽)의 확률 `p`를 출력합니다. 물론 행동 1(오른쪽)의 확률은 `1 - p`가 됩니다." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zwTSL6KcMM9U" }, "outputs": [], "source": [ "keras.backend.clear_session()\n", "tf.random.set_seed(42)\n", "np.random.seed(42)\n", "\n", "n_inputs = 4 # == env.observation_space.shape[0]\n", "\n", "model = keras.models.Sequential([\n", " keras.layers.Dense(5, activation=\"elu\", input_shape=[n_inputs]),\n", " keras.layers.Dense(1, activation=\"sigmoid\"),\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "3cDCANQlMM9U" }, "source": [ "이 환경에서는 지난 행동과 관측을 무시할 수 있습니다. 각 관측이 완전한 환경의 상태를 담고 있기 때문입니다. 은닉 상태가 있다면 환경의 은닉 상태를 추정하기 위해 지난 행동과 관측을 고려해야 할 수 있습니다. 예를 들어, 이 환경이 카트의 위치만 제공하고 속도를 알려 주지 않는다면, 현재 속도를 추정하기 위해 현재 관측 뿐만 아니라 지난 관측도 고려해야 합니다. 또 다른 예는 관측에 잡음이 있는 경우입니다. 가장 가능성 있는 현재 상태를 추정하기 위해 지난 몇 개의 관측을 사용할 수 있습니다. 이 문제는 매우 간단합니다. 현재 관측에 잡음이 없고 환경의 모든 상태가 담겨 있습니다." ] }, { "cell_type": "markdown", "metadata": { "id": "a4cbqL1oMM9U" }, "source": [ "정책 네트워크가 출력한 확률 중에서 가장 높은 확률을 가진 행동을 선택하지 않고 랜덤한 행동을 선택하는 이유가 궁금할지 모릅니다. 이 방법은 에이전트가 새로운 행동을 탐험하는 것과 잘 동작하는 행동을 활용하는 것 사이에 밸런스를 찾도록 합니다. 비유를 들어 보죠. 한 음식점에 처음 방문했다고 가정해 보죠. 모든 음식에 대한 선호도가 동일하다면 랜덤하게 하나를 선택합니다. 이 음식이 좋다고 느낀다면 다음 번에 이 음식을 주문할 확률을 높일 수 있습니다. 하지만 이 확률을 100%로 높여서는 안됩니다. 그렇지 않으면 다른 음식을 시도해 볼 수 없습니다. 어쩌면 다른 음식이 이번에 먹은 것보다 훨씬 더 좋을 수도 있습니다." 
] }, { "cell_type": "markdown", "metadata": { "id": "h5RiuKZ4MM9U" }, "source": [ "모델을 실행하여 한 에피소드를 플레이하고 애니메이션을 위한 프레임을 반환하는 함수를 작성해 보죠:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cUSmxiEhMM9U" }, "outputs": [], "source": [ "def render_policy_net(model, n_max_steps=200, seed=42):\n", " frames = []\n", " env = gym.make(\"CartPole-v1\")\n", " env.seed(seed)\n", " np.random.seed(seed)\n", " obs = env.reset()\n", " for step in range(n_max_steps):\n", " frames.append(env.render(mode=\"rgb_array\"))\n", " left_proba = model.predict(obs.reshape(1, -1))\n", " action = int(np.random.rand() > left_proba)\n", " obs, reward, done, info = env.step(action)\n", " if done:\n", " break\n", " env.close()\n", " return frames" ] }, { "cell_type": "markdown", "metadata": { "id": "Ku2fsJ7kMM9V" }, "source": [ "랜덤하게 초기화된 정책 네트워크가 얼마나 잘 수행하는지 확인해 보죠:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "C4xD0Cq5MM9V" }, "outputs": [], "source": [ "frames = render_policy_net(model)\n", "plot_animation(frames)" ] }, { "cell_type": "markdown", "metadata": { "id": "uwIOKbamMM9V" }, "source": [ "음.. 아주 나쁘군요. 이 신경망은 더 배워야 합니다. 먼저 앞에서 사용한 기본적인 정책을 학습할 수 있는지 확인해 보죠. 막대가 왼쪽으로 기울면 왼쪽으로 움직이고, 오른쪽으로 기울면 오른쪽으로 움직이도록 합니다." ] }, { "cell_type": "markdown", "metadata": { "id": "lGamEjtPMM9V" }, "source": [ "같은 신경망으로 동시에 50개의 다른 환경을 플레이할 수 있습니다(이렇게 하면 각 스텝마다 다양한 훈련 배치를 얻을 수 있습니다). 그리고 5000번 반복 동안에 훈련합니다. 게임이 종료되면 환경을 재설정합니다. 사용자 정의 훈련 루프를 사용하여 모델을 훈련하기 때문에 훈련 스텝마다 환경에 앞서 예측을 쉽게 만들 수 있습니다." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NNwa-__yMM9V", "outputId": "e2e20d61-1171-4192-f0fd-228d87c211cd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration: 4999, Loss: 0.094" ] } ], "source": [ "n_environments = 50\n", "n_iterations = 5000\n", "\n", "envs = [gym.make(\"CartPole-v1\") for _ in range(n_environments)]\n", "for index, env in enumerate(envs):\n", " env.seed(index)\n", "np.random.seed(42)\n", "observations = [env.reset() for env in envs]\n", "optimizer = keras.optimizers.RMSprop()\n", "loss_fn = keras.losses.binary_crossentropy\n", "\n", "for iteration in range(n_iterations):\n", " # if angle < 0, we want proba(left) = 1., or else proba(left) = 0.\n", " target_probas = np.array([([1.] if obs[2] < 0 else [0.])\n", " for obs in observations])\n", " with tf.GradientTape() as tape:\n", " left_probas = model(np.array(observations))\n", " loss = tf.reduce_mean(loss_fn(target_probas, left_probas))\n", " print(\"\\rIteration: {}, Loss: {:.3f}\".format(iteration, loss.numpy()), end=\"\")\n", " grads = tape.gradient(loss, model.trainable_variables)\n", " optimizer.apply_gradients(zip(grads, model.trainable_variables))\n", " actions = (np.random.rand(n_environments, 1) > left_probas.numpy()).astype(np.int32)\n", " for env_index, env in enumerate(envs):\n", " obs, reward, done, info = env.step(actions[env_index][0])\n", " observations[env_index] = obs if not done else env.reset()\n", "\n", "for env in envs:\n", " env.close()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1hyqRPv2MM9V" }, "outputs": [], "source": [ "frames = render_policy_net(model)\n", "plot_animation(frames)" ] }, { "cell_type": "markdown", "metadata": { "id": "du_XBPffMM9W" }, "source": [ "정책을 잘 학습한 것 같군요. 이제 스스로 더 나은 정책을 학습할 수 있는지 확인해 보겠습니다." 
] }, { "cell_type": "markdown", "metadata": { "id": "Krl3BwxCMM9W" }, "source": [ "# 정책 그레이디언트" ] }, { "cell_type": "markdown", "metadata": { "id": "Is92qcXZMM9W" }, "source": [ "이 신경망을 훈련하려면 타깃 확률 `y`를 정의해야 합니다. 행동이 좋으면 해당 확률을 증가시키고 반대로 나쁘면 감소시켜야 합니다. 하지만 행동이 좋은지 나쁜지 어떻게 알까요? 대부분 행동의 효과가 지연되어 나타나기 때문에 한 에피소드에서 점수를 얻거나 잃을 때 어떤 행동이 이 결과에 기여했는지 명확하지 않다는 것이 문제입니다. 마지막 행동일까요? 아니면 마지막에서 10번째 행동일까요? 아니면 50 스텝 이전의 행동일까요? 이를 _신용 할당 문제_ 라고 부릅니다.\n", "\n", "_정책 그레이디언트_ 알고리즘은 이 문제를 해결하기 위해 먼저 여러 개의 에피소드를 플레이하고 그다음 좋은 에피소드에 있는 행동의 가능성을 조금 더 높이고, 나쁜 에피소드에 있는 행동의 가능성을 조금 낮춥니다. 먼저 플레이해보고 다시 돌아가서 수행한 작업을 생각해 보겠습니다." ] }, { "cell_type": "markdown", "metadata": { "id": "qFuuHNucMM9W" }, "source": [ "이 모델을 사용해 하나의 스텝을 플레이하는 함수를 만듭니다. 지금은 선택한 행동이 모두 좋다고 가정하고 손실과 그레이디언트를 계산합니다(그레이디언트를 저장하고 나중에 행동이 좋은지 나쁜지에 따라 수정하겠습니다):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "twr__keJMM9W" }, "outputs": [], "source": [ "def play_one_step(env, obs, model, loss_fn):\n", " with tf.GradientTape() as tape:\n", " left_proba = model(obs[np.newaxis])\n", " action = (tf.random.uniform([1, 1]) > left_proba)\n", " y_target = tf.constant([[1.]]) - tf.cast(action, tf.float32)\n", " loss = tf.reduce_mean(loss_fn(y_target, left_proba))\n", " grads = tape.gradient(loss, model.trainable_variables)\n", " obs, reward, done, info = env.step(int(action[0, 0].numpy()))\n", " return obs, reward, done, grads" ] }, { "cell_type": "markdown", "metadata": { "id": "TqLY1B30MM9W" }, "source": [ "`left_proba`가 높으면 `action`이 `False`가 될 가능성이 높습니다(0~1 사이에서 균등 분포로 난수를 샘플링하면 `left_proba`보다 높지 않을 가능성이 높기 때문에). 그리고 `False`를 숫자로 바꾸면 0이므로 `y_target`은 1 - 0 = 1입니다. 다른 말로 하면 타깃을 1로 지정하는 것은 왼쪽일 확률을 100%로 가정한다는 의미입니다(따라서 올바른 행동을 선택했습니다)." 
] }, { "cell_type": "markdown", "metadata": { "id": "yUXNSzmiMM9X" }, "source": [ "이제 `play_one_step()` 함수를 사용해 여러 개의 에피소드를 플레이하고 에피소드와 스텝마다 모든 보상과 그레이디언트를 반환하는 또 다른 함수를 만들어 보죠:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FAldzAamMM9X" }, "outputs": [], "source": [ "def play_multiple_episodes(env, n_episodes, n_max_steps, model, loss_fn):\n", " all_rewards = []\n", " all_grads = []\n", " for episode in range(n_episodes):\n", " current_rewards = []\n", " current_grads = []\n", " obs = env.reset()\n", " for step in range(n_max_steps):\n", " obs, reward, done, grads = play_one_step(env, obs, model, loss_fn)\n", " current_rewards.append(reward)\n", " current_grads.append(grads)\n", " if done:\n", " break\n", " all_rewards.append(current_rewards)\n", " all_grads.append(current_grads)\n", " return all_rewards, all_grads" ] }, { "cell_type": "markdown", "metadata": { "id": "I4vI4svXMM9X" }, "source": [ "정책 그레이디언트 알고리즘은 모델을 사용해 여러 번 에피소드를 플레이합니다(예를 들어 10번). 그다음 모든 보상을 할인하고 정규화합니다. 이를 위한 함수를 만들어 보죠. 첫 번째 함수는 할인된 보상을 계산합니다. 두 번째 함수는 여러 에피소드에 걸쳐 할인된 보상을 정규화합니다." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EQF-In9vMM9Y" }, "outputs": [], "source": [ "def discount_rewards(rewards, discount_rate):\n", " discounted = np.array(rewards)\n", " for step in range(len(rewards) - 2, -1, -1):\n", " discounted[step] += discounted[step + 1] * discount_rate\n", " return discounted\n", "\n", "def discount_and_normalize_rewards(all_rewards, discount_rate):\n", " all_discounted_rewards = [discount_rewards(rewards, discount_rate)\n", " for rewards in all_rewards]\n", " flat_rewards = np.concatenate(all_discounted_rewards)\n", " reward_mean = flat_rewards.mean()\n", " reward_std = flat_rewards.std()\n", " return [(discounted_rewards - reward_mean) / reward_std\n", " for discounted_rewards in all_discounted_rewards]" ] }, { "cell_type": "markdown", "metadata": { "id": "nV2JnT8mMM9Y" }, "source": [ "3개의 행동을 수행하고 각 행동의 보상이 10, 0, -50이라고 가정해 보죠. 80%의 할인 계수를 사용하면 세 번째 행동은 -50(마지막 보상의 100%)를 받지만 두 번째 행동은 -40(마지막 보상의 80%)만 받습니다. 그리고 첫 번째 행동은 -40의 80%(-32)에 첫 번째 보상(+10)의 100%를 받습니다. 따라서 할인된 보상의 합은 -22가 됩니다:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cEiOHjbDMM9Y", "outputId": "0502b10f-8545-465e-9e88-0a292d6c7bc7" }, "outputs": [ { "data": { "text/plain": [ "array([-22, -40, -50])" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "discount_rewards([10, 0, -50], discount_rate=0.8)" ] }, { "cell_type": "markdown", "metadata": { "id": "97_H6aE3MM9Y" }, "source": [ "전체 에피소드에 대해 모든 할인된 보상을 정규화하기 위해 전체 할인된 보상의 평균과 표준 편차를 계산합니다. 
그리고 할인된 보상에서 평균을 빼고 표준 편차를 나눕니다:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true, "id": "7pMp9pWzMM9Y", "outputId": "606ed469-2126-44f4-bcbd-e3fccfa2fae7" }, "outputs": [ { "data": { "text/plain": [ "[array([-0.28435071, -0.86597718, -1.18910299]),\n", " array([1.26665318, 1.0727777 ])]" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "discount_and_normalize_rewards([[10, 0, -50], [10, 20]], discount_rate=0.8)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KYl9ZKkqMM9Z" }, "outputs": [], "source": [ "n_iterations = 150\n", "n_episodes_per_update = 10\n", "n_max_steps = 200\n", "discount_rate = 0.95" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3De70JDfMM9Z" }, "outputs": [], "source": [ "optimizer = keras.optimizers.Adam(learning_rate=0.01)\n", "loss_fn = keras.losses.binary_crossentropy" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-jypGQqYMM9Z" }, "outputs": [], "source": [ "keras.backend.clear_session()\n", "np.random.seed(42)\n", "tf.random.set_seed(42)\n", "\n", "model = keras.models.Sequential([\n", " keras.layers.Dense(5, activation=\"elu\", input_shape=[4]),\n", " keras.layers.Dense(1, activation=\"sigmoid\"),\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "l8X7hRqGMM9Z", "outputId": "4d0c3f24-e253-4a2a-9df2-06960a89cda6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration: 149, mean rewards: 199.6" ] } ], "source": [ "env = gym.make(\"CartPole-v1\")\n", "env.seed(42);\n", "\n", "for iteration in range(n_iterations):\n", " all_rewards, all_grads = play_multiple_episodes(\n", " env, n_episodes_per_update, n_max_steps, model, loss_fn)\n", " total_rewards = sum(map(sum, all_rewards)) # Not shown in the book\n", " print(\"\\rIteration: {}, mean rewards: {:.1f}\".format( # Not shown\n", " iteration, total_rewards / 
n_episodes_per_update), end=\"\") # Not shown\n", " all_final_rewards = discount_and_normalize_rewards(all_rewards,\n", " discount_rate)\n", " all_mean_grads = []\n", " for var_index in range(len(model.trainable_variables)):\n", " mean_grads = tf.reduce_mean(\n", " [final_reward * all_grads[episode_index][step][var_index]\n", " for episode_index, final_rewards in enumerate(all_final_rewards)\n", " for step, final_reward in enumerate(final_rewards)], axis=0)\n", " all_mean_grads.append(mean_grads)\n", " optimizer.apply_gradients(zip(all_mean_grads, model.trainable_variables))\n", "\n", "env.close()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B63GZYuLMM9Z" }, "outputs": [], "source": [ "frames = render_policy_net(model)\n", "plot_animation(frames)" ] }, { "cell_type": "markdown", "metadata": { "id": "DGkNh4UrMM9Z" }, "source": [ "# 마르코프 연쇄" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DHyyVS37MM9a", "outputId": "16c14253-c256-469e-b23a-97211c9394db" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "States: 0 0 3 \n", "States: 0 1 2 1 2 1 2 1 2 1 3 \n", "States: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n", "States: 0 3 \n", "States: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n", "States: 0 1 3 \n", "States: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 ...\n", "States: 0 0 3 \n", "States: 0 0 0 1 2 1 2 1 3 \n", "States: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n" ] } ], "source": [ "np.random.seed(42)\n", "\n", "transition_probabilities = [ # shape=[s, s']\n", " [0.7, 0.2, 0.0, 0.1], # from s0 to s0, s1, s2, s3\n", " [0.0, 0.0, 0.9, 0.1], # from s1 to ...\n", " [0.0, 1.0, 0.0, 0.0], # from s2 to ...\n", " [0.0, 0.0, 0.0, 1.0]] # from s3 to ...\n", "\n", "n_max_steps = 50\n", "\n", "def print_sequence():\n", " current_state = 0\n", " print(\"States:\", end=\" \")\n", " for step in range(n_max_steps):\n", " print(current_state, 
end=\" \")\n", " if current_state == 3:\n", " break\n", " current_state = np.random.choice(range(4), p=transition_probabilities[current_state])\n", " else:\n", " print(\"...\", end=\"\")\n", " print()\n", "\n", "for _ in range(10):\n", " print_sequence()" ] }, { "cell_type": "markdown", "metadata": { "id": "51NUkpFaMM9a" }, "source": [ "# 마르코프 결정 과정" ] }, { "cell_type": "markdown", "metadata": { "id": "JNr2w5rlMM9a" }, "source": [ "전이 확률, 보상, 가능한 행동을 정의해 보죠. 예를 들어, 상태 s0에서 행동 a0가 선택되면 0.7의 확률로 상태 s0로 가고 +10 보상을 받습니다. 그리고 0.3의 확률로 상태 s1으로 가고 보상이 없습니다. 상태 s2로는 이동하지 않습니다(따라서 전이 확률은 `[0.7, 0.3, 0.0]`이고 보상은 `[+10, 0, 0]`입니다):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0fKqxeBvMM9a" }, "outputs": [], "source": [ "transition_probabilities = [ # shape=[s, a, s']\n", " [[0.7, 0.3, 0.0], [1.0, 0.0, 0.0], [0.8, 0.2, 0.0]],\n", " [[0.0, 1.0, 0.0], None, [0.0, 0.0, 1.0]],\n", " [None, [0.8, 0.1, 0.1], None]]\n", "rewards = [ # shape=[s, a, s']\n", " [[+10, 0, 0], [0, 0, 0], [0, 0, 0]],\n", " [[0, 0, 0], [0, 0, 0], [0, 0, -50]],\n", " [[0, 0, 0], [+40, 0, 0], [0, 0, 0]]]\n", "possible_actions = [[0, 1, 2], [0, 2], [1]]" ] }, { "cell_type": "markdown", "metadata": { "id": "HVNgfWe0MM9a" }, "source": [ "# Q-가치 반복" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rew91PtsMM9a" }, "outputs": [], "source": [ "Q_values = np.full((3, 3), -np.inf) # 불가능한 행동은 -np.inf\n", "for state, actions in enumerate(possible_actions):\n", " Q_values[state, actions] = 0.0 # 모든 가능한 행동에 대해" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MzX_h6NgMM9b" }, "outputs": [], "source": [ "gamma = 0.90 # 할인 계수\n", "\n", "history1 = [] # 책에는 없음\n", "for iteration in range(50):\n", " Q_prev = Q_values.copy()\n", " history1.append(Q_prev) # 책에는 없음\n", " for s in range(3):\n", " for a in possible_actions[s]:\n", " Q_values[s, a] = np.sum([\n", " transition_probabilities[s][a][sp]\n", " * (rewards[s][a][sp] + gamma * 
np.max(Q_prev[sp]))\n", " for sp in range(3)])\n", "\n", "history1 = np.array(history1) # 책에는 없음" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PbyItG4pMM9b", "outputId": "f4b300c1-79f2-4b57-b163-21f512fccba8" }, "outputs": [ { "data": { "text/plain": [ "array([[18.91891892, 17.02702702, 13.62162162],\n", " [ 0. , -inf, -4.87971488],\n", " [ -inf, 50.13365013, -inf]])" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Q_values" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xHsBvKMpMM9b", "outputId": "4ddf59f8-2bdf-4891-8136-7468cd38c3b3" }, "outputs": [ { "data": { "text/plain": [ "array([0, 0, 1])" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.argmax(Q_values, axis=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "F1z-sffDMM9b" }, "source": [ "할인 계수 0.9를 사용했을 때 이 MDP의 최적 정책은 상태 s0에서 행동 a0를 선택하고, 상태 s1에서 행동 a0를 선택하고, 마지막으로 상태 s2에서 행동 a1(선택 가능한 유일한 행동)을 선택하는 것입니다." 
] }, { "cell_type": "markdown", "metadata": { "id": "4UDoc4sxMM9b" }, "source": [ "할인 계수 0.95로 시도해 보죠:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xUNa27J3MM9b" }, "outputs": [], "source": [ "Q_values = np.full((3, 3), -np.inf) # 불가능한 행동에 대해서는 -np.inf\n", "for state, actions in enumerate(possible_actions):\n", " Q_values[state, actions] = 0.0 # 모든 가능한 행동에 대해서" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GwPPi5GgMM9c" }, "outputs": [], "source": [ "gamma = 0.95 # 할인 계수\n", "\n", "for iteration in range(50):\n", " Q_prev = Q_values.copy()\n", " for s in range(3):\n", " for a in possible_actions[s]:\n", " Q_values[s, a] = np.sum([\n", " transition_probabilities[s][a][sp]\n", " * (rewards[s][a][sp] + gamma * np.max(Q_prev[sp]))\n", " for sp in range(3)])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IjMO4FG4MM9c", "outputId": "aecd24fe-1183-494f-9cb6-d9682b9861e4" }, "outputs": [ { "data": { "text/plain": [ "array([[21.73304188, 20.63807938, 16.70138772],\n", " [ 0.95462106, -inf, 1.01361207],\n", " [ -inf, 53.70728682, -inf]])" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Q_values" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TZ74tFImMM9c", "outputId": "9ac08a10-fd7a-46a6-d0b7-0c74455effbf" }, "outputs": [ { "data": { "text/plain": [ "array([0, 2, 1])" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.argmax(Q_values, axis=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "bPZSt8aWMM9c" }, "source": [ "이제 정책이 바뀌었습니다! 상태 s1에서 불 속으로 들어가는 것을 선택합니다(행동 a2). 할인 계수가 크기 때문에 에이전트가 미래에 더 많은 가치를 두기 때문에 미래 보상을 얻기 위해 당장의 불이익을 감내합니다." ] }, { "cell_type": "markdown", "metadata": { "id": "byyrVmOOMM9d" }, "source": [ "# Q-러닝" ] }, { "cell_type": "markdown", "metadata": { "id": "k6NvMNmhMM9d" }, "source": [ "Q-러닝은 에이전트의 (예를 들면, 랜덤한) 플레이를 보고 점진적으로 Q-가치 추정을 향상합니다. 
Once the Q-Value estimates are accurate (or close enough), the optimal policy is to choose the action with the highest Q-Value (i.e., the greedy policy)." ] }, { "cell_type": "markdown", "metadata": { "id": "SvxJeIjUMM9d" }, "source": [ "We need to simulate an agent moving around in the environment, so let's define a function that performs an action and returns the new state and the reward:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R2bGoGb4MM9d" }, "outputs": [], "source": [ "def step(state, action):\n", "    probas = transition_probabilities[state][action]\n", "    next_state = np.random.choice([0, 1, 2], p=probas)\n", "    reward = rewards[state][action][next_state]\n", "    return next_state, reward" ] }, { "cell_type": "markdown", "metadata": { "id": "0Q2HWVhNMM9d" }, "source": [ "We also need an exploration policy. Any policy works as long as it visits every possible state many times. Since the state space is very small, we will use a random policy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qlonNWVgMM9e" }, "outputs": [], "source": [ "def exploration_policy(state):\n", "    return np.random.choice(possible_actions[state])" ] }, { "cell_type": "markdown", "metadata": { "id": "u5yJmXHEMM9e" }, "source": [ "Now let's initialize the Q-Values as before and run the Q-Learning algorithm:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7C5MYH2NMM9e" }, "outputs": [], "source": [ "np.random.seed(42)\n", "\n", "Q_values = np.full((3, 3), -np.inf)\n", "for state, actions in enumerate(possible_actions):\n", "    Q_values[state][actions] = 0\n", "\n", "alpha0 = 0.05 # initial learning rate\n", "decay = 0.005 # learning rate decay\n", "gamma = 0.90 # discount factor\n", "state = 0 # initial state\n", "history2 = [] # not in the book\n", "\n", "for iteration in range(10000):\n", "    history2.append(Q_values.copy()) # not in the book\n", "    action = exploration_policy(state)\n", "    next_state, reward = step(state, action)\n", "    next_value = np.max(Q_values[next_state]) # greedy policy at the next step\n", "    alpha = alpha0 / (1 + iteration * decay)\n", "    Q_values[state, action] *= 1 - alpha\n", "    Q_values[state, action] += alpha * (reward + gamma * next_value)\n", "    state = next_state\n", "\n", "history2 = np.array(history2) # not in the book" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": { "id": "2pA3UavRMM9e", "outputId": "62542fdc-c754-4e1a-8e3c-5c030138ac23" }, "outputs": [ { "data": { "text/plain": [ "array([[18.77621289, 17.2238872 , 13.74543343],\n", " [ 0. , -inf, -8.00485647],\n", " [ -inf, 49.40208921, -inf]])" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Q_values" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "P2K_DxamMM9e", "outputId": "215116b7-af9f-442b-fda9-535c3073ea9d" }, "outputs": [ { "data": { "text/plain": [ "array([0, 0, 1])" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.argmax(Q_values, axis=1) # optimal action for each state" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mZyBRX1zMM9f", "outputId": "6aba482f-9bf2-462a-f754-a66473c47259" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saving figure q_value_plot\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "true_Q_value = history1[-1, 0, 0]\n", "\n", "fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)\n", "axes[0].set_ylabel(\"Q-Value$(s_0, a_0)$\", fontsize=14)\n", "axes[0].set_title(\"Q-Value Iteration\", fontsize=14)\n", "axes[1].set_title(\"Q-Learning\", fontsize=14)\n", "for ax, width, history in zip(axes, (50, 10000), (history1, history2)):\n", "    ax.plot([0, width], [true_Q_value, true_Q_value], \"k--\")\n", "    ax.plot(np.arange(width), history[:, 0, 0], \"b-\", linewidth=2)\n", "    ax.set_xlabel(\"Iterations\", fontsize=14)\n", "    ax.axis([0, width, 0, 24])\n", "\n", "save_fig(\"q_value_plot\")" ] }, { "cell_type": "markdown", "metadata": { "id": "hgzBnxQLMM9f" }, "source": [ "# Deep Q-Network" ] }, { "cell_type": "markdown", "metadata": { "id": "MO8UzaisMM9f" }, "source": [ "Let's build the DQN. Given a state, it estimates, for each possible action, the sum of discounted future rewards the agent can expect after playing that action (but before seeing its outcome):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "F8sJomGpMM9f" }, "outputs": [], "source": [ "keras.backend.clear_session()\n", "tf.random.set_seed(42)\n", "np.random.seed(42)\n", "\n", "env = gym.make(\"CartPole-v1\")\n", "input_shape = [4] # == env.observation_space.shape\n", "n_outputs = 2 # == env.action_space.n\n", "\n", "model = keras.models.Sequential([\n", "    keras.layers.Dense(32, activation=\"elu\", input_shape=input_shape),\n", "    keras.layers.Dense(32, activation=\"elu\"),\n", "    keras.layers.Dense(n_outputs)\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "cpvvM-jWMM9f" }, "source": [ "To select an action with this DQN, we pick the action with the largest predicted Q-Value. However, to make sure the agent explores the environment, it chooses a random action with probability `epsilon`."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VwRuzE5XMM9f" }, "outputs": [], "source": [ "def epsilon_greedy_policy(state, epsilon=0):\n", "    if np.random.rand() < epsilon:\n", "        return np.random.randint(n_outputs)\n", "    else:\n", "        Q_values = model.predict(state[np.newaxis])\n", "        return np.argmax(Q_values[0])" ] }, { "cell_type": "markdown", "metadata": { "id": "5fZ-3GVxMM9g" }, "source": [ "We also need a replay memory. It holds the agent's experiences, each stored as a tuple `(obs, action, reward, next_obs, done)`. We can use the `deque` class for this (for a more robust implementation of experience replay, see DeepMind's [Reverb library](https://github.com/deepmind/reverb)):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zWp-e3BOMM9g" }, "outputs": [], "source": [ "from collections import deque\n", "\n", "replay_memory = deque(maxlen=2000)" ] }, { "cell_type": "markdown", "metadata": { "id": "ISMBLBd9MM9g" }, "source": [ "Next, let's create a function that samples experiences from the replay memory. It returns five NumPy arrays: `(states, actions, rewards, next_states, dones)`."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CqeAa6CkMM9g" }, "outputs": [], "source": [ "def sample_experiences(batch_size):\n", "    indices = np.random.randint(len(replay_memory), size=batch_size)\n", "    batch = [replay_memory[index] for index in indices]\n", "    states, actions, rewards, next_states, dones = [\n", "        np.array([experience[field_index] for experience in batch])\n", "        for field_index in range(5)]\n", "    return states, actions, rewards, next_states, dones" ] }, { "cell_type": "markdown", "metadata": { "id": "djWVPLqCMM9g" }, "source": [ "Now we can create a function that plays a single step using the DQN and records the experience in the replay memory:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2AgsJA6kMM9g" }, "outputs": [], "source": [ "def play_one_step(env, state, epsilon):\n", "    action = epsilon_greedy_policy(state, epsilon)\n", "    next_state, reward, done, info = env.step(action)\n", "    replay_memory.append((state, action, reward, next_state, done))\n", "    return next_state, reward, done, info" ] }, { "cell_type": "markdown", "metadata": { "id": "ANVUmkAqMM9h" }, "source": [ "Lastly, let's create a function that samples some experiences from the replay memory and performs a training step:\n", "\n", "**Notes**:\n", "* The first three releases of the 2nd edition were missing the `reshape()` operation that converts `target_Q_values` to a column vector (this is required by `loss_fn()`).\n", "* The book uses a learning rate of 1e-3, but the code below uses 1e-2 because it significantly improves training. The learning rates of the other DQN variants were tuned as well."
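] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To see why the `reshape()` mentioned in the first note matters, here is a small sketch with made-up numbers (the `targets` and `predictions` arrays are illustrative assumptions, not data from this notebook). Subtracting a column vector of shape `(3, 1)` from a 1D array of shape `(3,)` broadcasts to a `(3, 3)` matrix, so a loss computed that way would silently compare every target with every prediction:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "targets = np.array([1.0, 2.0, 3.0])            # shape (3,), like target_Q_values before reshape()\n", "predictions = np.array([[1.0], [2.0], [3.0]])  # shape (3, 1), like Q_values with keepdims=True\n", "\n", "print((targets - predictions).shape)                 # (3, 3): unintended broadcasting\n", "print((targets.reshape(-1, 1) - predictions).shape)  # (3, 1): one error per experience"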
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IE95b_C3MM9h" }, "outputs": [], "source": [ "batch_size = 32\n", "discount_rate = 0.95\n", "optimizer = keras.optimizers.Adam(learning_rate=1e-2)\n", "loss_fn = keras.losses.mean_squared_error\n", "\n", "def training_step(batch_size):\n", "    experiences = sample_experiences(batch_size)\n", "    states, actions, rewards, next_states, dones = experiences\n", "    next_Q_values = model.predict(next_states)\n", "    max_next_Q_values = np.max(next_Q_values, axis=1)\n", "    target_Q_values = (rewards +\n", "                       (1 - dones) * discount_rate * max_next_Q_values)\n", "    target_Q_values = target_Q_values.reshape(-1, 1)\n", "    mask = tf.one_hot(actions, n_outputs)\n", "    with tf.GradientTape() as tape:\n", "        all_Q_values = model(states)\n", "        Q_values = tf.reduce_sum(all_Q_values * mask, axis=1, keepdims=True)\n", "        loss = tf.reduce_mean(loss_fn(target_Q_values, Q_values))\n", "    grads = tape.gradient(loss, model.trainable_variables)\n", "    optimizer.apply_gradients(zip(grads, model.trainable_variables))" ] }, { "cell_type": "markdown", "metadata": { "id": "HDzA6K8hMM9h" }, "source": [ "Now let's train the model!"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OoEEkDHuMM9h" }, "outputs": [], "source": [ "env.seed(42)\n", "np.random.seed(42)\n", "tf.random.set_seed(42)\n", "\n", "rewards = [] \n", "best_score = 0" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true, "id": "XYdtvldOMM9h", "outputId": "2a3f5c5a-32d7-4851-f4db-03f7e1014439" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Episode: 599, Steps: 200, eps: 0.010" ] } ], "source": [ "for episode in range(600):\n", "    obs = env.reset() \n", "    for step in range(200):\n", "        epsilon = max(1 - episode / 500, 0.01)\n", "        obs, reward, done, info = play_one_step(env, obs, epsilon)\n", "        if done:\n", "            break\n", "    rewards.append(step) # Not shown in the book\n", "    if step >= best_score: # Not shown\n", "        best_weights = model.get_weights() # Not shown\n", "        best_score = step # Not shown\n", "    print(\"\\rEpisode: {}, Steps: {}, eps: {:.3f}\".format(episode, step + 1, epsilon), end=\"\") # Not shown\n", "    if episode > 50:\n", "        training_step(batch_size)\n", "\n", "model.set_weights(best_weights)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EOy_doHlMM9i", "outputId": "c5b7f261-0654-4677-bba1-8e7ddfdb4bb5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saving figure dqn_rewards_plot\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(8, 4))\n", "plt.plot(rewards)\n", "plt.xlabel(\"Episode\", fontsize=14)\n", "plt.ylabel(\"Sum of rewards\", fontsize=14)\n", "save_fig(\"dqn_rewards_plot\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ao-N4W5GMM9i" }, "outputs": [], "source": [ "env.seed(42)\n", "state = env.reset()\n", "\n", "frames = []\n", "\n", "for step in range(200):\n", "    action = epsilon_greedy_policy(state)\n", "    state, reward, done, info = env.step(action)\n", "    if done:\n", "        break\n", "    img = env.render(mode=\"rgb_array\")\n", "    frames.append(img)\n", "    \n", "plot_animation(frames)" ] }, { "cell_type": "markdown", "metadata": { "id": "Jcs3uX_QMM9i" }, "source": [ "Not bad at all! 😀" ] }, { "cell_type": "markdown", "metadata": { "id": "FwKGEv9nMM9i" }, "source": [ "## Double DQN" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sPnMAa4OMM9i" }, "outputs": [], "source": [ "keras.backend.clear_session()\n", "tf.random.set_seed(42)\n", "np.random.seed(42)\n", "\n", "model = keras.models.Sequential([\n", "    keras.layers.Dense(32, activation=\"elu\", input_shape=[4]),\n", "    keras.layers.Dense(32, activation=\"elu\"),\n", "    keras.layers.Dense(n_outputs)\n", "])\n", "\n", "target = keras.models.clone_model(model)\n", "target.set_weights(model.get_weights())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TLAntgw1MM9i" }, "outputs": [], "source": [ "batch_size = 32\n", "discount_rate = 0.95\n", "optimizer = keras.optimizers.Adam(learning_rate=6e-3)\n", "loss_fn = keras.losses.Huber()\n", "\n", "def training_step(batch_size):\n", "    experiences = sample_experiences(batch_size)\n", "    states, actions, rewards, next_states, dones = experiences\n", "    next_Q_values = model.predict(next_states)\n", "    best_next_actions = np.argmax(next_Q_values, axis=1)\n", "    next_mask = tf.one_hot(best_next_actions, 
n_outputs).numpy()\n", "    next_best_Q_values = (target.predict(next_states) * next_mask).sum(axis=1)\n", "    target_Q_values = (rewards + \n", "                       (1 - dones) * discount_rate * next_best_Q_values)\n", "    target_Q_values = target_Q_values.reshape(-1, 1)\n", "    mask = tf.one_hot(actions, n_outputs)\n", "    with tf.GradientTape() as tape:\n", "        all_Q_values = model(states)\n", "        Q_values = tf.reduce_sum(all_Q_values * mask, axis=1, keepdims=True)\n", "        loss = tf.reduce_mean(loss_fn(target_Q_values, Q_values))\n", "    grads = tape.gradient(loss, model.trainable_variables)\n", "    optimizer.apply_gradients(zip(grads, model.trainable_variables))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "n0Qhf1P7MM9j" }, "outputs": [], "source": [ "replay_memory = deque(maxlen=2000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3xQNo7M1MM9j", "outputId": "8a29f368-626e-4bc2-ac2f-1c8e8f6e64ae" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Episode: 599, Steps: 55, eps: 0.0100" ] } ], "source": [ "env.seed(42)\n", "np.random.seed(42)\n", "tf.random.set_seed(42)\n", "\n", "rewards = []\n", "best_score = 0\n", "\n", "for episode in range(600):\n", "    obs = env.reset() \n", "    for step in range(200):\n", "        epsilon = max(1 - episode / 500, 0.01)\n", "        obs, reward, done, info = play_one_step(env, obs, epsilon)\n", "        if done:\n", "            break\n", "    rewards.append(step)\n", "    if step >= best_score:\n", "        best_weights = model.get_weights()\n", "        best_score = step\n", "    print(\"\\rEpisode: {}, Steps: {}, eps: {:.3f}\".format(episode, step + 1, epsilon), end=\"\")\n", "    if episode >= 50:\n", "        training_step(batch_size)\n", "        if episode % 50 == 0:\n", "            target.set_weights(model.get_weights())\n", "    # Alternatively, you can do soft updates at each step:\n", "    #if episode >= 50:\n", "        #target_weights = target.get_weights()\n", "        #online_weights = model.get_weights()\n", "        #for index in range(len(target_weights)):\n", "        # 
target_weights[index] = 0.99 * target_weights[index] + 0.01 * online_weights[index]\n", "        #target.set_weights(target_weights)\n", "\n", "model.set_weights(best_weights)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NKBNVII0MM9j", "outputId": "5f2e25a2-07b8-468b-c6b6-3a134991de07" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saving figure double_dqn_rewards_plot\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(8, 4))\n", "plt.plot(rewards)\n", "plt.xlabel(\"Episode\", fontsize=14)\n", "plt.ylabel(\"Sum of rewards\", fontsize=14)\n", "save_fig(\"double_dqn_rewards_plot\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true, "id": "gsE640k5MM9j" }, "outputs": [], "source": [ "env.seed(43)\n", "state = env.reset()\n", "\n", "frames = []\n", "\n", "for step in range(200):\n", "    action = epsilon_greedy_policy(state)\n", "    state, reward, done, info = env.step(action)\n", "    if done:\n", "        break\n", "    img = env.render(mode=\"rgb_array\")\n", "    frames.append(img)\n", "    \n", "plot_animation(frames)" ] }, { "cell_type": "markdown", "metadata": { "id": "RnHiUB4KMM9k" }, "source": [ "## Dueling Double DQN" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "96KQOB0RMM9k" }, "outputs": [], "source": [ "keras.backend.clear_session()\n", "tf.random.set_seed(42)\n", "np.random.seed(42)\n", "\n", "K = keras.backend\n", "input_states = keras.layers.Input(shape=[4])\n", "hidden1 = keras.layers.Dense(32, activation=\"elu\")(input_states)\n", "hidden2 = keras.layers.Dense(32, activation=\"elu\")(hidden1)\n", "state_values = keras.layers.Dense(1)(hidden2)\n", "raw_advantages = keras.layers.Dense(n_outputs)(hidden2)\n", "advantages = raw_advantages - K.max(raw_advantages, axis=1, keepdims=True)\n", "Q_values = state_values + advantages\n", "model = keras.models.Model(inputs=[input_states], outputs=[Q_values])\n", "\n", "target = keras.models.clone_model(model)\n", "target.set_weights(model.get_weights())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "y_-4S5UmMM9k" }, "outputs": [], "source": [ "batch_size = 32\n", "discount_rate = 0.95\n", "optimizer = keras.optimizers.Adam(learning_rate=7.5e-3)\n", "loss_fn = keras.losses.Huber()\n", "\n", "def training_step(batch_size):\n", "    
experiences = sample_experiences(batch_size)\n", "    states, actions, rewards, next_states, dones = experiences\n", "    next_Q_values = model.predict(next_states)\n", "    best_next_actions = np.argmax(next_Q_values, axis=1)\n", "    next_mask = tf.one_hot(best_next_actions, n_outputs).numpy()\n", "    next_best_Q_values = (target.predict(next_states) * next_mask).sum(axis=1)\n", "    target_Q_values = (rewards + \n", "                       (1 - dones) * discount_rate * next_best_Q_values)\n", "    target_Q_values = target_Q_values.reshape(-1, 1)\n", "    mask = tf.one_hot(actions, n_outputs)\n", "    with tf.GradientTape() as tape:\n", "        all_Q_values = model(states)\n", "        Q_values = tf.reduce_sum(all_Q_values * mask, axis=1, keepdims=True)\n", "        loss = tf.reduce_mean(loss_fn(target_Q_values, Q_values))\n", "    grads = tape.gradient(loss, model.trainable_variables)\n", "    optimizer.apply_gradients(zip(grads, model.trainable_variables))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QRit2vBkMM9k" }, "outputs": [], "source": [ "replay_memory = deque(maxlen=2000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jEWYxRQZMM9l", "outputId": "9be7270a-84be-4b8c-f132-b74357c798aa" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Episode: 599, Steps: 200, eps: 0.010" ] } ], "source": [ "env.seed(42)\n", "np.random.seed(42)\n", "tf.random.set_seed(42)\n", "\n", "rewards = []\n", "best_score = 0\n", "\n", "for episode in range(600):\n", "    obs = env.reset() \n", "    for step in range(200):\n", "        epsilon = max(1 - episode / 500, 0.01)\n", "        obs, reward, done, info = play_one_step(env, obs, epsilon)\n", "        if done:\n", "            break\n", "    rewards.append(step)\n", "    if step >= best_score:\n", "        best_weights = model.get_weights()\n", "        best_score = step\n", "    print(\"\\rEpisode: {}, Steps: {}, eps: {:.3f}\".format(episode, step + 1, epsilon), end=\"\")\n", "    if episode >= 50:\n", "        training_step(batch_size)\n", "        if episode % 50 == 0:\n", "            
target.set_weights(model.get_weights())\n", "\n", "model.set_weights(best_weights)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "41U4UdTxMM9l", "outputId": "11ea93c3-dfb5-45b9-f7c3-4f4bbf7b6363" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(rewards)\n", "plt.xlabel(\"Episode\")\n", "plt.ylabel(\"Sum of rewards\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true, "id": "K2CK44D5MM9l" }, "outputs": [], "source": [ "env.seed(42)\n", "state = env.reset()\n", "\n", "frames = []\n", "\n", "for step in range(200):\n", "    action = epsilon_greedy_policy(state)\n", "    state, reward, done, info = env.step(action)\n", "    if done:\n", "        break\n", "    img = env.render(mode=\"rgb_array\")\n", "    frames.append(img)\n", "    \n", "plot_animation(frames)" ] }, { "cell_type": "markdown", "metadata": { "id": "F_0zgUNNMM9l" }, "source": [ "This looks like a very stable agent!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1PSFnlTiMM9l" }, "outputs": [], "source": [ "env.close()" ] }, { "cell_type": "markdown", "metadata": { "id": "Vag2RJSiMM9m" }, "source": [ "# Using TF-Agents to Play Breakout" ] }, { "cell_type": "markdown", "metadata": { "id": "OzHU4KP2MM9m" }, "source": [ "Let's use TF-Agents to create an agent that learns to play Breakout. We will use the Deep Q-Learning algorithm, so the components are easy to compare with the previous implementation. TF-Agents also implements many other (and more sophisticated) algorithms!"
] }, { "cell_type": "markdown", "metadata": { "id": "4ejGbHM0MM9m" }, "source": [ "## TF-Agents Environments" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UfK1JXGwMM9m" }, "outputs": [], "source": [ "tf.random.set_seed(42)\n", "np.random.seed(42)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HdxvM9JeMM9o", "outputId": "2e35aa44-c936-4f58-9460-fe566dd7c48a" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 79, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from tf_agents.environments import suite_gym\n", "\n", "env = suite_gym.load(\"Breakout-v4\")\n", "env" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HSy8OORDMM9p", "outputId": "7cc47c0c-0dec-4c74-b7a8-600e9ccbf197" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.gym" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N7m8k3asMM9p", "outputId": "ce7a862c-efe8-43ea-a068-107a8fc5034c" }, "outputs": [ { "data": { "text/plain": [ "TimeStep(step_type=array(0, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([[[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " ...,\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]]], 
dtype=uint8))" ] }, "execution_count": 81, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.seed(42)\n", "env.reset()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DJQ1sDUFMM9p", "outputId": "f6f7f9f9-be2f-4f66-9017-2163dd10c3e3" }, "outputs": [ { "data": { "text/plain": [ "TimeStep(step_type=array(1, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([[[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " ...,\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]]], dtype=uint8))" ] }, "execution_count": 82, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.step(1) # Fire" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Z47Lx1BJMM9q", "outputId": "dcd62dc0-78ce-4aa3-95b1-c3e863cb681f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saving figure breakout_plot\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "img = env.render(mode=\"rgb_array\")\n", "\n", "plt.figure(figsize=(6, 8))\n", "plt.imshow(img)\n", "plt.axis(\"off\")\n", "save_fig(\"breakout_plot\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Jtx1u0ITMM9s", "outputId": "4f30bddd-6911-469d-ed0e-0901f27e051e" }, "outputs": [ { "data": { "text/plain": [ "TimeStep(step_type=array(1, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([[[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " ...,\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]]], dtype=uint8))" ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.current_time_step()" ] }, { "cell_type": "markdown", "metadata": { "id": "12x24B2PMM9s" }, "source": [ "## Environment Specifications" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kk38Rk_rMM9t", "outputId": "00928a64-3623-445d-f4c2-17f0fa8b63a6" }, "outputs": [ { "data": { "text/plain": [ "BoundedArraySpec(shape=(210, 160, 3), dtype=dtype('uint8'), name='observation', minimum=0, maximum=255)" ] }, "execution_count": 85, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.observation_spec()" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": { "id": "i9WO1wqXMM9t", "outputId": "6dc6bf2e-1b89-40bc-b052-0d883bf32aee" }, "outputs": [ { "data": { "text/plain": [ "BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=3)" ] }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_spec()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6okHHwEFMM9t", "outputId": "a4c2bc5a-72ad-40c2-f947-1b0383b559a2" }, "outputs": [ { "data": { "text/plain": [ "TimeStep(step_type=ArraySpec(shape=(), dtype=dtype('int32'), name='step_type'), reward=ArraySpec(shape=(), dtype=dtype('float32'), name='reward'), discount=BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0, maximum=1.0), observation=BoundedArraySpec(shape=(210, 160, 3), dtype=dtype('uint8'), name='observation', minimum=0, maximum=255))" ] }, "execution_count": 87, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.time_step_spec()" ] }, { "cell_type": "markdown", "metadata": { "id": "xDpYjirOMM9t" }, "source": [ "## Environment Wrappers" ] }, { "cell_type": "markdown", "metadata": { "id": "fmdUZvGDMM9t" }, "source": [ "You can wrap a TF-Agents environment in a TF-Agents wrapper:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "THk6YyUqMM9t", "outputId": "7df6c899-1b04-4c04-812b-24cc7c1575da" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from tf_agents.environments.wrappers import ActionRepeat\n", "\n", "repeating_env = ActionRepeat(env, times=4)\n", "repeating_env" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B64H12U2MM9u", "outputId": "141b8e0e-5ad6-4b0d-e2d1-893727d9652f" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 89, "metadata": {}, "output_type": "execute_result" } ], "source": [ "repeating_env.unwrapped" ] }, { "cell_type": "markdown", "metadata": { "id": 
"7Xkqpth6MM9u" }, "source": [ "Here is the list of available wrappers:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pQ3yQ-mdMM9u", "outputId": "aae08213-dd80-4c78-8202-95a8fa04337a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ActionClipWrapper Wraps an environment and clips actions to spec before applying.\n", "ActionDiscretizeWrapper Wraps an environment with continuous actions and discretizes them.\n", "ActionOffsetWrapper Offsets actions to be zero-based.\n", "ActionRepeat Repeates actions over n-steps while acummulating the received reward.\n", "FlattenObservationsWrapper Wraps an environment and flattens nested multi-dimensional observations.\n", "GoalReplayEnvWrapper Adds a goal to the observation, used for HER (Hindsight Experience Replay).\n", "HistoryWrapper Adds observation and action history to the environment's observations.\n", "ObservationFilterWrapper Filters observations based on an array of indexes.\n", "OneHotActionWrapper Converts discrete action to one_hot format.\n", "PerformanceProfiler End episodes after specified number of steps.\n", "PyEnvironmentBaseWrapper PyEnvironment wrapper forwards calls to the given environment.\n", "RunStats Wrapper that accumulates run statistics as the environment iterates.\n", "TimeLimit End episodes after specified number of steps.\n" ] } ], "source": [ "import tf_agents.environments.wrappers\n", "\n", "for name in dir(tf_agents.environments.wrappers):\n", "    obj = getattr(tf_agents.environments.wrappers, name)\n", "    if hasattr(obj, \"__base__\") and issubclass(obj, tf_agents.environments.wrappers.PyEnvironmentBaseWrapper):\n", "        print(\"{:27s} {}\".format(name, obj.__doc__.split(\"\\n\")[0]))" ] }, { "cell_type": "markdown", "metadata": { "id": "614L79fnMM9u" }, "source": [ "The `suite_gym.load()` function creates an environment and wraps it using both TF-Agents environment wrappers and Gym environment wrappers (the latter are applied first)."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f3Yr4xBAMM9u" }, "outputs": [], "source": [ "from functools import partial\n", "from gym.wrappers import TimeLimit\n", "\n", "limited_repeating_env = suite_gym.load(\n", "    \"Breakout-v4\",\n", "    gym_env_wrappers=[partial(TimeLimit, max_episode_steps=10000)],\n", "    env_wrappers=[partial(ActionRepeat, times=4)],\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sCEYP6X_MM9u", "outputId": "b0b0e284-28f5-49d7-fc0d-8fcb8f97d421" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 92, "metadata": {}, "output_type": "execute_result" } ], "source": [ "limited_repeating_env" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oGEO08eyMM9v", "outputId": "3e503dee-0d1f-41e3-e956-ea86aa8a328f" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 93, "metadata": {}, "output_type": "execute_result" } ], "source": [ "limited_repeating_env.unwrapped" ] }, { "cell_type": "markdown", "metadata": { "id": "wdgTTSshMM9v" }, "source": [ "Create an Atari Breakout environment and wrap it to apply the default Atari preprocessing steps:" ] }, { "cell_type": "markdown", "metadata": { "id": "eGxzv4XlMM9v" }, "source": [ "**Warning**: Breakout requires the player to press the FIRE button at the start of the game and after each life lost. The agent may take a very long time to learn this, because at first pressing FIRE just seems to make it lose faster. To speed up training, we subclass the `AtariPreprocessing` wrapper to create `AtariPreprocessingWithAutoFire`, which presses FIRE (i.e., plays action 1) automatically at the start of the game and after each life lost. This differs from the book, which uses the regular `AtariPreprocessing` wrapper."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ptRx_mCTMM9v" }, "outputs": [], "source": [ "from tf_agents.environments import suite_atari\n", "from tf_agents.environments.atari_preprocessing import AtariPreprocessing\n", "from tf_agents.environments.atari_wrappers import FrameStack4\n", "\n", "max_episode_steps = 27000 # <=> 108k ALE frames since 1 step = 4 frames\n", "environment_name = \"BreakoutNoFrameskip-v4\"\n", "\n", "class AtariPreprocessingWithAutoFire(AtariPreprocessing):\n", "    def reset(self, **kwargs):\n", "        obs = super().reset(**kwargs)\n", "        super().step(1) # FIRE to start\n", "        return obs\n", "    def step(self, action):\n", "        lives_before_action = self.ale.lives()\n", "        obs, rewards, done, info = super().step(action)\n", "        if self.ale.lives() < lives_before_action and not done:\n", "            super().step(1) # FIRE to start after life lost\n", "        return obs, rewards, done, info\n", "\n", "env = suite_atari.load(\n", "    environment_name,\n", "    max_episode_steps=max_episode_steps,\n", "    gym_env_wrappers=[AtariPreprocessingWithAutoFire, FrameStack4])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WorYqLXtMM9v", "outputId": "8194b3ab-178c-428e-d371-0f52233df0b1" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 95, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env" ] }, { "cell_type": "markdown", "metadata": { "id": "o8Lr8XZwMM9v" }, "source": [ "Play a few steps to see what happens:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XQt-IN-wMM9v" }, "outputs": [], "source": [ "env.seed(42)\n", "env.reset()\n", "for _ in range(4):\n", "    time_step = env.step(3) # LEFT" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zJYOoAhzMM9w" }, "outputs": [], "source": [ "def plot_observation(obs):\n", "    # Since there are only 3 color channels, we cannot display all 4 frames directly.\n", "    # Instead, compute the delta between the current frame and the mean of the other frames,\n", "    # then add this delta to the red and blue channels to highlight the current frame in purple.\n", 
" obs = obs.astype(np.float32)\n", " img = obs[..., :3]\n", " current_frame_delta = np.maximum(obs[..., 3] - obs[..., :3].mean(axis=-1), 0.)\n", " img[..., 0] += current_frame_delta\n", " img[..., 2] += current_frame_delta\n", " img = np.clip(img / 150, 0, 1)\n", " plt.imshow(img)\n", " plt.axis(\"off\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Sc2BjqzsMM9w", "outputId": "095a4c8f-76ac-4993-b834-1e201f7ca2a5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saving figure preprocessed_breakout_plot\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAacAAAGoCAYAAADiuSpNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADJ9JREFUeJzt3U+InOUBx/F3Jlmwi+AW/6QExMtiKggGeuplCaUiueToXQQFoaBIiUJ7KgQ8CEK9tFBE8KCeRZHmENJCT5LsQU08KVpipVFZZDAsO28vPWTmeZcdJ+87729mPp/b+/Ds+z7ZnZ1vXt533h3UdV0BQJJh3wsAgGniBEAccQIgjjgBEEecAIgjTgDEEScA4ogTAHGO93HQwWDgk78AVHVdD5rGnTkBEEecAIgjTgDEEScA4ogTAHHECYA44gRAHHECIE4vH8KlPSdOnCjGPvvssyO/7sqVK0fOefTRR4uxjY2Nie2dnZ1izu7ubjH29ttvT2yfPXu2mPPFF18UY99+++3EdtO/9+TJkxPbb7zxRjHn+eefL8aeeOKJie133323mDMajYqx69evT2xPf0+qqvl7N+2ee+45ck6il156qRg7f/78xPb0z62qqurLL7+c63jT3++qqqpnn312rn0tg9dee60Ye+qppya2L1y4UMx55ZVXOltTH5w5ARBHnACII04AxBEnAOK4IWJNnTlz5sg5n376aTE2ffNBm1599dVi7M0335zYbroY//LLL3e2pqaL8dPfu3lvSlll165dK8aabjiZxc2bN+90OSwhZ04AxBEnAOKIEwBxXHMCfpIPPvigGLtx48Zc+5r+sPIzzzxTzGn6wPh777031/FYHs6cAIgjTgDEEScA4ogTAHHcELGmLl26dOSc+++/v/uF3ObFF18sxqafxtz0gdcunTp1qhib/t41PZV8lZ0+fboYm/45zWpZn8xO95w5ARBHnACII04AxBEnAOIM6rpe+EFPnz69+IMCEOfq1auDpnFnTgDEEScA4ogTAHF6uea0t7e30IOOx+OJ7eGwbPL0nMPmtXH8Wffd5Zq61Pe6+z7+Oprld2yVzfqaa+u9qK05CWva2tpyzQmA5SBOAMQRJwDiiBMAccQJgDjiBEAccQIgjjgBEEecAIgjTgDEEScA4ogTAHHECYA44gRAHHECII44ARBHnACII04AxBEnAOKIEwBxxAmAOOIEQJzjfS/gTly/fr0YG41GPawEYHVtbm4WY6dOner0mM6cAIgjTgDEEScA4ogTAHEGdV0v/KB7e3utHHRnZ6cY293dbWPXAPzfY489Voxdvnx5Yns8HhdzhsPhkXO2trYGTcd05gRAHHECII44ARBHnACII04AxBEnAOKIEwBxxA
mAOOIEQBxxAiCOOAEQR5wAiCNOAMQRJwDiiBMAccQJgDjiBEAccQIgjjgBEEecAIgjTgDEEScA4ogTAHHECYA44gRAHHECII44ARBHnACII04AxBEnAOKIEwBxxAmAOOIEQBxxAiCOOAEQR5wAiCNOAMQRJwDiiBMAccQJgDjiBECc430v4E489NBDxdhoNCrGxuPxxPZwWDZ5es5h8+Yx7767XFOX+l5338dfR7P8jq2yWV9zbb0XtTVn1jU9+OCDxVjX1usVBMBSECcA4ogTAHGW+prT+fPnizHXnPrX97r7Pv46cs1pta85bW5uFmNdW69XEABLQZwAiCNOAMQRJwDiiBMAccQJgDjiBEAccQIgzlJ/CPe+++4rxvb393tYCcDq2tjYWPgxnTkBEEecAIgjTgDEEScA4iz1DRFNF+m6fBLwvDyV3FPJV52nkq/2U8mPHTtWjHVtvV5BACwFcQIgjjgBEEecAIiz1DdENF24Ozg4mGnePHPmNe++l/Wict/r7vv462jdv+ez/vvbei9q8z2t7/fHQ4+58CMCwBHECYA44gRAnJW75tT0YbG6rie2B4PBkXMOmzePeffd5Zq61Pe6+z7+Oprld2yVzfqaa+u9qK05s67JNScAqMQJgEDiBEAccQIgjjgBEEecAIgjTgDEEScA4ogTAHGW+gkRDzzwQDHW9IQIf6Z9sfped9/HX0f+TPtq/5n2pr/2cOvWrWKsTev1CgJgKYgTAHHECYA44gRAHHECII44ARBHnACII04AxFnqD+HevHmzGGv6sBgA82t6uMHdd9/d6TGdOQEQR5wAiCNOAMRZ6mtOe3t7xdj+/n4xVtf1xPZgMDhyzmHz5jHvvrtcU5f6Xnffx19Hs/yOrbJZX3NtvRe1NWfWNW1sbBRjrjkBsHbECYA44gRAHHECIM5S3xAxGo2KsR9//LGHlQCsrrvuumvhx3TmBEAccQIgjjgBEEecAIiz1DdEfPLJJ8VY05PKx+PxxPZwWDZ5es5h8+Yx7767XFOX+l5338dfR7P8jq2yWV9zbb0XtTVn1jXde++9xdj29nYx1qb1egUBsBTECYA44gRAHHECII44ARBHnACII04AxBEnAOIs9Ydw33rrrWLs448/Lsb8mfbF6nvdfR9/Hfkz7av9Z9ofeeSRYuzcuXPFWJucOQEQR5wAiCNOAMQRJwDiLPUNEV9//XUx9tVXX/WwEoDV1fRU8q45cwIgjjgBEEecAIgjTgDEEScA4ogTAHHECYA44gRAHHECII44ARBHnACII04AxBEnAOKIEwBxxAmAOOIEQBxxAiCOOAEQR5wAiCNOAMQRJwDiiBMAccQJgDjiBEAccQIgjjgBEEecAIgjTgDEEScA4ogTAHHECYA44gRAHHECII44ARBHnACII04AxBEnAOKIEwBxxAmAOOIEQBxxAiCOOAEQR5wAiCNOAMQRJwDiiBMAccQJgDjiBEAccQIgjjgBEEecAIgjTgDEEScA4ogTAHHECYA44gRAHHECII44ARBHnACIc7zvBUCy3zaM/Xpq+0+LWAisGWdOAMQRJwDiiBMAcVxzYn1tTm3/vJzym3+XY8c6WQxwO2dOAMQRJwDiiBMAccQJgDhuiGB9/WFqe6Oc8uffl2N7nSwGuJ0zJwDiiBMAccQJgDiuObGCdhrGflYO/fDh5PZH5ZQbrawH+KmcOQEQR5wAiCNOAMQRJwDiuCGCFfS7hrGGux0ufFiOARGcOQEQR5wAiCNOAMQRJwDiuCGCaINqMLH9evV6MWe32p3Y/mv1l4Y9/bPNZQEdc+YEQBxxAiCOOAEQxzUnog2n/v/0cPVwMeed6p2pkcsdrghYBGdOAMQRJwDiiBMAccQJgDhuiCDaQXUwsf149XhPKwEWyZkTAHHECYA44gRAHHECII44ARBHnACII04AxBEnAOKIEwBxxAmAOOIEQBxxAiCOOAEQR5wAiCNOAMQRJwDiiBMAccQJgDjiBE
AccQIgjjgBEEecAIgjTgDEEScA4ogTAHHECYA44gRAHHECII44ARBHnACII04AxBEnAOKIEwBxxAmAOOIEQBxxAiCOOAEQR5wAiCNOAMQRJwDiiBMAccQJgDjiBEAccQIgjjgBEEecAIgjTgDEEScA4ogTAHHECYA44gRAHHECII44ARBHnACII04AxBEnAOKIEwBxxAmAOOIEQBxxAiCOOAEQR5wAiCNOAMQRJwDiiBMAccQJgDjiBEAccQIgjjgBEEecAIgjTgDEEScA4ogTAHHECYA44gRAHHECII44ARDneB8HHY1GrexnPB63sp/DDKrBxPbZ6uxc+6mr+sh9t/l1s3i/utYw+stW9l01rLtqad2z+Xs5dGa/HNvsfiVdOHFjcvtXV/pZx+2mf+Kz/rT/M7X9UQtroX37++XvzzfffNPKvre2thrHnTkBEEecAIgjTgDEEScA4vRyQ8T333/fyn4ODg5a2c9hhlPtfqF6Ya79jKvyxo3pfbf5dbN4v/pbw+jTrey7alj3Yv8f9I9y6OmGGyJ+0f1KurB9eXL7hYAbIqZ/4rP+tP81te2GiEy3bt0qxj7//POJ7aYb1IbD4ZFztre3G4/pzAmAOOIEQBxxAiCOOAEQp5cbIuq64ckHgxmemNDwdV06qCZvuHiuem6u/STeEFFV/20Ya+tydN83RPxQDv2xYdpG5wvpxNW9ye35XpXtmveGiIafFFRV5cwJgEDiBEAccQIgzmDR13GqqqqefPLJVg568eLFYuy7775rY9cALEBd1403HDhzAiCOOAEQR5wAiCNOAMTp5YaIwWCw+IMCEMcNEQAsDXECII44ARBHnACII04AxBEnAOKIEwBxxAmAOOIEQBxxAiCOOAEQR5wAiCNOAMQRJwDiiBMAccQJgDjiBEAccQIgjjgBEEecAIgjTgDEEScA4gzquu57DQAwwZkTAHHECYA44gRAHHECII44ARBHnACII04AxBEnAOKIEwBxxAmAOOIEQBxxAiCOOAEQR5wAiCNOAMQRJwDiiBMAccQJgDjiBEAccQIgjjgBEEecAIgjTgDEEScA4ogTAHHECYA4/wPd8aiZkkg/HgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(6, 6))\n", "plot_observation(time_step.observation)\n", "save_fig(\"preprocessed_breakout_plot\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "m5LkCvqDMM9w" }, "source": [ "파이썬 환경을 TF 환경으로 변환합니다:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wQ-_-OGOMM9w" }, "outputs": [], "source": [ "from tf_agents.environments.tf_py_environment import TFPyEnvironment\n", "\n", "tf_env = TFPyEnvironment(env)" ] }, { "cell_type": "markdown", "metadata": { "id": "fPzGxdg0MM9w" }, "source": [ "## DQN 만들기" ] }, { "cell_type": "markdown", "metadata": { "id": "Lat01T6IMM9w" }, "source": [ "관측을 정규화하는 작은 클래스를 만듭니다. 이미지를 0~255 사이의 바이트로 저장하는 것이 램을 적게 사용하지만 신경망에는 0.0~1.0 사이의 실수를 전달해야 합니다:" ] }, { "cell_type": "markdown", "metadata": { "id": "qewdBtjYMM9w" }, "source": [ "Q-네트워크를 만듭니다:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x0KEhXLPMM9x" }, "outputs": [], "source": [ "from tf_agents.networks.q_network import QNetwork\n", "\n", "preprocessing_layer = keras.layers.Lambda(\n", " lambda obs: tf.cast(obs, np.float32) / 255.)\n", "conv_layer_params=[(32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)]\n", "fc_layer_params=[512]\n", "\n", "q_net = QNetwork(\n", " tf_env.observation_spec(),\n", " tf_env.action_spec(),\n", " preprocessing_layers=preprocessing_layer,\n", " conv_layer_params=conv_layer_params,\n", " fc_layer_params=fc_layer_params)" ] }, { "cell_type": "markdown", "metadata": { "id": "dO8912T2MM9x" }, "source": [ "DQN 에이전트를 만듭니다:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YZ7_mYSDMM9x" }, "outputs": [], "source": [ "from tf_agents.agents.dqn.dqn_agent import DqnAgent\n", "\n", "train_step = tf.Variable(0)\n", "update_period = 4 # run a training step every 4 collect steps\n", "optimizer = keras.optimizers.RMSprop(learning_rate=2.5e-4, rho=0.95, 
momentum=0.0,\n", "                                     epsilon=0.00001, centered=True)\n", "epsilon_fn = keras.optimizers.schedules.PolynomialDecay(\n", "    initial_learning_rate=1.0, # initial ε\n", "    decay_steps=250000 // update_period, # <=> 1,000,000 ALE frames\n", "    end_learning_rate=0.01) # final ε\n", "agent = DqnAgent(tf_env.time_step_spec(),\n", "                 tf_env.action_spec(),\n", "                 q_network=q_net,\n", "                 optimizer=optimizer,\n", "                 target_update_period=2000, # <=> 32,000 ALE frames\n", "                 td_errors_loss_fn=keras.losses.Huber(reduction=\"none\"),\n", "                 gamma=0.99, # discount factor\n", "                 train_step_counter=train_step,\n", "                 epsilon_greedy=lambda: epsilon_fn(train_step))\n", "agent.initialize()" ] }, { "cell_type": "markdown", "metadata": { "id": "7FDZnZydMM9x" }, "source": [ "Create the replay buffer (this uses a lot of RAM, so reduce the buffer size if you get an out-of-memory error):" ] }, { "cell_type": "markdown", "metadata": { "id": "BIR_XItoMM9x" }, "source": [ "**Warning**: unlike the book, we use a replay buffer of size 100,000 instead of 1,000,000, because the larger buffer causes an out-of-memory error on most machines." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uTnd-iELMM9x" }, "outputs": [], "source": [ "from tf_agents.replay_buffers import tf_uniform_replay_buffer\n", "\n", "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n", "    data_spec=agent.collect_data_spec,\n", "    batch_size=tf_env.batch_size,\n", "    max_length=100000) # reduce if you get an OOM error\n", "\n", "replay_buffer_observer = replay_buffer.add_batch" ] }, { "cell_type": "markdown", "metadata": { "id": "kH1cPdVPMM9x" }, "source": [ "Create a simple custom observer that counts and displays the number of times it is called (except when it is passed a trajectory that represents the boundary between two episodes, as this does not count as a step):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GtlRfjVVMM9x" }, "outputs": [], "source": [ "class ShowProgress:\n", "    def __init__(self, total):\n", "        self.counter = 0\n", "        self.total = total\n", "    def __call__(self, trajectory):\n", "        if not trajectory.is_boundary():\n", "            self.counter += 1\n", "        if self.counter % 100 == 0:\n", "            print(\"\\r{}/{}\".format(self.counter, self.total), 
end=\"\")" ] }, { "cell_type": "markdown", "metadata": { "id": "eJ0eNLtOMM9y" }, "source": [ "Let's add some training metrics:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u3G1rCtKMM9y" }, "outputs": [], "source": [ "from tf_agents.metrics import tf_metrics\n", "\n", "train_metrics = [\n", "    tf_metrics.NumberOfEpisodes(),\n", "    tf_metrics.EnvironmentSteps(),\n", "    tf_metrics.AverageReturnMetric(),\n", "    tf_metrics.AverageEpisodeLengthMetric(),\n", "]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "h-wWJZddMM9y", "outputId": "6ac4c76c-58b5-4c71-d482-efc86541594f" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 105, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_metrics[0].result()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8SXxCde5MM9y", "outputId": "0daa0786-045f-4072-a237-a236d82a83a9" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl: \n", "\t\t NumberOfEpisodes = 0\n", "\t\t EnvironmentSteps = 0\n", "\t\t AverageReturn = 0.0\n", "\t\t AverageEpisodeLength = 0.0\n" ] } ], "source": [ "from tf_agents.eval.metric_utils import log_metrics\n", "import logging\n", "logging.getLogger().setLevel(logging.INFO)\n", "log_metrics(train_metrics)" ] }, { "cell_type": "markdown", "metadata": { "id": "YlXRcaauMM9y" }, "source": [ "Create the collect driver:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Wj-UwBSDMM9z" }, "outputs": [], "source": [ "from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver\n", "\n", "collect_driver = DynamicStepDriver(\n", "    tf_env,\n", "    agent.collect_policy,\n", "    observers=[replay_buffer_observer] + train_metrics,\n", "    num_steps=update_period) # collect 4 steps for each training iteration" ] }, { "cell_type": "markdown", "metadata": { "id": "jA7YyShHMM9z" }, "source": [ "Collect the initial experiences before training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"CyBfjSfeMM9z", "outputId": "6fe75dec-4296-4db4-d419-1c3d0304d8a3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "20000/20000" ] } ], "source": [ "from tf_agents.policies.random_tf_policy import RandomTFPolicy\n", "\n", "initial_collect_policy = RandomTFPolicy(tf_env.time_step_spec(),\n", "                                        tf_env.action_spec())\n", "init_driver = DynamicStepDriver(\n", "    tf_env,\n", "    initial_collect_policy,\n", "    observers=[replay_buffer.add_batch, ShowProgress(20000)],\n", "    num_steps=20000) # <=> 80,000 ALE frames\n", "final_time_step, final_policy_state = init_driver.run()" ] }, { "cell_type": "markdown", "metadata": { "id": "kVsKhQqQMM9z" }, "source": [ "Let's sample 2 sub-episodes, with 3 time steps each, and display them:" ] }, { "cell_type": "markdown", "metadata": { "id": "FXwAvr6CMM9z" }, "source": [ "**Note**: `replay_buffer.get_next()` is deprecated. Use `replay_buffer.as_dataset(..., single_deterministic_pass=False)` instead." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "C08ZmQKvMM9z" }, "outputs": [], "source": [ "tf.random.set_seed(9) # chosen to show an example of a trajectory at the end of an episode\n", "\n", "#trajectories, buffer_info = replay_buffer.get_next( # get_next() is deprecated\n", "# sample_batch_size=2, num_steps=3)\n", "\n", "trajectories, buffer_info = next(iter(replay_buffer.as_dataset(\n", "    sample_batch_size=2,\n", "    num_steps=3,\n", "    single_deterministic_pass=False)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cnYnHDN8MM90", "outputId": "f02f60a9-aa0a-40a3-abe7-31431c2d44cf" }, "outputs": [ { "data": { "text/plain": [ "('step_type',\n", " 'observation',\n", " 'action',\n", " 'policy_info',\n", " 'next_step_type',\n", " 'reward',\n", " 'discount')" ] }, "execution_count": 110, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trajectories._fields" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aychI9z6MM90", "outputId": "baddbb8b-9b6f-48af-fdc1-cc338d6f9306" }, "outputs": [ { "data": { 
"text/plain": [ "TensorShape([2, 3, 84, 84, 4])" ] }, "execution_count": 111, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trajectories.observation.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "16Er6LrOMM90", "outputId": "8f4d3800-63ba-49a4-fc83-eb956b5af3a2" }, "outputs": [ { "data": { "text/plain": [ "TensorShape([2, 2, 84, 84, 4])" ] }, "execution_count": 112, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from tf_agents.trajectories.trajectory import to_transition\n", "\n", "time_steps, action_steps, next_time_steps = to_transition(trajectories)\n", "time_steps.observation.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mn5k4vzBMM90", "outputId": "b3d7f23a-b7e6-4c08-89f4-a081142945d7" }, "outputs": [ { "data": { "text/plain": [ "array([[1, 1, 1],\n", " [1, 1, 1]], dtype=int32)" ] }, "execution_count": 113, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trajectories.step_type.numpy()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ixiR9LUgMM91", "outputId": "07a7695c-3a1e-4850-81e2-9fcb070a8b8f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saving figure sub_episodes_plot\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(10, 6.8))\n", "for row in range(2):\n", " for col in range(3):\n", " plt.subplot(2, 3, row * 3 + col + 1)\n", " plot_observation(trajectories.observation[row, col].numpy())\n", "plt.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0, wspace=0.02)\n", "save_fig(\"sub_episodes_plot\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "4DcXRL7KMM91" }, "source": [ "이제 데이터셋을 만들어 보죠:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "W1GCdTZZMM91" }, "outputs": [], "source": [ "dataset = replay_buffer.as_dataset(\n", " sample_batch_size=64,\n", " num_steps=2,\n", " num_parallel_calls=3).prefetch(3)" ] }, { "cell_type": "markdown", "metadata": { "id": "1Zk3mXi6MM91" }, "source": [ "성능을 높이기 위해 메인 함수를 TF 함수로 변환합니다:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pjQ1bgAKMM91" }, "outputs": [], "source": [ "from tf_agents.utils.common import function\n", "\n", "collect_driver.run = function(collect_driver.run)\n", "agent.train = function(agent.train)" ] }, { "cell_type": "markdown", "metadata": { "id": "OfdoyCEVMM91" }, "source": [ "이제 메인 루프를 실행할 준비가 되었습니다!" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-p7irXxjMM91" }, "outputs": [], "source": [ "def train_agent(n_iterations):\n", "    time_step = None\n", "    policy_state = agent.collect_policy.get_initial_state(tf_env.batch_size)\n", "    iterator = iter(dataset)\n", "    for iteration in range(n_iterations):\n", "        time_step, policy_state = collect_driver.run(time_step, policy_state)\n", "        trajectories, buffer_info = next(iterator)\n", "        train_loss = agent.train(trajectories)\n", "        print(\"\\r{} loss:{:.5f}\".format(\n", "            iteration, train_loss.loss.numpy()), end=\"\")\n", "        if iteration % 1000 == 0:\n", "            log_metrics(train_metrics)" ] }, { "cell_type": "markdown", "metadata": { "id": "cNlkxwaLMM92" }, "source": [ "Run the next cell to train the agent for 50,000 iterations (each iteration collects 4 environment steps). Then run the following cell to look at the agent's behavior. You can run these two cells as many times as you wish, and the agent will keep improving! It will likely take about 200,000 iterations for the agent to become reasonably good." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UnCIGFW8MM92", "outputId": "337d2a2c-9a3c-4408-8fdb-628cbe281483" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /opt/conda/envs/tf2/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:201: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "back_prop=False is deprecated. 
Consider using tf.stop_gradient instead.\n", "Instead of:\n", "results = tf.foldr(fn, elems, back_prop=False)\n", "Use:\n", "results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /opt/conda/envs/tf2/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:201: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "back_prop=False is deprecated. Consider using tf.stop_gradient instead.\n", "Instead of:\n", "results = tf.foldr(fn, elems, back_prop=False)\n", "Use:\n", "results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))\n", "INFO:absl: \n", "\t\t NumberOfEpisodes = 0\n", "\t\t EnvironmentSteps = 4\n", "\t\t AverageReturn = 0.0\n", "\t\t AverageEpisodeLength = 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "998 loss:0.00008" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl: \n", "\t\t NumberOfEpisodes = 24\n", "\t\t EnvironmentSteps = 4004\n", "\t\t AverageReturn = 1.7000000476837158\n", "\t\t AverageEpisodeLength = 184.1999969482422\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "1998 loss:0.00181" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl: \n", "\t\t NumberOfEpisodes = 48\n", "\t\t EnvironmentSteps = 8004\n", "\t\t AverageReturn = 1.7000000476837158\n", "\t\t AverageEpisodeLength = 182.39999389648438\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2998 loss:0.00005" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl: \n", "\t\t NumberOfEpisodes = 73\n", "\t\t EnvironmentSteps = 12004\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "<<244 more lines>>\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\t\t NumberOfEpisodes = 1003\n", "\t\t EnvironmentSteps = 176004\n", "\t\t 
AverageReturn = 5.099999904632568\n", "\t\t AverageEpisodeLength = 246.5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "44998 loss:0.00165" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl: \n", "\t\t NumberOfEpisodes = 1019\n", "\t\t EnvironmentSteps = 180004\n", "\t\t AverageReturn = 5.199999809265137\n", "\t\t AverageEpisodeLength = 256.6000061035156\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "45998 loss:0.00136" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl: \n", "\t\t NumberOfEpisodes = 1035\n", "\t\t EnvironmentSteps = 184004\n", "\t\t AverageReturn = 4.599999904632568\n", "\t\t AverageEpisodeLength = 252.1999969482422\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "46998 loss:0.00100" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl: \n", "\t\t NumberOfEpisodes = 1050\n", "\t\t EnvironmentSteps = 188004\n", "\t\t AverageReturn = 5.699999809265137\n", "\t\t AverageEpisodeLength = 276.5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "47998 loss:0.00116" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl: \n", "\t\t NumberOfEpisodes = 1063\n", "\t\t EnvironmentSteps = 192004\n", "\t\t AverageReturn = 5.900000095367432\n", "\t\t AverageEpisodeLength = 296.3999938964844\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "48998 loss:0.00049" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl: \n", "\t\t NumberOfEpisodes = 1077\n", "\t\t EnvironmentSteps = 196004\n", "\t\t AverageReturn = 7.800000190734863\n", "\t\t AverageEpisodeLength = 308.29998779296875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "49999 loss:0.00073" ] } ], "source": [ "train_agent(n_iterations=50000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pLO9TGYzMM92" }, "outputs": [], "source": [ "frames = []\n", "def save_frames(trajectory):\n", " global frames\n", " 
frames.append(tf_env.pyenv.envs[0].render(mode=\"rgb_array\"))\n", "\n", "watch_driver = DynamicStepDriver(\n", "    tf_env,\n", "    agent.policy,\n", "    observers=[save_frames, ShowProgress(1000)],\n", "    num_steps=1000)\n", "final_time_step, final_policy_state = watch_driver.run()\n", "\n", "plot_animation(frames)" ] }, { "cell_type": "markdown", "metadata": { "id": "l0sfIwMrMM92" }, "source": [ "If you want to save an animated GIF to show off your agent to your friends, here is one way to do it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wniytrMWMM92" }, "outputs": [], "source": [ "import PIL.Image # import the submodule explicitly; a bare `import PIL` does not guarantee PIL.Image is loaded\n", "\n", "image_path = os.path.join(\"images\", \"rl\", \"breakout.gif\")\n", "frame_images = [PIL.Image.fromarray(frame) for frame in frames[:150]]\n", "frame_images[0].save(image_path, format='GIF',\n", "                     append_images=frame_images[1:],\n", "                     save_all=True,\n", "                     duration=30,\n", "                     loop=0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "78ons3n3MM92", "outputId": "651bd711-39c2-40ee-899b-87336adbb389" }, "outputs": [ { "data": { "text/html": [ "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%html\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "ybGmShZDMM93" }, "source": [ "# Extra Material" ] }, { "cell_type": "markdown", "metadata": { "id": "md6Ca2uYMM93" }, "source": [ "## Deque vs Rotating List" ] }, { "cell_type": "markdown", "metadata": { "id": "fsOqt6KqMM93" }, "source": [ "The `deque` class offers fast appends, but fairly slow random access (for large replay memories):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5eMTXxp7MM93", "outputId": "d7bc4176-f01c-42cb-db26-4baf1d61f70d" }, "outputs": [ { "data": { "text/plain": [ "[121958, 671155, 131932, 365838, 259178]" ] }, "execution_count": 122, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from collections import deque\n", "np.random.seed(42)\n", "\n", "mem = deque(maxlen=1000000)\n", "for i in range(1000000):\n", "    mem.append(i)\n", "[mem[i] for i in 
np.random.randint(1000000, size=5)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "utM_wl_9MM93", "outputId": "01bcb5f1-f71e-4497-f56e-1c315a828e94" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "47.4 ns ± 3.02 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)\n" ] } ], "source": [ "%timeit mem.append(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WJjo2-_7MM93", "outputId": "726ae27e-5811-4ff1-dcf1-07c80c45b3a4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "182 µs ± 6.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" ] } ], "source": [ "%timeit [mem[i] for i in np.random.randint(1000000, size=5)]" ] }, { "cell_type": "markdown", "metadata": { "id": "9UkzfOjdMM93" }, "source": [ "Alternatively, you could use a rotating list like the following `ReplayMemory` class. This makes random access faster for large replay memories:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Sor4s3k0MM94" }, "outputs": [], "source": [ "class ReplayMemory:\n", "    def __init__(self, max_size):\n", "        # the np.object alias was removed in NumPy 1.24; use the builtin object\n", "        self.buffer = np.empty(max_size, dtype=object)\n", "        self.max_size = max_size\n", "        self.index = 0\n", "        self.size = 0\n", "\n", "    def append(self, obj):\n", "        self.buffer[self.index] = obj\n", "        self.size = min(self.size + 1, self.max_size)\n", "        self.index = (self.index + 1) % self.max_size\n", "\n", "    def sample(self, batch_size):\n", "        indices = np.random.randint(self.size, size=batch_size)\n", "        return self.buffer[indices]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QEHf3W1QMM94", "outputId": "b9f66f83-c991-435d-df09-6f8da5177ae3" }, "outputs": [ { "data": { "text/plain": [ "array([757386, 904203, 190588, 595754, 865356], dtype=object)" ] }, "execution_count": 126, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mem = ReplayMemory(max_size=1000000)\n", "for i in range(1000000):\n", "    mem.append(i)\n", "mem.sample(5)" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": { "id": "-b_IXtyhMM94", "outputId": "86b5047a-ac80-4121-d63c-3ba1b8d5473e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "519 ns ± 17.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" ] } ], "source": [ "%timeit mem.append(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cZ47hlV0MM94", "outputId": "2683bf27-8f6c-49df-e008-04bf86467606" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "9.24 µs ± 227 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" ] } ], "source": [ "%timeit mem.sample(5)" ] }, { "cell_type": "markdown", "metadata": { "id": "8mqPd8eRMM95" }, "source": [ "## Creating a Custom TF-Agents Environment" ] }, { "cell_type": "markdown", "metadata": { "id": "cQcOSOEcMM95" }, "source": [ "To create a custom TF-Agents environment, write a class that inherits from the `PyEnvironment` class and implements a few methods. For example, the following minimal environment represents a simple 4x4 grid. The agent starts in one corner (0,0) and must move to the opposite corner (3,3). The episode ends when the agent reaches the goal (+10 reward) or when it goes out of bounds (-1 reward). The possible actions are up (0), down (1), left (2) and right (3)."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mHtkxartMM95" }, "outputs": [], "source": [ "# explicit submodule imports so that this cell does not depend on earlier cells\n", "import tf_agents.environments.py_environment\n", "import tf_agents.specs\n", "import tf_agents.trajectories.time_step\n", "\n", "class MyEnvironment(tf_agents.environments.py_environment.PyEnvironment):\n", "    def __init__(self, discount=1.0):\n", "        super().__init__()\n", "        self._action_spec = tf_agents.specs.BoundedArraySpec(\n", "            shape=(), dtype=np.int32, name=\"action\", minimum=0, maximum=3)\n", "        self._observation_spec = tf_agents.specs.BoundedArraySpec(\n", "            shape=(4, 4), dtype=np.int32, name=\"observation\", minimum=0, maximum=1)\n", "        self.discount = discount\n", "\n", "    def action_spec(self):\n", "        return self._action_spec\n", "\n", "    def observation_spec(self):\n", "        return self._observation_spec\n", "\n", "    def _reset(self):\n", "        self._state = np.zeros(2, dtype=np.int32)\n", "        obs = np.zeros((4, 4), dtype=np.int32)\n", "        obs[self._state[0], self._state[1]] = 1\n", "        return tf_agents.trajectories.time_step.restart(obs)\n", "\n", "    def _step(self, action):\n", "        self._state += [(-1, 0), (+1, 0), (0, -1), (0, +1)][action]\n", "        reward = 0\n", "        obs = np.zeros((4, 4), dtype=np.int32)\n", "        done = (self._state.min() < 0 or self._state.max() > 3)\n", "        if not done:\n", "            obs[self._state[0], self._state[1]] = 1\n", "        if done or np.all(self._state == np.array([3, 3])):\n", "            reward = -1 if done else +10\n", "            return tf_agents.trajectories.time_step.termination(obs, reward)\n", "        else:\n", "            return tf_agents.trajectories.time_step.transition(obs, reward,\n", "                                                               self.discount)" ] }, { "cell_type": "markdown", "metadata": { "id": "ASFuy2U5MM95" }, "source": [ "The action and observation specs are generally instances of the `ArraySpec` or `BoundedArraySpec` classes from the `tf_agents.specs` package (check out the other specs in this package as well). Optionally, you can also define a `render()` method and a `close()` method to free resources. You can also define a `time_step_spec()` method if you do not want the `reward` and `discount` to be 32-bit float scalars. The base class keeps track of the current time step, which is why you must implement `_reset()` and `_step()` rather than `reset()` and `step()`."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8ikv6ppsMM95", "outputId": "3a07cfe3-ab81-4e1c-a580-c9b52e740e59" }, "outputs": [ { "data": { "text/plain": [ "TimeStep(step_type=array(0, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([[1, 0, 0, 0],\n", " [0, 0, 0, 0],\n", " [0, 0, 0, 0],\n", " [0, 0, 0, 0]], dtype=int32))" ] }, "execution_count": 130, "metadata": {}, "output_type": "execute_result" } ], "source": [ "my_env = MyEnvironment()\n", "time_step = my_env.reset()\n", "time_step" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nyL2ivJOMM95", "outputId": "27944b75-d53d-4d93-c2bc-74c8e2001cab" }, "outputs": [ { "data": { "text/plain": [ "TimeStep(step_type=array(1, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([[0, 0, 0, 0],\n", " [1, 0, 0, 0],\n", " [0, 0, 0, 0],\n", " [0, 0, 0, 0]], dtype=int32))" ] }, "execution_count": 131, "metadata": {}, "output_type": "execute_result" } ], "source": [ "time_step = my_env.step(1)\n", "time_step" ] }, { "cell_type": "markdown", "metadata": { "id": "9ZgnO0JSMM96" }, "source": [ "# Exercise Solutions" ] }, { "cell_type": "markdown", "metadata": { "id": "k5WRiK2CMM96" }, "source": [ "## 1. to 7.\n", "\n", "See Appendix A." ] }, { "cell_type": "markdown", "metadata": { "id": "lzXGaVbIMM96" }, "source": [ "## 8.\n", "_Exercise: Use policy gradients to solve OpenAI Gym's LunarLander-v2 environment. 
This requires installing the Box2D package (`%pip install -U gym[box2d]`)._" ] }, { "cell_type": "markdown", "metadata": { "id": "5qzPiBWzMM96" }, "source": [ "Let's start by creating a LunarLander-v2 environment:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J3QWnu4wMM96" }, "outputs": [], "source": [ "env = gym.make(\"LunarLander-v2\")" ] }, { "cell_type": "markdown", "metadata": { "id": "w-QO2DYPMM96" }, "source": [ "The inputs are 8-dimensional:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7v8hpHyCMM96", "outputId": "d38cb675-0bfc-4127-a88c-d41843bee1f0" }, "outputs": [ { "data": { "text/plain": [ "Box(-inf, inf, (8,), float32)" ] }, "execution_count": 241, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.observation_space" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "khWHDPkuMM96", "outputId": "77511fba-ad73-4e79-9dc8-0998a98426ed" }, "outputs": [ { "data": { "text/plain": [ "array([-0.00499964, 1.4194578 , -0.506422 , 0.37943238, 0.00580009,\n", " 0.11471219, 0. , 0. ], dtype=float32)" ] }, "execution_count": 242, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.seed(42)\n", "obs = env.reset()\n", "obs" ] }, { "cell_type": "markdown", "metadata": { "id": "FSwLy_zXMM97" }, "source": [ "From the [source code](https://github.com/openai/gym/blob/master/gym/envs/box2d/lunar_lander.py), we can see that each 8D observation (x, y, h, v, a, w, l, r) corresponds to:\n", "* x,y: the coordinates of the spaceship. It starts at a random location near (0, 1.4) and must land near the target at (0, 0).\n", "* h,v: the horizontal and vertical speed of the spaceship. It starts with a small random speed.\n", "* a,w: the spaceship's angle and angular velocity.\n", "* l,r: whether the left or right leg touches the ground (1.0) or not (0.0)."
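, "

To make the eight components easier to read, an observation can be unpacked into named fields. This is a small sketch: `LanderObs` is a hypothetical name introduced here, its field names follow the (x, y, h, v, a, w, l, r) order listed above, and the values are the first observation printed earlier in this notebook.

```python
from collections import namedtuple
import numpy as np

# Hypothetical container; field order follows the list above.
LanderObs = namedtuple('LanderObs', ['x', 'y', 'h', 'v', 'a', 'w', 'l', 'r'])

# The first observation printed earlier in this notebook.
first_obs = np.array([-0.00499964, 1.4194578, -0.506422, 0.37943238,
                      0.00580009, 0.11471219, 0.0, 0.0], dtype=np.float32)
named = LanderObs(*first_obs)

# The leg contact flags are 1.0 when a leg touches the ground, 0.0 otherwise.
legs_on_ground = bool(named.l) or bool(named.r)
print(named.y, legs_on_ground)
```

At the start of an episode the spaceship is high above the pad, so neither leg flag is set."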
] }, { "cell_type": "markdown", "metadata": { "id": "4J8l50gDMM97" }, "source": [ "The action space is discrete, with 4 possible actions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xs0R_U5TMM97", "outputId": "d886221b-af4e-414e-9e96-48be9c16792d" }, "outputs": [ { "data": { "text/plain": [ "Discrete(4)" ] }, "execution_count": 243, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space" ] }, { "cell_type": "markdown", "metadata": { "id": "alRmUUExMM97" }, "source": [ "According to the [LunarLander-v2 description](https://gym.openai.com/envs/LunarLander-v2/), these actions are:\n", "* do nothing\n", "* fire the left orientation engine\n", "* fire the main engine\n", "* fire the right orientation engine" ] }, { "cell_type": "markdown", "metadata": { "id": "YDfJ1tdvMM97" }, "source": [ "Let's create a simple policy network with 4 output neurons (one per possible action):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gREhSkWHMM97" }, "outputs": [], "source": [ "keras.backend.clear_session()\n", "np.random.seed(42)\n", "tf.random.set_seed(42)\n", "\n", "n_inputs = env.observation_space.shape[0]\n", "n_outputs = env.action_space.n\n", "\n", "model = keras.models.Sequential([\n", "    keras.layers.Dense(32, activation=\"relu\", input_shape=[n_inputs]),\n", "    keras.layers.Dense(32, activation=\"relu\"),\n", "    keras.layers.Dense(n_outputs, activation=\"softmax\"),\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "iq235HmbMM97" }, "source": [ "Note that the output layer uses the softmax activation function instead of the sigmoid activation function we used for the CartPole-v1 environment. CartPole-v1 has only two possible actions, so a binary classification model was enough. Here there are more than two possible actions, so we need a multiclass classification model." ] }, { "cell_type": "markdown", "metadata": { "id": "Eraao9BuMM97" }, "source": [ "Next, let's reuse the `play_one_step()` and `play_multiple_episodes()` functions we defined in the CartPole-v1 policy gradient code, tweaking `play_one_step()` slightly to fit a multiclass classification model. We'll also tweak `play_multiple_episodes()` so that it calls the tweaked `play_one_step()`. The intent is to penalize the spaceship heavily if it neither lands nor crashes before the maximum number of steps; note that the code below only cuts the episode off after `n_max_steps` steps and does not add such a penalty itself."
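, "

The multiclass tweak comes down to how an action is drawn from the four output probabilities. In the code that follows, the log of the probabilities is passed to `tf.random.categorical()`, which expects unnormalized log-probabilities. The NumPy sketch below is only an illustration of that idea (the probability values and the `softmax` helper are made up here): it checks that a softmax over the log-probabilities recovers the original probabilities, then samples actions from them.

```python
import numpy as np

def softmax(logits):
    '''Numerically stable softmax over the last axis.'''
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Made-up policy output for a single observation (4 actions).
probas = np.array([[0.1, 0.2, 0.6, 0.1]])
logits = np.log(probas + 1e-7)  # a small epsilon avoids log(0)

# Softmax of the log-probabilities gives back (almost exactly) the probabilities.
recovered = softmax(logits)

# Sample many actions and compare their frequencies with the probabilities.
rng = np.random.default_rng(42)
actions = rng.choice(4, size=10_000, p=recovered[0])
frequencies = np.bincount(actions, minlength=4) / len(actions)
print(recovered.round(3), frequencies.round(2))
```

Sampling, rather than always taking the most probable action, is what lets the policy explore during training."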
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zVxW3agtMM98" }, "outputs": [], "source": [ "def lander_play_one_step(env, obs, model, loss_fn):\n", "    with tf.GradientTape() as tape:\n", "        probas = model(obs[np.newaxis])\n", "        # Sample an action from the predicted probabilities\n", "        logits = tf.math.log(probas + keras.backend.epsilon())\n", "        action = tf.random.categorical(logits, num_samples=1)\n", "        # Treat the sampled action as the target class\n", "        loss = tf.reduce_mean(loss_fn(action, probas))\n", "    grads = tape.gradient(loss, model.trainable_variables)\n", "    obs, reward, done, info = env.step(action[0, 0].numpy())\n", "    return obs, reward, done, grads\n", "\n", "def lander_play_multiple_episodes(env, n_episodes, n_max_steps, model, loss_fn):\n", "    all_rewards = []\n", "    all_grads = []\n", "    for episode in range(n_episodes):\n", "        current_rewards = []\n", "        current_grads = []\n", "        obs = env.reset()\n", "        for step in range(n_max_steps):\n", "            obs, reward, done, grads = lander_play_one_step(env, obs, model, loss_fn)\n", "            current_rewards.append(reward)\n", "            current_grads.append(grads)\n", "            if done:\n", "                break\n", "        all_rewards.append(current_rewards)\n", "        all_grads.append(current_grads)\n", "    return all_rewards, all_grads" ] }, { "cell_type": "markdown", "metadata": { "id": "ZJ3-NqLoMM98" }, "source": [ "We'll keep exactly the same `discount_rewards()` and `discount_and_normalize_rewards()` functions as earlier:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wK2nO-T0MM98" }, "outputs": [], "source": [ "def discount_rewards(rewards, discount_rate):\n", "    discounted = np.array(rewards)\n", "    for step in range(len(rewards) - 2, -1, -1):\n", "        discounted[step] += discounted[step + 1] * discount_rate\n", "    return discounted\n", "\n", "def discount_and_normalize_rewards(all_rewards, discount_rate):\n", "    all_discounted_rewards = [discount_rewards(rewards, discount_rate)\n", "                              for rewards in all_rewards]\n", "    flat_rewards = np.concatenate(all_discounted_rewards)\n", "    reward_mean = flat_rewards.mean()\n", "    reward_std = flat_rewards.std()\n", "    
return [(discounted_rewards - reward_mean) / reward_std\n", "            for discounted_rewards in all_discounted_rewards]" ] }, { "cell_type": "markdown", "metadata": { "id": "p-oWW-uDMM98" }, "source": [ "Now let's define some hyperparameters:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "j62ORPXfMM98" }, "outputs": [], "source": [ "n_iterations = 200\n", "n_episodes_per_update = 16\n", "n_max_steps = 1000\n", "discount_rate = 0.99" ] }, { "cell_type": "markdown", "metadata": { "id": "VTOAXopvMM98" }, "source": [ "Again, since the model is a multiclass classification model, we must use the categorical cross-entropy rather than the binary cross-entropy. Moreover, since the `lander_play_one_step()` function sets the targets as class indices rather than class probabilities, we must use the `sparse_categorical_crossentropy()` loss function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MHQ544cjMM98" }, "outputs": [], "source": [ "optimizer = keras.optimizers.Nadam(learning_rate=0.005)\n", "loss_fn = keras.losses.sparse_categorical_crossentropy" ] }, { "cell_type": "markdown", "metadata": { "id": "YV9y4MpRMM99" }, "source": [ "We're ready to train the model. Let's go!"
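, "

Before launching the loop, it can help to check the reward arithmetic it relies on with a tiny made-up episode. The sketch below reimplements the discounting recurrence under one assumption added here: it uses an explicit float dtype, so that a list of integer rewards is not truncated. `discount_rewards_float` is a hypothetical name, and the rewards and discount rate are illustrative only.

```python
import numpy as np

def discount_rewards_float(rewards, discount_rate):
    '''Same recurrence as discount_rewards(), but always in floating point.'''
    discounted = np.array(rewards, dtype=np.float64)
    for step in range(len(rewards) - 2, -1, -1):
        discounted[step] += discounted[step + 1] * discount_rate
    return discounted

# Made-up rewards for one 3-step episode.
returns = discount_rewards_float([10, 0, -50], discount_rate=0.8)
# Last step: -50; middle: 0 + 0.8 * (-50) = -40; first: 10 + 0.8 * (-40) = -22

# Normalizing gives zero mean and unit standard deviation.
normalized = (returns - returns.mean()) / returns.std()
print(returns, normalized.round(2))
```

In the training loop, each step's gradients are multiplied by its normalized return, so actions followed by above-average returns are reinforced and the others are discouraged."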
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WaWxYxI0MM99", "outputId": "c55c97c9-a912-4870-bf33-406d5f43eb96" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration: 200/200, mean reward: 134.2 " ] } ], "source": [ "env.seed(42)\n", "\n", "mean_rewards = []\n", "\n", "for iteration in range(n_iterations):\n", "    all_rewards, all_grads = lander_play_multiple_episodes(\n", "        env, n_episodes_per_update, n_max_steps, model, loss_fn)\n", "    mean_reward = sum(map(sum, all_rewards)) / n_episodes_per_update\n", "    print(\"\\rIteration: {}/{}, mean reward: {:.1f} \".format(\n", "        iteration + 1, n_iterations, mean_reward), end=\"\")\n", "    mean_rewards.append(mean_reward)\n", "    all_final_rewards = discount_and_normalize_rewards(all_rewards,\n", "                                                       discount_rate)\n", "    all_mean_grads = []\n", "    # Average the reward-weighted gradients over all steps of all episodes\n", "    for var_index in range(len(model.trainable_variables)):\n", "        mean_grads = tf.reduce_mean(\n", "            [final_reward * all_grads[episode_index][step][var_index]\n", "             for episode_index, final_rewards in enumerate(all_final_rewards)\n", "                 for step, final_reward in enumerate(final_rewards)], axis=0)\n", "        all_mean_grads.append(mean_grads)\n", "    optimizer.apply_gradients(zip(all_mean_grads, model.trainable_variables))" ] }, { "cell_type": "markdown", "metadata": { "id": "axnoL3jMMM99" }, "source": [ "Let's look at the learning curve:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ol8ai_ZzMM99", "outputId": "433c404c-b364-404d-b70a-9e201fd0949f" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# One point per training iteration (each averaged over several episodes)\n", "plt.plot(mean_rewards)\n", "plt.xlabel(\"Iteration\")\n", "plt.ylabel(\"Mean reward\")\n", "plt.grid()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "3VgVDzDfMM99" }, "source": [ "Now let's look at the result!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IS0TjvFzMM99" }, "outputs": [], "source": [ "def lander_render_policy_net(model, n_max_steps=500, seed=42):\n", "    frames = []\n", "    env = gym.make(\"LunarLander-v2\")\n", "    env.seed(seed)\n", "    tf.random.set_seed(seed)\n", "    np.random.seed(seed)\n", "    obs = env.reset()\n", "    for step in range(n_max_steps):\n", "        frames.append(env.render(mode=\"rgb_array\"))\n", "        probas = model(obs[np.newaxis])\n", "        logits = tf.math.log(probas + keras.backend.epsilon())\n", "        action = tf.random.categorical(logits, num_samples=1)\n", "        obs, reward, done, info = env.step(action[0, 0].numpy())\n", "        if done:\n", "            break\n", "    env.close()\n", "    return frames" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pzfWaUVGMM-B" }, "outputs": [], "source": [ "frames = lander_render_policy_net(model, seed=42)\n", "plot_animation(frames)" ] }, { "cell_type": "markdown", "metadata": { "id": "RQW9sCXpMM-B" }, "source": [ "That's pretty good. Try training for longer or tuning the hyperparameters to see whether the mean reward can go over 200." ] }, { "cell_type": "markdown", "metadata": { "id": "6xoce5bUMM-B" }, "source": [ "## 9.\n", "_Exercise: Use TF-Agents to train an agent that can achieve a superhuman level at SpaceInvaders-v4, using any of the available algorithms._" ] }, { "cell_type": "markdown", "metadata": { "id": "S0Zm1na_MM-B" }, "source": [ "Replace `\"Breakout-v4\"` with `\"SpaceInvaders-v4\"` and follow the steps in the [Using TF-Agents to Beat Breakout](#TF-Agents%EB%A5%BC-%EC%82%AC%EC%9A%A9%ED%95%B4-%EB%B8%8C%EB%A0%88%EC%9D%B4%ED%81%AC%EC%95%84%EC%9B%83-%EA%B2%8C%EC%9E%84%ED%95%98%EA%B8%B0) section. A few things need to change, however. For example, the Space Invaders game does not require pressing the FIRE button to start the game. 
Instead, the player's laser cannon blinks for a few seconds, then the game starts automatically. For better performance, you can skip this blinking phase (which lasts about 40 steps) at the beginning of each episode and after each lost life. Indeed, nothing can be done during this phase, and nothing moves. One way to skip it is to use the following custom environment wrapper instead of the `AtariPreprocessingWithAutoFire` wrapper:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "i9vUt5WJMM-B" }, "outputs": [], "source": [ "class AtariPreprocessingWithSkipStart(AtariPreprocessing):\n", "    def skip_frames(self, num_skip):\n", "        for _ in range(num_skip):\n", "            super().step(0) # NOOP for num_skip steps\n", "    def reset(self, **kwargs):\n", "        obs = super().reset(**kwargs)\n", "        self.skip_frames(40)\n", "        return obs\n", "    def step(self, action):\n", "        lives_before_action = self.ale.lives()\n", "        obs, rewards, done, info = super().step(action)\n", "        # Skip the blinking phase after a lost life (unless the game is over)\n", "        if self.ale.lives() < lives_before_action and not done:\n", "            self.skip_frames(40)\n", "        return obs, rewards, done, info" ] }, { "cell_type": "markdown", "metadata": { "id": "ncMSSzA6MM-B" }, "source": [ "Moreover, you should always make sure that the preprocessed images contain enough information to play the game. For example, the shots fired by the laser cannon and by the aliens must remain visible despite the low resolution. In this case, the preprocessing we did for Breakout also works well for Space Invaders, but this is something to check for any other game. To do so, let the agent play randomly for a while, record the preprocessed frames, then play the animation and check that the gameplay is clearly visible.\n", "\n", "You will also need to train the agent for quite a long time to get good performance. Unfortunately, the DQN algorithm cannot reach a superhuman level on Space Invaders: humans can learn efficient long-term strategies in this game, whereas DQN only learns fairly short ones. But there has been a lot of progress over the past few years, and many RL algorithms can now surpass human experts at this game. Check out the [State-of-the-Art for Space Invaders on paperswithcode.com](https://paperswithcode.com/sota/atari-games-on-atari-2600-space-invaders)." ] }, { "cell_type": "markdown", "metadata": { "id": "GdCKaaQGMM-B" }, "source": [ "## 10.\n", "_Exercise: If you have about 100,000 won to spare, you can purchase a Raspberry Pi 3 plus some cheap robotics components, then install TensorFlow on the Pi and run it! For example, check out this fun [post](https://homl.info/2) by Lukas Biewald, or take a look at GoPiGo or BrickPi. Start with simple goals, such as making the robot turn toward the brightest angle (if it has a light sensor) or move toward the closest object (if it has a sonar sensor). Then start using Deep Learning. 
For example, if the robot has a camera, you can implement an object detection algorithm so that it detects people and moves toward them. You can also use reinforcement learning to let the agent learn on its own how to use the motors to achieve that goal._" ] }, { "cell_type": "markdown", "metadata": { "id": "MHUiq7VhMM-C" }, "source": [ "It's your turn now. Be bold and creative, but above all be patient and move forward one step at a time. You can do it!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "colab": { "provenance": [] } }, "nbformat": 4, "nbformat_minor": 0 }