{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## ATARI Asteroids DQN_gym with keras-rl" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", " from ._conv import register_converters as _register_converters\n", "Using TensorFlow backend.\n" ] } ], "source": [ "import numpy as no\n", "import gym\n", "from keras.models import Sequential\n", "from keras.layers import Dense, Activation, Flatten\n", "from keras.optimizers import Adam\n", "\n", "from rl.agents.dqn import DQNAgent\n", "from rl.agents.ddpg import DDPGAgent\n", "from rl.policy import BoltzmannQPolicy , LinearAnnealedPolicy , EpsGreedyQPolicy\n", "from rl.memory import SequentialMemory" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "ENV_NAME_2 = 'Asteroids-v0'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "14" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get the environment and extract the number of actions\n", "env = gym.make(ENV_NAME_2)\n", "nb_actions = env.action_space.n\n", "nb_actions" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "flatten_1 (Flatten) (None, 100800) 0 \n", "_________________________________________________________________\n", "dense_1 (Dense) (None, 3) 302403 \n", "_________________________________________________________________\n", "dense_2 (Dense) (None, 14) 56 \n", "_________________________________________________________________\n", "activation_1 (Activation) (None, 14) 0 \n", "=================================================================\n", "Total params: 302,459\n", "Trainable params: 302,459\n", "Non-trainable params: 0\n", "_________________________________________________________________\n", "None\n" ] } ], "source": [ "# Next, we build a neural network model\n", "model = Sequential()\n", "model.add(Flatten(input_shape=(1,) + env.observation_space.shape))\n", "model.add(Dense(3, activation= 'tanh')) # One layer of 3 units with tanh activation function \n", "model.add(Dense(nb_actions))\n", "model.add(Activation('sigmoid')) # one layer of 1 unit with sigmoid activation function\n", "print(model.summary())" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "#DQN -- Deep Reinforcement Learning \n", "\n", "#Configure and compile the agent. \n", "#Use every built-in Keras optimizer and metrics!\n", "memory = SequentialMemory(limit=20000, window_length=1)\n", "policy = BoltzmannQPolicy()\n", "dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=10,\n", " target_model_update=1e-2, policy=policy)\n", "dqn.compile(Adam(lr=1e-3), metrics=['mae', 'acc'])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training for 100000 steps ...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/rl/memory.py:29: UserWarning: Not enough entries to sample without replacement. Consider increasing your warm-up phase to avoid oversampling!\n", " warnings.warn('Not enough entries to sample without replacement. Consider increasing your warm-up phase to avoid oversampling!')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 1455/100000: episode: 1, duration: 46.233s, episode steps: 1455, steps per second: 31, episode reward: 1350.000, mean reward: 0.928 [0.000, 100.000], mean action: 6.509 [0.000, 13.000], mean observation: 2.139 [0.000, 240.000], loss: 21.742520, mean_absolute_error: 0.748284, acc: 0.079748, mean_q: 0.865032\n", " 2423/100000: episode: 2, duration: 30.502s, episode steps: 968, steps per second: 32, episode reward: 630.000, mean reward: 0.651 [0.000, 100.000], mean action: 6.411 [0.000, 13.000], mean observation: 2.337 [0.000, 240.000], loss: 36.164433, mean_absolute_error: 0.916387, acc: 0.075252, mean_q: 0.964510\n", " 3192/100000: episode: 3, duration: 23.303s, episode steps: 769, steps per second: 33, episode reward: 780.000, mean reward: 1.014 [0.000, 100.000], mean action: 6.113 [0.000, 13.000], mean observation: 1.996 [0.000, 240.000], loss: 32.863270, mean_absolute_error: 0.941716, acc: 0.073716, mean_q: 0.978775\n", " 4037/100000: episode: 4, duration: 25.042s, episode steps: 845, steps per second: 34, episode reward: 880.000, mean reward: 1.041 [0.000, 100.000], mean action: 6.650 [0.000, 13.000], mean observation: 2.124 [0.000, 240.000], loss: 33.935211, mean_absolute_error: 0.961146, acc: 0.068861, mean_q: 0.984799\n", " 4472/100000: episode: 5, duration: 12.989s, episode steps: 435, steps per second: 33, episode reward: 330.000, mean reward: 0.759 [0.000, 50.000], mean action: 6.377 [0.000, 13.000], mean observation: 2.465 [0.000, 240.000], loss: 31.147161, mean_absolute_error: 0.964446, acc: 0.073851, mean_q: 0.988034\n", " 6292/100000: episode: 6, duration: 56.351s, episode steps: 1820, steps per second: 32, episode reward: 1180.000, mean reward: 0.648 [0.000, 100.000], mean action: 6.511 [0.000, 13.000], mean observation: 1.940 [0.000, 240.000], loss: 31.117493, mean_absolute_error: 0.973881, acc: 0.072373, mean_q: 0.991440\n", " 7098/100000: episode: 7, duration: 24.283s, episode steps: 806, steps per second: 33, episode reward: 780.000, mean reward: 0.968 [0.000, 100.000], mean action: 6.561 [0.000, 13.000], mean observation: 2.627 [0.000, 240.000], loss: 28.760992, mean_absolute_error: 0.976429, acc: 0.070720, mean_q: 0.994323\n", " 8897/100000: episode: 8, duration: 53.421s, episode steps: 1799, steps per second: 34, episode reward: 1180.000, mean reward: 0.656 [0.000, 100.000], mean action: 6.430 [0.000, 13.000], mean observation: 1.981 [0.000, 240.000], loss: 27.216291, mean_absolute_error: 0.976294, acc: 0.074833, mean_q: 0.996801\n", " 10784/100000: episode: 9, duration: 56.207s, episode steps: 1887, steps per second: 34, episode reward: 1320.000, mean reward: 0.700 [0.000, 100.000], mean action: 6.379 [0.000, 13.000], mean observation: 1.796 [0.000, 240.000], loss: 31.204847, mean_absolute_error: 0.985795, acc: 0.070830, mean_q: 0.998568\n", " 11988/100000: episode: 10, duration: 35.985s, episode steps: 1204, steps per second: 33, episode reward: 1180.000, mean reward: 0.980 [0.000, 100.000], mean action: 6.602 [0.000, 13.000], mean observation: 1.731 [0.000, 240.000], loss: 30.669523, mean_absolute_error: 0.986347, acc: 0.074517, mean_q: 0.999281\n", " 13541/100000: episode: 11, duration: 46.164s, episode steps: 1553, steps per second: 34, episode reward: 1080.000, mean reward: 0.695 [0.000, 100.000], mean action: 6.574 [0.000, 13.000], mean observation: 1.633 [0.000, 240.000], loss: 28.269192, mean_absolute_error: 0.983127, acc: 0.073487, mean_q: 0.999560\n", " 14309/100000: episode: 12, duration: 23.052s, episode steps: 768, steps per second: 33, episode reward: 580.000, mean reward: 0.755 [0.000, 100.000], mean action: 6.793 [0.000, 13.000], mean observation: 2.165 [0.000, 240.000], loss: 28.651321, mean_absolute_error: 0.982694, acc: 0.075562, mean_q: 0.999704\n", " 14855/100000: episode: 13, duration: 16.280s, episode steps: 546, steps per second: 34, episode reward: 430.000, mean reward: 0.788 [0.000, 100.000], mean action: 6.255 [0.000, 13.000], mean observation: 2.600 [0.000, 240.000], loss: 25.917961, mean_absolute_error: 0.979710, acc: 0.075321, mean_q: 0.999771\n", " 15676/100000: episode: 14, duration: 24.421s, episode steps: 821, steps per second: 34, episode reward: 780.000, mean reward: 0.950 [0.000, 100.000], mean action: 6.279 [0.000, 13.000], mean observation: 2.161 [0.000, 240.000], loss: 29.796518, mean_absolute_error: 0.985070, acc: 0.073196, mean_q: 0.999817\n", " 17303/100000: episode: 15, duration: 48.080s, episode steps: 1627, steps per second: 34, episode reward: 1320.000, mean reward: 0.811 [0.000, 100.000], mean action: 6.373 [0.000, 13.000], mean observation: 1.493 [0.000, 240.000], loss: 31.838648, mean_absolute_error: 0.989215, acc: 0.072065, mean_q: 0.999859\n", " 18249/100000: episode: 16, duration: 27.618s, episode steps: 946, steps per second: 34, episode reward: 880.000, mean reward: 0.930 [0.000, 150.000], mean action: 6.580 [0.000, 13.000], mean observation: 1.984 [0.000, 240.000], loss: 33.025875, mean_absolute_error: 0.990334, acc: 0.072509, mean_q: 0.999892\n", " 21461/100000: episode: 17, duration: 93.989s, episode steps: 3212, steps per second: 34, episode reward: 1880.000, mean reward: 0.585 [0.000, 100.000], mean action: 6.588 [0.000, 13.000], mean observation: 1.762 [0.000, 240.000], loss: 30.154207, mean_absolute_error: 0.987160, acc: 0.070809, mean_q: 0.999936\n", " 22917/100000: episode: 18, duration: 42.662s, episode steps: 1456, steps per second: 34, episode reward: 1180.000, mean reward: 0.810 [0.000, 100.000], mean action: 6.620 [0.000, 13.000], mean observation: 1.905 [0.000, 240.000], loss: 28.899130, mean_absolute_error: 0.983624, acc: 0.068402, mean_q: 0.999970\n", " 23238/100000: episode: 19, duration: 9.459s, episode steps: 321, steps per second: 34, episode reward: 430.000, mean reward: 1.340 [0.000, 100.000], mean action: 6.533 [0.000, 13.000], mean observation: 2.253 [0.000, 240.000], loss: 26.920767, mean_absolute_error: 0.980413, acc: 0.075058, mean_q: 0.999978\n", " 23586/100000: episode: 20, duration: 10.229s, episode steps: 348, steps per second: 34, episode reward: 160.000, mean reward: 0.460 [0.000, 50.000], mean action: 6.871 [0.000, 13.000], mean observation: 2.307 [0.000, 240.000], loss: 25.303034, mean_absolute_error: 0.980627, acc: 0.065823, mean_q: 0.999980\n", " 26189/100000: episode: 21, duration: 78.370s, episode steps: 2603, steps per second: 33, episode reward: 1320.000, mean reward: 0.507 [0.000, 100.000], mean action: 6.503 [0.000, 13.000], mean observation: 1.784 [0.000, 240.000], loss: 27.165476, mean_absolute_error: 0.981853, acc: 0.070652, mean_q: 0.999986\n", " 27369/100000: episode: 22, duration: 34.425s, episode steps: 1180, steps per second: 34, episode reward: 980.000, mean reward: 0.831 [0.000, 100.000], mean action: 6.385 [0.000, 13.000], mean observation: 1.941 [0.000, 240.000], loss: 27.088072, mean_absolute_error: 0.980287, acc: 0.067505, mean_q: 0.999992\n", " 28548/100000: episode: 23, duration: 35.070s, episode steps: 1179, steps per second: 34, episode reward: 980.000, mean reward: 0.831 [0.000, 100.000], mean action: 6.528 [0.000, 13.000], mean observation: 2.060 [0.000, 240.000], loss: 24.597326, mean_absolute_error: 0.977471, acc: 0.068119, mean_q: 0.999994\n", " 28919/100000: episode: 24, duration: 11.016s, episode steps: 371, steps per second: 34, episode reward: 480.000, mean reward: 1.294 [0.000, 100.000], mean action: 6.625 [0.000, 13.000], mean observation: 2.276 [0.000, 240.000], loss: 30.021408, mean_absolute_error: 0.985281, acc: 0.075219, mean_q: 0.999995\n", " 30695/100000: episode: 25, duration: 52.892s, episode steps: 1776, steps per second: 34, episode reward: 1300.000, mean reward: 0.732 [0.000, 100.000], mean action: 6.465 [0.000, 13.000], mean observation: 1.728 [0.000, 240.000], loss: 26.489769, mean_absolute_error: 0.980855, acc: 0.071087, mean_q: 0.999996\n", " 32084/100000: episode: 26, duration: 41.220s, episode steps: 1389, steps per second: 34, episode reward: 930.000, mean reward: 0.670 [0.000, 100.000], mean action: 6.603 [0.000, 13.000], mean observation: 1.972 [0.000, 240.000], loss: 28.961477, mean_absolute_error: 0.984377, acc: 0.070082, mean_q: 0.999997\n", " 33770/100000: episode: 27, duration: 50.046s, episode steps: 1686, steps per second: 34, episode reward: 1350.000, mean reward: 0.801 [0.000, 100.000], mean action: 6.718 [0.000, 13.000], mean observation: 1.699 [0.000, 240.000], loss: 30.137896, mean_absolute_error: 0.985636, acc: 0.070192, mean_q: 0.999998\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 35502/100000: episode: 28, duration: 51.162s, episode steps: 1732, steps per second: 34, episode reward: 1760.000, mean reward: 1.016 [0.000, 100.000], mean action: 6.483 [0.000, 13.000], mean observation: 1.864 [0.000, 240.000], loss: 29.654802, mean_absolute_error: 0.986030, acc: 0.069302, mean_q: 0.999999\n", " 36281/100000: episode: 29, duration: 23.153s, episode steps: 779, steps per second: 34, episode reward: 610.000, mean reward: 0.783 [0.000, 100.000], mean action: 6.589 [0.000, 13.000], mean observation: 2.412 [0.000, 240.000], loss: 27.241671, mean_absolute_error: 0.980745, acc: 0.069480, mean_q: 0.999999\n", " 37151/100000: episode: 30, duration: 25.850s, episode steps: 870, steps per second: 34, episode reward: 830.000, mean reward: 0.954 [0.000, 100.000], mean action: 6.493 [0.000, 13.000], mean observation: 2.012 [0.000, 240.000], loss: 28.615261, mean_absolute_error: 0.984251, acc: 0.071480, mean_q: 0.999999\n", " 38717/100000: episode: 31, duration: 46.353s, episode steps: 1566, steps per second: 34, episode reward: 1180.000, mean reward: 0.754 [0.000, 100.000], mean action: 6.408 [0.000, 13.000], mean observation: 1.558 [0.000, 240.000], loss: 27.654753, mean_absolute_error: 0.983781, acc: 0.072777, mean_q: 0.999999\n", " 39922/100000: episode: 32, duration: 35.704s, episode steps: 1205, steps per second: 34, episode reward: 1080.000, mean reward: 0.896 [0.000, 100.000], mean action: 6.408 [0.000, 13.000], mean observation: 1.722 [0.000, 240.000], loss: 27.877979, mean_absolute_error: 0.982937, acc: 0.068361, mean_q: 0.999999\n", " 41911/100000: episode: 33, duration: 58.841s, episode steps: 1989, steps per second: 34, episode reward: 1560.000, mean reward: 0.784 [0.000, 100.000], mean action: 6.626 [0.000, 13.000], mean observation: 2.046 [0.000, 240.000], loss: 27.723555, mean_absolute_error: 0.982845, acc: 0.071220, mean_q: 0.999999\n", " 42709/100000: episode: 34, duration: 23.760s, episode steps: 798, steps per second: 34, episode reward: 880.000, mean reward: 1.103 [0.000, 100.000], mean action: 6.267 [0.000, 13.000], mean observation: 2.308 [0.000, 240.000], loss: 27.761700, mean_absolute_error: 0.982930, acc: 0.071429, mean_q: 0.999999\n", " 43787/100000: episode: 35, duration: 32.243s, episode steps: 1078, steps per second: 33, episode reward: 780.000, mean reward: 0.724 [0.000, 100.000], mean action: 6.662 [0.000, 13.000], mean observation: 2.266 [0.000, 240.000], loss: 27.965694, mean_absolute_error: 0.983301, acc: 0.072182, mean_q: 0.999999\n", " 46754/100000: episode: 36, duration: 88.179s, episode steps: 2967, steps per second: 34, episode reward: 2080.000, mean reward: 0.701 [0.000, 100.000], mean action: 6.551 [0.000, 13.000], mean observation: 2.167 [0.000, 240.000], loss: 30.416687, mean_absolute_error: 0.986695, acc: 0.070147, mean_q: 1.000000\n", " 48099/100000: episode: 37, duration: 40.026s, episode steps: 1345, steps per second: 34, episode reward: 1130.000, mean reward: 0.840 [0.000, 100.000], mean action: 6.546 [0.000, 13.000], mean observation: 1.729 [0.000, 240.000], loss: 32.340946, mean_absolute_error: 0.989692, acc: 0.068425, mean_q: 1.000000\n", " 48561/100000: episode: 38, duration: 13.956s, episode steps: 462, steps per second: 33, episode reward: 530.000, mean reward: 1.147 [0.000, 100.000], mean action: 6.578 [0.000, 13.000], mean observation: 2.210 [0.000, 240.000], loss: 25.409176, mean_absolute_error: 0.980768, acc: 0.071361, mean_q: 1.000000\n", "done, took 1479.706 seconds\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "## Visualize the training during 500000 steps \n", "dqn.fit(env, nb_steps=500000, visualize=True, verbose=2)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3XeYFGW2BvD3MCQJMggjIoKAgghcBZkrIKwCKiK4oq6KAeUqLK673kXxyop4Tat3xQDIYkIMsJgQA8KKEhUTYUCigAySGZghR4GZOfePU033wIQOVd09U+/vefrp6qquqtM1U3Wq6gslqgoiIvKfcokOgIiIEoMJgIjIp5gAiIh8igmAiMinmACIiHyKCYCIyKfCTgAikiIiP4nIFOdzIxGZJyKZIvKhiFT0LkwiInJbJFcAAwCsDPk8FMBwVT0XwG4Afd0MjIiIvBVWAhCRswD0ADDG+SwAugCY6HxlLIDrvAiQiIi8UT7M740AMAhAdedzLQB7VDXX+bwZQL3CZhSR/gD6A0DVqlXbNGvWLPpoiYh8aOHChTtUNc3t5ZaYAETkGgDZqrpQRDpFugJVHQ1gNACkp6drRkZGxEESEfmZiGzwYrnhXAF0AHCtiHQHUBnAqQBeApAqIuWdq4CzAGzxIkAiIvJGiWUAqjpYVc9S1YYAbgEwS1VvBzAbwI3O1/oAmORZlERE5LpY2gH8DcBAEcmElQm86U5IREQUD+EWAgMAVPVrAF87w78CuNj9kIiIKB7YEpiIyKeYAIiIfIoJgIjIp5gAiIh8igmAiMinmACIiHyKCYCIyKeYAIiIfIoJgIjIp5gAiIh8igmAiMinmACIiHyKCYCIyKeYAIiIfIoJgIjIp0pXAjhyBDh6NNFREBGVCaUnAbz/PlC5MvDvfyc6EiKiMqH0JIARI+z9iy8SGwcRURlRYgIQkcoiMl9ElojIChF50hn/joisE5HFzquVp5Hu3m3vU6cCqp6uiojID8J5JvARAF1U9YCIVADwnYhMdaY9pKoTvQsvRE4OUKUKsGULsGwZcMEFcVktEVFZVeIVgJoDzscKziu+p+BHjwJ79gB33mmfp04t/vtERFSisMoARCRFRBYDyAYwXVXnOZOeEZGlIjJcRCp5FuWOHfZ+4YXAyJFA9+6erYqIyC/CuQUEVc0D0EpEUgF8KiItAQwGsA1ARQCjAfwNwFMnzisi/QH0B4AGDRpEF2WNGsBnn9ltn0aNolsGEREVEFEtIFXdA2A2gG6qmuXcHjoC4G0AFxcxz2hVTVfV9LS0tOiirFoV6NnTDv65uXYL6KefolsWEREBCK8WUJpz5g8ROQXAlQBWiUhdZ5wAuA7Acs+iXLfODvq//Wafb70VGDXKs9UREflBOLeA6gIYKyIpsIQxQVWniMgsEUkDIAAWA/iTZ1FOngwMGGBlAbVqAV27BquDini2WiKisqzEBKCqSwG0LmR8F08iKkxODpCSAtSsaZ+7dwc++ghYsgRoVUjzg9WrgYoVWV5ARFSM0tESOCfHzvzLOeF262bvhVUH/eYboHVr4Pbb4xcfEVEpVHoSQGgB8hlnABddBPzwQ8HvqQLPPGNlBQsWAIcOxTdOIqJSpHQkgOzsggkAsE7hJk0qOE4E+PhjYMwYqy20YEH8YiQiKmVKRwJ49VVg2LCC4844I3hL6KefgBtuAA4cAKpXB66/HkhPt+6jiYioUGE1BEu4li0LHz94sFURnT4dqFbNOoyrVs0Ki3n2T0RUrOS/AsjNBd54A1i16uRp69cDH34InHIKMHs2UL9+wenHjgH5+XEJkwDs3WvveXnAs88C+/YlNh4iKlbyJ4AdO4D+/YGZM0+eds89QNu2wKxZQOPGBad99RWQmgqsWBGfOP1uxQqgaVNg3Dhg7lxgyBCrjTVvXsnzElFCJH8CyMmx98K6kejUyQ42TZuePK1JE6sF9N13noZHsKuzyy+3thrt2wMdOgBz5tjVW8eOdjXAKzGipFO6E0BxGjWyguLvv3c/Jgpaswbo0sWq4M6aZYkXsCSweLEVyA8eDPTtm9g4iZLBsGHAyy8nOorjkr8QOJAATj89svlE7OyTCcA7e/bYwf/YMeDrr4FmzQpOr1nTymi6dgXOP9/GsfsO8qvXXwcefNC6tf/LXxIdDYCyfAUA2Fno+vX2FDFyX2oq8NBDwIwZQIsWhX9HBOjXz/4WgH3//vtZRZf8ZdUqO+iLBJ9vkgSS/wrgzjuBSy8FateOfN5u3awmSvnk/5mlyqZN1jivTRvgr38Nfz5VqyH00kt2xfDBBydfNRCVRc2aAe+8YyejCxcmOprjROP4gPX09HTNyMiI2/qO27kTeP55O2ClpwMNG/I2RLS2bLHC96NH7f5/xYqRL2PKFOCuu6yQfuRI4O67+fegsikry06WLrwwpsWIyEJVTXcpquOS/xbQhAl2phiLVaus8OXmm626aO3awJVXAoFklJtrZ6dUvKwsu+e/fbvd24/m4A8A11xjPbm2bw/cdx+wcaO7cRIlg4MHgWuvtTsRhw8nOppCJX8CePll6woiFh06APv3W+vg114D/vAHYNcuoJLzGON337Wk0LUr8MgjwNKlsccdqUWLgBdfBDZsiP+6w7F9u1X13LLFemFt1y625Z15JjBtmnXod/bZNu6XX2KPkygZ5OVZj8SLFllD1lNOsfGzZ9v//qJFiY0vQFXj9mrTpo1G7PzzVf/wh8jni8S336r266faurVq+fKqtWurHj7s7TpD/fab6jnnqNp1iOqll6q+91781h+OgQNVq1RR/eYbb5Y/YYJqSorqU0+p5uZ6sw6ieBk40PblkSMLjv/+exv/5ZcRLQ5AhnpwTE7+K4DCegJ1W8eOlqUXLQK+/NJK6T/5xNt1hlq50rpReOcd4O9/B7Zts0JSwFLCtGlW1TKR/vEPa1R36aXeLP+qq4BevYDHHrMrjc2bvVkPkdcmTbJbzv/93/YKFajMkiw1gbzIKkW9Ir4CyM1VFVH93/+NbL5Y5OerfvKJnZXH0/79BWM4eNCG5861M4a0NNVHH7Vp8bJrl2rv3qrZ2fFZX36+6tixqlWrqtasqfr55/FZL5Gbjh5VHTWq8CvZnTttfx4xIqJFwpdXADt32hmw11cAoUSs9WqgfMBrM2bY/cJq1QrGUKWKDbdubWcU7dsDTz9tZ+LxsGePlYlMmAAsXx6fdYpYtd+ffrLCerYVoNLk55/tjkWFClbnPyXl5O+kplo39klyBVBiAhCRyiIyX0SWiMgKEXnSGd9IROaJSKaIfCgiUVYJKUZamm2o//ov1xddohEjgCef9HYdP/5otZFGjiz6OxUrWk2Czz4DbrsNePRR4IsvvI1r3z6rubBkid0K69zZ2/WdqEkT60Tuxhvt83vvxS8JEUVK1Vr5tm1r/7PF1SgsVw7o0wdo3jx+8RWnpEsEAAKgmjNcAcA8AO0ATABwizP+NQD3lrSsqAqBE+Wuu+xWxN693iw/L081PV31zDML3v4pzsGDqt26WaG1V/btU73kEisM/+wz79YTrsOHVRs0UK1cWfWf/4zvLTAvhPu39tKQIarDh9vtxSNHEh1N6bZhg+qVV9ptncsvV12/3pPVwKNbQJF9GagCYBGAtgB2ACjvjG8P4KuS5o84AWRkqD7yiOqOHZHN54YFC2zzjBrlzfLffNOWP3589MvIy3MvnoCtW1VbtFCdONH9ZUdr2zbVq6+27dWtm2pWVqIjis7mzfYbnn46vus9dkx1xgwbzstTbdJEj9c4q1xZtWNHK3uhyMydq1q9up0ovvpqZCcnEZ7IJDQBAEgBsBjAAQBDAdQGkBkyvT6A5UXM2x9ABoCMBg0aRPSj9eWXLcRE7fDp6arNm7t/1rlnj+rpp9uZdrTLfuwx1Rtv9OaM+Ngx95cZq/x8+3+oXNm23Z49iY4oct98Ezzwxiv+r75SbdnS1rlsWXD85s2qH32k+sADqu3aqQ4bZuNzcqxKcu/etr2XLIlPnKVJ4MTr8GHVvn1Vf/01svl791a94IKIZkmWK4BUALMBdAw3AYS+Ir4CeOIJCzFRB6S33rL1u133fcUKO8vOyIh+GS+8YLE984w7Me3fr3rPPcm/w//8s+orrwQ/J2OyKsrEifY3GzrU+3WtWBG8amrcWPXjj8M7WVi7VvX661XPOCOYrMaM8T7e0iA/344JrVrFdiuvXz/bvhHwKgFEVAtIVfc4CaA9gFQRCfSydhYA97vczMkBTjstcZ259eplhTqBVnxuad7cWhu3aRP9MgYOdLdQeM4cK8jKzo59WV46/3zg3nttePZs25Zz5yY2pnAFera94w5796ptx8GDwCWXWCvrF16w2ik33BBef0uNG1vB/9at9rzt1q2BUaO8ibM02bLFujC5+26gevXYHndau7ZVbtHEdz8TTi2gNBFJdYZPAXAlgJWwROBU00AfAJNcjy4nJ75VQE9UpQrw0UfAf/6nO8tTBcaMsX+ecjHWwBWxxmsXXGCJIDMztuXNmmVVXwPdNpcGlSpZp3QdOwJPPGF9OiWznj2tUV9amtX8at0a+O03d5Z95Ig9jlMVqFrV+mrKzLT+56Op0ixinSa+/bY1jvQrVduuLVvaCUegJ9szz4x+mbVr2/9qEjwzO5yjUF0As0VkKYAFAKar6hQAfwMwUEQyAdQC8Kbr0e3endgEELBhQ7Blbiz+/W/gj38Exo6NfVmAJahPP7X3xYtjW9bMmXbW6PbVjpcuucSqqt52m1XZ7dgx9kTopbp1rdpv+fLAeefZc5Sfey725X7+uV0Z9ekDfPutjbvqqui6UD/RhRcCderEvpzSShV46y173sWSJdb9eawnb8nUGtiL+0pFvSIuA8jPj2+fPEW58krV+vVju9/822+q556r2qyZ+1XvAq2Go5WTY/d6//53d+JJhA8+UE1NjbiFZVxNm2avgF69VCtVUv3ll+iX+f33quXKWUHv9Omxx1iYefNUu3ZV3b3bm+Unm/x81XfftRpxqtZ6183+qRYtUh0wQHXLlrBnQTIUAsf6KlXtAEJNmmSbqksX1U2bolvG0KEaTSdQEfngAyscPtHWrap/+pPtyIVZsMAS3A8/eBdbPGRleVM11i2dO6t26BD8vHWr6qmnql5xRXS1ufbvtwLehg2t/YZXFi7UEqtEBw6a331XuttqZGWp9uxpv/fhhxMdzXH+SwB5eap9+qhOnRr+PF7Jz1cdPdrq+6amqr7/fmTzb92qWq2a6u9/7018AbfdZn0nffGFnbUsXGjjDx60P/VttxU9b35+6d5xA/Lzk/dMtUULq2ETatQoa3S3eHHky5s/3/qImjPHnfiK06aN6n/8R9H/I6NG6fFaQ/37ex+P2/LzrQfe006zq7Lnn/e2V9rDh1UPHQr76/5LAIHbEsl0Sb9mjWrbtqp//GNk861bp9qjh83vpYMHrYpa1ar2T9y0aXCHvekm1Xr1Ct+By8KBP6BjR9Wrrkp0FIVLSzv54Jibq7pyZfTLPHAgtpjC9frrtj/OnXvytO+/tyTWo4dVk5w928Zv2KB67bXWovzo0fjEGa2XXrLf17ZtbH+PcOzZY+sq7Gq9CF4lgOTtDC6Wh8F75dxzrUvkl16yz4sXW82AkjRsaI9BPPdcT8M7XijcogXQt6/VYApU/evc2aqy/fprwXk2bQLq1bOHvJQFTZvaM1c18VXsCsjLs84NTz+94PiUlOBzkVevDm9ZO3YA//ynLbNqVXfjLMqtt9q6Ro8uOD4vzx7vefbZwPjxNtypk01bswaYPx+47jqgfn1g0CDr+jyZ7N9v77ffDgwfbvu318+pPvVUqwiQBIXATACRKl8+WFPmiSes7/qHHiq858r8fHvCWDyf8tWwoXWk9vLLVkU0INCP/zffFPz+zJn2qMd69eIWoqfatLEda9OmREdS0K5d9v9Q1P/zuHFWk2fevOKXo2rtIB580A6w8VK9OjB4sHV4FiolxdoNfPqp9XQZ6vLL7e8webLV2Bo+3P4nd+606fPnW4eIGzbEp+fXRYusrcvAgVanv3Fje8Rpbi5QqxZw//3xaXMkEmwLkGheXFYU9YroFlCg1WQ090bj5cAB1XvvtTgvuEB16dKC099+26YlQz8r+fnWdcTkyQXH9+5ttyaSufA0EoHnJ3z8caIjKejYMWudu3174dP37rWOAVu1Kr622fjx9vv+8Q9v4ozEjz9Gdvtw+3Z71kbApZcGyw0A1Vq1VG+4ITg9Jye2+I4cscoPgRh797b1nHKK7a833WRP7EpEa/KWLVWvuy7sr8N3ZQDjx1tHSxFUlUqYKVOsf5qKFe1+qKrt0HXqWD8ryXpwzc+3g06vXomOxD2HDtn96EceSXQkkfvoI9slhw8vfPrGjao1algfUol6bOaBA/b4zgkTNObODH/+2SosjBljVZDvvVf18cdt2tGjqnXr2m8dPz7yBzTl5aneeqvFuHmzjVu3zsolkmF/7NTJyqvC5L8EUNpkZ6v+7W/Bwq7AM0Hnz09sXCfauTPYj8nKlRbj668nNia3jRjhbZfZ0Vi2zAoai+sELj/f+u+pVq3w6sZdu1oBf2amd3GW5NVX7X+mUiXV9u2960768GFLhIGeS9PSVAcPDh7Mi5Ofr/rnP9t8zz7rTXyx+te/rEfgMDEBlCbbt9umveuuREdS0MqVVk103Dj7nJmpet99dmZERQt9JsSbb1qVx3HjrHbLrFmqq1aVvIxANclt24r/3tq1duZbWHuROXMir4Lstj17VKtUsSvecA7GscrLs8ZzPXtag7dAY7cDB4o+k3/0UdvWDz3kfXxx4r8E8PTTdhZdGm3cqDpoUOz3MN2Wm2vtGPr2TXQk3jp0yK4Adu2KbTnr16vefLNVpw2c6Z57rha4bw1Y9ceSPPaYJd9w7jefeFadbA9t+e471dWr47/ejRuD9/MfeMCuDoYNK/h3XrfO7vH37Zvc1ZsPHrRkH+btKK8SQPLWApo1q/T08nii+vWBoUPd6YvFTSkpwO9+Zz1/5uUBGRn2XtYsX26/c9q06OY/dAh4/HGrDjh5slURzM+3acuWWY+pmZlWq+Trr4OPDi2uY7dIeratWNHW969/AXv32m956qnofosXOnSw6rbxVr9+sFpzx462fw0caDXY+vWzZ0k3bGi1i157LbzeTxPljTeAc86xZ28nUPImgET3BFpWXXaZVR+cOtV6Of3ww0RH5L7Wra2u9axZkc+7bp0d+J96yuqvr14NPPYYULmyTa9c2f4vzznH1nPZZVb1dO5cO/hkZBS+3Ozsk9sAFGfRIuDOO4GLL7YDWrI8QzZZ3HCDdXe9aJEl6PfeC7bPadkycV3IhytJOoRL3gSQnc0E4IXLLrP3xx+393g/8D0eype33zVjRvjzBM7Ezj7b6q/PmQO8/76ddYbj/PPthtBf/1p4I7RIT2jS061R1S+/2AHuxhtLnsePWre2s+mtW4Fnnkl0NOFjAihGfr5tGCYA97VqBbzyCnD4sJ1V1q2b6Ii8ccUV1ur5xJbPJ8rOBvr3B5o0sQZK5cpZH/i/+11k66tRA3j2WWvY9O67J0+fPDnyq60XXwT+7/+sUR8VLzW1dDVmZAIoxqFDdo+xYcNER1L2lC9v3USsX29numXVFVfY+8yZhU8/etRapjZtagf8O+4AKlSIbZ19+thttUGDgl0MBJx6KnDGGZEtr2ZNa31bo0ZscVHySZIEkJw3yqpVs8fYkTfmz7crgEsuSXQk3jnvPOvXpbCnue3dC7RrB6xaBXTrZonAjf5fypWzJ321bw9MnGi3cADramDIEHsiWFne5hS+OnWAESNO7lojzpIzAZC3GjWy/mSuvTbRkXhH5OTHW+7ebWfVNWoA3bvb83K7d3e3tki7dvbkqNB+mHbutCd/nX02EwCZypWBAQMSHUWS3gKaMcM6L1u3LtGRlE316lk5QJUqiY7EW1u2AP/zP9bB2qBBwFln2Vk/YPfXe/Twpqpg4OCflWXvydqxISVWZmZ8O/QrRDgPha8vIrNF5GcRWSEiA5zxT4jIFhFZ7Ly6uxbV2rX2bNNY78mSv4nYgb5DB+D554FevewKIB7mzLEyrK++soJmILJqoFT23XST9eqaQOHcAsoF8KCqLhKR6gAWish0Z9pwVX3B9ah4xkRuOPNMewj74cPAsGGFlwd4pW1boEED62J4yBAbx/9nCpUEXUKXmABUNQtAljO8X0RWAvC2vlVOjtWaqFTJ09WQD0TbGjhWlSpZ4fLvfx9sxcsrAArVuXPCE0BEZQAi0hBAawCBp1bcJyJLReQtESn02lpE+otIhohk5ATO7EvCVsBUFvToYbWMtm+38qxatRIdESWTRx6xK9MECjsBiEg1AB8DuF9V9wF4FcA5AFrBrhBeLGw+VR2tqumqmp4W7kG9fn3WlqDST8SuAo4etQZiydw3DflSWNVARaQC7OD/rqp+AgCquj1k+hsAprgW1dChri2KKKGaNQM2buQVLSWlcGoBCYA3AaxU1WEh40P7ELgewHL3wyMqA3jwpyQVzi2gDgDuANDlhCqfz4nIMhFZCqAzgAdciUjVOtZ65RVXFkdERIULpxbQdwAKu3n5hfvhANi3zxrrFNe3OhERxSz5WgIHGs3wspmIyFPJlwDYCIyIKC6YAIiIfCr5EkBqKnD11daMn4iIPJN83UFfdlnwsYVEROSZ5LsCICKiuEi+BNCv38kP8iAiItclXwLYtMkeoUdERJ5KvgTAnkCJiOKCCYCIyKeSKwGoWktgJgAiIs8lVwLIzQXuvJPPAiAiioPkagdQoQLwxhuJjoKIyBeS6wqAiIjihgmAiMinmACIiHyKCYCIyKeYAIiIfCqch8LXF5HZIvKziKwQkQHO+NNEZLqIrHHea3ofLhERuSWcK4BcAA+qanMA7QD8RUSaA3gYwExVbQJgpvOZiIhKiRITgKpmqeoiZ3g/gJUA6gHoCWCs87WxAK7zKkgiInJfRGUAItIQQGsA8wDUUdUsZ9I2AHWKmKe/iGSISEZO4HGPRESUcGEnABGpBuBjAPer6r7QaaqqALSw+VR1tKqmq2p6Gvv4ISJKGmElABGpADv4v6uqnzijt4tIXWd6XQDZ3oRIREReCKcWkAB4E8BKVR0WMulzAH2c4T4AJrkfHhEReSWczuA6ALgDwDIRWeyMewTAswAmiEhfABsA3OxNiERE5IUSE4CqfgdAiph8ubvhEBFRvLAlMBGRTzEBEBH5FBMAEZFPMQEQEfkUEwARkU8xARAR+RQTABGRTzEBEBH5FBMAEZFPMQEQEfkUEwARkU8xARAR+RQTABGRTzEBEBH5FBMAEZFPMQEQEfkUEwARkU8xARAR+VQ4D4V/S0SyRWR5yLgnRGSLiCx2Xt29DZOIiNwWzhXAOwC6FTJ+uKq2cl5fuBsWERF5rcQEoKpzAOyKQyxERBRHsZQB3CciS51bRDWL+pKI9BeRDBHJyMnJiWF1RETkpmgTwKsAzgHQCkAWgBeL+qKqjlbVdFVNT0tLi3J1RETktqgSgKpuV9U8Vc0H8AaAi90Ni4iIvBZVAhCRuiEfrwewvKjvEhFRcipf0hdE5H0AnQDUFpHNAB4H0ElEWgFQAOsB3ONhjERE5IESE4Cq3lrI6Dc9iIWIiOKILYGJiHyKCYCIyKeYAIiIfIoJgIjIp5gAiIh8igmAiMinmACIiHyKCYCIyKeYAIiIfIoJgIjIp5gAiIh8igmAiMinmACIiHyKCYCIyKeYAIiIfIoJgIjIp5gAiIh8igmAiMinSkwAIvKWiGSLyPKQcaeJyHQRWeO81/Q2TCIicls4VwDvAOh2wriHAcxU1SYAZjqfiYioFCkxAajqHAC7ThjdE8BYZ3gsgOtcjouIiDwWbRlAHVXNcoa3AahT1BdFpL+IZIhIRk5OTpSrIyIit8VcCKyqCkCLmT5aVdNVNT0tLS3W1RERkUuiTQDbRaQuADjv2e6FRERE8RBtAvgcQB9nuA+ASe6EQ0RE8RJONdD3AfwI4DwR2SwifQE8C+BKEVkD4ArnMxERlSLlS/qCqt5axKTLXY6FiIjiiC2BiYh8igmAiMinmACIiHyKCYCIyKeYAIiIfIoJgIjIp5gAiIh8igmAiMinmACIiHyKCYCIyKeYAIiIfIoJgIjIp5gAiIh8igmAiMinmACIiHyKCYCIyKeYAIiIfIoJgIjIp0p8JGRxRGQ9gP0A8gDkqmq6G0EREZH3YkoAjs6qusOF5RARURzxFhARkU/FmgAUwDQRWSgi/d0IiIiI4iPWW0AdVXWLiJwOYLqIrFLVOaFfcBJDfwBo0KBBjKsjIiK3xHQFoKpbnPdsAJ8CuLiQ74xW1XRVTU9LS4tldURE5KKoE4CIVBWR6oFhAF0BLHcrMCIi8lYst4DqAPhURALLeU9Vv3QlKiIi8lzUCUBVfwVwoYuxEBFRHLEaKBGRTzEBEBH5FBMAEZFPMQEQEfkUEwARkU8xARAR+RQTABGRTzEBEBH5FBMAEZFPMQEQEfkUEwARkU8xARAR+RQTABGRTzEBEBH5FBMAEZFPMQEQEfkUEwARkU8xARAR+RQTABGRT8WUAESkm4isFpFMEXnYraCIiMh7UScAEUkB8DKAqwE0B3CriDR3KzAiIvJWLFcAFwPIVNVfVfUogA8A9HQnLCIi8lr5GOatB2BTyOfNANqe+CUR6Q+gv/PxiIgsj2GdZUltADsSHUSS4LYI4rYI4rYIOs+LhcaSAMKiqqMBjAYAEclQ1XSv11kacFsEcVsEcVsEcVsEiUiGF8uN5RbQFgD1Qz6f5YwjIqJSIJYEsABAExFpJCIVAdwC4HN3wiIiIq9FfQtIVXNF5D4AXwFIAfCWqq4oYbbR0a6vDOK2COK2COK2COK2CPJkW4iqerFcIiJKcmwJTETkU0wAREQ+FZcEUFa7jBCRt0QkO7Rtg4icJiLTRWSN817TGS8iMtLZBktF5KKQefo4318jIn1CxrcRkWXOPCNFROL7C8MnIvVFZLaI/CwiK0RkgDPed9tDRCqLyHwRWeJsiyed8Y1EZJ4T/4dO5QmISCXnc6YzvWHIsgY741eLyFUvORhaAAADpElEQVQh40vVPiUiKSLyk4hMcT77cluIyHrnf3hxoGpnQvcRVfX0BSsgXgugMYCKAJYAaO71euPxAnApgIsALA8Z9xyAh53hhwEMdYa7A5gKQAC0AzDPGX8agF+d95rOcE1n2nznu+LMe3Wif3Mx26IugIuc4eoAfoF1EeK77eHEV80ZrgBgnhP3BAC3OONfA3CvM/xnAK85w7cA+NAZbu7sL5UANHL2o5TSuE8BGAjgPQBTnM++3BYA1gOofcK4hO0j8bgCKLNdRqjqHAC7ThjdE8BYZ3gsgOtCxo9TMxdAqojUBXAVgOmquktVdwOYDqCbM+1UVZ2r9pcdF7KspKOqWaq6yBneD2AlrLW477aH85sOOB8rOC8F0AXARGf8idsisI0mArjcOXPrCeADVT2iqusAZML2p1K1T4nIWQB6ABjjfBb4dFsUIWH7SDwSQGFdRtSLw3oTpY6qZjnD2wDUcYaL2g7Fjd9cyPik51y2t4ad+fpyezi3PBYDyIbtoGsB7FHVXOcrofEf/83O9L0AaiHybZSsRgAYBCDf+VwL/t0WCmCaiCwU6yYHSOA+4nlXEH6mqioivqpnKyLVAHwM4H5V3Rd6C9JP20NV8wC0EpFUAJ8CaJbgkBJCRK4BkK2qC0WkU6LjSQIdVXWLiJwOYLqIrAqdGO99JB5XAH7rMmK7cykG5z3bGV/Udihu/FmFjE9aIlIBdvB/V1U/cUb7dnsAgKruATAbQHvYJXzgpCs0/uO/2ZleA8BORL6NklEHANeKyHrY7ZkuAF6CP7cFVHWL854NOzG4GIncR+JQ6FEeVkjRCMFCmhZerzdeLwANUbAQ+HkULNB5zhnugYIFOvM1WKCzDlaYU9MZPk0LL9DpnujfW8x2ENg9xxEnjPfd9gCQBiDVGT4FwLcArgHwEQoWfP7ZGf4LChZ8TnCGW6BgweevsELPUrlPAeiEYCGw77YFgKoAqocM/wCgWyL3kXj98O6wWiFrAQxJ9B/Cxd/1PoAsAMdg99v6wu5XzgSwBsCMkD+MwB6gsxbAMgDpIcu5G1aolQngrpDx6QCWO/OMgtNyOxlfADrC7m8uBbDYeXX34/YAcAGAn5xtsRzAY874xs4OmukcACs54ys7nzOd6Y1DljXE+b2rEVKjozTuUyiYAHy3LZzfvMR5rQjEmsh9hF1BEBH5FFsCExH5FBMAEZFPMQEQEfkUEwARkU8xARAR+RQTABGRTzEBEBH51P8DQLzbFOy5pggAAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#Plot loss variations \n", "import matplotlib.pyplot as plt\n", "episodes = [1455,2423,3192,4037,4472,6292,7098,8897,\n", " 10784,11988,13541,14309,14855,15676,17303,18249,\n", " 21461,22917,23238,23586,26189,27369,28548,28919,\n", " 30695,32084,33770,35502,36281,37151,38717,39922,\n", " 41911,42709,43787,46754,48099,48561]\n", "\n", "loss = [21.74,36.16,32.86,33.93,31.62,31.17,28.76,27.21,31.20,\n", " 30.66,28.269,28.651,25.91,29.79,31.83,33.02,30.15,28.89,\n", " 26.92,25.30,27.16,27.08,24.59,30.02,26.48,28.96,30.13,\n", " 29.65,27.24,28.61,27.87,27.72,26.7,27.76,27.96,30.41,32.34,25.04]\n", "\n", "plt.plot(episodes, loss, 'r--')\n", "plt.axis([0, 50000, 0, 40])\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "## Save the model \n", "dqn.save_weights('dqn_{}_weights.h5f'.format(ENV_NAME_2), overwrite=True)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing for 10 episodes ...\n", "Episode 1: reward: 110.000, steps: 726\n", "Episode 2: reward: 130.000, steps: 604\n", "Episode 3: reward: 210.000, steps: 613\n", "Episode 4: reward: 110.000, steps: 922\n", "Episode 5: reward: 110.000, steps: 622\n", "Episode 6: reward: 260.000, steps: 571\n", "Episode 7: reward: 130.000, steps: 612\n", "Episode 8: reward: 260.000, steps: 567\n", "Episode 9: reward: 260.000, steps: 576\n", "Episode 10: reward: 260.000, steps: 578\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluate the algorithm for 10 episodes \n", "dqn.test(env, nb_episodes=10, visualize=True)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "### Another Policy with dqn " ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr=\"eps\", value_max=.8, value_min=.01,\n", " value_test=.0,\n", " nb_steps=100000)\n", "dqn = DQNAgent(model=model, nb_actions=nb_actions, nb_steps_warmup=10, \n", " policy=policy, test_policy=policy, memory = memory,\n", " target_model_update=1e-2)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "dqn.compile(Adam(lr=1e-3), metrics=['mae', 'acc'])" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training for 50000 steps ...\n", " 2647/50000: episode: 1, duration: 78.745s, episode steps: 2647, steps per second: 34, episode reward: 1180.000, mean reward: 0.446 [0.000, 100.000], mean action: 5.080 [0.000, 13.000], mean observation: 1.496 [0.000, 240.000], loss: 29.992819, mean_absolute_error: 0.987530, acc: 0.366749, mean_q: 1.000000, mean_eps: 0.789505\n", " 5062/50000: episode: 2, duration: 70.523s, episode steps: 2415, steps per second: 34, episode reward: 1390.000, mean reward: 0.576 [0.000, 100.000], mean action: 4.961 [0.000, 13.000], mean observation: 2.056 [0.000, 240.000], loss: 29.211633, mean_absolute_error: 0.985632, acc: 0.362526, mean_q: 1.000000, mean_eps: 0.769553\n", " 6988/50000: episode: 3, duration: 56.540s, episode steps: 1926, steps per second: 34, episode reward: 1410.000, mean reward: 0.732 [0.000, 100.000], mean action: 4.895 [0.000, 13.000], mean observation: 2.147 [0.000, 240.000], loss: 30.044086, mean_absolute_error: 0.987133, acc: 0.360965, mean_q: 1.000000, mean_eps: 0.752406\n", " 7721/50000: episode: 4, duration: 21.421s, episode steps: 733, steps per second: 34, episode reward: 430.000, mean reward: 0.587 [0.000, 100.000], mean action: 4.943 [0.000, 13.000], mean observation: 2.584 [0.000, 240.000], loss: 26.033418, mean_absolute_error: 0.980401, acc: 0.356114, mean_q: 1.000000, mean_eps: 0.741903\n", " 9006/50000: episode: 5, duration: 37.220s, episode steps: 1285, steps per second: 35, episode reward: 1280.000, mean reward: 0.996 [0.000, 100.000], mean action: 4.936 [0.000, 13.000], mean observation: 1.588 [0.000, 240.000], loss: 26.786589, mean_absolute_error: 0.981151, acc: 0.347811, mean_q: 1.000000, mean_eps: 0.733932\n", " 9489/50000: episode: 6, duration: 14.189s, episode steps: 483, steps per second: 34, episode reward: 460.000, mean reward: 0.952 [0.000, 100.000], mean action: 4.818 [0.000, 13.000], mean observation: 2.550 [0.000, 240.000], loss: 27.920443, mean_absolute_error: 0.983796, acc: 0.344591, mean_q: 1.000000, mean_eps: 0.726949\n", " 10482/50000: episode: 7, duration: 29.064s, episode steps: 993, steps per second: 34, episode reward: 930.000, mean reward: 0.937 [0.000, 100.000], mean action: 4.862 [0.000, 13.000], mean observation: 2.175 [0.000, 240.000], loss: 32.332520, mean_absolute_error: 0.988627, acc: 0.349100, mean_q: 1.000000, mean_eps: 0.721118\n", " 11303/50000: episode: 8, duration: 24.297s, episode steps: 821, steps per second: 34, episode reward: 580.000, mean reward: 0.706 [0.000, 100.000], mean action: 4.653 [0.000, 13.000], mean observation: 2.164 [0.000, 240.000], loss: 25.372195, mean_absolute_error: 0.978653, acc: 0.348432, mean_q: 1.000000, mean_eps: 0.713953\n", " 12967/50000: episode: 9, duration: 49.121s, episode steps: 1664, steps per second: 34, episode reward: 1180.000, mean reward: 0.709 [0.000, 100.000], mean action: 4.603 [0.000, 13.000], mean observation: 1.855 [0.000, 240.000], loss: 26.680792, mean_absolute_error: 0.980070, acc: 0.343675, mean_q: 1.000000, mean_eps: 0.704137\n", " 14767/50000: episode: 10, duration: 52.906s, episode steps: 1800, steps per second: 34, episode reward: 2050.000, mean reward: 1.139 [0.000, 100.000], mean action: 4.512 [0.000, 13.000], mean observation: 2.010 [0.000, 240.000], loss: 26.146817, mean_absolute_error: 0.980223, acc: 0.340104, mean_q: 1.000000, mean_eps: 0.690455\n", " 17088/50000: episode: 11, duration: 68.414s, episode steps: 2321, steps per second: 34, episode reward: 1880.000, mean reward: 0.810 [0.000, 100.000], mean action: 4.256 [0.000, 13.000], mean observation: 2.131 [0.000, 240.000], loss: 28.392949, mean_absolute_error: 0.984508, acc: 0.342996, mean_q: 1.000000, mean_eps: 0.674177\n", " 17370/50000: episode: 12, duration: 8.740s, episode steps: 282, steps per second: 32, episode reward: 160.000, mean reward: 0.567 [0.000, 50.000], mean action: 4.316 [0.000, 13.000], mean observation: 2.853 [0.000, 240.000], loss: 33.064301, mean_absolute_error: 0.992362, acc: 0.329455, mean_q: 1.000000, mean_eps: 0.663895\n", " 17887/50000: episode: 13, duration: 15.122s, episode steps: 517, steps per second: 34, episode reward: 580.000, mean reward: 1.122 [0.000, 100.000], mean action: 4.555 [0.000, 13.000], mean observation: 2.274 [0.000, 240.000], loss: 24.594645, mean_absolute_error: 0.978975, acc: 0.335529, mean_q: 1.000000, mean_eps: 0.660739\n", " 18599/50000: episode: 14, duration: 20.679s, episode steps: 712, steps per second: 34, episode reward: 630.000, mean reward: 0.885 [0.000, 100.000], mean action: 3.958 [0.000, 13.000], mean observation: 2.229 [0.000, 240.000], loss: 26.505238, mean_absolute_error: 0.981136, acc: 0.337166, mean_q: 1.000000, mean_eps: 0.655884\n", " 19641/50000: episode: 15, duration: 30.318s, episode steps: 1042, steps per second: 34, episode reward: 1080.000, mean reward: 1.036 [0.000, 100.000], mean action: 4.361 [0.000, 13.000], mean observation: 1.825 [0.000, 240.000], loss: 30.079884, mean_absolute_error: 0.988954, acc: 0.330824, mean_q: 1.000000, mean_eps: 0.648956\n", " 20361/50000: episode: 16, duration: 20.884s, episode steps: 720, steps per second: 34, episode reward: 830.000, mean reward: 1.153 [0.000, 100.000], mean action: 4.181 [0.000, 13.000], mean observation: 2.156 [0.000, 240.000], loss: 30.022883, mean_absolute_error: 0.986804, acc: 0.341710, mean_q: 1.000000, mean_eps: 0.641996\n", " 20921/50000: episode: 17, duration: 16.383s, episode steps: 560, steps per second: 34, episode reward: 280.000, mean reward: 0.500 [0.000, 50.000], mean action: 4.029 [0.000, 13.000], mean observation: 2.809 [0.000, 240.000], loss: 32.525301, mean_absolute_error: 0.990283, acc: 0.343750, mean_q: 1.000000, mean_eps: 0.636940\n", " 22419/50000: episode: 18, duration: 43.820s, episode steps: 1498, steps per second: 34, episode reward: 1080.000, mean reward: 0.721 [0.000, 100.000], mean action: 4.009 [0.000, 13.000], mean observation: 1.739 [0.000, 240.000], loss: 29.147020, mean_absolute_error: 0.986433, acc: 0.338722, mean_q: 1.000000, mean_eps: 0.628811\n", " 23198/50000: episode: 19, duration: 22.840s, episode steps: 779, steps per second: 34, episode reward: 930.000, mean reward: 1.194 [0.000, 100.000], mean action: 4.067 [0.000, 13.000], mean observation: 2.250 [0.000, 240.000], loss: 28.683109, mean_absolute_error: 0.986974, acc: 0.344071, mean_q: 1.000000, mean_eps: 0.619817\n", " 24514/50000: episode: 20, duration: 38.692s, episode steps: 1316, steps per second: 34, episode reward: 1360.000, mean reward: 1.033 [0.000, 100.000], mean action: 4.226 [0.000, 13.000], mean observation: 2.471 [0.000, 240.000], loss: 30.828469, mean_absolute_error: 0.991295, acc: 0.365715, mean_q: 1.000000, mean_eps: 0.611542\n", " 26366/50000: episode: 21, duration: 53.995s, episode steps: 1852, steps per second: 34, episode reward: 1320.000, mean reward: 0.713 [0.000, 100.000], mean action: 4.002 [0.000, 13.000], mean observation: 1.992 [0.000, 240.000], loss: 30.103927, mean_absolute_error: 0.989485, acc: 0.372182, mean_q: 1.000000, mean_eps: 0.599028\n", " 27983/50000: episode: 22, duration: 47.561s, episode steps: 1617, steps per second: 34, episode reward: 1490.000, mean reward: 0.921 [0.000, 100.000], mean action: 3.993 [0.000, 13.000], mean observation: 1.986 [0.000, 240.000], loss: 31.202262, mean_absolute_error: 0.990682, acc: 0.369086, mean_q: 1.000000, mean_eps: 0.585325\n", " 29873/50000: episode: 23, duration: 55.033s, episode steps: 1890, steps per second: 34, episode reward: 1480.000, mean reward: 0.783 [0.000, 100.000], mean action: 3.815 [0.000, 13.000], mean observation: 1.960 [0.000, 240.000], loss: 33.852708, mean_absolute_error: 0.996705, acc: 0.380539, mean_q: 1.000000, mean_eps: 0.571473\n", " 30851/50000: episode: 24, duration: 29.198s, episode steps: 978, steps per second: 33, episode reward: 780.000, mean reward: 0.798 [0.000, 100.000], mean action: 3.757 [0.000, 13.000], mean observation: 2.426 [0.000, 240.000], loss: 30.200397, mean_absolute_error: 0.990111, acc: 0.375991, mean_q: 1.000000, mean_eps: 0.560144\n", " 31931/50000: episode: 25, duration: 31.562s, episode steps: 1080, steps per second: 34, episode reward: 1180.000, mean reward: 1.093 [0.000, 100.000], mean action: 3.903 [0.000, 13.000], mean observation: 1.719 [0.000, 240.000], loss: 35.342483, mean_absolute_error: 0.997397, acc: 0.402402, mean_q: 1.000000, mean_eps: 0.552015\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 32370/50000: episode: 26, duration: 12.899s, episode steps: 439, steps per second: 34, episode reward: 230.000, mean reward: 0.524 [0.000, 50.000], mean action: 3.752 [0.000, 13.000], mean observation: 3.075 [0.000, 240.000], loss: 31.779249, mean_absolute_error: 0.992009, acc: 0.407247, mean_q: 1.000000, mean_eps: 0.546015\n", " 34069/50000: episode: 27, duration: 49.556s, episode steps: 1699, steps per second: 34, episode reward: 1370.000, mean reward: 0.806 [0.000, 100.000], mean action: 3.727 [0.000, 13.000], mean observation: 1.739 [0.000, 240.000], loss: 32.917870, mean_absolute_error: 0.994308, acc: 0.397752, mean_q: 1.000000, mean_eps: 0.537570\n", " 35248/50000: episode: 28, duration: 34.380s, episode steps: 1179, steps per second: 34, episode reward: 930.000, mean reward: 0.789 [0.000, 100.000], mean action: 3.759 [0.000, 13.000], mean observation: 2.046 [0.000, 240.000], loss: 31.251266, mean_absolute_error: 0.990669, acc: 0.407125, mean_q: 1.000000, mean_eps: 0.526202\n", " 36460/50000: episode: 29, duration: 35.646s, episode steps: 1212, steps per second: 34, episode reward: 1080.000, mean reward: 0.891 [0.000, 100.000], mean action: 3.340 [0.000, 13.000], mean observation: 1.672 [0.000, 240.000], loss: 32.289888, mean_absolute_error: 0.991152, acc: 0.403208, mean_q: 1.000000, mean_eps: 0.516757\n", " 38501/50000: episode: 30, duration: 59.686s, episode steps: 2041, steps per second: 34, episode reward: 1350.000, mean reward: 0.661 [0.000, 100.000], mean action: 3.484 [0.000, 13.000], mean observation: 1.783 [0.000, 240.000], loss: 30.293467, mean_absolute_error: 0.990134, acc: 0.409343, mean_q: 1.000000, mean_eps: 0.503908\n", " 39551/50000: episode: 31, duration: 30.607s, episode steps: 1050, steps per second: 34, episode reward: 980.000, mean reward: 0.933 [0.000, 100.000], mean action: 3.148 [0.000, 13.000], mean observation: 1.775 [0.000, 240.000], loss: 32.976856, mean_absolute_error: 0.992399, acc: 0.419137, mean_q: 1.000000, mean_eps: 0.491699\n", " 40200/50000: episode: 32, duration: 18.901s, episode steps: 649, steps per second: 34, episode reward: 380.000, mean reward: 0.586 [0.000, 100.000], mean action: 3.661 [0.000, 13.000], mean observation: 2.205 [0.000, 240.000], loss: 29.077991, mean_absolute_error: 0.986476, acc: 0.389686, mean_q: 1.000000, mean_eps: 0.484988\n", " 42374/50000: episode: 33, duration: 63.338s, episode steps: 2174, steps per second: 34, episode reward: 1300.000, mean reward: 0.598 [0.000, 100.000], mean action: 9.507 [0.000, 13.000], mean observation: 1.670 [0.000, 240.000], loss: 31.056681, mean_absolute_error: 0.986676, acc: 0.079836, mean_q: 0.999999, mean_eps: 0.473837\n", " 43610/50000: episode: 34, duration: 36.134s, episode steps: 1236, steps per second: 34, episode reward: 880.000, mean reward: 0.712 [0.000, 100.000], mean action: 9.278 [0.000, 13.000], mean observation: 2.080 [0.000, 240.000], loss: 28.141218, mean_absolute_error: 0.981980, acc: 0.117440, mean_q: 1.000000, mean_eps: 0.460367\n", " 45651/50000: episode: 35, duration: 59.989s, episode steps: 2041, steps per second: 34, episode reward: 1480.000, mean reward: 0.725 [0.000, 100.000], mean action: 9.611 [0.000, 13.000], mean observation: 1.897 [0.000, 240.000], loss: 27.262516, mean_absolute_error: 0.981316, acc: 0.158623, mean_q: 1.000000, mean_eps: 0.447423\n", " 48137/50000: episode: 36, duration: 72.982s, episode steps: 2486, steps per second: 34, episode reward: 2150.000, mean reward: 0.865 [0.000, 100.000], mean action: 9.624 [0.000, 13.000], mean observation: 2.096 [0.000, 240.000], loss: 26.563936, mean_absolute_error: 0.980988, acc: 0.219831, mean_q: 1.000000, mean_eps: 0.429541\n", " 49183/50000: episode: 37, duration: 30.598s, episode steps: 1046, steps per second: 34, episode reward: 980.000, mean reward: 0.937 [0.000, 100.000], mean action: 9.748 [0.000, 13.000], mean observation: 2.036 [0.000, 240.000], loss: 30.250379, mean_absolute_error: 0.987904, acc: 0.271481, mean_q: 1.000000, mean_eps: 0.415590\n", "done, took 1484.830 seconds\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dqn.fit(env, nb_steps=50000, visualize=True, verbose=2)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3XmYFOW1P/DvAQZEQEEYCauAsoheRR0BB6KoURC8Isp1CZcQ9YqXgNEgIkbiEoxX1J+CRgVUBCNhFfeFTYSACgyLLKLCsMPIjsO+zfn9carTPTM901tVdw/1/TxPP11bV52uma5T9b5vvSWqCiIi8p9yqQ6AiIhSgwmAiMinmACIiHyKCYCIyKeYAIiIfIoJgIjIp6JOACJSXkSWisgnznhjEVkgImtFZKKIVPQuTCIiclssVwAPAFgdMj4UwEuqeh6AvQDucTMwIiLyVlQJQETqA+gC4E1nXABcA2CKs8hYADd7ESAREXmjQpTLDQMwEEA1Z7wmgH2qesIZ3wKgXrgPikhvAL0BoEqVKpe1aNEi/miJiHxo8eLFu1Q10+31RkwAInIjgB2qulhEOsS6AVUdBWAUAGRlZWlOTk7MQRIR+ZmIbPRivdFcAbQDcJOIdAZwGoAzAAwHUF1EKjhXAfUBbPUiQCIi8kbEOgBVfVRV66tqIwB3APhSVXsAmA2gu7NYLwAfehYlERG5LpH7AB4B0F9E1sLqBN5yJyQiIkqGaCuBAQCq+hWAr5zhdQBaux8SERElA+8EJiLyKSYAIiKfYgIgIvIpJgAiIp9iAiAi8ikmACIin2ICICLyKSYAIiKfYgIgIvIpJgAiIp9iAiAi8ikmACIin2ICICLyKSYAIiKfYgIgIvIpJgAiIp9iAiAi8ikmACIin4qYAETkNBFZKCLficgqEXnKmT5GRNaLyDLn1cr7cInIFaqpjoDSQDTPBD4K4BpVPSAiGQDmicjnzryHVXWKd+ERkSd69gRyc4Fvvkl1JJRCEROAqiqAA85ohvPi6QNRWbV9OzBunA3/8APQokVq46GUiaoOQETKi8gyADsAzFDVBc6sv4nIchF5SUQqeRYlEbln9Gh7b90aEEltLJRSUSUAVT2pqq0A1AfQWkQuBPAogBYALgdwFoBHwn1WRHqLSI6I5OzcudOlsIkobh9/DFx9NbBgAdC8eaqjoRSKqRWQqu4DMBtAJ1XNU3MUwNsAWpfwmVGqmqWqWZmZmYlHTESJmTMH+Mc/bHj5cmDdutTGQykTTSugTBGp7gxXBnAdgB9EpI4zTQDcDGCll4ESkQtUgYwMoF494MABoE0b4MUXUx0VpUg0VwB1AMwWkeUAFsHqAD4BME5EVgBYAaAWgKe9C5OIErZpE9C4MfDllzZetSpw003AxInA8eOpjY1SIppWQMsBXBJm+jWeRERE3njzTUsCTZoEp/XoAUyaBMyYAXTunLrYKCV4JzBRWXX4MHDxxcCECZGXPX7cEkDnzkCjRsHpnToBNWoEm4WSrzABEJVVmzZZJe4zz0Re9uOPgbw84H//t/D0ihWB//ovuwJgMZDvMAEQlVXNmwNXXglUqxZ52ddfBxo0AG64ofi8v/7VWgJlZLgfI6W1aLqCIKJ0s3Ch3cHbogUwdWrk5R9+GMjPB8qXLz6vdm3346MygVcAlDhV4G9/s24FyHuHDlnrnd//HmjaFNi1C9izp/TPXH890L17yfPnzgUuvxzYscPVUCm9MQFQ4vLzgcGDgTvuSHUk/jBihPXn078/cNttwPz51qQznCNH7G+zaVPp66xRA8jJsRZB5BtMAJS4M88E6tcHLinWWpjcdvAgMHQo8JvfAO3bAw0bAtnZVpkbzpQpdnW2Zk3p6/2P/7DXP//pfszRWLIEeOed6Cq0yTVMAOSOunWBrVtTHcWpb8QIK6Z58sngtEmTgNmzS16+aVPr+yeSHj2se+hkdw2Rn28JrVcv4LnngIKC5G7fx5gA0tHx48ADD5SdPlpeeMEqJbdsSXUkp76lS4HrrgPatQtO+/OfgZEjiy+7YoUVD913H1Auip/6nXfaezRXAcuXA//6V3QxR/L668DevdZE9ZdfgLVr3VkvRcQEkI7mzQNefhl47bVURxKdZcvsvaRyaC8dOQLs35/87abKu+8CH3xQeFqzZsBPPxVfdsQIoFIlqyyORsOGwIABwEUXlbzM8ePAU0/ZDWjhthmrQ4esL6Lrrwf69LFpOTmJr5eiwgSQjr791t4HDUptHNHKzbUihoULk7vdffusM7NwbdtPNQcPBityTz+98LxAAgj3mMeePYGaNaPfzvPPWwujcH74wa48nnzSKvxvvdWmT59uZ/DxePNNK9IaPBho2RKoXBlYtCi+dVHMmADS0ddf200+NWsCx46lOprIcnOBc89N7jYPH7YD1fLldgVyqj/j9rXXrCx/48bi85o1swSRl1d4+quvAqNGxb6tn3+2ZqGhpk61Sv7cXGDyZGD8eKB6dTvwd+8OZGXZ3yJWM2bYzWy//jVQoYI1RY03mVDMmADSTUGBJYCWLe3OzREjUh1R6fbvB3butKKYa69N3uX7nj223Y4dwx/8TiUHD1rl6NVXA+ecU3x+s2b2/uOP9q4avCcjnid+9eljZ/gnTwantWoF3HgjsHJl4fsJatQApk2zv3/btrG3Ivroo8I3ss2eDYwZE3vMFBcmgHSTn2/N+26/3W7x/+STVEdUuoMHgVtusTO3L7/0vgJP1ZJkvXrAd99Zgvz4Y2uKeqp69VW72euJJ8LPz8624qGrrrLxRYuA88+Pv03/b39rCXXAAGuZo2o9iE6eDNSpU3z5K64AFi+2q4AePYA//SnyFdmJE3amL1K4iCqaympyj6om7XXZZZcpxeChh1QrVlTNz091JJHt3asKqL7wgrfbGTxYtUcP1ePHvd1OutizR7VWLdWOHaP/zF13qVapovrLL/Ft89Ah1WrV7O+Zna26b190nzt2TPWPf1R95JHIy/7jH7aN778vPH33btVrr1WdMCH2uE9hAHLUg2My0226OXw4ONyli9UBzJyZungiCbTZPvNMq5zcts27bb3yCvD007ad0D5tZs60YrNT0ddf25O7hgwpfbmxY23/7N1r3UP36AGccUZ826xc2eochg+3uoBor64yMuwz//d/Nv7tt+EbBhQU2A1fjRoVfyZx9er2maJ1EOQNL7JKSS9eAUSheXPV++6z4WPHVM84Q/Wee1IbU2nuu0/1ggts+LzzVG+/3ZvtTJigKqJ6883Fz/5btFDt1s2b7cZj61bVV16xs9uCgvjWcfhw4fVFcuut9r8zfLiduS9ZEt923VJQYFcPFSuqvvlm4XlTpliM48eH/2yHDqqtW8e33e+/V120yIbLwpVzlMArAB/Ytcsq8ho3tvGMDGuWd/vtqY2rNGvXAlWq2HB2NvCrX7m/jRkzrDnjr39trU8qFOnEtmnTyF0dJMuOHVZZe//9VpHfvLmVj8di8WLgvPPsewN2l3UkzZpZC52xY61pbKq75RCxCt6rrgL+53/sJq+jR4MdBzZtas8hCCcry1p2xdICThV44w3gssuAP/zBriDq1wcWLHDn+5yi2B10OgkUY2RnB6f17p2aWKKVm2uVgIAdfLyQkWEtTD78EDjttOLzmza1g2VBQeorEXv2BDZvBt5/3ypSP/oo+ASuMWPs7tmuXa3rg6Lt+QGrSO/a1SpGw7X4KUmzZlax+sYb4debCjVrAp9/bm38n33WKu2ffNLuZh49OnzX1IA1KDh2zFocXXpp5O3s3Wu/kylTbL++847dlFipEvDoo8CsWfG1hvKDSJcIAE4DsBDAdwBWAXjKmd4YwAIAawFMBFAx0rpYBBTBwIGqGRlWCRdqyRLVOXNSE1Npjh1TLVdO9bHHvFn/gQPB4dKKUl5/3YoUNm3yJo5YfP+96rRp4ec984zqmWdarJUrq3btqjpmTHD+5MlWZHLhhapbtsS23XnzbL2ffhp/7F6aMsWKB0+eVF261P53SrJ+vep116kuWBB5vRs2qDZooFqhgurQobb+gEBx2PTpCYefavCoCCiaBCAAqjrDGc5Bvy2ASQDucKaPANAn0rqYACJo1061TZvi09u2jb9M1Etr1ti/0OjRNj55smrTptaSI1a7d9t6Jk9WnTVLdf581caNVV99NfJnZ860OGbNin27bjh2zFq1RFPef/SoHZD69rUDV/v2Nv2bb6yOo107a/kTq507bR8MHx77Z8uyEydU7747fLI4ckT1nHNUL7ss/rqYNJGyBFBoYeB0AEsAtAGwC0AFZ/oVAKZF+rwrCeCrr+xg+Oij9sMverbsllWrrMlhnz7erD+csWPDN38bMsT+VD//nLxYorFxozX7W7HCxidOtDiXL499Xdu2qd5/v30+8KpSJVihV5r9+1WXLStccZosJ09as1RAde7c2D5bUKC6a1dw+O9/Vz14MP5Y5swpfNVU1pX099y40Sr9t22LvI4xY+xvM3++u7ElWUoTAIDyAJYBOABgKIBaANaGzG8AYGUJn+0NIAdATsOGDRPfEzNn2llShQoWfqVKqtdcE90/QyQbN6o++6zqxRfbusuVU+3c2f4Ri7ZXTqYlSyyet99OXQzRCBRDfPFFfJ/fscOSx+zZVmTw44+uhue6ggLVfv3sOz/zTKqjObUMG2a/7aIneFOmqFavrlq1anRFOydOpL5FlAvS5QqgOoDZANpHmwBCX64WAeXnW3ln//7W3CzQNPAvf1G96SbVl1+2s/hIl347dgTPNJ5+2nZJ27Z2KZ2XZ9P79bN/uBMn3Iu/qJUrSz7gFRSo1q1rTf3Sya5ddpkdsG6d7b+33optPQcOqL7zjhVjxOu995KfIP/yF/u+AwaU+SKGtPPBB7Zvv/7axg8eVL33Xpt2+eVW/BirRK6uUiwtEoDFgccBPJyyIqBInn5atUkT/XcxQp06qg88UHiZ/Hw74HTqpFq+vJU7q6pu366am1t8ne++a+tatsy7uLt1Uz333JLn33uv6tlne5uEYtW1q1VYBhw5YvtpyJDY1jNtmn2upMrTaGNp2TL+z8dq9Wr737nnHh78vbBli/1PvPyyjffvb+MDB1o9Sqyee87qA7wqMvaYVwkgYps5EckUkerOcGUA1wFY7VwJBHqF6gXgw0jrSorHHrOmievWWZO4q64q3KnVNdcAZ58N/O53wOrVwMMPW0dXgE1v0qT4OgPNHL2621TVHtwR2vyzqCFDgPXrS246lwq5ucF7FgBrdtetm3ViF4s5c6xtf2nfP5KmTS2eZD1NqkUL+5uNHMkmhl6oV8/uKZk+3cYHD7bmnEOHlvz4y9K0bm09qZaVZ2wkS6QMAeAiAEsBLAewEsDjzvQmsOahawFMBlAp0rpS3gro4EE70+7b18qrQ5uMlaagQLV2bdWePb2JK9CaZsQIb9bvhYIC1dNPV33wwcTXlZ1txW6JGDnS9uGGDdEt/+KLwau97t3tCmLIENXPPrNiwZJMnWov8l6HDvY3jeeMP5yOHVXPOiv6vo3SCNKlCCiRV8oTQCIiFdEkYuxYjar1zLhxdrBKB9u2WcyvvJLYeg4csAr9QYMSW8+XX1o8M2ZEXnbRIlt22DAbv/de1WbNgsWGgGqvXsHl58yx+o6ZM62dfvv20Z88UPxWrbKWZW4VsS1ebH/bwYPdWV8SeZUAeCdwtB55xDpqU3X/kn/+fOtw64ILSl9uzx6723HNGivySKXcXHsv+iCYRx8NxhiNnBy7gzXQlXG8Avtj/frIyw4fbl1t33WXjQcemvLLL3aXak5OsGhr375gbOXK2d/oo49Sf8exH7RsaS+3XHqpdavy+uv2HOXKld1bdxklllySIysrS3P4vM/itm+3A2b79qUvt3691VG89BLw4IPJia0kmzZZ//C//W3hPuIff9z6ejl6tHifPSXZvBnIzAzfzUO0VC1BRnr8YV6edbHQp48lgkiOHrUEvXixPSlrwIDwfeJT2bB5s73HWk+VYiKyWFWzXF8vE0AMpk+3A03HjqmLoWVLqyALdBSWbkaMsIPrli0WZ7p54gmrUP/pJ+twjfzryJHETjqSyKsEwOvYWATObt20eDHwwgtW/BCNG2+0VjP797sbR6yWLwe2bi0+PXDQDzevqEOH7JLcrdZVkyYBDz0UeZu33sqDv5+pArfdBnzwQaojSTkmgFhkZ9vj9tx8UPvYscCgQdEXl3TtClx/vXUdnUq/+x1w333FpwcSQDQPhvnmGztoR5v8IlmyxB6KcuJEycs8/3z8j0qkU4MI8Pbb9txjn2MCiEV2tl02LluW+LpU7alIr7xi/4iBPvUjadfOnhMc2v4+2QoK7LkFRZ/mBFj5eq9e0T0XYM4cq0xt186duJo2BY4ft/qJolSBFStsmO32Kdrf2ymOCSAWgRuVvvkmsfWoAgMH2k1r//3fdjYSq7y85N30VNSmTZYIW7QoPq9mTev3vm3byOv56it7gEe8jy4sKtASKFwLpHnzgIsuAqZOdWdbRKcAJoBY1K1rZ7iJVmTn5dlBsm9fKwLKyIjt8x9/bLEsWZJYHPH68Ud7D5cAAEtwhw6Vvo7Dh+1pTR06uBdXaQlg+HCgRg2gUyf3tkdUxvE+gFjNmxd/M8Djx62sv25dK0aqWze+4ogrrrDPffKJPT4v2X74wd7DFQEB1m7+tNOCt/GHs22btam/+mr34vrVr4Date0h6qE2brQndA0YkD5PyyJKA0wAsapfP77PHTwI3HKLPe7u6acTayJZq5YVsXz6qT1iL9luvtmSYGZm+Pm1agWvEkpy7rnuX8GI2NVV0aT66qs2rW9fd7dHVMaxCChWe/faA65LO7stKj8fuO46YObM8J3NxaNLFyuKystzZ32xOOcca0ZX0tVLvXqRWwF5VX9RNKaCArszuVs3oGFDb7ZJVEYxAcSqalV76PRnn0W3fEGBVfQuXGh3zt59tztx3HijvX/+efSfUQXee8+Kbho3toPisGGxb/vdd0s/w69b17pQKKke4MgR63l15MjYtx3J1KlWBHX8uI2XK2cPI3/pJfe3RVTGMQHEKiPDinGibQn0179ape2wYVYE5JaLLgJGjwZuuCH6zzz+ONC9u3Wn26aNdYf9xRfB+VdeCVx7rd1M9e67wMqVwQNpQH4+0LOnlamXpG5dey/pKuDzz4Hdu4PLuSk/H5g7F9iwIdi1W7Vq8RfdEZ3CWAcQj+xsu3v38OHIHUqdf76VPbtd/iwS7MysNPn5Fmft2nYlcvbZ1lVD4MazwLMSVIELL7Ripddes7N0wNr0jxlj80eODH6upApgwCqmBw8O39Z63z7g/vttv1x/fdRfN2qhLYHWrwf+9Ce74zPVnecRpSEmgHhkZ9vdpjNnAv/5n3bQnD7dintU7b2gwA4+t99uLy8cOgSMG2dXJIGH2gQUFNhZ/MCBdqNVoOin6IE78IAZkeDDMk6csCKepUuD5eabN1viCDj//JLjuuAC628nnP79rVO199+3B8i4LTQBTJ9uVxos+ycKiwkgHu3aWffNgR4Fv/7abuoq6q67gOrVvY3lj3+0SunQBLBkCdCvnxVTtWljXTTHokIFO4iHdk/doIHdALZsmTWzLOkeAMCS4O7dNlyrVnD65s3AxInWtfbll8cWU7QyM+3Gss8+swTwxBPeJBqiUwB7A43XsWN29ly+vJ0xFxTYWXS5cvZKVncDnTvb2W7g5qfx44EePexAOHSo9dmT7L7rA+Xu995bvPI1N9fK4708KHfvblc8GRmWtKLploIojbE30HRTsWKw+KRCBRvPyLBpyexrpksXYO1a61cHsHL1gQOtu+Pf/z41Dy4RKd4UdP58Swznnuv9Gflbb1lrrTvu4MGfqBTRPBS+gYjMFpHvRWSViDzgTH9SRLaKyDLn1dn7cKmYLl3svUMHq9CtWRN49lkrokqlunWDXUJ/+KE97Gb8+ORs+/TT7SlfAwcmZ3tEZVQ0dQAnADykqktEpBqAxSISeBrJS6r6gnfhUUSNGlkxS/366fWYwrp1rQ5i927rNrpVKyuaSYaMDODOO5OzLaIyLGICUNU8AHnO8H4RWQ0gDR/15GOpfjxkOIEioH797FGN06ZZMRkRpY2YThlFpBGASwAscCb1E5HlIjJaRGqU8JneIpIjIjk7d+5MKFgqQ265xe7InTDBbkC7+OJUR0RERUTdCkhEqgKYA+BvqjpVRGoD2AVAAQwBUEdVS+3n4JRqBUSRffqp3VvwwQexd3lNRP+W0lZAIpIB4D0A41R1KgCo6nZVPamqBQDeANDa7eCojOvSxZIAD/5EaSmaVkAC4C0Aq1X1xZDpoZ3idwOw0v3wiIjIK9G0AmoHoCeAFSISeBjunwHcKSKtYEVAGwCEeUI4ERGlq2haAc0DEO7Opij7QyYionSURg3HiYgomZgAiIh8igmAiMinmACIiHyKCYCIyKeYAIiIfIoJgIjIp5gAiIh8igmAiMinmACIiHyKCYCIyKeYAIiIfIoJgIjIp5gAiIh8igmAiMinmACIiHyKCYCIyKeYAIiIfCqah8I3EJHZIvK9iKwSkQec6WeJyAwRWeO81/A+XCIicks0VwAnADykqi0BtAXQV0RaAhgEYJaqNgUwyxknIqIyImICUNU8VV3iDO8HsBpAPQBdAYx1FhsL4GavgiQiIvfFVAcgIo0AXAJgAYDaqprnzPoZQO0SPtNbRHJEJGfnzp0JhEpERG6KOgGISFUA7wF4UFXzQ+epqgLQcJ9T1VGqmqWqWZmZmQkFS0RE7okqAYhIBuzgP05VpzqTt4tIHWd+HQA7vAmRiIi8EE0rIAHwFoDVqvpiyKyPAPRyhnsB+ND98IiIyCsVolimHYCeAFaIyDJn2p8BPAtgkojcA2AjgNu8CZGIiLwQMQGo6jwAUsLsa90Nh4iIkoV3AhMR+RQTABGRTzEBEBH5FBMAEZFPMQEQEfkUEwARkU8xARAR+RQTABGRTzEBEBH5FBMAEZFPMQEQEfkUEwARkU8xARAR+RQTABGRTzEBEBH5FBMAEZFPMQEQEfkUEwARkU9F81D40SKyQ0RWhkx7UkS2isgy59XZ2zCJiMht0VwBjAHQKcz0l1S1lfP6zN2wiIjIaxETgKrOBbAnCbEQEVESJVIH0E9EljtFRDVKWkhEeotIjojk7Ny5M4HNERGRm+JNAK8DOBdAKwB5AP5fSQuq6ihVzVLVrMzMzDg3R0REbosrAajqdlU9qaoFAN4A0NrdsIiIyGtxJQARqRMy2g3AypKWJSKi9FQh0gIiMh5ABwC1RGQLgCcAdBCRVgAUwAYA93kYIxEReSBiAlDVO8NMfsuDWIiIKIl4JzARkU8xARAR+RQTABGRTzEBEBH5FBMAEZFPMQEQEfkUEwARkU8xARAR+RQTABGRTzEBEBH5FBMAEZFPMQEQEfkUEwARkU8xARAR+RQTABGRTzEBEBH5FBMAEZFPMQEQEflUxAQgIqNFZIeIrAyZdpaIzBCRNc57DW/DJCIit0VzBTAGQKci0wYBmKWqTQHMcsaJiKgMiZgAVHUugD1FJncFMNYZHgvgZpfjIiIij8VbB1BbVfOc4Z8B1C5pQRHpLSI5IpKzc+fOODdHRERuS7gSWFUVgJYyf5SqZqlqVmZmZqKbIyIil8SbALaLSB0AcN53uBcSERElQ7wJ4CMAvZzhXgA+dCccIiJKlmiagY4H8A2A5iKyRUTuAfAsgOtEZA2A3zjjRERUhlSItICq3lnCrGtdjoWIiJKIdwITEfkUEwARkU8xARAR+RQTABGRTzEBEBH5FBMAEZFPMQEQEfkUEwARkU8xARAR+RQTABGRTzEBEBH5FBMAEZFPMQEQEfkUEwARkU8xARAR+RQTABGRTzEBEBH5FBMAEZFPRXwkZGlEZAOA/QBOAjihqlluBEVERN5LKAE4rlbVXS6sh4iIkohFQEREPpVoAlAA00VksYj0diMgIiJKjkSLgNqr6lYRORvADBH5QVXnhi7gJIbeANCwYcMEN0dERG5J6ApAVbc67zsAvA+gdZhlRqlqlqpmZWZmJrI5IiJyUdwJQESqiEi1wDCA6wGsdCswIiLyViJFQLUBvC8igfX8U1W/cCUqIiLyXNwJQFXXAbjYxViIiCiJ2AyUiMinmACIiHyKCYCIyKeYAIiIfIoJgIjIp5gAiIh8igmAiMinmACIiHyKCYCIyKeYAIiIfIoJgIjIp5gAiIh8igmAiMinmACIiHyKCYCIyKeYAIiIfIoJgIjIp5gAiIh8igmAiMinEkoAItJJRH4UkbUiMsitoIiIyHtxJwARKQ/gVQA3AGgJ4E4RaelWYERE5K1ErgBaA1irqutU9RiACQC6uhMWERF5rUICn60HYHPI+BYAbYouJCK9AfR2Ro+KyMoEtnkqqQVgV6qDSBPcF0HcF0HcF0HNvVhpIgkgKqo6CsAoABCRHFXN8nqbZQH3RRD3RRD3RRD3RZCI5Hix3kSKgLYCaBAyXt+ZRkREZUAiCWARgKYi0lhEKgK4A8BH7oRFRERei7sISFVPiEg/ANMAlAcwWlVXRfjYqHi3dwrivgjivgjivgjivgjyZF+IqnqxXiIiSnO8E5iIyKeYAIiIfCopCeBU7TJCREaLyI7QextE5CwRmSEia5z3Gs50EZGXnX2wXEQuDflML2f5NSLSK2T6ZSKywvnMyyIiyf2G0RORBiIyW0S+F5FVIvKAM913+0NEThORhSLynbMvnnKmNxaRBU78E53GExCRSs74Wmd+o5B1PepM/1FEOoZML1O/KREpLyJLReQTZ9yX+0JENjj/w8sCTTtT+htRVU9fsAriXABNAFQE8B2All5vNxkvAFcCuBTAypBpzwEY5AwPAjDUGe4M4HMAAqAtgAXO9LMArHPeazjDNZx5C51lxfnsDan+zqXsizoALnWGqwH4CdZFiO/2hxNfVWc4A8ACJ+5JAO5wpo8A0McZ/gOAEc7wHQAmOsMtnd9LJQCNnd9R+bL4mwLQH8A/AXzijPtyXwDYAKBWkWkp+40k4wrglO0yQlXnAthTZHJXAGOd4bEAbg6Z/o6abwFUF5E6ADrt3wuTAAACxElEQVQCmKGqe1R1L4AZADo5885Q1W/V/rLvhKwr7ahqnqoucYb3A1gNu1vcd/vD+U4HnNEM56UArgEwxZledF8E9tEUANc6Z25dAUxQ1aOquh7AWtjvqUz9pkSkPoAuAN50xgU+3RclSNlvJBkJIFyXEfWSsN1Uqa2qec7wzwBqO8Ml7YfSpm8JMz3tOZftl8DOfH25P5wij2UAdsB+oLkA9qnqCWeR0Pj//Z2d+b8AqInY91G6GgZgIIACZ7wm/LsvFMB0EVks1k0OkMLfiOddQfiZqqqI+KqdrYhUBfAegAdVNT+0CNJP+0NVTwJoJSLVAbwPoEWKQ0oJEbkRwA5VXSwiHVIdTxpor6pbReRsADNE5IfQmcn+jSTjCsBvXUZsdy7F4LzvcKaXtB9Km14/zPS0JSIZsIP/OFWd6kz27f4AAFXdB2A2gCtgl/CBk67Q+P/9nZ35ZwLYjdj3UTpqB+AmEdkAK565BsBw+HNfQFW3Ou87YCcGrZHK30gSKj0qwCopGiNYSXOB19tN1gtAIxSuBH4ehSt0nnOGu6Bwhc5CDVborIdV5tRwhs/S8BU6nVP9fUvZDwIrcxxWZLrv9geATADVneHKAP4F4EYAk1G44vMPznBfFK74nOQMX4DCFZ/rYJWeZfI3BaADgpXAvtsXAKoAqBYy/DWATqn8jSTri3eGtQrJBfBYqv8QLn6v8QDyAByHlbfdAyuvnAVgDYCZIX8YgT1AJxfACgBZIeu5G1aptRbAXSHTswCsdD7zdzh3bqfjC0B7WPnmcgDLnFdnP+4PABcBWOrsi5UAHnemN3F+oGudA2AlZ/ppzvhaZ36TkHU95nzfHxHSoqMs/qZQOAH4bl843/k757UqEGsqfyPsCoKIyKd4JzARkU8xARAR+RQTABGRTzEBEBH5FBMAEZFPMQEQEfkUEwARkU/9f6KH32W1s6qGAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "episodes_p2 = [2647,5062,6988,7721,9006,9489,\n", " 10482,11303,12967,14767,17088,17370,17887,18599,19641,\n", " 20361,20921,22419,23198,24514,26366,27983,29873,\n", " 30851,31931,32370,34069,35248,36460,38501,39551,\n", " 40200,42374,43610]\n", "\n", "loss_p2 = [29.99,29.21,30.04,26.03,26.04,26.78,27.92,32.33,25.37,26.68,26.14,28.39,\n", " 33.06,24.59,26.5,30.07,30.02,32.05,25.4,29.14,28.68,30.82, 30.10,31.20,\n", " 33.85,30.20,35.34,31.25,32.28,30.29,32.97,29.07,31.01,28.14]\n", "\n", "plt.plot(episodes_p2, loss_p2, 'r--')\n", "plt.axis([0, 50000, 0, 40])\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dqn.test(env, nb_episodes=10, visualize=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#SARSA Agent -- Reinforcement Learning \n", "from rl.agents.sarsa import SARSAAgent\n", "sarsa = SARSAAgent(model, nb_actions, \n", " policy=None, test_policy=None, \n", " gamma=0.99, nb_steps_warmup=10, \n", " train_interval=1)\n", "sarsa.compile(Adam(lr=1e-3), metrics=['mae', 'acc'])\n", "sarsa.fit(env, nb_steps=50000, visualize=True, verbose=2)\n", "sarsa.test(env, nb_episodes=10, visualize=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }