{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/erle/.local/lib/python2.7/site-packages/matplotlib/__init__.py:1350: UserWarning: This call to matplotlib.use() has no effect\n", "because the backend has already been chosen;\n", "matplotlib.use() must be called *before* pylab, matplotlib.pyplot,\n", "or matplotlib.backends is imported for the first time.\n", "\n", " warnings.warn(_use_error_msg)\n", "[2016-09-21 14:55:04,801] Making new env: MountainCar-v0\n" ] } ], "source": [ "import matplotlib\n", "from matplotlib import pyplot\n", "import numpy as np\n", "\n", "import sys\n", "sys.path.append(\"..\")\n", "from hiora_cartpole import features\n", "from hiora_cartpole import fourier_fa\n", "from hiora_cartpole import driver\n", "\n", "import gym\n", "\n", "env = gym.make('MountainCar-v0')\n", "\n", "state_ranges = np.array([env.observation_space.low, env.observation_space.high])\n", "four_n_weights, four_feature_vec \\\n", " = fourier_fa.make_feature_vec(state_ranges,\n", " n_acts=3,\n", " order=7)\n", "\n", "#fv = feature_vec(cartpole.observation_space.sample(), cartpole.action_space.sample())\n", "\n", "from hiora_cartpole import linfa\n", "experience = linfa.init(lmbda=0.9,\n", " init_alpha=1.0,\n", " epsi=0.01,\n", " feature_vec=four_feature_vec,\n", " n_weights=four_n_weights,\n", " act_space=env.action_space,\n", " theta=None,\n", " is_use_alpha_bounds=True)" ] }, { "cell_type": "code", "execution_count": 75, "metadata": { "collapsed": false }, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " this.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width);\n", " canvas.attr('height', height);\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " event.shiftKey = false;\n", " // Send a \"J\" for go to next cell\n", " event.which = 74;\n", " event.keyCode = 74;\n", " manager.command_mode();\n", " manager.handle_keydown(event);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an active figure into a static one), is too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib notebook\n", "state_ranges = np.array([env.observation_space.low, env.observation_space.high])\n", "driver.plot_2D_V(state_ranges, env.action_space, four_feature_vec, experience.theta)" ] }, { "cell_type": "code", "execution_count": 69, "metadata": { "collapsed": false }, "outputs": [], "source": [ "experience, steps_per_episode = driver.train(env, linfa, experience, n_episodes=10, max_steps=200, is_render=True)tg \n", "#pyplot.plot(steps_per_episode)\n", "#pyplot.show()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[2016-09-20 10:13:46,516] Making new env: MountainCar-v0\n" ] } ], "source": [ "from hiora_cartpole import features\n", "\n", "env = gym.make('MountainCar-v0')\n", "\n", "state_ranges = np.array([env.observation_space.low, env.observation_space.high])\n", "\n", "tilec_n_weights, tilec_feature_vec = features.make_feature_vec(state_ranges, 3, [9, 9], 5)\n", "\n", "#fv = feature_vec(cartpole.observation_space.sample(), cartpole.action_space.sample())\n", "\n", "from hiora_cartpole import linfa\n", "fexperience = linfa.init(lmbda=0.9,\n", " init_alpha=1.0,\n", " epsi=0.01,\n", " feature_vec=tilec_feature_vec,\n", " n_weights=tilec_n_weights,\n", " act_space=env.action_space,\n", " theta=None,\n", " is_use_alpha_bounds=True)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m----------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0mTraceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfexperience\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msteps_per_episode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0malpha_per_episode\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdriver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlinfa\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfexperience\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_episodes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_steps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m200\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_render\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;31m# Credits: http://matplotlib.org/examples/api/two_scales.html\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0max1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpyplot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msubplots\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0max1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msteps_per_episode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcolor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'b'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0max2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0max1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtwinx\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/home/erle/repos/cartpole/hiora_cartpole/driver.pyc\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(env, learner, experience, n_episodes, max_steps, is_render)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mxrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_steps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m \u001b[0mis_render\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# pylint: disable=expression-not-assigned\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 56\u001b[0m experience, action = learner.think(experience, observation, reward,\n\u001b[1;32m 57\u001b[0m done)\n", "\u001b[0;32m/home/erle/repos/gym/gym/core.pyc\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, mode, close)\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mUnsupportedMode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Unsupported rendering mode: {}. (Supported modes for {}: {})'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 189\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 190\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/home/erle/repos/gym/gym/envs/classic_control/mountain_car.pyc\u001b[0m in \u001b[0;36m_render\u001b[0;34m(self, mode, close)\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcartrans\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_rotation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcos\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mpos\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 119\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreturn_rgb_array\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m'rgb_array'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/home/erle/repos/gym/gym/envs/classic_control/rendering.pyc\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, return_rgb_array)\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_rgb_array\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0mglClearColor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 82\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 83\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mswitch_to\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_events\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/home/erle/.local/lib/python2.7/site-packages/pyglet/window/__init__.pyc\u001b[0m in \u001b[0;36mclear\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1149\u001b[0m \u001b[0mbuffer\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mThe\u001b[0m \u001b[0mwindow\u001b[0m \u001b[0mmust\u001b[0m \u001b[0mbe\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mactive\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0msee\u001b[0m \u001b[0;34m`\u001b[0m\u001b[0mswitch_to\u001b[0m\u001b[0;34m`\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1150\u001b[0m '''\n\u001b[0;32m-> 1151\u001b[0;31m \u001b[0mgl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mglClear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGL_COLOR_BUFFER_BIT\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0mgl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGL_DEPTH_BUFFER_BIT\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1152\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1153\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdispatch_event\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/home/erle/.local/lib/python2.7/site-packages/pyglet/gl/lib.pyc\u001b[0m in \u001b[0;36merrcheck\u001b[0;34m(result, func, arguments)\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 84\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0merrcheck\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marguments\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 85\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_debug_gl_trace\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "fexperience, steps_per_episode, alpha_per_episode \\\n", " = driver.train(env, linfa, fexperience, n_episodes=100, max_steps=200, is_render=True)\n", "# Credits: http://matplotlib.org/examples/api/two_scales.html\n", "fig, ax1 = pyplot.subplots()\n", "ax1.plot(steps_per_episode, color='b')\n", "ax2 = ax1.twinx()\n", "ax2.plot(alpha_per_episode, color='r')\n", "pyplot.show()" ] }, { "cell_type": "code", "execution_count": 103, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(40, 40) (40, 40) (40, 40)\n" ] }, { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " this.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width);\n", " canvas.attr('height', height);\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " event.shiftKey = false;\n", " // Send a \"J\" for go to next cell\n", " event.which = 74;\n", " event.keyCode = 74;\n", " manager.command_mode();\n", " manager.handle_keydown(event);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an active figure into a static one), is too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "eexperience, steps_per_episode, alpha_per_episode \\\n", " = driver.train(env, linfa, eexperience, n_episodes=5000, max_steps=200, is_render=False)\n", "# Credits: http://matplotlib.org/examples/api/two_scales.html\n", "fig, ax1 = pyplot.subplots()\n", "ax1.plot(steps_per_episode, color='b')\n", "ax2 = ax1.twinx()\n", "ax2.plot(alpha_per_episode, color='r')\n", "pyplot.show()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": false }, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " this.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width);\n", " canvas.attr('height', height);\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " event.shiftKey = false;\n", " // Send a \"J\" for go to next cell\n", " event.which = 74;\n", " event.keyCode = 74;\n", " manager.command_mode();\n", " manager.handle_keydown(event);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an active figure into a static one), is too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib notebook\n", "driver.plot_2D_V(state_ranges, env.action_space, easyt_feature_vec, -eexperience.theta)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "collapsed": false }, "outputs": [], "source": [ "eexperience, steps_per_episode, alpha_per_episode \\\n", " = driver.train(env, linfa, eexperience, n_episodes=20, max_steps=200, is_render=True)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 4.4 s, sys: 4 ms, total: 4.4 s\n", "Wall time: 4.4 s\n" ] }, { "data": { "text/plain": [ "(Immutable(feature_vec=, theta=array([ 0. , -1.6374056, -7.3466172, ..., 0. , 0. , 0. ]), E=array([ 0., 0., 0., ..., 0., 0., 0.]), epsi=0.01, init_alpha=1.0, p_alpha=0.024309124525753373, lmbda=0.9, p_obs=None, p_act=None, p_feat=None, act_space=Discrete(3), is_use_alpha_bounds=True),\n", " array([200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200,\n", " 200, 200, 200, 200, 200, 161, 200, 200, 200, 151, 174, 200, 169,\n", " 171, 200, 138, 188, 141, 88, 88, 135, 140, 141, 92, 158, 150,\n", " 88, 90, 86, 139, 144, 161, 122, 161, 152, 136, 135, 148, 139,\n", " 100, 141, 140, 173, 200, 154, 152, 145, 138, 139, 135, 132, 134,\n", " 134, 134, 136, 138, 134, 134, 119, 116, 110, 109, 187, 135, 117,\n", " 120, 113, 116, 89, 136, 103, 104, 106, 109, 128, 138, 86, 137,\n", " 144, 140, 136, 88, 130, 84, 83, 168, 129], dtype=int32),\n", " array([ 0.04487296, 0.03067321, 0.03067321, 0.03067321, 0.03067321,\n", " 0.02911278, 0.02911278, 0.02702132, 0.02702132, 0.02702132,\n", " 0.02702132, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912,\n", " 0.02430912, 0.02430912, 0.02430912, 0.02430912, 0.02430912]))" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time driver.train(env, linfa, fexperience, n_episodes=100, max_steps=200, is_render=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Faster." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m----------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0mTraceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0mmax_episodes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10000\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0mavg_window\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m200\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m max_diff=1)\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;32mprint\u001b[0m \u001b[0mepisode_nr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlast_avg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/home/erle/repos/cartpole/hiora_cartpole/driver.pyc\u001b[0m in \u001b[0;36mtrain_until_converged\u001b[0;34m(env, train_and_prep, init_experience, max_steps, max_episodes, avg_window, max_diff)\u001b[0m\n\u001b[1;32m 138\u001b[0m imap(lambda (a1, a2): abs(a1 - a2),\n\u001b[1;32m 139\u001b[0m pairwise(\n\u001b[0;32m--> 140\u001b[0;31m \u001b[0maverages\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlimited_cnts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mavg_window\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 141\u001b[0m )\n\u001b[1;32m 142\u001b[0m )\n", "\u001b[0;32m/home/erle/repos/cartpole/hiora_cartpole/driver.pyc\u001b[0m in \u001b[0;36miterate\u001b[0;34m(f, x)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;31m#### Some iterators\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/home/erle/repos/cartpole/hiora_cartpole/driver.pyc\u001b[0m in \u001b[0;36m\u001b[0;34m((_, dt))\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0mtrain_and_prep_ms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunctools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_and_prep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_steps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_steps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 125\u001b[0;31m cnts_dtimesteps = iterate(lambda (_, dt): train_and_prep_ms(dt),\n\u001b[0m\u001b[1;32m 126\u001b[0m (0, first_dtimestep))\n\u001b[1;32m 127\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/home/erle/repos/cartpole/hiora_cartpole/driver.pyc\u001b[0m in \u001b[0;36mtrain_and_prep_inner\u001b[0;34m(first_dtimestep, max_steps)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0mn_dtimestep\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mn_dtimestep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 94\u001b[0;31m \u001b[0mdtimestep\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext_dtimestep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtimestep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdtimestep\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/home/erle/repos/cartpole/hiora_cartpole/driver.pyc\u001b[0m in \u001b[0;36mnext_dtimestep_inner\u001b[0;34m(dtimestep)\u001b[0m\n\u001b[1;32m 77\u001b[0m new_experience, new_action = think(dtimestep.experience,\n\u001b[1;32m 78\u001b[0m \u001b[0mdtimestep\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobservation\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 79\u001b[0;31m dtimestep.reward)\n\u001b[0m\u001b[1;32m 80\u001b[0m \u001b[0mnew_observation\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnew_reward\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnew_done\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_action\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/home/erle/repos/cartpole/hiora_cartpole/linfa.pyc\u001b[0m in \u001b[0;36mthink\u001b[0;34m(e, o, r, done)\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mdone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 64\u001b[0;31m \u001b[0ma\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mchoose_action\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# action\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 65\u001b[0m \u001b[0mfeat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeature_vec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[0mQnext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfeat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtheta\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/home/erle/repos/cartpole/hiora_cartpole/linfa.pyc\u001b[0m in \u001b[0;36mchoose_action\u001b[0;34m(e, o)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mchoose_action\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mtrue_with_prob\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mepsi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 48\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mact_space\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/home/erle/repos/cartpole/hiora_cartpole/linfa.pyc\u001b[0m in \u001b[0;36mtrue_with_prob\u001b[0;34m(p)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtrue_with_prob\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchoice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "# What do I want?\n", "#\n", "# - Write one procedure that just trains for a number of episodes.\n", "#\n", "# - Write another procedure that keeps a running average of episode lengths and\n", "# stops training when the average doesn't change much anymore.\n", "#\n", "# - Possibly write a procedure that returns the sequence of Q functions\n", "# resulting from training.\n", "\n", "\n", "next_dtimestep = driver.make_next_dtimestep(env, linfa.think)\n", "train_and_prep = driver.make_train_and_prep(env, next_dtimestep, linfa.wrapup)\n", "\n", "episode_nr, last_avg, experience \\\n", " = driver.train_until_converged(\n", " env=env,\n", " train_and_prep=train_and_prep,\n", " init_experience=experience,\n", " max_steps=100,\n", " max_episodes=10000,\n", " avg_window=200,\n", " max_diff=1)\n", "print episode_nr, last_avg\n", "\n", "\n", "\n", "\n", "cnts_dtimesteps = driver.cnts_dtimesteps_iter(env, train_and_prep, experience,\n", " 100)\n", "\n", "thetas = driver.train_return_thetas(cnts_dtimesteps, 1000)\n", "\n", "sqes = 1.0 / experience.theta.shape[0] * np.sum(np.diff(thetas, axis=0) ** 2, axis=1)\n", "\n", "pyplot.plot(sqes)\n", "pyplot.show()\n", "\n", "sums = np.sum(np.abs(thetas), axis=1)\n", "pyplot.plot(sums)\n", "pyplot.show()\n", "\n", "with np.load(\"hard-earned-theta.npz\") as data:\n", " old_theta = data['arr_0']\n", " print np.sum(old_theta)\n", "\n", "#hard_earned_theta = np.copy(experience.theta)\n", "#np.savez_compressed(\"hard-earned-theta\", hard_earned_theta)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([1, 2])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.array([[1, 2], [3, 4]])[0]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "env.close()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.12+" } }, "nbformat": 4, "nbformat_minor": 1 }