{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-26-reinforce.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/T365137%20%7C%20REINFORCE%20in%20PyTorch.ipynb","timestamp":1644673798189}],"collapsed_sections":[],"authorship_tag":"ABX9TyPvQry7G08zNChJqXqGydZl"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","metadata":{"id":"EsWO9mLNJqqy"},"source":["# REINFORCE in PyTorch"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"2xm8x19-HLxm","executionInfo":{"status":"ok","timestamp":1634908084389,"user_tz":-330,"elapsed":3601,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"c610fd77-c3d8-414e-8e6e-fc9c492fa089"},"source":["!pip install -q watermark\n","%reload_ext watermark\n","%watermark -m -iv -u -t -d"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Last updated: 2021-10-22 13:08:10\n","\n","Compiler : GCC 7.5.0\n","OS : Linux\n","Release : 5.4.104+\n","Machine : x86_64\n","Processor : x86_64\n","CPU cores : 2\n","Architecture: 64bit\n","\n","gym : 0.17.3\n","numpy : 1.19.5\n","torch : 1.9.0+cu111\n","matplotlib: 3.2.2\n","IPython : 5.5.0\n","\n"]}]},{"cell_type":"markdown","metadata":{"id":"tUedeymSIaGe"},"source":["---"]},{"cell_type":"code","metadata":{"id":"gB58AmCQH5IP"},"source":["!pip install gym pyvirtualdisplay > /dev/null 2>&1\n","!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1\n","\n","!apt-get update > /dev/null 2>&1\n","!apt-get install cmake > /dev/null 2>&1\n","!pip install --upgrade setuptools 2>&1\n","!pip install ez_setup > /dev/null 2>&1\n","!pip install gym[atari] > /dev/null 2>&1\n","\n","!wget http://www.atarimania.com/roms/Roms.rar\n","!mkdir /content/ROM/\n","!unrar e /content/Roms.rar /content/ROM/\n","!python -m 
atari_py.import_roms /content/ROM/"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"5hmUsRZNH5If"},"source":[
"import gym\n",
"from gym.wrappers import Monitor\n",
"import glob\n",
"import io\n",
"import base64\n",
"from IPython.display import HTML\n",
"from pyvirtualdisplay import Display\n",
"from IPython import display as ipythondisplay\n",
"\n",
"display = Display(visible=0, size=(1400, 900))\n",
"display.start()\n",
"\n",
"# Utility functions to enable video recording of a gym environment\n",
"# and displaying it inline.\n",
"# To enable video, just do \"env = wrap_env(env)\".\n",
"\n",
"def show_video():\n",
"    \"\"\"Embed the first recorded episode found under ./video in the cell output.\n",
"\n",
"    Prints \"Could not find video\" when no .mp4 file exists there.\n",
"    \"\"\"\n",
"    # Sort so the chosen file does not depend on directory listing order.\n",
"    mp4list = sorted(glob.glob('video/*.mp4'))\n",
"    if len(mp4list) > 0:\n",
"        mp4 = mp4list[0]\n",
"        # Read-only access through a context manager: the handle is always closed.\n",
"        with io.open(mp4, 'rb') as video_file:\n",
"            video = video_file.read()\n",
"        encoded = base64.b64encode(video)\n",
"        # Inline the clip as a base64 data URI inside an HTML5 video element.\n",
"        video_html = ('<video alt=\"test\" autoplay loop controls style=\"height: 400px;\">'\n",
"                      '<source src=\"data:video/mp4;base64,{0}\" type=\"video/mp4\" />'\n",
"                      '</video>')\n",
"        ipythondisplay.display(HTML(data=video_html.format(encoded.decode('ascii'))))\n",
"    else:\n",
"        print(\"Could not find video\")\n",
"\n",
"\n",
"def wrap_env(env):\n",
"    \"\"\"Wrap env in a gym Monitor that writes its recordings to ./video.\"\"\"\n",
"    env = Monitor(env, './video', force=True)\n",
"    return env"
],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XxINAZsiHLp0"},"source":["---"]},{"cell_type":"code","metadata":{"id":"vaxMAMEFE4R8"},"source":["import gym\n","import numpy as np\n","from collections import deque\n","import matplotlib.pyplot as plt\n","%matplotlib inline\n","\n","import torch\n","torch.manual_seed(0) # set random seed\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","from torch.distributions import Categorical"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"hjW7DIxtFP8v","executionInfo":{"status":"ok","timestamp":1634908503685,"user_tz":-330,"elapsed":10,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"f0eafef1-cdb0-4c91-90a5-cc85f7a4984e"},"source":["env = 
gym.make('CartPole-v0')\n",
"env.seed(0)\n",
"print('observation space:', env.observation_space)\n",
"print('action space:', env.action_space)\n",
"\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"class Policy(nn.Module):\n",
"    \"\"\"Two-layer MLP that maps a state to a probability distribution over actions.\"\"\"\n",
"\n",
"    def __init__(self, s_size=4, h_size=16, a_size=2):\n",
"        \"\"\"Build the network.\n",
"\n",
"        Args:\n",
"            s_size: input size of the first linear layer (state vector length).\n",
"            h_size: number of hidden units.\n",
"            a_size: output size of the second linear layer (number of actions).\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.fc1 = nn.Linear(s_size, h_size)\n",
"        self.fc2 = nn.Linear(h_size, a_size)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return action probabilities; the softmax runs over the last dimension.\"\"\"\n",
"        x = F.relu(self.fc1(x))\n",
"        x = self.fc2(x)\n",
"        # dim=-1 matches the previous dim=1 for (batch, s_size) input and\n",
"        # additionally accepts a single unbatched state.\n",
"        return F.softmax(x, dim=-1)\n",
"\n",
"    def act(self, state):\n",
"        \"\"\"Sample an action for one NumPy state.\n",
"\n",
"        Returns:\n",
"            (action, log_prob): the sampled action as a Python number and the\n",
"            log-probability tensor of that action (shape (1,)), which the\n",
"            caller keeps for the policy-gradient loss.\n",
"        \"\"\"\n",
"        # Follow the module's own parameters instead of a module-level `device`.\n",
"        model_device = next(self.parameters()).device\n",
"        state = torch.from_numpy(state).float().unsqueeze(0).to(model_device)\n",
"        probs = self.forward(state).cpu()\n",
"        m = Categorical(probs)\n",
"        action = m.sample()\n",
"        return action.item(), m.log_prob(action)"
],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["observation space: Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)\n","action space: Discrete(2)\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"zGVueeA_Fn7b","executionInfo":{"status":"ok","timestamp":1634908550599,"user_tz":-330,"elapsed":46921,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"331eeab6-adaa-407a-9217-f4452beaa085"},"source":[
"policy = Policy().to(device)\n",
"optimizer = optim.Adam(policy.parameters(), lr=1e-2)\n",
"\n",
"def reinforce(n_episodes=1000, max_t=1000, gamma=1.0, print_every=100):\n",
"    \"\"\"Train the module-level `policy` on `env` with the REINFORCE policy gradient.\n",
"\n",
"    One optimizer step is taken per episode on the loss\n",
"    sum_t(-log_prob_t * R), where R is the discounted return of the whole\n",
"    episode. Training stops early once the running average score over the\n",
"    last (up to) 100 episodes reaches 195.0.\n",
"\n",
"    Args:\n",
"        n_episodes: maximum number of training episodes.\n",
"        max_t: maximum number of environment steps per episode.\n",
"        gamma: discount factor applied to the rewards.\n",
"        print_every: print the running average score every this many episodes.\n",
"\n",
"    Returns:\n",
"        List with the undiscounted total reward of every episode played.\n",
"    \"\"\"\n",
"    scores_deque = deque(maxlen=100)  # sliding window for the running average\n",
"    scores = []\n",
"    for i_episode in range(1, n_episodes + 1):\n",
"        saved_log_probs = []\n",
"        rewards = []\n",
"        state = env.reset()\n",
"        for t in range(max_t):\n",
"            action, log_prob = policy.act(state)\n",
"            saved_log_probs.append(log_prob)\n",
"            state, reward, done, _ = env.step(action)\n",
"            rewards.append(reward)\n",
"            if done:\n",
"                break\n",
"        episode_score = sum(rewards)\n",
"        scores_deque.append(episode_score)\n",
"        scores.append(episode_score)\n",
"\n",
"        # Discounted return of the whole episode: sum_t gamma**t * r_t.\n",
"        R = sum(gamma**step * r for step, r in enumerate(rewards))\n",
"\n",
"        # Same loss as summing -log_prob * R per step, built in one tensor op.\n",
"        policy_loss = (-torch.cat(saved_log_probs) * R).sum()\n",
"\n",
"        optimizer.zero_grad()\n",
"        policy_loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        average_score = np.mean(scores_deque)\n",
"        if i_episode % print_every == 0:\n",
"            print('Episode {}\\tAverage Score: {:.2f}'.format(i_episode, average_score))\n",
"        if average_score >= 195.0:\n",
"            # Report the episode count before the averaging window started; the\n",
"            # window holds fewer than 100 scores early on, so this never goes negative.\n",
"            print('Environment solved in {:d} episodes!\\tAverage Score: {:.2f}'.format(i_episode - len(scores_deque), average_score))\n",
"            break\n",
"\n",
"    return scores\n",
"\n",
"scores = reinforce()"
],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Episode 100\tAverage Score: 34.47\n","Episode 200\tAverage Score: 66.26\n","Episode 300\tAverage Score: 87.82\n","Episode 400\tAverage Score: 72.83\n","Episode 500\tAverage Score: 172.00\n","Episode 600\tAverage Score: 160.65\n","Episode 700\tAverage Score: 167.15\n","Environment solved in 691 episodes!\tAverage Score: 196.69\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":279},"id":"Zyp30sNGHHHf","executionInfo":{"status":"ok","timestamp":1634908550601,"user_tz":-330,"elapsed":23,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"5076a0db-d879-422d-8fc9-60d7b96d7df9"},"source":[
"# Per-episode training scores returned by reinforce().\n",
"fig, ax = plt.subplots()\n",
"ax.plot(np.arange(1, len(scores)+1), scores)\n",
"ax.set_ylabel('Score')\n",
"ax.set_xlabel('Episode #')\n",
"plt.show()"
],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":421},"id":"-kQt576BHvxZ","executionInfo":{"status":"ok","timestamp":1634908595120,"user_tz":-330,"elapsed":4594,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"9b2cd399-081b-4ed2-a6b5-60067c07d9dc"},"source":[
"# Roll out the trained policy for one recorded episode. A separate name keeps\n",
"# the training `env` (read by reinforce()) intact if cells are re-run.\n",
"eval_env = wrap_env(gym.make('CartPole-v0'))\n",
"\n",
"try:\n",
"    state = eval_env.reset()\n",
"    with torch.no_grad():  # inference only; no autograd graph is needed\n",
"        for t in range(1000):\n",
"            action, _ = policy.act(state)\n",
"            eval_env.render()\n",
"            state, reward, done, _ = eval_env.step(action)\n",
"            if done:\n",
"                break\n",
"finally:\n",
"    # Close even when the rollout raises, so the wrapped env is always released.\n",
"    eval_env.close()\n",
"\n",
"show_video()"
],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"text/html":[""],"text/plain":[""]},"metadata":{}}]}]}