{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Policy Gradient with gym-MiniGrid\n", "> This post walks through a PyTorch implementation of Policy Gradient in the Gym-MiniGrid environment. Along the way, you will learn how to implement Vanilla Policy Gradient (also known as REINFORCE) and test it on an open-source RL environment.\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, PyTorch, Reinforcement_Learning]\n", "- image: images/Minigrid_sample.png" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic Jupyter Setting" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The autoreload extension is already loaded. To reload it, use:\n", " %reload_ext autoreload\n" ] } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from pprint import pprint\n", "\n", "%matplotlib inline\n", "plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots\n", "plt.rcParams['image.interpolation'] = 'nearest'\n", "plt.rcParams['image.cmap'] = 'gray'\n", "\n", "# for auto-reloading external modules\n", "# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython\n", "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup the environment\n", "Gridworld is a widely used type of environment in RL. [Gym-MiniGrid](https://github.com/maximecb/gym-minigrid) is a custom GridWorld environment in the style of OpenAI [gym](https://github.com/openai/gym). Before diving into this environment, you need to install both of them.\n", "```\n", "pip install gym\n", "pip install gym-minigrid \n", "```\n", "The code in this post uses the older gym API, in which `env.step` returns four values and `gym.wrappers.Monitor` is available, so an older gym release may be needed.\n", "\n", "First, let's look at some frames of MiniGrid." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlYAAAE5CAYAAABS724NAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAVqUlEQVR4nO3dfaxlV3kf4N8bD+YCJjVgQ13bKm5q8ZHIAWtE3YIiipOJoRF2JZBsRcGitqZVDSZNomCSP2ilRgK1DelEjSvHdjEVxbgOH1bkJmO5IERUDMOX8QfEU0PtwQ6G8pE0MFDTt3/cPfFlcsdj37POnHPuPI90tfdee59z3ru8vfzz2vvsW90dAABm92OLLgAAYLsQrAAABhGsAAAGEawAAAYRrAAABhGsAAAGmVuwqqoLqupLVbW/qq6a1+cAACyLmsdzrKrqhCR/muTnkhxI8qkkl3T3PcM/DABgScxrxuplSfZ39/3d/YMkNya5cE6fBQCwFHbM6X1PT/Lghu0DSf7ekQ5eW1vrZz7zmXMqZfGe9rSnLboEOC5973vfW3QJzMDYybJ68MEHv9Hdp262b17BqjZp+5FrjlW1O8nuJDnppJNy0UUXzamUxTvnnHMWXQIcl+68885Fl8AMjJ0sqyuvvPJ/HWnfvC4FHkhy5obtM5I8tPGA7r6mu3d29861tbU5lQEAcOzMK1h9KsnZVXVWVZ2Y5OIkt8zpswAAlsJcLgV296NV9aYkf5zkhCTXd/fd8/gsAIBlMa97rNLdtya5dV7vDwCwbDx5HQBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGCQLQerqjqzqj5SVfdW1d1V9Zap/dlVdVtV3TctnzWuXACA5TXLjNWjSX61u1+U5LwkV1TVi5NcleT27j47ye3TNgDAtrflYNXdD3f3Z6b1v0hyb5LTk1yY5IbpsBuSXDRrkQAAq2DIPVZV9fwkL01yR5LndffDyXr4SvLcEZ8BALDsZg5WVXVSkj9I8svd/edP4nW7q2pfVe07ePDgrGUAACzcTMGqqp6S9VD13u7+wNT8tao6bdp/WpJHNnttd1/T3Tu7e+fa2tosZQAALIVZvhVYSa5Lcm93//aGXbckuXRavzTJh7deHgDA6tgxw2tfnuSXknyhqj43tf1GknckuamqLkvyQJLXz1YiAMBq2HKw6u6PJ6kj7D5/q+8LALCqPHkdAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYJAdiy7geLJ3795FlzAXu3bt2ta/W7K9/9mx+q699tpFlzAXe/bs2fb/7m333+94ZMYKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwWrB3jr9vGDRhQCskEPjprGTZTNzsKqqE6rqs1X1h9P2WVV1R1XdV1Xvr6oTZy8TAGD5jXjy+luS3Jvkx6ftdyZ5V3ffWFX/McllSa4e8Dnb0is2LL84rX98Wn742JcDsBJekcfGT2Mny2SmGauqOiPJP0py7bRdSV6V5ObpkBuSXDTLZwAArIpZZ6x+J8mvJ3nmtP2cJN/u7ken7QNJTt/shVW1O8nuJDnppJNmLGN7eOFhy8uTfGhaP/R/Yl86phUBLD9jJ8tky8Gqqn4hySPd/emqeuWh5k0O7c1e393XJLkmSU4
99dRNj+Gx6b5Dy0NT3h9K8ifHvhyAlWDsZFFmmbF6eZLXVtVrkqxl/R6r30lyclXtmGatzkjy0OxlAgAsvy0Hq+5+W5K3Jck0Y/Vr3f2LVfVfk7wuyY1JLo37CIc6NNV9VZJvTOuHproPTX3/72NaEcDyO9rYadxklHk8x+qtSX6lqvZn/Z6r6+bwGQAAS2fE4xbS3R9N8tFp/f4kLxvxvjy+U6bl4fcSfDyPzV65YRPgR202dh4+82/sZKs8eR0AYJAhM1Ysl40Pzjt0L8HGrx67lwDgR73isOXGsfPQbJaxkydCsNrmDk15X75h6fkuAI9v49h5aPw0dvJEuBQIADCIGavj0JEenPfxeDYGwJFsNnb6+4QczowVAMAgZqz4kb+ztfGRDYmHjgIcyQvz2Pi52dhp3Dw+mbECABhEsAIAGMSlQNy8DrAFbl5nM2asAAAGMWN1HPKQO4Anz9jJE2HGCgBgEDNW25y/FQjw5PlbgWyVGSsAgEHMWG1DH89jM1TuAQA4usMfimzsZKsEqxV2aKraU9IBnrjNxk7jJqO4FAgAMIgZqxVz6GGeH0ryJ4ssBGCFGDs5VsxYAQAMYsZqyXkgHcCTZ+xkUcxYAQAMYsZqiWz8Y8iJP+oJ8EQYO1kmgtWCbfy6r6lqgCfG8/pYVi4FAgAMYsZqwd656AIAVpCxk2U104xVVZ1cVTdX1Rer6t6q+vtV9eyquq2q7puWzxpVLADAMpv1UuC/T/JH3f3CJD+d5N4kVyW5vbvPTnL7tA0AsO1tOVhV1Y8n+Zkk1yVJd/+gu7+d5MIkN0yH3ZDkolmLBABYBbPMWP2dJF9P8p+q6rNVdW1VPSPJ87r74SSZls/d7MVVtbuq9lXVvoMHD85QBgDAcpglWO1Icm6Sq7v7pUn+Mk/isl93X9PdO7t759ra2gxlAAAsh1mC1YEkB7r7jmn75qwHra9V1WlJMi0fma1EAIDVsOVg1d1/luTBqnrB1HR+knuS3JLk0qnt0ngILgBwnJj1OVZvTvLeqjoxyf1J3pj1sHZTVV2W5IEkr5/xMwAAVsJMwaq7P5dk5ya7zp/lfQEAVpE/aQMAMEh196JryKmnntoXXbR9H3d1zjnnLLoEOC7deeediy6BGRg7WVZXXnnlp7t7syt2/lbgsbR3795FlzAXu3bt2ta/W7K9/9mx+q699tpFlzAXe/bs2fb/7m333+945FIgAMAgghUAwCCCFQDAIIIVAMAgghUAwCCCFQDAIIIVAMAgghUAwCCCFQDAIIIVAMAgghUAwCCCFQDAIIIVAMAgghUAwCCCFQDAIIIVAMAgghUAwCCCFQDAIIIVAMAgghUAwCCCFQDAIIIVAMAgghUAwCAzBauq+hdVdXdV3VVV76uqtao6q6ruqKr7qur9VXXiqGIBAJbZloNVVZ2e5MokO7v7p5KckOTiJO9M8q7uPjvJt5JcNqJQAIBlN+ulwB1JnlZVO5I8PcnDSV6V5OZp/w1JLprxMwAAVsKWg1V3fzXJv03yQNYD1XeSfDrJt7v70emwA0lO3+z1VbW7qvZV1b6DBw9utQwAgKUxy6XAZyW5MMlZSf5WkmckefUmh/Zmr+/ua7p7Z3fvXFtb22oZAABLY5ZLgT+b5Mvd/fXu/r9JPpDkHyQ5ebo0mCRnJHloxhoBAFbCLMHqgSTnVdXTq6qSnJ/kniQfSfK66ZhLk3x4thIBAFbDLPdY3ZH1m9Q/k+QL03tdk+StSX6lqvYneU6S6wbUCQCw9HYc/ZAj6+63J3n7Yc33J3nZLO8LALCKPHkdAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgkB2LLuB4smvXrkWXMDfb9Xfb++a96ytvXmwdc/O7iy6AES6//PJFlzA3xpYVdRy
PLWasAAAGMWN1DO3du3fRJczFrl27tu3vtm3/b5Jt5dprr110CXOxZ88eYwsrx4wVAMAgghUAwCBHDVZVdX1VPVJVd21oe3ZV3VZV903LZ03tVVV7qmp/Vd1ZVefOs3gAgGXyRGas3p3kgsParkpye3efneT2aTtJXp3k7Olnd5Krx5QJALD8jhqsuvtjSb55WPOFSW6Y1m9IctGG9vf0uk8kObmqThtVLADAMtvqPVbP6+6Hk2RaPndqPz3JgxuOOzC1/TVVtbuq9lXVvoMHD26xDACA5TH65vXapK03O7C7r+nund29c21tbXAZAADH3laD1dcOXeKblo9M7QeSnLnhuDOSPLT18gAAVsdWg9UtSS6d1i9N8uEN7W+Yvh14XpLvHLpkCACw3R31yetV9b4kr0xySlUdSPL2JO9IclNVXZbkgSSvnw6/NclrkuxP8t0kb5xDzQAAS+mowaq7LznCrvM3ObaTXDFrUQAAq8iT1wEABhGsAAAGEawAAAYRrAAABhGsAAAGEawAAAYRrAAABhGsAAAGEawAAAYRrAAABhGsAAAGEawAAAYRrAAABhGsAAAGEawAAAYRrAAABhGsAAAGEawAAAYRrAAABhGsAAAGEawAAAYRrAAABhGsAAAGEawAAAYRrAAABjlqsKqq66vqkaq6a0Pbv6mqL1bVnVX1wao6ecO+t1XV/qr6UlX9/LwKBwBYNk9kxurdSS44rO22JD/V3eck+dMkb0uSqnpxkouT/OT0mt+rqhOGVQsAsMSOGqy6+2NJvnlY297ufnTa/ESSM6b1C5Pc2N3f7+4vJ9mf5GUD6wUAWFoj7rH6J0n+27R+epIHN+w7MLUBAGx7MwWrqvrNJI8mee+hpk0O6yO8dndV7auqfQcPHpylDACApbDlYFVVlyb5hSS/2N2HwtOBJGduOOyMJA9t9vruvqa7d3b3zrW1ta2WAQCwNLYUrKrqgiRvTfLa7v7uhl23JLm4qp5aVWclOTvJJ2cvEwBg+e042gFV9b4kr0xySlUdSPL2rH8L8KlJbquqJPlEd/+z7r67qm5Kck/WLxFe0d0/nFfxAADL5KjBqrsv2aT5usc5/reS/NYsRQEArCJPXgcAGESwAgAYRLACABhEsAIAGESwAgAYRLACABhEsAIAGESwAgAY5KgPCGWcXbt2LbqEudm2v9vvLroAOLrLL7980SXMjbGFVWPGCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGCQowarqrq+qh6pqrs22fdrVdVVdcq0XVW1p6r2V9WdVXXuPIoGAFhGT2TG6t1JLji8sarOTPJzSR7Y0PzqJGdPP7uTXD17iQAAq+Gowaq7P5bkm5vseleSX0/SG9ouTPKeXveJJCdX1WlDKgUAWHJbuseqql6b5Kvd/fnDdp2e5MEN2wemNgCAbW/Hk31BVT09yW8m2bXZ7k3aepO2VNXurF8uzEknnfRkywAAWDpbmbH6iSRnJfl8VX0lyRlJPlNVfzPrM1Rnbjj2jCQPbfYm3X1Nd+/s7p1ra2tbKAMAYLk86WDV3V/o7ud29/O7+/lZD1PndvefJbklyRumbweel+Q73f3w2JIBAJbTE3ncwvuS/I8kL6iqA1V12eMcfmuS+5PsT/L7Sf75kCoBAFbAUe+x6u5LjrL/+RvWO8kVs5cFALB6PHkdAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYBDBCgBgEMEKAGAQwQoAYJDq7kXXkKr6epK/TPKNRdeyTZ0SfTsv+na+9O/86Nv50bf
zsyx9+7e7+9TNdixFsEqSqtrX3TsXXcd2pG/nR9/Ol/6dH307P/p2flahb10KBAAYRLACABhkmYLVNYsuYBvTt/Ojb+dL/86Pvp0ffTs/S9+3S3OPFQDAqlumGSsAgJW2FMGqqi6oqi9V1f6qumrR9ay6qvpKVX2hqj5XVfumtmdX1W1Vdd+0fNai61wFVXV9VT1SVXdtaNu0L2vdnuk8vrOqzl1c5cvvCH37L6vqq9O5+7mqes2GfW+b+vZLVfXzi6l6NVTVmVX1kaq6t6rurqq3TO3O3Rk9Tt86dweoqrWq+mRVfX7q3381tZ9VVXdM5+77q+rEqf2p0/b+af/zF1l/sgTBqqpOSPIfkrw6yYuTXFJVL15sVdvCP+zul2z4WupVSW7v7rOT3D5tc3TvTnLBYW1H6stXJzl7+tmd5OpjVOOqenf+et8mybumc/cl3X1rkkxjwsVJfnJ6ze9NYwebezTJr3b3i5Kcl+SKqQ+du7M7Ut8mzt0Rvp/kVd3900lekuSCqjovyTuz3r9nJ/lWksum4y9L8q3u/rtJ3jUdt1ALD1ZJXpZkf3ff390/SHJjkgsXXNN2dGGSG6b1G5JctMBaVkZ3fyzJNw9rPlJfXpjkPb3uE0lOrqrTjk2lq+cIfXskFya5sbu/391fTrI/62MHm+juh7v7M9P6XyS5N8npce7O7HH69kicu0/CdA7+n2nzKdNPJ3lVkpun9sPP3UPn9M1Jzq+qOkblbmoZgtXpSR7csH0gj3+ScnSdZG9Vfbqqdk9tz+vuh5P1gSHJcxdW3eo7Ul86l8d403Q56voNl6z17RZNl0ZemuSOOHeHOqxvE+fuEFV1QlV9LskjSW5L8j+TfLu7H50O2diHf9W/0/7vJHnOsa34Ry1DsNosWfqq4mxe3t3nZn16/4qq+plFF3SccC7P7uokP5H1SwAPJ/l3U7u+3YKqOinJHyT55e7+88c7dJM2/fs4Nulb5+4g3f3D7n5JkjOyPrv3os0Om5ZL17/LEKwOJDlzw/YZSR5aUC3bQnc/NC0fSfLBrJ+YXzs0tT8tH1lchSvvSH3pXJ5Rd39tGlT/X5Lfz2OXTPTtk1RVT8n6f/jf290fmJqduwNs1rfO3fG6+9tJPpr1e9lOrqod066NffhX/Tvt/xt54rcYzMUyBKtPJTl7uuP/xKzf5HfLgmtaWVX1jKp65qH1JLuS3JX1Pr10OuzSJB9eTIXbwpH68pYkb5i+YXVeku8cuuzCE3PYfT3/OOvnbrLetxdP3wA6K+s3WX/yWNe3KqZ7TK5Lcm93//aGXc7dGR2pb527Y1TVqVV18rT+tCQ/m/X72D6S5HXTYYefu4fO6dcl+e+94Ad07jj6IfPV3Y9W1ZuS/HGSE5Jc3913L7isVfa8JB+c7t3bkeS/dPcfVdWnktxUVZcleSDJ6xdY48qoqvcleWWSU6rqQJK3J3lHNu/LW5O8Jus3p343yRuPecEr5Ah9+8qqeknWp/K/kuSfJkl3311VNyW5J+vfyrqiu3+4iLpXxMuT/FKSL0z3qiTJb8S5O8KR+vYS5+4QpyW5Yfrm5I8luam7/7Cq7klyY1X96ySfzXq4zbT8z1W1P+szVRcvouiNPHkdAGCQZbgUCACwLQhWAACDCFYAAIMIVgAAgwhWAACDCFYAAIMIVgAAgwhWAACD/H9qTo41YGvz9QAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import gym\n", "import gym_minigrid\n", "\n", "env = gym.make('MiniGrid-Empty-5x5-v0')\n", "env.reset()\n", "before_img = env.render('rgb_array')\n", "action = env.actions.forward\n", "obs, reward, done, info = env.step(action)\n", "after_img = env.render('rgb_array')\n", "\n", "plt.imshow(np.concatenate([before_img, after_img], 1));" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is an example of the `MiniGrid-Empty-5x5-v0` environment. It contains blank cells and gray wall cells that the agent cannot pass through. The green cell is the goal. The ultimate goal in this environment (and in most RL problems) is to find the optimal policy, that is, the one with the highest reward. Here, a well-trained agent should find the optimal path to the goal.\n", "\n", "Let's move to the larger environment `MiniGrid-Empty-8x8-v0` and see what information we can get from it." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Observation: {'image': array([[[2, 5, 0],\n", " [1, 0, 0],\n", " [1, 0, 0],\n", " [1, 0, 0],\n", " [1, 0, 0],\n", " [1, 0, 0],\n", " [1, 0, 0]],\n", "\n", " [[2, 5, 0],\n", " [1, 0, 0],\n", " [1, 0, 0],\n", " [1, 0, 0],\n", " [1, 0, 0],\n", " [1, 0, 0],\n", " [1, 0, 0]],\n", "\n", " [[2, 5, 0],\n", " [1, 0, 0],\n", " [1, 0, 0],\n", " [1, 0, 0],\n", " [1, 0, 0],\n", " [1, 0, 0],\n", " [1, 0, 0]],\n", "\n", " [[2, 5, 0],\n", " [1, 0, 0],\n", " [1, 0, 0],\n", " [1, 0, 0],\n", " [1, 0, 0],\n", " [1, 0, 0],\n", " [1, 0, 0]],\n", "\n", " [[2, 5, 0],\n", " [2, 5, 0],\n", " [2, 5, 0],\n", " [2, 5, 0],\n", " [2, 5, 0],\n", " [2, 5, 0],\n", " [2, 5, 0]],\n", "\n", " [[2, 5, 0],\n", " [2, 5, 0],\n", " [2, 5, 0],\n", " [2, 5, 0],\n", " [2, 5, 0],\n", " [2, 5, 0],\n", " [2, 5, 0]],\n", "\n", " [[2, 5, 0],\n", " [2, 5, 0],\n", " [2, 5, 0],\n", " [2, 5, 0],\n", " [2, 5, 
0],\n", " [2, 5, 0],\n", " [2, 5, 0]]], dtype=uint8), 'direction': 1, 'mission': 'get to the green goal square'}\n", "Reward: 0\n", "Done: False\n", "Info: {}\n", "Image shape: (256, 256, 3)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdsAAAHVCAYAAAC5cFFEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAVtklEQVR4nO3dX4ilB5nn8d+zxjEyHTDimHaT9CpDBkfBjdK4gsPiINtqbhIvXOKFBrG3vYigIM1Gb/RGcGFVEHbCxjIYwdEN+Ce5CDPtBkEc8E8ioUyMWRvNapuQbHRRewWHxGcv6jSpidXpTtd5+lRVfz7QnFPvec+p533rJF/e95w6Vd0dAGDOv1r1AACw14ktAAwTWwAYJrYAMExsAWCY2ALAsLHYVtVbq+qhqjpeVTdNfR8A2Olq4vdsq+p5Sf5Xkv+Q5ESS7yd5Z3f/aOnfDAB2uKkj29cnOd7dP+3uf07y5STXDn0vANjRLhp63MuT/GLT1yeS/LvTrXzxxRf3JZdcMjQKAMx74oknnujuv9jqtqnY1hbL/sX56qo6kuRIkuzbty/XXXfd0CgAMG9tbe1/n+62qdPIJ5JcuenrK5I8snmF7r6luw9298GLL754aAwAWL2p2H4/yVVV9Yqq+rMk1ye5c+h7AcCONnIaubufrKr3J/nHJM9Lcmt3PzDxvQBgp5t6zTbdfVeSu6YeHwB2C58gBQDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGEXrXqAJHnhC1+Y17zmNaseA9iF1tfXVz0CnJEjWwAYtiOObPe6Y8eOrXqEpTt06NCe3a5k7/3M9up2Jcn+/fuTJGtrayueZLkOHz6cZO9tV/L0tl1IHNkCwDCxBYBhYgsAw8QWAIZ5g9QO8Z+T/M0Kv/+3F5f/ZYUzAOxVjmwBYJgj2x3i61ntke3XV/i9AfY6R7YAMMyR7Q7xUJIfL66/8jx+31Pf86Hz+D0BLjRiu4OcepPS+Yztt8+8CgDb5DQyAAxzZLuD3LG4PJ+fGnrHmVcBYJsc2QLAMEe2O9CpX8O57jx8DwDmie0OdOpNS5Ox9cYogPPHaWQAGObIdgc69TuvP87MrwH9OH6vFuB8cmQLAMPEdgebehOTN0cBnF9OI+9g/5TkicX1lyzh8U491j8t4bEAOHuObAFgmCPbHW6Zvwbk130AVsORLQAMc2S7wy3z06S8MQpgNRzZAsCwbR3ZVtXDSX6X5KkkT3b3wap6cZL/keTlSR5O8h+7+/9ub8wL168Wl6deb/2bc3iMU/f91bOuBcCUZRzZ/m13X93dBxdf35Tk7u6+Ksndi6/Zpq/n3E8Db+e+AGzfxGnka5Pctrh+W2Y/Tx8AdrztvkGqkxyrqk7y37v7liSXdfejSdLdj1bVS7c7JE9/lvETeW4fcPFEfA4ywKptN7Zv7O5HFkH9RlX9+GzvWFVHkhxJkksvvXSbYwDAzrWt08jd/cji8vEkX0vy+iSPVdXLkmRx+fhp7ntLdx/s7oP79u3bzhgXlOf62qv
XagFW75xjW1V/XlWXnLqe5FCS+5PcmeSGxWo3JLlju0PytOf6KVA+NQpg9bZzGvmyJF+rqlOP8/fd/Q9V9f0kt1fVe5P8PMk7tj8mAOxe5xzb7v5pkn+7xfJfJXnzdobi9H6Vs/tUqVPr+N1agNXzCVIAMMxnI+9CZ/OXgLxWC7BziO0udOr3Zk/9ntUrN93242esA8DqOY0MAMMc2e5ip04Vv3KLZQDsHI5sAWCYI9td7NSnhVy3xTIAdg5HtgAwzJHtHuB1WoCdTWz3AH9sAGBncxoZAIY5st0DfP4xwM7myBYAhoktAAwTWwAYJrYAMExsAWCY2ALAsOruVc+QAwcO9NGjR1c9BrALra+vr3oESJKsra3d290Ht7rN79meB8eOHVv1CEt36NChPbtdyd77me3V7UqS/fv3J0nW1tZWPMlyHT58OMne267k6W27kDiNDADDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhoktAAwTWwAYVt296hly4MCBPnr06KrHAHah9fX1VY8ASZK1tbV7u/vgVrc5sgWAYReteoALwbFjx1Y9wtIdOnRoz25Xsvd+Znt1u5Jk//79SZK1tbUVT7Jchw8fTrL3tit5etsuJI5sAWDYGWNbVbdW1eNVdf+mZS+uqm9U1U8Wl5culldVfaaqjlfVelW9bnJ4ANgNzubI9vNJ3vqMZTclubu7r0py9+LrJHlbkqsW/44kuXk5YwLA7nXG2Hb3t5L8+hmLr01y2+L6bUmu27T8C73hO0leVFUvW9awALAbnetrtpd196NJsrh86WL55Ul+sWm9E4tlAHDBWvYbpGqLZVv+Im9VHamqe6rqnpMnTy55DADYOc41to+dOj28uHx8sfxEkis3rXdFkke2eoDuvqW7D3b3wX379p3jGACw851rbO9McsPi+g1J7ti0/N2LdyW/IclvTp1uBoAL1Rk/1KKqvpTkTUleUlUnknw0ySeS3F5V703y8yTvWKx+V5JrkhxP8vsk7xmYGQB2lTPGtrvfeZqb3rzFup3kxu0OBQB7iU+QAoBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBh1d2rniEHDhzoo0ePrnoMYBdaX19f9QiQJFlbW7u3uw9udZsjWwAYdtGqB7gQHDt2bNUjLN2hQ4f27HYle+9ntle3K0n279+fJFlbW1vxJMt1+PDhJHtvu5Knt+1C4sgWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw84Y26q6taoer6r7Ny37WFX9sqruW/y7ZtNtH66q41X1UFW9ZWpwANgtzubI9vNJ3rrF8k9399WLf3clSVW9Ksn1SV69uM/fVdXzljUsAOxGZ4xtd38rya/P8vGuTfLl7v5Dd/8syfEkr9/GfACw623nNdv3V9X64jTzpYtllyf5xaZ1TiyWAcAF61x
je3OSv0xydZJHk3xysby2WLe3eoCqOlJV91TVPSdPnjzHMQBg5zun2Hb3Y939VHf/Mcln8/Sp4hNJrty06hVJHjnNY9zS3Qe7++C+ffvOZQwA2BXOKbZV9bJNX749yal3Kt+Z5PqqekFVvSLJVUm+t70RAWB3u+hMK1TVl5K8KclLqupEko8meVNVXZ2NU8QPJ3lfknT3A1V1e5IfJXkyyY3d/dTM6ACwO5wxtt39zi0Wf+5Z1v94ko9vZygA2Et8ghQADBNbABgmtgAwTGwBYJjYAsAwsQWAYWILAMPEFgCGiS0ADBNbABgmtgAwrLq3/HOz59WBAwf66NGjqx4D2IXW19dXPQIkSdbW1u7t7oNb3ebIFgCGnfGv/rB9x44dW/UIS3fo0KE9u13J3vuZ7dXtSpL9+/cnSdbW1lY8yXIdPnw4yd7bruTpbbuQOLIFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDxBYAhlV3r3qGHDhwoI8ePbrqMYBdaH19fdUjjFj77NqqRxhz+D8dXvUII9bW1u7t7oNb3ebIFgCGXbTqAS4Ex44dW/UIS3fo0KE9u13J3vuZ7dXtSpL9+/cnSdbW9tiR4GdXPQDL5MgWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw84Y26q6sqq+WVUPVtUDVfWBxfIXV9U3quoni8tLF8urqj5TVcerar2qXje9EQCwk53Nke2TST7U3X+d5A1JbqyqVyW5Kcnd3X1VkrsXXyfJ25Jctfh3JMnNS58aAHaRM8a2ux/t7h8srv8uyYNJLk9ybZLbFqvdluS6xfVrk3yhN3wnyYuq6mVLnxwAdonn9JptVb08yWuTfDfJZd39aLIR5CQvXax2eZJfbLrbicUyALggnXVsq2pfkq8k+WB3//bZVt1i2Z/8Hb+qOlJV91TVPSdPnjzbMQBg1zmr2FbV87MR2i9291cXix87dXp4cfn4YvmJJFduuvsVSR555mN29y3dfbC7D+7bt+9c5weAHe9s3o1cST6X5MHu/tSmm+5McsPi+g1J7ti0/N2LdyW/IclvTp1uBoAL0dn8Pds3JnlXkh9W1X2LZR9J8okkt1fVe5P8PMk7FrfdleSaJMeT/D7Je5Y6MQDsMmeMbXd/O1u/Dpskb95i/U5y4zbnAoA9wydIAcAwsQWAYWILAMPEFgCGiS0ADBNbABgmtgAwTGwBYJjYAsAwsQWAYWILAMPEFgCGiS0ADBNbABgmtgAwTGwBYJjYAsAwsQWAYWILAMPEFgCGiS0ADBNbABgmtgAwTGwBYJjYAsAwsQWAYWILAMPEFgCGiS0ADBNbABgmtgAwTGwBYJjYAsAwsQWAYWILAMOqu1c9Qw4cONBHjx5d9RjALrS+vr7qESBJsra2dm93H9zqNke2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGHbG2FbVlVX1zap6sKoeqKoPLJZ/rKp+WVX3Lf5ds+k+H66q41X1UFW9ZXIDAGCnu+gs1nkyyYe6+wdVdUmSe6vqG4vbPt3d/3XzylX1qiTXJ3l1kn+d5H9W1V9
191PLHBwAdoszHtl296Pd/YPF9d8leTDJ5c9yl2uTfLm7/9DdP0tyPMnrlzEsAOxGz+k126p6eZLXJvnuYtH7q2q9qm6tqksXyy5P8otNdzuRLeJcVUeq6p6quufkyZPPeXAA2C3OOrZVtS/JV5J8sLt/m+TmJH+Z5Ookjyb55KlVt7h7/8mC7lu6+2B3H9y3b99zHhwAdouzim1VPT8bof1id381Sbr7se5+qrv/mOSzefpU8YkkV266+xVJHlneyACwu5zNu5EryeeSPNjdn9q0/GWbVnt7kvsX1+9Mcn1VvaCqXpHkqiTfW97IALC7nM27kd+Y5F1JflhV9y2WfSTJO6vq6mycIn44yfuSpLsfqKrbk/woG+9kvtE7kQG4kJ0xtt397Wz9Ouxdz3Kfjyf5+DbmAoA9wydIAcAwsQWAYWILAMPEFgCGVfeffN7E+R+i6v8k+X9Jnlj1LHvMS2KfLpt9unz26fLZp8t3Nvv033T3X2x1w46IbZJU1T3dfXDVc+wl9uny2afLZ58un326fNvdp04jA8AwsQWAYTsptreseoA9yD5dPvt0+ezT5bNPl29b+3THvGYLAHvVTjqyBYA9aUfEtqreWlUPVdXxqrpp1fPsRlX1cFX9sKruq6p7FsteXFXfqKqfLC4vXfWcO11V3VpVj1fV/ZuWbbkfa8NnFs/b9ap63eom35lOsz8/VlW/XDxX76uqazbd9uHF/nyoqt6ymql3tqq6sqq+WVUPVtUDVfWBxXLP03P0LPt0ac/Vlce2qp6X5L8leVuSV2Xjrwm9arVT7Vp/291Xb3p7+k1J7u7uq5LcvfiaZ/f5JG99xrLT7ce3ZeNPSF6V5EiSm8/TjLvJ5/On+zNJPr14rl7d3XclyeK/++uTvHpxn79b/P+Bf+nJJB/q7r9O8oYkNy72nefpuTvdPk2W9FxdeWyz8Ufnj3f3T7v7n5N8Ocm1K55pr7g2yW2L67cluW6Fs+wK3f2tJL9+xuLT7cdrk3yhN3wnyYue8XeeL3in2Z+nc22SL3f3H7r7Z0mOZ+P/D2zS3Y929w8W13+X5MEkl8fz9Jw9yz49nef8XN0Jsb08yS82fX0iz76RbK2THKuqe6vqyGLZZd39aLLxZEry0pVNt7udbj967p679y9Oad666eUN+/M5qqqXJ3ltku/G83QpnrFPkyU9V3dCbLf6W7neIv3cvbG7X5eNU0Y3VtW/X/VAFwDP3XNzc5K/THJ1kkeTfHKx3P58DqpqX5KvJPlgd//22VbdYpn9uoUt9unSnqs7IbYnkly56esrkjyyoll2re5+ZHH5eJKvZeOUxmOnThctLh9f3YS72un2o+fuOejux7r7qe7+Y5LP5unTb/bnWaqq52cjCl/s7q8uFnuebsNW+3SZz9WdENvvJ7mqql5RVX+WjRed71zxTLtKVf15VV1y6nqSQ0nuz8Z+vGGx2g1J7ljNhLve6fbjnUnevXi35xuS/ObUaTxO7xmvF749G8/VZGN/Xl9VL6iqV2TjDT3fO9/z7XRVVUk+l+TB7v7Upps8T8/R6fbpMp+rFy135Oeuu5+sqvcn+cckz0tya3c/sOKxdpvLknxt4/mSi5L8fXf/Q1V9P8ntVfXeJD9P8o4VzrgrVNWXkrwpyUuq6kSSjyb5RLbej3cluSYbb474fZL3nPeBd7jT7M83VdXV2Tjt9nCS9yVJdz9QVbcn+VE23h16Y3c/tYq5d7g3JnlXkh9W1X2LZR+J5+l2nG6fvnNZz1WfIAUAw3bCaWQA2NPEFgCGiS0ADBNbABgmtgAwTGwBYJjYAsAwsQWAYf8fKoYjlfniNlsAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Make a new environment MiniGrid-Empty-8x8-v0\n", "env = gym.make('MiniGrid-Empty-8x8-v0')\n", "\n", "# Reset the environment\n", "env.reset()\n", "\n", "# Select the 'turn right' action (sample action)\n", "action = env.actions.right\n", "\n", "# Take a step in the environment and store it in appropriate variables\n", "obs, reward, done, info = env.step(action)\n", "\n", "# Render the current state of the environment\n", "img = env.render('rgb_array')\n", "\n", "print('Observation:', obs)\n", "print('Reward:', reward)\n", "print('Done:', done)\n", "print('Info:', info)\n", "print('Image shape:', img.shape)\n", "plt.imshow(img);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When the agent takes an action, the environment (MiniGrid) changes in response. \n", "To find the optimal path, the agent must notice the difference between the current state and the next state that results from each action. To support this, the environment returns the next state, the reward, and a terminal flag.\n", "\n", "The helper function below renders a recorded episode inside the Jupyter Notebook." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "import base64\n", "import glob\n", "import io\n", "from IPython.display import HTML\n", "from IPython import display\n", "\n", "def show_video():\n", "    # Look for videos recorded into ./video\n", "    mp4list = glob.glob('video/*.mp4')\n", "    if len(mp4list) > 0:\n", "        mp4 = mp4list[0]\n", "        video = io.open(mp4, 'rb').read()\n", "        encoded = base64.b64encode(video)\n", "        # Embed the mp4 inline as a base64 data URI\n", "        display.display(HTML(data='''<video alt=\"test\" autoplay loop controls style=\"height: 400px;\">\n", "            <source src=\"data:video/mp4;base64,{0}\" type=\"video/mp4\" />\n", "        </video>'''.format(encoded.decode('ascii'))))\n", "    else:\n", "        print(\"Could not find video\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "MiniGrid offers a `FlatObsWrapper` for flattening observations. To make the agent easier to train, the variant defined below returns the fully observable grid as a 1D array." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "import gym\n", "from gym import spaces\n", "from gym_minigrid.minigrid import OBJECT_TO_IDX, COLOR_TO_IDX\n", "\n", "max_env_steps = 50\n", "\n", "class FlatObsWrapper(gym.core.ObservationWrapper):\n", "    \"\"\"Fully observable gridworld returning a flat grid encoding.\"\"\"\n", "\n", "    def __init__(self, env):\n", "        super().__init__(env)\n", "\n", "        # Since the outer walls are always present, we remove left, right, top, bottom walls\n", "        # from the observation space of the agent. There are 3 channels, but for simplicity\n", "        # we will deal with a flattened version of the state.\n", "\n", "        self.observation_space = spaces.Box(\n", "            low=0,\n", "            high=255,\n", "            shape=((self.env.width-2) * (self.env.height-2) * 3,),  # inner cells x 3 channels\n", "            dtype='uint8'\n", "        )\n", "        self.unwrapped.max_steps = max_env_steps\n", "\n", "    def observation(self, obs):\n", "        # this method is called in the step() function to get the observation\n", "        # we provide code that gets the grid state and places the agent in it\n", "        env = self.unwrapped\n", "        full_grid = env.grid.encode()\n", "        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([\n", "            OBJECT_TO_IDX['agent'],\n", "            COLOR_TO_IDX['red'],\n", "            env.agent_dir\n", "        ])\n", "        full_grid = full_grid[1:-1, 1:-1]  # remove outer walls of the environment (for efficiency)\n", "\n", "        flattened_grid = full_grid.ravel()\n", "        return flattened_grid\n", "\n", "    def render(self, *args, **kwargs):\n", "        \"\"\"This removes the default visualization of the partially observable field of view.\"\"\"\n", "        kwargs['highlight'] = False\n", "        return self.unwrapped.render(*args, **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now it's time to run the wrapped environment with a sample action!" 
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Observation: [10 0 1 1 0 0 1 0 0 1 0 0 1 0 0 1 0 0 1 0 0 1 0 0\n", " 1 0 0 1 0 0 1 0 0 1 0 0 1 0 0 1 0 0 1 0 0 1 0 0\n", " 1 0 0 1 0 0 1 0 0 1 0 0 1 0 0 1 0 0 1 0 0 1 0 0\n", " 1 0 0 1 0 0 1 0 0 1 0 0 1 0 0 1 0 0 1 0 0 1 0 0\n", " 1 0 0 1 0 0 1 0 0 8 1 0] , Observation Shape: (108,)\n", "Reward: 0\n", "Done: False\n", "Info: {}\n", "Image shape: (256, 256, 3)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdsAAAHVCAYAAAC5cFFEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAVB0lEQVR4nO3dX4ild53n8c93jWPDJKBiG9wkrCK9MPFiozSuIiwOsqPmJuWFS7zQIPa0FxF015vojd4IXqwKwk7YWAYjOLoRFXMRZsYNggz4L5GgiVnXRrOmTUjWdVFZaYfE717UKbomVqe7q+rb59Tp1wuKc+p3nnPq9zz9mLfPc/5VdwcAmPMvlj0BAFh3YgsAw8QWAIaJLQAME1sAGCa2ADBsLLZV9Zaq+klVnaqq26b+DgCsupp4n21VPS/J/0zy75OcTvL9JO/o7h8f+B8DgBU3dWT72iSnuvtn3f1PSb6U5KahvwUAK+2Koce9JsljO34/neTfnmvhI0eO9FVXXTU0FQCY96tf/epX3X10t9umYlu7jP2z89VVdTLJySS58sors7GxMTQVAJi3ubn5v85129Rp5NNJrtvx+7VJHt+5QHff0d3Hu/v4kSNHhqYBAMs3FdvvJzlWVa+oqj9LcnOSe4b+FgCstJHTyN39dFW9L8nfJ3lekju7++GJvwUAq27qOdt0971J7p16fAA4LHyCFAAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGHbFsidwOdjc3Fz2FA7ciRMn1na9kvX7N1vX9UrWd93Wdb2Ss+t2OXFkCwDDxBYAhoktAAwTWwAYJrYr4u4kvcSfuxc/ABw8sQWAYd76syI+meTtS/77AMxwZAsAwxzZrojvJPn24vrrL+Hf3f6b37mEfxPgciO2K+TLi8tLGdsvn38RAPbJaWQAGObIdoV8anF5KV+s9KnzLwLAPjmyBYBhjmxX0PaR7X+6BH8DgHliu4K2X7Q0GVsvjAK4dJxGBoBhjmxX0PZ7Xr+dmbcBfTveVwtwKTmyBYBhYrvCpt6W4+0+AJeW08gr7MtJHltcv+4AHm/7sbw4CuDScmQLAMMc2a64g3wbkCNagOVwZAsAwxzZrrjtFzMdxJGtF0YBLIcjWwAYtq8j26p6NMnvkjyT5OnuPl5VL07y35K8PMmjSf5Dd//f/U3z8nV6cbn9fOvb9/AY2/c9/ZxLATDl
II5s/7K7b+ju44vfb0tyX3cfS3Lf4nf26ZPZ+5cH7Oe+AOzfxGnkm5Lctbh+V5KNgb8BAIfGfl8g1Un+oao6yX/t7juSXN3dTyRJdz9RVS/d7yQ5+1nGj+XiPuDisfgcZIBl229s39Ddjy+C+o2q+h8XeseqOpnkZJJceeWV+5wGAKyufZ1G7u7HF5dPJflaktcmebKqXpYki8unznHfO7r7eHcfP3LkyH6mcVm52LfveLsPwPLtObZV9edVddX29SR/leShJPckuWWx2C1Jvr7fSXLWxX4KlE+NAli+/ZxGvjrJ16pq+3H+trv/rqq+n+TuqnpPkl9kb+9WAYC1sefYdvfPkvybXcb/T5I37WdSnNvpnH0bz3N9qtT2Mt5bC7B8PkEKAIb5bORD6EK+CchztQCrQ2wPoe33zX57cfn6Hbd9+1nLALB8TiMDwDBHtofY9qni1+8yBsDqcGQLAMMc2R5i258O9R93GQNgdTiyBYBhjmzXgOdpAVab2K4Bp44BVpvTyAAwzJHtGvD5xwCrzZEtAAwTWwAYJrYAMExsAWCY2ALAMLEFgGHV3cueQ44ePdobGxvLngYA7Nnm5uYD3X18t9u8z/YS2NzcXPYUDtyJEyfWdr2S9fs3W9f1StZ33dZ1vZKz63Y5cRoZAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADCsunvZc8jRo0d7Y2Nj2dMAgD3b3Nx8oLuP73abI1sAGHbFsidwOdjc3Fz2FA7ciRMn1na9kvX7N1vX9UrWd93Wdb2Ss+t2OXFkCwDDzhvbqrqzqp6qqod2jL24qr5RVT9dXL5oMV5V9emqOlVVP6yq10xOHgAOgws5sv1ckrc8a+y2JPd197Ek9y1+T5K3Jjm2+DmZ5PaDmSYAHF7njW13fyvJr581fFOSuxbX70qysWP8873lO0leWFUvO6jJAsBhtNfnbK/u7ieSZHH50sX4NUke27Hc6cUYAFy2DvoFUrXL2K5v5K2qk1V1f1Xdf+bMmQOeBgCsjr3G9snt08OLy6cW46eTXLdjuWuTPL7bA3T3Hd19vLuPHzlyZI/TAIDVt9fY3pPklsX1W5J8fcf4uxavSn5dkt9sn24GgMvVeT/Uoqq+mOSNSV5SVaeTfCTJx5PcXVXvSfKLJG9fLH5vkhuTnEry+yTvHpgzABwq541td7/jHDe9aZdlO8mt+50UAKwTnyAFAMPEFgCGiS0ADBNbABgmtgAwTGwBYJjYAsAwsQWAYWILAMPEFgCGiS0ADBNbABgmtgAwTGwBYJjYAsAwsQWAYWILAMPEFgCGiS0ADBNbABgmtgAwTGwBYJjYAsAwsQWAYWILAMPEFgCGiS0ADBNbABgmtgAwTGwBYJjYAsAwsQWAYWILAMPEFgCGiS0ADBNbABgmtgAwTGwBYJjYAsAwsQWAYWILAMOqu5c9hxw9erQ3NjaWPQ0A2LPNzc0Huvv4brc5sgWAYVcsewKXg83NzWVP4cCdOHFibdcrWb9/s3Vdr2R9121d1ys5u26XE0e2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGHbe2FbVnVX1VFU9tGPso1X1y6p6cPFz447bPlRVp6rq
J1X15qmJA8BhcSFHtp9L8pZdxj/V3Tcsfu5Nkqq6PsnNSV61uM/fVNXzDmqyAHAYnTe23f2tJL++wMe7KcmXuvsP3f3zJKeSvHYf8wOAQ28/z9m+r6p+uDjN/KLF2DVJHtuxzOnFGABctvYa29uTvDLJDUmeSPKJxXjtsmzv9gBVdbKq7q+q+8+cObPHaQDA6ttTbLv7ye5+prv/mOQzOXuq+HSS63Ysem2Sx8/xGHd09/HuPn7kyJG9TAMADoU9xbaqXrbj17cl2X6l8j1Jbq6qF1TVK5IcS/K9/U0RAA63K863QFV9Mckbk7ykqk4n+UiSN1bVDdk6RfxokvcmSXc/XFV3J/lxkqeT3Nrdz8xMHQAOh/PGtrvfscvwZ59j+Y8l+dh+JgUA68QnSAHAMLEFgGFiCwDDxBYAhoktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDDqnvXr5u9pI4ePdobGxvLngYA7Nnm5uYD3X18t9sc2QLAsPN+6w/7t7m5uewpHLgTJ06s7Xol6/dvtq7rlazvuq3reiVn1+1y4sgWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGFbdvew55OjRo72xsbHsaQCsjM3PbC57CmNO/PWJZU9hxObm5gPdfXy32xzZAsCwK5Y9gcvB5ub6/T/UEydOrO16Jev3b7au65Ws8bp9ZtkT4CA5sgWAYWILAMPEFgCGiS0ADBNbABgmtgAwTGwBYJjYAsCw88a2qq6rqm9W1SNV9XBVvX8x/uKq+kZV/XRx+aLFeFXVp6vqVFX9sKpeM70SALDKLuTI9ukkH+zuv0jyuiS3VtX1SW5Lcl93H0ty3+L3JHlrkmOLn5NJbj/wWQPAIXLe2Hb3E939g8X13yV5JMk1SW5KctdisbuSbH+TwE1JPt9bvpPkhVX1sgOfOQAcEhf1nG1VvTzJq5N8N8nV3f1EshXkJC9dLHZNksd23O30YgwALksXHNuqujLJV5J8oLt/+1yL7jL2J9/jV1Unq+r+qrr/zJkzFzoNADh0Lii2VfX8bIX2C9391cXwk9unhxeXTy3GTye5bsfdr03y+LMfs7vv6O7j3X38yJEje50/AKy8C3k1ciX5bJJHuvuTO266J8kti+u3JPn6jvF3LV6V/Lokv9k+3QwAl6ML+T7bNyR5Z5IfVdWDi7EPJ/l4krur6j1JfpHk7Yvb7k1yY5JTSX6f5N0HOmMAOGTOG9vu/sfs/jxskrxpl+U7ya37nBcArA2fIAUAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAMq+5e9hxy9OjR3tjYWPY0AGDPNjc3H+ju47vd5sgWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw8QWAIaJLQAME1sAGCa2ADBMbAFgmNgCwDCxBYBhYgsAw84b26q6rqq+WVWPVNXDVfX+xfhHq+qXVfXg4ufGHff5
UFWdqqqfVNWbJ1cAAFbdFRewzNNJPtjdP6iqq5I8UFXfWNz2qe7+zzsXrqrrk9yc5FVJ/mWS/15V/7q7nznIiQPAYXHeI9vufqK7f7C4/rskjyS55jnuclOSL3X3H7r750lOJXntQUwWAA6ji3rOtqpenuTVSb67GHpfVf2wqu6sqhctxq5J8tiOu53OLnGuqpNVdX9V3X/mzJmLnjgAHBYXHNuqujLJV5J8oLt/m+T2JK9MckOSJ5J8YnvRXe7efzLQfUd3H+/u40eOHLnoiQPAYXFBsa2q52crtF/o7q8mSXc/2d3PdPcfk3wmZ08Vn05y3Y67X5vk8YObMgAcLhfyauRK8tkkj3T3J3eMv2zHYm9L8tDi+j1Jbq6qF1TVK5IcS/K9g5syABwuF/Jq5DckeWeSH1XVg4uxDyd5R1XdkK1TxI8meW+SdPfDVXV3kh9n65XMt3olMgCXs/PGtrv/Mbs/D3vvc9znY0k+to95AcDa8AlSADBMbAFgmNgCwDCxBYBh1f0nnzdx6SdR9b+T/L8kv1r2XNbMS2KbHjTb9ODZpgfPNj14F7JN/1V3H93thpWIbZJU1f3dfXzZ81gntunBs00Pnm168GzTg7ffbeo0MgAME1sAGLZKsb1j2RNYQ7bpwbNND55tevBs04O3r226Ms/ZAsC6WqUjWwBYSysR26p6S1X9pKpOVdVty57PYVRVj1bVj6rqwaq6fzH24qr6RlX9dHH5omXPc9VV1Z1V9VRVPbRjbNftWFs+vdhvf1hVr1nezFfTObbnR6vql4t99cGqunHHbR9abM+fVNWblzPr1VZV11XVN6vqkap6uKrevxi3n+7Rc2zTA9tXlx7bqnpekv+S5K1Jrs/Wtwldv9xZHVp/2d037Hh5+m1J7uvuY0nuW/zOc/tckrc8a+xc2/Gt2foKyWNJTia5/RLN8TD5XP50eybJpxb76g3dfW+SLP53f3OSVy3u8zeL/z7wzz2d5IPd/RdJXpfk1sW2s5/u3bm2aXJA++rSY5utL50/1d0/6+5/SvKlJDcteU7r4qYkdy2u35VkY4lzORS6+1tJfv2s4XNtx5uSfL63fCfJC5/1Pc+XvXNsz3O5KcmXuvsP3f3zJKey9d8HdujuJ7r7B4vrv0vySJJrYj/ds+fYpudy0fvqKsT2miSP7fj9dJ57JdldJ/mHqnqgqk4uxq7u7ieSrZ0pyUuXNrvD7Vzb0b67d+9bnNK8c8fTG7bnRaqqlyd5dZLvxn56IJ61TZMD2ldXIba7fVeul0hfvDd092uydcro1qr6d8ue0GXAvrs3tyd5ZZIbkjyR5BOLcdvzIlTVlUm+kuQD3f3b51p0lzHbdRe7bNMD21dXIbank1y34/drkzy+pLkcWt39+OLyqSRfy9YpjSe3TxctLp9a3gwPtXNtR/vuHnT3k939THf/Mclncvb0m+15garq+dmKwhe6+6uLYfvpPuy2TQ9yX12F2H4/ybGqekVV/Vm2nnS+Z8lzOlSq6s+r6qrt60n+KslD2dqOtywWuyXJ15czw0PvXNvxniTvWrza83VJfrN9Go9ze9bzhW/L1r6abG3Pm6vqBVX1imy9oOd7l3p+q66qKslnkzzS3Z/ccZP9dI/OtU0Pcl+94mCnfPG6++mqel+Sv0/yvCR3dvfDS57WYXN1kq9t7S+5IsnfdvffVdX3k9xdVe9J8oskb1/iHA+FqvpikjcmeUlVnU7ykSQfz+7b8d4kN2brxRG/T/LuSz7hFXeO7fnGqrohW6fdHk3y3iTp7oer6u4kP87Wq0Nv7e5nljHvFfeGJO9M8qOqenAx9uHYT/fjXNv0HQe1r/oEKQAYtgqnkQFgrYktAAwTWwAYJrYAMExsAWCY2ALAMLEFgGFiCwDD/j+o8Sx5qwnojgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Wrap the MiniGrid environment so that it returns flat observations\n", "env = FlatObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))\n", "\n", "# Reset the environment\n", "env.reset()\n", "\n", "# Select the 'turn right' action\n", "action = env.actions.right\n", "\n", "# Take a step in the environment and store it in appropriate variables\n", "obs, reward, done, info = env.step(action)\n", "\n", "# Render the current state of the environment\n", "img = env.render('rgb_array')\n", "\n", "print('Observation:', obs, ', Observation Shape: ', obs.shape)\n", "print('Reward:', reward)\n", "print('Done:', done)\n", "print('Info:', info)\n", "print('Image shape:', img.shape)\n", "plt.imshow(img);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the output shows, the observation has changed from a multi-dimensional image array to a 1D array of length 108 (6 x 6 inner cells x 3 channels). We will feed this observation into a neural network so that the agent can interpret it. Let's check a video of random movement." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "from gym.wrappers import Monitor\n", "\n", "# Monitor is a gym wrapper, which helps easy rendering of videos of the wrapped environment.\n", "def wrap_env(env):\n", "    env = Monitor(env, './video', force=True)\n", "    return env\n", "\n", "def gen_wrapped_env(env_name):\n", "    return wrap_env(FlatObsWrapper(gym.make(env_name)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "OpenAI Gym offers several utilities for tracking training progress. `Monitor` is one of them; it logs episode history. If the rendering option is set to `rgb_array`, the video data is stored in the specified path. 
(Recording video may require additional tools such as ffmpeg.)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Test with Random Policy" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [
"# Random agent - we only use it in this cell for demonstration\n",
"class RandPolicy:\n",
"    def __init__(self, action_space):\n",
"        self.action_space = action_space\n",
"\n",
"    def act(self, *unused_args):\n",
"        # Return (action, log_prob); a random policy has no log-probability to report\n",
"        return self.action_space.sample(), None" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"First, we want to check that the environment-agent interaction works. To do this, we define a random policy that simply samples a random action from the predefined action space, and then run it.\n",
"> Note that the `pytorch_policy` flag is `False` by default. Policy gradient needs gradient computation, so the learned policy later on is implemented in PyTorch and run with `pytorch_policy=True`." ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total reward: 0\n", "Total length: 50\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"# This function plots videos of rollouts (episodes) of a given policy and environment\n",
"def log_policy_rollout(policy, env_name, pytorch_policy=False):\n",
"    # Create environment with flat observation\n",
"    env = gen_wrapped_env(env_name)\n",
"\n",
"    # Initialize environment\n",
"    observation = env.reset()\n",
"\n",
"    done = False\n",
"    episode_reward = 0\n",
"    episode_length = 0\n",
"\n",
"    # Run until done == True\n",
"    while not done:\n",
"        # Take a step\n",
"        if pytorch_policy:\n",
"            observation = torch.tensor(observation, dtype=torch.float32)\n",
"            action = policy.act(observation)[0].data.cpu().numpy()\n",
"        else:\n",
"            action = policy.act(observation)[0]\n",
"        observation, reward, done, info = env.step(action)\n",
"\n",
"        episode_reward += reward\n",
"        episode_length += 1\n",
"\n",
"    print('Total reward:', episode_reward)\n",
"    print('Total length:', episode_length)\n",
"\n",
"    env.close()\n",
"\n",
"    show_video()\n",
"\n",
"# Test that the logging function is working\n",
"test_env_name = 'MiniGrid-Empty-8x8-v0'\n",
"rand_policy = RandPolicy(FlatObsWrapper(gym.make(test_env_name)).action_space)\n",
"log_policy_rollout(rand_policy, test_env_name)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"That is how the agent behaves under the random policy. The random policy is clearly not optimal: the agent (the red triangle) does not reach the goal within the episode (it might, given enough time). To reach the goal reliably, we need a more intelligent policy. Intuitively, the agent needs to:\n",
"\n",
"- Remember its previous trajectories\n",
"- Use that stored experience to find a way to the goal when it enters an unfamiliar cell" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"## Implement Rollout Buffer\n",
"Before implementing Policy Gradient, we need a memory object that stores previous trajectories and the information returned by the environment. It is sometimes called a \"Replay Buffer\" or \"Rollout Buffer\"; this post uses the name `RolloutBuffer`. To implement it, we need to consider:\n",
"\n",
"- How many trajectory steps are stored in the buffer?\n",
"- How is a step added to the buffer?\n",
"- (From the Reinforcement Learning point of view) how is the discounted future return computed from the stored rewards?\n",
"- (+) How can stored steps be sampled efficiently?\n",
"\n",
"Here is the `RolloutBuffer` implementation."
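, "\n", "\n",
"One note before the code, as a sketch of the third point: writing $r_t$ for the reward and $d_t$ for the done flag stored at step $t$ (1 if the episode ended at that step, 0 otherwise; this notation is introduced here only for illustration), the discounted return can be accumulated backward through the buffer as\n",
"\n",
"$$ V_t = r_t + \\gamma (1 - d_t) V_{t+1} $$\n",
"\n",
"where $\\gamma$ is the discount factor and the factor $(1 - d_t)$ keeps the return from leaking across episode boundaries. For example, with $\\gamma = 0.99$, rewards $(0, 0, 1)$ and done flags $(0, 0, 1)$, this gives $V_2 = 1$, $V_1 = 0.99$ and $V_0 = 0.9801$."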
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler\n",
"\n",
"class RolloutBuffer():\n",
"    def __init__(self, rollout_size, obs_size):\n",
"        self.rollout_size = rollout_size\n",
"        self.obs_size = obs_size\n",
"        self.reset()\n",
"\n",
"    def insert(self, step, done, action, log_prob, reward, obs):\n",
"        # Copy one environment step into slot `step` of the buffer\n",
"        self.done[step].copy_(done)\n",
"        self.actions[step].copy_(action)\n",
"        self.log_probs[step].copy_(log_prob)\n",
"        self.rewards[step].copy_(reward)\n",
"        self.obs[step].copy_(obs)\n",
"\n",
"    def reset(self):\n",
"        self.done = torch.zeros(self.rollout_size, 1)\n",
"        self.returns = torch.zeros(self.rollout_size + 1, 1, requires_grad=False)\n",
"        # Assuming Discrete Action Space\n",
"        self.actions = torch.zeros(self.rollout_size, 1, dtype=torch.int64)\n",
"        self.log_probs = torch.zeros(self.rollout_size, 1)\n",
"        self.rewards = torch.zeros(self.rollout_size, 1)\n",
"        self.obs = torch.zeros(self.rollout_size, self.obs_size)\n",
"\n",
"    def compute_returns(self, gamma):\n",
"        # Compute Returns until the last finished episode\n",
"        self.last_done = (self.done == 1).nonzero().max()\n",
"        self.returns[self.last_done + 1] = 0.\n",
"\n",
"        # Accumulate discounted returns backward in time;\n",
"        # (1 - done) resets the accumulation at episode boundaries\n",
"        for step in reversed(range(self.last_done + 1)):\n",
"            self.returns[step] = self.returns[step + 1] * \\\n",
"                gamma * (1 - self.done[step]) + self.rewards[step]\n",
"\n",
"    def batch_sampler(self, batch_size, get_old_log_probs=False):\n",
"        # Yield random mini-batches drawn from the steps that have valid returns\n",
"        sampler = BatchSampler(\n",
"            SubsetRandomSampler(range(self.last_done)),\n",
"            batch_size,\n",
"            drop_last=True)\n",
"        for indices in sampler:\n",
"            if get_old_log_probs:\n",
"                yield self.actions[indices], self.returns[indices], self.obs[indices], self.log_probs[indices]\n",
"            else:\n",
"                yield self.actions[indices], self.returns[indices], self.obs[indices]" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "There are a couple of things 
to note:\n",
"\n",
"- Everything stored in `RolloutBuffer` is a `torch.Tensor`\n",
"- The returns only weight the log-probabilities in the loss; no gradient flows through them, so the `returns` tensor is created with `requires_grad=False`\n",
"- Using the whole buffer for every gradient step is inefficient. Instead, the buffer yields random mini-batches, built here with `BatchSampler` over a `SubsetRandomSampler`." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"## Construct Policy Network\n",
"\n",
"Now that we can store rollouts, we need a policy to collect them. The policy is implemented as a small neural network with simple fully-connected layers, the `ActorNetwork`. The policy is the strategy that generates actions; more precisely, it outputs a probability distribution over actions, from which an action is sampled.\n",
" The other important job is to update the policy at every iteration. With PyTorch, we need to decide:\n",
"\n",
"- Which optimizer should we use?\n",
"- How do we define the loss function?\n",
"\n",
"First, let's look at the gradient used in policy gradient:\n",
"\n",
"$$ \\nabla J(\\theta) = \\mathbb{E}_{\\pi}\\big[ \\nabla_{\\theta} \\log \\pi_{\\theta}(a, s) \\; V_t(s) \\big] $$\n",
"\n",
"Here, $\\theta$ are the parameters of the policy network $\\pi_{\\theta}$ and $V_t(s)$ is the observed future discounted reward from state $s$ onwards, which should be **maximized**. This keyword matters: neural network training **minimizes** a loss, it does not **maximize** an objective. So in practice we compute the mean of $\\log \\pi_{\\theta}(a, s) \\; V_t(s)$ over the sampled steps and minimize its negative; automatic differentiation then yields the gradient above.\n",
"\n",
"In addition, there are ways to encourage exploration. Adding an **entropy** term to the objective rewards policies that keep their action distribution spread out, so the agent tries more diverse actions. 
In that case, the gradient becomes,\n",
"\n",
"$$ \\nabla J(\\theta) = \\mathbb{E}_{\\pi}\\big[ \\nabla_{\\theta} \\log \\pi_{\\theta}(a, s) \\; V_t(s) \\big] + \\nabla_{\\theta}\\mathcal{H}\\big[\\pi_\\theta(a, s)\\big]$$\n",
"\n",
"Here is the implementation of the actor network (it is quite simple!)" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class ActorNetwork(nn.Module):\n",
"    def __init__(self, num_inputs, num_actions, hidden_dim):\n",
"        super().__init__()\n",
"        self.num_actions = num_actions\n",
"\n",
"        # Two hidden layers with tanh activations; the output is one logit per action\n",
"        self.fc = nn.Sequential(\n",
"            nn.Linear(num_inputs, hidden_dim),\n",
"            nn.Tanh(),\n",
"            nn.Linear(hidden_dim, hidden_dim),\n",
"            nn.Tanh(),\n",
"            nn.Linear(hidden_dim, num_actions)\n",
"        )\n",
"\n",
"    def forward(self, state):\n",
"        x = self.fc(state)\n",
"        return x" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"Below is the implementation of `Policy`. We use the Adam optimizer, and the loss is the negative of the objective above (so that minimizing it maximizes the expected return), with the entropy term weighted by `entropy_coef`." ] },
{ "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [
"import torch.optim as optim\n",
"import torch.nn.functional as F\n",
"from torch.distributions.categorical import Categorical\n",
"from utils.utils import count_model_params\n",
"\n",
"class Policy():\n",
"    def __init__(self, num_inputs, num_actions, hidden_dim, learning_rate,\n",
"                 batch_size, policy_epochs, entropy_coef=0.001):\n",
"        self.actor = ActorNetwork(num_inputs, num_actions, hidden_dim)\n",
"        self.optimizer = optim.Adam(self.actor.parameters(), lr=learning_rate)\n",
"        self.batch_size = batch_size\n",
"        self.policy_epochs = policy_epochs\n",
"        self.entropy_coef = entropy_coef\n",
"\n",
"    def act(self, state):\n",
"        logits = self.actor(state)\n",
"        # Model the action as a categorical distribution parameterized by the logits\n",
"        dist = Categorical(logits=logits)\n",
"        action = dist.sample()\n",
"        log_prob = dist.log_prob(action)\n",
"        return action, log_prob\n",
"\n",
"    def evaluate_actions(self, state, action):\n",
"        logits = self.actor(state)\n",
"        dist = Categorical(logits=logits)\n",
"        log_prob = dist.log_prob(action.squeeze(-1)).view(-1, 1)\n",
"        entropy = dist.entropy().view(-1, 1)\n",
"        return log_prob, entropy\n",
"\n",
"    def update(self, rollouts):\n",
"        for epoch in range(self.policy_epochs):\n",
"            data = rollouts.batch_sampler(self.batch_size)\n",
"\n",
"            for sample in data:\n",
"                actions_batch, returns_batch, obs_batch = sample\n",
"\n",
"                log_probs_batch, entropy_batch = self.evaluate_actions(obs_batch, actions_batch)\n",
"\n",
"                # Compute the mean loss for the policy update using\n",
"                # action log-probabilities and policy returns\n",
"                policy_loss = -(log_probs_batch * returns_batch).mean()\n",
"                # Compute the mean entropy for the policy update\n",
"                # (negated, so that minimizing the loss increases entropy)\n",
"                entropy_loss = -entropy_batch.mean()\n",
"\n",
"                loss = policy_loss + self.entropy_coef * entropy_loss\n",
"\n",
"                self.optimizer.zero_grad()\n",
"                loss.backward(retain_graph=False)\n",
"                self.optimizer.step()\n",
"\n",
"    @property\n",
"    def num_params(self):\n",
"        return count_model_params(self.actor)" ] },
{ "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [
"from IPython.display import clear_output\n",
"from utils.utils import AverageMeter, plot_learning_curve\n",
"import time\n",
"\n",
"def train(env, rollouts, policy, params, seed=123):\n",
"    # SETTING SEED: it is good practice to set seeds when running experiments to keep results comparable\n",
"    np.random.seed(seed)\n",
"    torch.manual_seed(seed)\n",
"    env.seed(seed)\n",
"\n",
"    rollout_time, update_time = AverageMeter(), AverageMeter()  # Loggers\n",
"    rewards, success_rate = [], []\n",
"\n",
"    print(\"Training model with {} parameters...\".format(policy.num_params))\n",
"\n",
"    # Training Loop\n",
"    for j in range(params.num_updates):\n",
"        ## Initialization\n",
"        avg_eps_reward, avg_success_rate = AverageMeter(), AverageMeter()\n",
"        done = False\n",
"        prev_obs = env.reset()\n",
"        prev_obs = torch.tensor(prev_obs, dtype=torch.float32)\n",
"        eps_reward = 0.\n",
"        start_time = time.time()\n",
"\n",
"        ## Collect rollouts\n",
"        for step in range(rollouts.rollout_size):\n",
"            if done:\n",
"                # Store episode statistics\n",
"                avg_eps_reward.update(eps_reward)\n",
"                if 'success' in info:\n",
"                    avg_success_rate.update(int(info['success']))\n",
"\n",
"                # Reset the environment; the new episode's first observation becomes prev_obs,\n",
"                # so the buffer stores the observation the action was actually sampled from\n",
"                prev_obs = env.reset()\n",
"                prev_obs = torch.tensor(prev_obs, dtype=torch.float32)\n",
"                eps_reward = 0.\n",
"\n",
"            action, log_prob = policy.act(prev_obs)\n",
"            obs, reward, done, info = env.step(action)\n",
"\n",
"            rollouts.insert(step, torch.tensor(done, dtype=torch.float32), action, log_prob,\n",
"                            torch.tensor(reward, dtype=torch.float32),\n",
"                            prev_obs)\n",
"\n",
"            prev_obs = torch.tensor(obs, dtype=torch.float32)\n",
"            eps_reward += reward\n",
"\n",
"        # Use the rollout buffer's function to compute the returns for all stored rollout steps\n",
"        rollouts.compute_returns(params['discount'])\n",
"\n",
"        rollout_done_time = time.time()\n",
"\n",
"        # Call the policy's update function using the collected rollouts\n",
"        policy.update(rollouts)\n",
"\n",
"        update_done_time = time.time()\n",
"        rollouts.reset()\n",
"\n",
"        ## log metrics\n",
"        rewards.append(avg_eps_reward.avg)\n",
"        if avg_success_rate.count > 0:\n",
"            success_rate.append(avg_success_rate.avg)\n",
"        rollout_time.update(rollout_done_time - start_time)\n",
"        update_time.update(update_done_time - rollout_done_time)\n",
"        print('it {}: avgR: {:.3f} -- rollout_time: {:.3f}sec -- update_time: {:.3f}sec'.format(\n",
"            j, avg_eps_reward.avg, rollout_time.avg, update_time.avg))\n",
"        if j % params.plotting_iters == 0 and j != 0:\n",
"            plot_learning_curve(rewards, success_rate, params.num_updates)\n",
"            log_policy_rollout(policy, params.env_name, pytorch_policy=True)\n",
"    clear_output()  # remove the training output to keep the notebook clean\n",
"    return rewards, success_rate" ] },
{ "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [
"from utils.utils import ParamDict\n",
"import copy\n",
"\n",
"def instantiate(params_in, nonwrapped_env=None):\n",
"    # Work on a copy so that popping 'policy_class' below does not modify the caller's params\n",
"    params = copy.deepcopy(params_in)\n",
"\n",
"    if nonwrapped_env is None:\n",
"        nonwrapped_env = gym.make(params.env_name)\n",
"\n",
"    env = FlatObsWrapper(nonwrapped_env)\n",
"    obs_size = env.observation_space.shape[0]\n",
"    num_actions = env.action_space.n\n",
"\n",
"    rollouts = RolloutBuffer(params.rollout_size, obs_size)\n",
"    policy_class = params.policy_params.pop('policy_class')\n",
"\n",
"    policy = policy_class(obs_size, num_actions, **params.policy_params)\n",
"    return env, rollouts, policy" ] },
{ "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [
"# hyperparameters\n",
"policy_params = ParamDict(\n",
"    policy_class = Policy, # 
Policy class to use (replaced later)\n",
"    hidden_dim = 32, # dimension of the hidden state in actor network\n",
"    learning_rate = 1e-3, # learning rate of policy update\n",
"    batch_size = 1024, # batch size for policy update\n",
"    policy_epochs = 4, # number of epochs per policy update\n",
"    entropy_coef = 0.001, # hyperparameter to vary the contribution of entropy loss\n",
")\n",
"params = ParamDict(\n",
"    policy_params = policy_params,\n",
"    rollout_size = 2050, # number of collected rollout steps per policy update\n",
"    num_updates = 50, # number of training policy iterations\n",
"    discount = 0.99, # discount factor\n",
"    plotting_iters = 10, # interval for logging graphs and policy rollouts\n",
"    env_name = 'MiniGrid-Empty-5x5-v0', # we are using a tiny environment here for testing\n",
")" ] },
{ "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training completed!\n" ] } ], "source": [
"env, rollouts, policy = instantiate(params)\n",
"rewards, success_rate = train(env, rollouts, policy, params)\n",
"print(\"Training completed!\")" ] },
{ "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Total reward: 0.874\n", "Total length: 7\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Total reward: 0.91\n", "Total length: 5\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Total reward: 0.874\n", "Total length: 7\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"# final reward + policy plotting for easier evaluation\n",
"plot_learning_curve(rewards, success_rate, params.num_updates)\n",
"for _ in range(3):\n",
"    log_policy_rollout(policy, params.env_name, pytorch_policy=True)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }