{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 405 DQN Reinforcement Learning\n", "\n", "View more, visit my tutorial page: https://morvanzhou.github.io/tutorials/\n", "My Youtube Channel: https://www.youtube.com/user/MorvanZhou\n", "More about Reinforcement learning: https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/\n", "\n", "Dependencies:\n", "* torch: 0.1.11\n", "* gym: 0.8.1\n", "* numpy" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.autograd import Variable\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import gym" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[2017-06-20 22:23:40,418] Making new env: CartPole-v0\n" ] } ], "source": [ "# Hyper Parameters\n", "BATCH_SIZE = 32\n", "LR = 0.01 # learning rate\n", "EPSILON = 0.9 # greedy policy\n", "GAMMA = 0.9 # reward discount\n", "TARGET_REPLACE_ITER = 100 # target update frequency\n", "MEMORY_CAPACITY = 2000\n", "env = gym.make('CartPole-v0')\n", "env = env.unwrapped\n", "N_ACTIONS = env.action_space.n\n", "N_STATES = env.observation_space.shape[0]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class Net(nn.Module):\n", " def __init__(self, ):\n", " super(Net, self).__init__()\n", " self.fc1 = nn.Linear(N_STATES, 10)\n", " self.fc1.weight.data.normal_(0, 0.1) # initialization\n", " self.out = nn.Linear(10, N_ACTIONS)\n", " self.out.weight.data.normal_(0, 0.1) # initialization\n", "\n", " def forward(self, x):\n", " x = self.fc1(x)\n", " x = F.relu(x)\n", " actions_value = self.out(x)\n", " return actions_value" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class DQN(object):\n", " def __init__(self):\n", " self.eval_net, self.target_net = Net(), Net()\n", "\n", " self.learn_step_counter = 0 # for target updating\n", " self.memory_counter = 0 # for storing memory\n", " self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2)) # initialize memory\n", " self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)\n", " self.loss_func = nn.MSELoss()\n", "\n", " def choose_action(self, x):\n", " x = Variable(torch.unsqueeze(torch.FloatTensor(x), 0))\n", " # input only one sample\n", " if np.random.uniform() < EPSILON: # greedy\n", " actions_value = self.eval_net.forward(x)\n", " action = torch.max(actions_value, 1)[1].data.numpy()[0, 0] # return the argmax\n", " else: # random\n", " action = np.random.randint(0, N_ACTIONS)\n", " return action\n", "\n", " def store_transition(self, s, a, r, s_):\n", " transition = np.hstack((s, [a, r], s_))\n", " # replace the old memory with new memory\n", " index = self.memory_counter % MEMORY_CAPACITY\n", " self.memory[index, :] = transition\n", " self.memory_counter += 1\n", "\n", " def learn(self):\n", " # target parameter update\n", " if self.learn_step_counter % TARGET_REPLACE_ITER == 0:\n", " self.target_net.load_state_dict(self.eval_net.state_dict())\n", " self.learn_step_counter += 1\n", "\n", " # sample batch transitions\n", " sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)\n", " b_memory = self.memory[sample_index, :]\n", " b_s = Variable(torch.FloatTensor(b_memory[:, :N_STATES]))\n", " b_a = Variable(torch.LongTensor(b_memory[:, 
N_STATES:N_STATES+1].astype(int)))\n", " b_r = Variable(torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2]))\n", " b_s_ = Variable(torch.FloatTensor(b_memory[:, -N_STATES:]))\n", "\n", " # q_eval w.r.t the action in experience\n", " q_eval = self.eval_net(b_s).gather(1, b_a) # shape (batch, 1)\n", " q_next = self.target_net(b_s_).detach() # detach from graph, don't backpropagate\n", " q_target = b_r + GAMMA * q_next.max(1)[0] # shape (batch, 1)\n", " loss = self.loss_func(q_eval, q_target)\n", "\n", " self.optimizer.zero_grad()\n", " loss.backward()\n", " self.optimizer.step()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "dqn = DQN()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Collecting experience...\n", "Ep: 201 | Ep_r: 1.59\n", "Ep: 202 | Ep_r: 4.18\n", "Ep: 203 | Ep_r: 2.73\n", "Ep: 204 | Ep_r: 1.97\n", "Ep: 205 | Ep_r: 1.18\n", "Ep: 206 | Ep_r: 0.86\n", "Ep: 207 | Ep_r: 2.88\n", "Ep: 208 | Ep_r: 1.63\n", "Ep: 209 | Ep_r: 3.91\n", "Ep: 210 | Ep_r: 3.6\n", "Ep: 211 | Ep_r: 0.98\n", "Ep: 212 | Ep_r: 3.85\n", "Ep: 213 | Ep_r: 1.81\n", "Ep: 214 | Ep_r: 2.32\n", "Ep: 215 | Ep_r: 3.75\n", "Ep: 216 | Ep_r: 3.53\n", "Ep: 217 | Ep_r: 4.75\n", "Ep: 218 | Ep_r: 2.4\n", "Ep: 219 | Ep_r: 0.64\n", "Ep: 220 | Ep_r: 1.15\n", "Ep: 221 | Ep_r: 2.3\n", "Ep: 222 | Ep_r: 7.37\n", "Ep: 223 | Ep_r: 1.25\n", "Ep: 224 | Ep_r: 5.02\n", "Ep: 225 | Ep_r: 10.29\n", "Ep: 226 | Ep_r: 17.54\n", "Ep: 227 | Ep_r: 36.2\n", "Ep: 228 | Ep_r: 6.61\n", "Ep: 229 | Ep_r: 10.04\n", "Ep: 230 | Ep_r: 55.19\n", "Ep: 231 | Ep_r: 10.03\n", "Ep: 232 | Ep_r: 13.25\n", "Ep: 233 | Ep_r: 8.75\n", "Ep: 234 | Ep_r: 3.83\n", "Ep: 235 | Ep_r: -0.92\n", "Ep: 236 | Ep_r: 5.12\n", "Ep: 237 | Ep_r: 3.56\n", "Ep: 238 | Ep_r: 5.69\n", "Ep: 239 | Ep_r: 8.43\n", "Ep: 240 | Ep_r: 29.27\n", "Ep: 241 | Ep_r: 17.95\n", "Ep: 242 | Ep_r: 44.77\n", "Ep: 243 | Ep_r: 98.0\n", "Ep: 244 | Ep_r: 38.78\n", "Ep: 245 | Ep_r: 45.02\n", "Ep: 246 | Ep_r: 27.73\n", "Ep: 247 | Ep_r: 36.96\n", "Ep: 248 | Ep_r: 48.98\n", "Ep: 249 | Ep_r: 111.36\n", "Ep: 250 | Ep_r: 95.61\n", "Ep: 251 | Ep_r: 149.77\n", "Ep: 252 | Ep_r: 29.96\n", "Ep: 253 | Ep_r: 2.79\n", "Ep: 254 | Ep_r: 20.1\n", "Ep: 255 | Ep_r: 24.25\n", "Ep: 256 | Ep_r: 3074.75\n", "Ep: 257 | Ep_r: 1258.49\n", "Ep: 258 | Ep_r: 127.39\n", "Ep: 259 | Ep_r: 283.46\n", "Ep: 260 | Ep_r: 166.96\n", "Ep: 261 | Ep_r: 101.71\n", "Ep: 262 | Ep_r: 63.45\n", "Ep: 263 | Ep_r: 288.94\n", "Ep: 264 | Ep_r: 130.49\n", "Ep: 265 | Ep_r: 207.05\n", "Ep: 266 | Ep_r: 183.71\n", "Ep: 267 | Ep_r: 142.75\n", "Ep: 268 | Ep_r: 126.53\n", "Ep: 269 | Ep_r: 310.79\n", "Ep: 270 | Ep_r: 863.2\n", "Ep: 271 | Ep_r: 365.12\n", "Ep: 272 | Ep_r: 659.52\n", "Ep: 273 | Ep_r: 103.98\n", "Ep: 274 | Ep_r: 554.83\n", "Ep: 275 | Ep_r: 246.01\n", "Ep: 276 | Ep_r: 332.23\n", "Ep: 277 | Ep_r: 323.35\n", "Ep: 278 | Ep_r: 278.71\n", "Ep: 279 | Ep_r: 613.6\n", "Ep: 280 | Ep_r: 152.21\n", "Ep: 281 | Ep_r: 402.02\n", "Ep: 282 | Ep_r: 351.4\n", "Ep: 283 | Ep_r: 115.87\n", "Ep: 284 | Ep_r: 163.26\n", "Ep: 285 | Ep_r: 631.0\n", "Ep: 286 | Ep_r: 263.47\n", "Ep: 287 | Ep_r: 511.21\n", "Ep: 288 | Ep_r: 337.18\n", "Ep: 289 | Ep_r: 819.76\n", "Ep: 290 | Ep_r: 190.83\n", "Ep: 291 | Ep_r: 442.98\n", "Ep: 292 | Ep_r: 537.24\n", "Ep: 293 | Ep_r: 1101.12\n", "Ep: 294 | Ep_r: 178.42\n", "Ep: 295 | Ep_r: 225.61\n", "Ep: 296 | Ep_r: 252.62\n", "Ep: 297 | Ep_r: 617.5\n", "Ep: 298 | Ep_r: 617.8\n", "Ep: 299 | Ep_r: 
244.01\n", "Ep: 300 | Ep_r: 687.91\n", "Ep: 301 | Ep_r: 618.51\n", "Ep: 302 | Ep_r: 1405.07\n", "Ep: 303 | Ep_r: 456.95\n", "Ep: 304 | Ep_r: 340.33\n", "Ep: 305 | Ep_r: 502.91\n", "Ep: 306 | Ep_r: 441.21\n", "Ep: 307 | Ep_r: 255.81\n", "Ep: 308 | Ep_r: 403.03\n", "Ep: 309 | Ep_r: 229.1\n", "Ep: 310 | Ep_r: 308.49\n", "Ep: 311 | Ep_r: 165.37\n", "Ep: 312 | Ep_r: 153.76\n", "Ep: 313 | Ep_r: 442.05\n", "Ep: 314 | Ep_r: 229.23\n", "Ep: 315 | Ep_r: 128.52\n", "Ep: 316 | Ep_r: 358.18\n", "Ep: 317 | Ep_r: 319.03\n", "Ep: 318 | Ep_r: 381.76\n", "Ep: 319 | Ep_r: 199.19\n", "Ep: 320 | Ep_r: 418.63\n", "Ep: 321 | Ep_r: 223.95\n", "Ep: 322 | Ep_r: 222.37\n", "Ep: 323 | Ep_r: 405.4\n", "Ep: 324 | Ep_r: 311.32\n", "Ep: 325 | Ep_r: 184.85\n", "Ep: 326 | Ep_r: 1026.71\n", "Ep: 327 | Ep_r: 252.41\n", "Ep: 328 | Ep_r: 224.93\n", "Ep: 329 | Ep_r: 620.02\n", "Ep: 330 | Ep_r: 174.54\n", "Ep: 331 | Ep_r: 782.45\n", "Ep: 332 | Ep_r: 263.79\n", "Ep: 333 | Ep_r: 178.63\n", "Ep: 334 | Ep_r: 242.84\n", "Ep: 335 | Ep_r: 635.43\n", "Ep: 336 | Ep_r: 668.89\n", "Ep: 337 | Ep_r: 265.42\n", "Ep: 338 | Ep_r: 207.81\n", "Ep: 339 | Ep_r: 293.09\n", "Ep: 340 | Ep_r: 530.23\n", "Ep: 341 | Ep_r: 479.26\n", "Ep: 342 | Ep_r: 559.77\n", "Ep: 343 | Ep_r: 241.39\n", "Ep: 344 | Ep_r: 158.83\n", "Ep: 345 | Ep_r: 1510.69\n", "Ep: 346 | Ep_r: 425.17\n", "Ep: 347 | Ep_r: 266.94\n", "Ep: 348 | Ep_r: 166.08\n", "Ep: 349 | Ep_r: 630.52\n", "Ep: 350 | Ep_r: 250.95\n", "Ep: 351 | Ep_r: 625.88\n", "Ep: 352 | Ep_r: 417.7\n", "Ep: 353 | Ep_r: 867.81\n", "Ep: 354 | Ep_r: 150.62\n", "Ep: 355 | Ep_r: 230.89\n", "Ep: 356 | Ep_r: 1017.52\n", "Ep: 357 | Ep_r: 190.28\n", "Ep: 358 | Ep_r: 396.91\n", "Ep: 359 | Ep_r: 305.53\n", "Ep: 360 | Ep_r: 131.61\n", "Ep: 361 | Ep_r: 387.54\n", "Ep: 362 | Ep_r: 298.82\n", "Ep: 363 | Ep_r: 207.56\n", "Ep: 364 | Ep_r: 248.56\n", "Ep: 365 | Ep_r: 589.12\n", "Ep: 366 | Ep_r: 179.52\n", "Ep: 367 | Ep_r: 130.19\n", "Ep: 368 | Ep_r: 1220.84\n", "Ep: 369 | Ep_r: 126.35\n", "Ep: 370 | Ep_r: 133.31\n", "Ep: 371 | Ep_r: 485.81\n", "Ep: 372 | Ep_r: 823.4\n", "Ep: 373 | Ep_r: 253.26\n", "Ep: 374 | Ep_r: 466.06\n", "Ep: 375 | Ep_r: 203.27\n", "Ep: 376 | Ep_r: 386.5\n", "Ep: 377 | Ep_r: 491.02\n", "Ep: 378 | Ep_r: 239.45\n", "Ep: 379 | Ep_r: 276.93\n", "Ep: 380 | Ep_r: 331.98\n", "Ep: 381 | Ep_r: 764.79\n", "Ep: 382 | Ep_r: 198.29\n", "Ep: 383 | Ep_r: 717.18\n", "Ep: 384 | Ep_r: 562.15\n", "Ep: 385 | Ep_r: 29.44\n", "Ep: 386 | Ep_r: 344.95\n", "Ep: 387 | Ep_r: 671.87\n", "Ep: 388 | Ep_r: 299.81\n", "Ep: 389 | Ep_r: 899.76\n", "Ep: 390 | Ep_r: 319.04\n", "Ep: 391 | Ep_r: 252.11\n", "Ep: 392 | Ep_r: 865.62\n", "Ep: 393 | Ep_r: 255.64\n", "Ep: 394 | Ep_r: 81.74\n", "Ep: 395 | Ep_r: 213.13\n", "Ep: 396 | Ep_r: 422.33\n", "Ep: 397 | Ep_r: 167.47\n", "Ep: 398 | Ep_r: 507.34\n", "Ep: 399 | Ep_r: 614.0\n" ] } ], "source": [ "\n", "print('\\nCollecting experience...')\n", "for i_episode in range(400):\n", " s = env.reset()\n", " ep_r = 0\n", " while True:\n", " env.render()\n", " a = dqn.choose_action(s)\n", "\n", " # take action\n", " s_, r, done, info = env.step(a)\n", "\n", " # modify the reward\n", " x, x_dot, theta, theta_dot = s_\n", " r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8\n", " r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5\n", " r = r1 + r2\n", "\n", " dqn.store_transition(s, a, r, s_)\n", "\n", " ep_r += r\n", " if dqn.memory_counter > MEMORY_CAPACITY:\n", " dqn.learn()\n", " if done:\n", " print('Ep: ', i_episode,\n", " '| Ep_r: ', round(ep_r, 2))\n", "\n", " if done:\n", " 
break\n", " s = s_" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }