{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Balancing immediate and long-term goals\n", "\n", "> 그로킹 심층 강화학습 중 3장 내용인 \"순간 목표와 장기 목표간의 균형\"에 대한 내용입니다.\n", "\n", "- hide: true\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, Reinforcement_Learning, Grokking_Deep_Reinforcement_Learning]\n", "- permalink: /book/:title:output_ext\n", "- search_exclude: false" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: 라이브러리 설치를 위해 아래의 패키지들을 설치해주기 바랍니다." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#collapse\n", "!pip install tqdm numpy scikit-learn pyglet setuptools && \\\n", "!pip install gym asciinema pandas tabulate tornado==5.* PyBullet && \\\n", "!pip install git+https://github.com/pybox2d/pybox2d#egg=Box2D && \\\n", "!pip install git+https://github.com/mimoralea/gym-bandits#egg=gym-bandits && \\\n", "!pip install git+https://github.com/mimoralea/gym-walk#egg=gym-walk && \\\n", "!pip install git+https://github.com/mimoralea/gym-aima#egg=gym-aima && \\\n", "!pip install gym[atari]" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import gym, gym_walk, gym_aima" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 정책 반복법, 가치 반복법\n", "\n", "**참고**: 해당 노트북에서 사용되는 환경에 대한 정보는 아래 링크를 참고하시기 바랍니다.\n", "\n", "- [gym_walk](https://github.com/mimoralea/gym-walk)\n", "- [gym_aima](https://github.com/mimoralea/gym-aima)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import warnings ; warnings.filterwarnings('ignore')\n", "\n", "import gym, gym_walk, gym_aima\n", "import numpy as np\n", "from pprint import pprint\n", "from tqdm import tqdm_notebook as tqdm\n", "\n", "from itertools import cycle\n", "\n", "import random\n", "\n", "np.set_printoptions(suppress=True)\n", "random.seed(123); np.random.seed(123)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 출력을 위한 helper function" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def print_policy(pi, P, action_symbols=('<', 'v', '>', '^'), n_cols=4, title='정책:'):\n", " print(title)\n", " arrs = {k:v for k,v in enumerate(action_symbols)}\n", " for s in range(len(P)):\n", " a = pi(s)\n", " print(\"| \", end=\"\")\n", " if np.all([done for action in P[s].values() for _, _, _, done in action]):\n", " print(\"\".rjust(9), end=\" \")\n", " else:\n", " print(str(s).zfill(2), arrs[a].rjust(6), end=\" \")\n", " if (s + 1) % n_cols == 0: print(\"|\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def print_state_value_function(V, P, n_cols=4, prec=3, title='상태-가치 함수:'):\n", " print(title)\n", " for s in range(len(P)):\n", " v = V[s]\n", " print(\"| \", end=\"\")\n", " if np.all([done for action in P[s].values() for _, _, _, done in action]):\n", " print(\"\".rjust(9), end=\" \")\n", " else:\n", " print(str(s).zfill(2), '{}'.format(np.round(v, prec)).rjust(6), end=\" \")\n", " if (s + 1) % n_cols == 0: print(\"|\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def print_action_value_function(Q, \n", " optimal_Q=None, \n", " action_symbols=('<', '>'), \n", " prec=3, \n", " title='행동-가치 함수:'):\n", " vf_types=('',) if optimal_Q is None else ('', '*', 'err')\n", " headers = ['s',] + [' '.join(i) for i in list(itertools.product(vf_types, action_symbols))]\n", " 
"    print(title)\n", "    states = np.arange(len(Q))[..., np.newaxis]\n", "    arr = np.hstack((states, np.round(Q, prec)))\n", "    if not (optimal_Q is None):\n", "        arr = np.hstack((arr, np.round(optimal_Q, prec), np.round(optimal_Q-Q, prec)))\n", "    print(tabulate(arr, headers, tablefmt=\"fancy_grid\"))" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def probability_success(env, pi, goal_state, n_episodes=100, max_steps=200):\n", "    random.seed(123); np.random.seed(123) ; env.seed(123)\n", "    results = []\n", "    for _ in range(n_episodes):\n", "        state, done, steps = env.reset(), False, 0\n", "        while not done and steps < max_steps:\n", "            state, _, done, _ = env.step(pi(state))\n", "            steps += 1\n", "        results.append(state == goal_state)\n", "    return np.sum(results)/len(results)" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def mean_return(env, pi, n_episodes=100, max_steps=200):\n", "    random.seed(123); np.random.seed(123) ; env.seed(123)\n", "    results = []\n", "    for _ in range(n_episodes):\n", "        state, done, steps = env.reset(), False, 0\n", "        results.append(0.0)\n", "        while not done and steps < max_steps:\n", "            state, reward, done, _ = env.step(pi(state))\n", "            results[-1] += reward\n", "            steps += 1\n", "    return np.mean(results)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Slippery Walk Five MDP and sample policy" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "정책:\n", "| | 01 < | 02 < | 03 < | 04 < | 05 < | |\n", "Reaches goal 7.00%. Obtains an average undiscounted return of 0.0700.\n" ] } ], "source": [ "env = gym.make('SlipperyWalkFive-v0')\n", "P = env.env.P\n", "init_state = env.reset()\n", "goal_state = 6\n", "\n", "LEFT, RIGHT = range(2)\n", "pi = lambda s: {\n", " 0:LEFT, 1:LEFT, 2:LEFT, 3:LEFT, 4:LEFT, 5:LEFT, 6:LEFT\n", "}[s]\n", "print_policy(pi, P, action_symbols=('<', '>'), n_cols=7)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, pi, goal_state=goal_state)*100, \n", " mean_return(env, pi)))" ] },
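{ "cell_type": "markdown", "metadata": {}, "source": [ "Everything below consumes the MDP through `P`, a dict of dicts that maps each state and action to a list of `(probability, next_state, reward, done)` transition tuples. As a quick sanity check (an extra cell, not required by anything that follows), we can peek at one non-terminal state of the Slippery Walk Five MDP; the exact numbers are whatever `gym-walk` defines." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Extra sanity check: inspect the transition model of a single state.\n", "# Each P[s][a] entry is a (probability, next_state, reward, done) tuple.\n", "pprint(P[3])" ] },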
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Policy evaluation" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def policy_evaluation(pi, P, gamma=1.0, theta=1e-10):\n", "    prev_V = np.zeros(len(P), dtype=np.float64)\n", "    while True:\n", "        V = np.zeros(len(P), dtype=np.float64)\n", "        for s in range(len(P)):\n", "            for prob, next_state, reward, done in P[s][pi(s)]:\n", "                V[s] += prob * (reward + gamma * prev_V[next_state] * (not done))\n", "        if np.max(np.abs(prev_V - V)) < theta:\n", "            break\n", "        prev_V = V.copy()\n", "    return V" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태-가치 함수:\n", "| | 01 0.00275 | 02 0.01099 | 03 0.03571 | 04 0.10989 | 05 0.33242 | |\n" ] } ], "source": [ "V = policy_evaluation(pi, P)\n", "print_state_value_function(V, P, n_cols=7, prec=5)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Policy improvement" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def policy_improvement(V, P, gamma=1.0):\n", "    Q = np.zeros((len(P), len(P[0])), dtype=np.float64)\n", "    for s in range(len(P)):\n", "        for a in range(len(P[s])):\n", "            for prob, next_state, reward, done in P[s][a]:\n", "                Q[s][a] += prob * (reward + gamma * V[next_state] * (not done))\n", "    new_pi = lambda s: {s:a for s, a in enumerate(np.argmax(Q, axis=1))}[s]\n", "    return new_pi" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "정책:\n", "| | 01 > | 02 > | 03 > | 04 > | 05 > | |\n", "Reaches goal 93.00%. Obtains an average undiscounted return of 0.9300.\n" ] } ], "source": [ "improved_pi = policy_improvement(V, P)\n", "print_policy(improved_pi, P, action_symbols=('<', '>'), n_cols=7)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, improved_pi, goal_state=goal_state)*100, \n", " mean_return(env, improved_pi)))" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태-가치 함수:\n", "| | 01 0.66758 | 02 0.89011 | 03 0.96429 | 04 0.98901 | 05 0.99725 | |\n" ] } ], "source": [ "# how about we evaluate the improved policy?\n", "improved_V = policy_evaluation(improved_pi, P)\n", "print_state_value_function(improved_V, P, n_cols=7, prec=5)" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "정책:\n", "| | 01 > | 02 > | 03 > | 04 > | 05 > | |\n", "Reaches goal 93.00%. Obtains an average undiscounted return of 0.9300.\n" ] } ], "source": [ "# can we improve the improved policy?\n", "improved_improved_pi = policy_improvement(improved_V, P)\n", "print_policy(improved_improved_pi, P, action_symbols=('<', '>'), n_cols=7)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, improved_improved_pi, goal_state=goal_state)*100, \n", " mean_return(env, improved_improved_pi)))" ] },
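{ "cell_type": "markdown", "metadata": {}, "source": [ "A short recap of the two steps being alternated here. Iterative policy evaluation repeatedly applies the Bellman expectation backup for the fixed policy $\\pi$,\n", "\n", "$$V_{k+1}(s) = \\sum_{s', r} p(s', r \\mid s, \\pi(s)) \\left[ r + \\gamma V_k(s') \\right],$$\n", "\n", "and policy improvement acts greedily with respect to the one-step lookahead over the resulting value function,\n", "\n", "$$\\pi'(s) = \\arg\\max_a \\sum_{s', r} p(s', r \\mid s, a) \\left[ r + \\gamma V^{\\pi}(s') \\right].$$\n", "\n", "Once the greedy policy stops changing, $V^{\\pi}$ already satisfies the Bellman optimality equation, which is why the second improvement above returned the same policy and why, as checked below, there is nothing left to improve." ] },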
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태-가치 함수:\n", "| | 01 0.66758 | 02 0.89011 | 03 0.96429 | 04 0.98901 | 05 0.99725 | |\n" ] } ], "source": [ "# it is the same policy\n", "# if we evaluate it again, we can see there is nothing to improve\n", "# that also means we have reached the optimal policy\n", "improved_improved_V = policy_evaluation(improved_improved_pi, P)\n", "print_state_value_function(improved_improved_V, P, n_cols=7, prec=5)" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# the state-value function didn't improve, so we have reached the optimal policy\n", "assert np.all(improved_V == improved_improved_V)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Policy iteration" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def policy_iteration(P, gamma=1.0, theta=1e-10):\n", "    random_actions = np.random.choice(tuple(P[0].keys()), len(P))\n", "    pi = lambda s: {s:a for s, a in enumerate(random_actions)}[s]\n", "    while True:\n", "        old_pi = {s:pi(s) for s in range(len(P))}\n", "        V = policy_evaluation(pi, P, gamma, theta)\n", "        pi = policy_improvement(V, P, gamma)\n", "        if old_pi == {s:pi(s) for s in range(len(P))}:\n", "            break\n", "    return V, pi" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Optimal policy and state-value function (PI):\n", "정책:\n", "| | 01 > | 02 > | 03 > | 04 > | 05 > | |\n", "Reaches goal 93.00%. Obtains an average undiscounted return of 0.9300.\n", "\n", "상태-가치 함수:\n", "| | 01 0.66758 | 02 0.89011 | 03 0.96429 | 04 0.98901 | 05 0.99725 | |\n" ] } ], "source": [ "optimal_V, optimal_pi = policy_iteration(P)\n", "print('Optimal policy and state-value function (PI):')\n", "print_policy(optimal_pi, P, action_symbols=('<', '>'), n_cols=7)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, optimal_pi, goal_state=goal_state)*100, \n", " mean_return(env, optimal_pi)))\n", "print()\n", "print_state_value_function(optimal_V, P, n_cols=7, prec=5)" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "assert np.all(improved_V == optimal_V)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Frozen Lake MDP and sample policies" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "정책:\n", "| 00 > | 01 < | 02 v | 03 ^ |\n", "| 04 < | | 06 > | |\n", "| 08 ^ | 09 v | 10 ^ | |\n", "| | 13 > | 14 v | |\n", "Reaches goal 12.00%. Obtains an average undiscounted return of 0.1200.\n" ] } ], "source": [ "env = gym.make('FrozenLake-v0')\n", "P = env.env.P\n", "init_state = env.reset()\n", "goal_state = 15\n", "\n", "LEFT, DOWN, RIGHT, UP = range(4)\n", "random_pi = lambda s: {\n", " 0:RIGHT, 1:LEFT, 2:DOWN, 3:UP,\n", " 4:LEFT, 5:LEFT, 6:RIGHT, 7:LEFT,\n", " 8:UP, 9:DOWN, 10:UP, 11:LEFT,\n", " 12:LEFT, 13:RIGHT, 14:DOWN, 15:LEFT\n", "}[s]\n", "print_policy(random_pi, P)\n",
"print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, random_pi, goal_state=goal_state)*100, \n", " mean_return(env, random_pi)))" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "정책:\n", "| 00 > | 01 > | 02 v | 03 < |\n", "| 04 v | | 06 v | |\n", "| 08 > | 09 > | 10 v | |\n", "| | 13 > | 14 > | |\n", "Reaches goal 5.00%. Obtains an average undiscounted return of 0.0500.\n" ] } ], "source": [ "go_get_pi = lambda s: {\n", " 0:RIGHT, 1:RIGHT, 2:DOWN, 3:LEFT,\n", " 4:DOWN, 5:LEFT, 6:DOWN, 7:LEFT,\n", " 8:RIGHT, 9:RIGHT, 10:DOWN, 11:LEFT,\n", " 12:LEFT, 13:RIGHT, 14:RIGHT, 15:LEFT\n", "}[s]\n", "print_policy(go_get_pi, P)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, go_get_pi, goal_state=goal_state)*100, \n", " mean_return(env, go_get_pi)))" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "정책:\n", "| 00 < | 01 ^ | 02 ^ | 03 ^ |\n", "| 04 < | | 06 ^ | |\n", "| 08 ^ | 09 v | 10 < | |\n", "| | 13 > | 14 > | |\n", "Reaches goal 52.00%. Obtains an average undiscounted return of 0.5200.\n" ] } ], "source": [ "careful_pi = lambda s: {\n", " 0:LEFT, 1:UP, 2:UP, 3:UP,\n", " 4:LEFT, 5:LEFT, 6:UP, 7:LEFT,\n", " 8:UP, 9:DOWN, 10:LEFT, 11:LEFT,\n", " 12:LEFT, 13:RIGHT, 14:RIGHT, 15:LEFT\n", "}[s]\n", "print_policy(careful_pi, P)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, careful_pi, goal_state=goal_state)*100, \n", " mean_return(env, careful_pi)))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Policy evaluation" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태-가치 함수:\n", "| 00 0.4079 | 01 0.3754 | 02 0.3543 | 03 0.3438 |\n", "| 04 0.4203 | | 06 0.1169 | |\n", "| 08 0.4454 | 09 0.484 | 10 0.4328 | |\n", "| | 13 0.5884 | 14 0.7107 | |\n" ] } ], "source": [ "V = policy_evaluation(careful_pi, P, gamma=0.99)\n", "print_state_value_function(V, P, prec=4)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Policy improvement" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "정책:\n", "| 00 < | 01 ^ | 02 ^ | 03 ^ |\n", "| 04 < | | 06 < | |\n", "| 08 ^ | 09 v | 10 < | |\n", "| | 13 > | 14 v | |\n", "Reaches goal 74.00%. Obtains an average undiscounted return of 0.7400.\n" ] } ], "source": [ "careful_plus_pi = policy_improvement(V, P, gamma=0.99)\n", "print_policy(careful_plus_pi, P)\n", "print('Reaches goal {:.2f}%. 
Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, careful_plus_pi, goal_state=goal_state)*100, \n", " mean_return(env, careful_plus_pi)))" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태-가치 함수:\n", "| 00 0.542 | 01 0.4988 | 02 0.4707 | 03 0.4569 |\n", "| 04 0.5585 | | 06 0.3583 | |\n", "| 08 0.5918 | 09 0.6431 | 10 0.6152 | |\n", "| | 13 0.7417 | 14 0.8628 | |\n" ] } ], "source": [ "new_V = policy_evaluation(careful_plus_pi, P, gamma=0.99)\n", "print_state_value_function(new_V, P, prec=4)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태-가치 함수:\n", "| 00 0.1341 | 01 0.1234 | 02 0.1164 | 03 0.113 |\n", "| 04 0.1381 | | 06 0.2414 | |\n", "| 08 0.1464 | 09 0.1591 | 10 0.1824 | |\n", "| | 13 0.1533 | 14 0.1521 | |\n" ] } ], "source": [ "print_state_value_function(new_V - V, P, prec=4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Alternating between evaluation and improvement" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "정책:\n", "| 00 ^ | 01 ^ | 02 ^ | 03 ^ |\n", "| 04 ^ | | 06 ^ | |\n", "| 08 < | 09 < | 10 < | |\n", "| | 13 < | 14 < | |\n", "Reaches goal 0.00%. Obtains an average undiscounted return of 0.0000.\n" ] } ], "source": [ "adversarial_pi = lambda s: {\n", " 0:UP, 1:UP, 2:UP, 3:UP,\n", " 4:UP, 5:LEFT, 6:UP, 7:LEFT,\n", " 8:LEFT, 9:LEFT, 10:LEFT, 11:LEFT,\n", " 12:LEFT, 13:LEFT, 14:LEFT, 15:LEFT\n", "}[s]\n", "print_policy(adversarial_pi, P)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, adversarial_pi, goal_state=goal_state)*100, \n", " mean_return(env, adversarial_pi)))" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태-가치 함수:\n", "| 00 0.0 | 01 0.0 | 02 0.0 | 03 0.0 |\n", "| 04 0.0 | | 06 0.0 | |\n", "| 08 0.0 | 09 0.0 | 10 0.0 | |\n", "| | 13 0.0 | 14 0.0 | |\n" ] } ], "source": [ "V = policy_evaluation(adversarial_pi, P, gamma=0.99)\n", "print_state_value_function(V, P, prec=2)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "정책:\n", "| 00 < | 01 < | 02 < | 03 < |\n", "| 04 < | | 06 < | |\n", "| 08 < | 09 < | 10 < | |\n", "| | 13 < | 14 v | |\n", "Reaches goal 0.00%. Obtains an average undiscounted return of 0.0000.\n" ] } ], "source": [ "i_pi = policy_improvement(V, P, gamma=0.99)\n", "print_policy(i_pi, P)\n", "print('Reaches goal {:.2f}%. 
Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, i_pi, goal_state=goal_state)*100, \n", " mean_return(env, i_pi)))" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태-가치 함수:\n", "| 00 0.0 | 01 0.0 | 02 0.04 | 03 0.02 |\n", "| 04 0.0 | | 06 0.07 | |\n", "| 08 0.0 | 09 0.0 | 10 0.19 | |\n", "| | 13 0.0 | 14 0.5 | |\n" ] } ], "source": [ "i_V = policy_evaluation(i_pi, P, gamma=0.99)\n", "print_state_value_function(i_V, P, prec=2)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "정책:\n", "| 00 < | 01 v | 02 > | 03 ^ |\n", "| 04 < | | 06 < | |\n", "| 08 < | 09 v | 10 < | |\n", "| | 13 v | 14 > | |\n", "Reaches goal 0.00%. Obtains an average undiscounted return of 0.0000.\n" ] } ], "source": [ "ii_pi = policy_improvement(i_V, P, gamma=0.99)\n", "print_policy(ii_pi, P)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, ii_pi, goal_state=goal_state)*100, \n", " mean_return(env, ii_pi)))" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태-가치 함수:\n", "| 00 0.0 | 01 0.05 | 02 0.16 | 03 0.15 |\n", "| 04 0.0 | | 06 0.17 | |\n", "| 08 0.0 | 09 0.22 | 10 0.35 | |\n", "| | 13 0.33 | 14 0.67 | |\n" ] } ], "source": [ "ii_V = policy_evaluation(ii_pi, P, gamma=0.99)\n", "print_state_value_function(ii_V, P, prec=2)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "정책:\n", "| 00 v | 01 > | 02 > | 03 ^ |\n", "| 04 < | | 06 < | |\n", "| 08 v | 09 v | 10 < | |\n", "| | 13 > | 14 > | |\n", "Reaches goal 20.00%. Obtains an average undiscounted return of 0.2000.\n" ] } ], "source": [ "iii_pi = policy_improvement(ii_V, P, gamma=0.99)\n", "print_policy(iii_pi, P)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, iii_pi, goal_state=goal_state)*100, \n", " mean_return(env, iii_pi)))" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태-가치 함수:\n", "| 00 0.12 | 01 0.09 | 02 0.19 | 03 0.19 |\n", "| 04 0.15 | | 06 0.2 | |\n", "| 08 0.19 | 09 0.38 | 10 0.43 | |\n", "| | 13 0.53 | 14 0.71 | |\n" ] } ], "source": [ "iii_V = policy_evaluation(iii_pi, P, gamma=0.99)\n", "print_state_value_function(iii_V, P, prec=2)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "정책:\n", "| 00 < | 01 ^ | 02 > | 03 ^ |\n", "| 04 < | | 06 < | |\n", "| 08 ^ | 09 v | 10 < | |\n", "| | 13 > | 14 v | |\n", "Reaches goal 73.00%. Obtains an average undiscounted return of 0.7300.\n" ] } ], "source": [ "iiii_pi = policy_improvement(iii_V, P, gamma=0.99)\n", "print_policy(iiii_pi, P)\n", "print('Reaches goal {:.2f}%. 
Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, iiii_pi, goal_state=goal_state)*100, \n", " mean_return(env, iiii_pi)))" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태-가치 함수:\n", "| 00 0.52 | 01 0.38 | 02 0.26 | 03 0.25 |\n", "| 04 0.54 | | 06 0.28 | |\n", "| 08 0.57 | 09 0.62 | 10 0.58 | |\n", "| | 13 0.72 | 14 0.85 | |\n" ] } ], "source": [ "iiii_V = policy_evaluation(iiii_pi, P, gamma=0.99)\n", "print_state_value_function(iiii_V, P, prec=2)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "정책:\n", "| 00 < | 01 ^ | 02 < | 03 ^ |\n", "| 04 < | | 06 < | |\n", "| 08 ^ | 09 v | 10 < | |\n", "| | 13 > | 14 v | |\n", "Reaches goal 74.00%. Obtains an average undiscounted return of 0.7400.\n" ] } ], "source": [ "iiiii_pi = policy_improvement(iiii_V, P, gamma=0.99)\n", "print_policy(iiiii_pi, P)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, iiiii_pi, goal_state=goal_state)*100, \n", " mean_return(env, iiiii_pi)))" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태-가치 함수:\n", "| 00 0.53 | 01 0.45 | 02 0.38 | 03 0.37 |\n", "| 04 0.55 | | 06 0.32 | |\n", "| 08 0.58 | 09 0.63 | 10 0.6 | |\n", "| | 13 0.73 | 14 0.86 | |\n" ] } ], "source": [ "iiiii_V = policy_evaluation(iiiii_pi, P, gamma=0.99)\n", "print_state_value_function(iiiii_V, P, prec=2)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "정책:\n", "| 00 < | 01 ^ | 02 ^ | 03 ^ |\n", "| 04 < | | 06 < | |\n", "| 08 ^ | 09 v | 10 < | |\n", "| | 13 > | 14 v | |\n", "Reaches goal 74.00%. Obtains an average undiscounted return of 0.7400.\n" ] } ], "source": [ "iiiiii_pi = policy_improvement(iiiii_V, P, gamma=0.99)\n", "print_policy(iiiiii_pi, P)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, iiiiii_pi, goal_state=goal_state)*100, \n", " mean_return(env, iiiiii_pi)))" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태-가치 함수:\n", "| 00 0.54 | 01 0.5 | 02 0.47 | 03 0.46 |\n", "| 04 0.56 | | 06 0.36 | |\n", "| 08 0.59 | 09 0.64 | 10 0.62 | |\n", "| | 13 0.74 | 14 0.86 | |\n" ] } ], "source": [ "iiiiii_V = policy_evaluation(iiiiii_pi, P, gamma=0.99)\n", "print_state_value_function(iiiiii_V, P, prec=2)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "정책:\n", "| 00 < | 01 ^ | 02 ^ | 03 ^ |\n", "| 04 < | | 06 < | |\n", "| 08 ^ | 09 v | 10 < | |\n", "| | 13 > | 14 v | |\n", "Reaches goal 74.00%. Obtains an average undiscounted return of 0.7400.\n" ] } ], "source": [ "iiiiiii_pi = policy_improvement(iiiiii_V, P, gamma=0.99)\n", "print_policy(iiiiiii_pi, P)\n", "print('Reaches goal {:.2f}%. 
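Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, iiiiiii_pi, goal_state=goal_state)*100, \n", " mean_return(env, iiiiiii_pi)))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Alternating evaluation and improvement by hand, as above, takes seven rounds of improvement, the last of which leaves the policy unchanged. The cell below is a small extra sketch (not part of the original walkthrough) that wraps the same loop with a counter, using the `policy_evaluation` and `policy_improvement` functions defined earlier; it is precisely what `policy_iteration` automates in the next section." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Extra sketch: repeat evaluation + greedy improvement until the policy is stable,\n", "# counting the improvement rounds. This mirrors the manual cells above.\n", "def count_improvement_rounds(pi, P, gamma=0.99):\n", "    rounds = 0\n", "    while True:\n", "        V = policy_evaluation(pi, P, gamma=gamma)\n", "        new_pi = policy_improvement(V, P, gamma=gamma)\n", "        rounds += 1\n", "        if all(new_pi(s) == pi(s) for s in range(len(P))):\n", "            return rounds\n", "        pi = new_pi\n", "\n", "print('Improvement rounds until stable:', count_improvement_rounds(adversarial_pi, P, gamma=0.99))" ] },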
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Policy iteration" ] },
{ "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태-가치 함수:\n", "| 00 0.542 | 01 0.4988 | 02 0.4707 | 03 0.4569 |\n", "| 04 0.5585 | | 06 0.3583 | |\n", "| 08 0.5918 | 09 0.6431 | 10 0.6152 | |\n", "| | 13 0.7417 | 14 0.8628 | |\n", "\n", "Optimal policy and state-value function (PI):\n", "정책:\n", "| 00 < | 01 ^ | 02 ^ | 03 ^ |\n", "| 04 < | | 06 < | |\n", "| 08 ^ | 09 v | 10 < | |\n", "| | 13 > | 14 v | |\n", "Reaches goal 74.00%. Obtains an average undiscounted return of 0.7400.\n" ] } ], "source": [ "V_best_p, pi_best_p = policy_iteration(P, gamma=0.99)\n", "print_state_value_function(V_best_p, P, prec=4)\n", "print()\n", "print('Optimal policy and state-value function (PI):')\n", "print_policy(pi_best_p, P)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, pi_best_p, goal_state=goal_state)*100, \n", " mean_return(env, pi_best_p)))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Slippery Walk Five" ] },
{ "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "env = gym.make('SlipperyWalkFive-v0')\n", "init_state = env.reset()\n", "goal_state = 6\n", "P = env.env.P" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Value iteration" ] },
{ "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "def value_iteration(P, gamma=1.0, theta=1e-10):\n", "    V = np.zeros(len(P), dtype=np.float64)\n", "    while True:\n", "        Q = np.zeros((len(P), len(P[0])), dtype=np.float64)\n", "        for s in range(len(P)):\n", "            for a in range(len(P[s])):\n", "                for prob, next_state, reward, done in P[s][a]:\n", "                    Q[s][a] += prob * (reward + gamma * V[next_state] * (not done))\n", "        if np.max(np.abs(V - np.max(Q, axis=1))) < theta:\n", "            break\n", "        V = np.max(Q, axis=1)\n", "    pi = lambda s: {s:a for s, a in enumerate(np.argmax(Q, axis=1))}[s]\n", "    return V, pi" ] },
{ "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Optimal policy and state-value function (VI):\n", "정책:\n", "| | 01 > | 02 > | 03 > | 04 > | 05 > | |\n", "Reaches goal 93.00%. Obtains an average undiscounted return of 0.9300.\n", "\n", "상태-가치 함수:\n", "| | 01 0.66758 | 02 0.89011 | 03 0.96429 | 04 0.98901 | 05 0.99725 | |\n" ] } ], "source": [ "optimal_V, optimal_pi = value_iteration(P)\n", "print('Optimal policy and state-value function (VI):')\n", "print_policy(optimal_pi, P, action_symbols=('<', '>'), n_cols=7)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, optimal_pi, goal_state=goal_state)*100, \n", " mean_return(env, optimal_pi)))\n", "print()\n", "print_state_value_function(optimal_V, P, n_cols=7, prec=5)\n", "# | | 01 0.668 | 02 0.890 | 03 0.964 | 04 0.989 | 05 0.997 | |" ] },
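{ "cell_type": "markdown", "metadata": {}, "source": [ "Unlike policy iteration, value iteration never evaluates a fixed policy to convergence; it applies the Bellman optimality backup directly,\n", "\n", "$$V_{k+1}(s) = \\max_a \\sum_{s', r} p(s', r \\mid s, a) \\left[ r + \\gamma V_k(s') \\right],$$\n", "\n", "which is the `np.max(Q, axis=1)` step in `value_iteration` above. The greedy policy is only extracted at the end, from the final $Q$, with `np.argmax`." ] },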
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Frozen Lake MDP" ] },
{ "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "env = gym.make('FrozenLake-v0')\n", "init_state = env.reset()\n", "goal_state = 15\n", "P = env.env.P" ] },
{ "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Optimal policy and state-value function (VI):\n", "정책:\n", "| 00 < | 01 ^ | 02 ^ | 03 ^ |\n", "| 04 < | | 06 < | |\n", "| 08 ^ | 09 v | 10 < | |\n", "| | 13 > | 14 v | |\n", "Reaches goal 74.00%. Obtains an average undiscounted return of 0.7400.\n", "\n", "상태-가치 함수:\n", "| 00 0.542 | 01 0.4988 | 02 0.4707 | 03 0.4569 |\n", "| 04 0.5585 | | 06 0.3583 | |\n", "| 08 0.5918 | 09 0.6431 | 10 0.6152 | |\n", "| | 13 0.7417 | 14 0.8628 | |\n" ] } ], "source": [ "V_best_v, pi_best_v = value_iteration(P, gamma=0.99)\n", "print('Optimal policy and state-value function (VI):')\n", "print_policy(pi_best_v, P)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, pi_best_v, goal_state=goal_state)*100, \n", " mean_return(env, pi_best_v)))\n", "print()\n", "print_state_value_function(V_best_v, P, prec=4)" ] },
{ "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "For comparison, optimal policy and state-value function (PI):\n", "정책:\n", "| 00 < | 01 ^ | 02 ^ | 03 ^ |\n", "| 04 < | | 06 < | |\n", "| 08 ^ | 09 v | 10 < | |\n", "| | 13 > | 14 v | |\n", "Reaches goal 74.00%. Obtains an average undiscounted return of 0.7400.\n", "\n", "상태-가치 함수:\n", "| 00 0.542 | 01 0.499 | 02 0.471 | 03 0.457 |\n", "| 04 0.558 | | 06 0.358 | |\n", "| 08 0.592 | 09 0.643 | 10 0.615 | |\n", "| | 13 0.742 | 14 0.863 | |\n" ] } ], "source": [ "print('For comparison, optimal policy and state-value function (PI):')\n", "print_policy(pi_best_p, P)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, pi_best_p, goal_state=goal_state)*100, \n", " mean_return(env, pi_best_p)))\n", "print()\n", "print_state_value_function(V_best_p, P)" ] },
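{ "cell_type": "markdown", "metadata": {}, "source": [ "Policy iteration and value iteration should land on the same optimal state-value function for this MDP. The extra check below compares the two results numerically; `np.allclose` is used rather than exact equality because each algorithm stops once its own update falls below the $\\theta = 10^{-10}$ threshold, so the last few digits can differ." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# The two solutions should agree up to the convergence threshold.\n", "print(np.allclose(V_best_p, V_best_v, atol=1e-8))\n", "print(np.max(np.abs(V_best_p - V_best_v)))" ] },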
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Changing the Frozen Lake environment MDP" ] },
{ "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "env = gym.make('FrozenLake-v0')\n", "P = env.env.P\n", "\n", "# change reward function\n", "reward_goal, reward_holes, reward_others = 1, -1, -0.01\n", "goal, hole = 15, [5, 7, 11, 12]\n", "for s in range(len(P)):\n", "    for a in range(len(P[s])):\n", "        for t in range(len(P[s][a])):\n", "            values = list(P[s][a][t])\n", "            if values[1] == goal:\n", "                values[2] = reward_goal\n", "                values[3] = False\n", "            elif values[1] in hole:\n", "                values[2] = reward_holes\n", "                values[3] = False\n", "            else:\n", "                values[2] = reward_others\n", "                values[3] = False\n", "            if s in hole or s == goal:\n", "                values[2] = 0\n", "                values[3] = True\n", "            P[s][a][t] = tuple(values)\n", "\n", "# change transition function\n", "prob_action, prob_drift_one, prob_drift_two = 0.8, 0.1, 0.1\n", "for s in range(len(P)):\n", "    for a in range(len(P[s])):\n", "        for t in range(len(P[s][a])):\n", "            if P[s][a][t][0] == 1.0:\n", "                continue\n", "            values = list(P[s][a][t])\n", "            if t == 0:\n", "                values[0] = prob_drift_one\n", "            elif t == 1:\n", "                values[0] = prob_action\n", "            elif t == 2:\n", "                values[0] = prob_drift_two\n", "            P[s][a][t] = tuple(values)\n", "\n", "env.env.P = P" ] },
{ "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Optimal policy and state-value function (PI):\n", "정책:\n", "| 00 v | 01 ^ | 02 v | 03 ^ |\n", "| 04 < | | 06 v | |\n", "| 08 > | 09 v | 10 < | |\n", "| | 13 > | 14 > | |\n", "Reaches goal 78.00%. Obtains an average undiscounted return of 0.3657.\n", "\n", "상태-가치 함수:\n", "| 00 0.433 | 01 0.353 | 02 0.409 | 03 0.28 |\n", "| 04 0.461 | | 06 0.45 | |\n", "| 08 0.636 | 09 0.884 | 10 0.831 | |\n", "| | 13 0.945 | 14 0.977 | |\n" ] } ], "source": [ "V_best, pi_best = policy_iteration(env.env.P, gamma=0.99)\n", "print('Optimal policy and state-value function (PI):')\n", "print_policy(pi_best, P)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, pi_best, goal_state=goal_state)*100, \n", " mean_return(env, pi_best)))\n", "print()\n", "print_state_value_function(V_best, P)" ] },
{ "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Optimal policy and state-value function (VI):\n", "정책:\n", "| 00 v | 01 ^ | 02 v | 03 ^ |\n", "| 04 < | | 06 v | |\n", "| 08 > | 09 v | 10 < | |\n", "| | 13 > | 14 > | |\n", "Reaches goal 78.00%. Obtains an average undiscounted return of 0.3657.\n", "\n", "상태-가치 함수:\n", "| 00 0.433 | 01 0.353 | 02 0.409 | 03 0.28 |\n", "| 04 0.461 | | 06 0.45 | |\n", "| 08 0.636 | 09 0.884 | 10 0.831 | |\n", "| | 13 0.945 | 14 0.977 | |\n" ] } ], "source": [ "V_best, pi_best = value_iteration(env.env.P, gamma=0.99)\n", "print('Optimal policy and state-value function (VI):')\n", "print_policy(pi_best, P)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, pi_best, goal_state=goal_state)*100, \n", " mean_return(env, pi_best)))\n", "print()\n", "print_state_value_function(V_best, P)" ] },
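{ "cell_type": "markdown", "metadata": {}, "source": [ "Because the reward and transition functions were edited in place above, it is worth sanity-checking the surgery: every `(state, action)` pair should still define a proper probability distribution. The small extra check below assumes the modified `P` from the cells above." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Extra sanity check on the hand-modified MDP:\n", "# the transition probabilities of every (state, action) pair must sum to 1.\n", "for s in range(len(P)):\n", "    for a in range(len(P[s])):\n", "        total = sum(prob for prob, _, _, _ in P[s][a])\n", "        assert np.isclose(total, 1.0), (s, a, total)\n", "print('All transition distributions sum to 1.')" ] },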
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Russell & Norvig's Gridworld" ] },
{ "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "env = gym.make('RussellNorvigGridworld-v0')\n", "init_state = env.reset()\n", "goal_state = 3\n", "P = env.env.P" ] },
{ "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Optimal policy and state-value function (PI):\n", "정책:\n", "| 00 > | 01 > | 02 > | |\n", "| 04 ^ | | 06 ^ | |\n", "| 08 ^ | 09 < | 10 < | 11 < |\n", "Reaches goal 96.00%. Obtains an average undiscounted return of 0.6424.\n", "\n", "상태-가치 함수:\n", "| 00 0.812 | 01 0.868 | 02 0.918 | |\n", "| 04 0.762 | | 06 0.66 | |\n", "| 08 0.705 | 09 0.655 | 10 0.611 | 11 0.388 |\n" ] } ], "source": [ "V_best_p, pi_best = policy_iteration(P)\n", "print('Optimal policy and state-value function (PI):')\n", "print_policy(pi_best, P)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, pi_best, goal_state=goal_state)*100, \n", " mean_return(env, pi_best)))\n", "print()\n", "print_state_value_function(V_best_p, P)" ] },
{ "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Optimal policy and state-value function (VI):\n", "정책:\n", "| 00 > | 01 > | 02 > | |\n", "| 04 ^ | | 06 ^ | |\n", "| 08 ^ | 09 < | 10 < | 11 < |\n", "Reaches goal 96.00%. Obtains an average undiscounted return of 0.6424.\n", "\n", "상태-가치 함수:\n", "| 00 0.812 | 01 0.868 | 02 0.918 | |\n", "| 04 0.762 | | 06 0.66 | |\n", "| 08 0.705 | 09 0.655 | 10 0.611 | 11 0.388 |\n" ] } ], "source": [ "V_best_v, pi_best = value_iteration(P)\n", "print('Optimal policy and state-value function (VI):')\n", "print_policy(pi_best, P)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, pi_best, goal_state=goal_state)*100, \n", " mean_return(env, pi_best)))\n", "print()\n", "print_state_value_function(V_best_v, P)" ] },
{ "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Re-construct optimal policy:\n", "정책:\n", "| 00 > | 01 > | 02 > | |\n", "| 04 ^ | | 06 ^ | |\n", "| 08 ^ | 09 < | 10 < | 11 < |\n", "Reaches goal 96.00%. Obtains an average undiscounted return of 0.6424.\n" ] } ], "source": [ "LEFT, DOWN, RIGHT, UP = range(4)\n", "pi = lambda s: {\n", " 0:RIGHT, 1:RIGHT, 2:RIGHT, 3:LEFT,\n", " 4:UP, 5:LEFT, 6:UP, 7:LEFT,\n", " 8:UP, 9:LEFT, 10:LEFT, 11:LEFT\n", "}[s]\n", "print('Re-construct optimal policy:')\n", "print_policy(pi, P)\n", "print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(\n", " probability_success(env, pi, goal_state=goal_state)*100, \n", " mean_return(env, pi)))" ] },
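{ "cell_type": "markdown", "metadata": {}, "source": [ "Before evaluating the hand-written policy, a quick extra check confirms that it picks the same action as `pi_best`, the policy returned by `value_iteration` two cells above, in every state." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# The hand-written policy should match the one found by value iteration in every state.\n", "print(all(pi(s) == pi_best(s) for s in range(len(P))))" ] },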
{ "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Evaluate optimal policy:\n", "상태-가치 함수:\n", "| 00 0.812 | 01 0.868 | 02 0.918 | |\n", "| 04 0.762 | | 06 0.66 | |\n", "| 08 0.705 | 09 0.655 | 10 0.611 | 11 0.388 |\n" ] } ], "source": [ "V = policy_evaluation(pi, P)\n", "print('Evaluate optimal policy:')\n", "print_state_value_function(V, P)" ] },
{ "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Improve optimal policy (nothing to improve -- it is the same policy, because it is optimal):\n", "정책:\n", "| 00 > | 01 > | 02 > | |\n", "| 04 ^ | | 06 ^ | |\n", "| 08 ^ | 09 < | 10 < | 11 < |\n" ] } ], "source": [ "pi = policy_improvement(V, P)\n", "print('Improve optimal policy (nothing to improve -- it is the same policy, because it is optimal):')\n", "print_policy(pi, P)" ] },
{ "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "There are no differences, nothing to improve on the optimal policy and state-value function:\n", "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n" ] } ], "source": [ "print('There are no differences, nothing to improve on the optimal policy and state-value function:')\n", "print(np.abs(V_best_p - V))\n", "print(np.abs(V_best_v - V))" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "PyTorch_Tutorial.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 4 }