{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# SNUZ - white noise to drive _and_ relax you!\n", "\n", "> Arrive at your destination safely, comfortably, and well rested. We combine state-of-the-art methods in random search to get you safely to your destination. Using random methods lets us generate efficient routes and high-quality (mandatory) white noise for your journey -- across the town or across the country!\n", "\n", "In this experiment an autonomous car will learn to drive up a hill. We'll compare random search ([ARS](https://arxiv.org/abs/1803.07055)) to Proximal Policy Optimization ([PPO](https://blog.openai.com/openai-baselines-ppo/)).\n", "\n", "# Aims\n", "1. Install PyTorch, et al.\n", "2. Answer the question: does random search do better than a state of the 'cart' RL method in... one of the simplest continuous control tasks?\n", "3. _Acquirehire_.\n", "\n", "\n", "# Install\n", "Before doing anything else, we need to install some libraries.\n", "\n", "From the command line, run:\n", "\n", "`pip install gym`\n", "\n", "`pip install ray`\n", "\n", "`pip install opencv-python`\n", " \n", "Then for your OS, do:\n", "\n", "## Mac\n", "`conda install pytorch torchvision -c pytorch`\n", "## Linux\n", "`conda install pytorch torchvision -c pytorch`\n", "## Windows\n", "`conda install pytorch -c pytorch`\n", "\n", "`pip3 install torchvision`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from ADMCode import visualize as vis\n", "from ADMCode.snuz import run_ppo\n", "from ADMCode.snuz import run_ars\n", "import numpy as np\n", "import pandas as pd\n", "\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import warnings\n", "\n", "warnings.simplefilter('ignore', np.RankWarning)\n", "warnings.filterwarnings(\"ignore\", module=\"matplotlib\")\n", "warnings.filterwarnings(\"ignore\")\n", "sns.set(style='white', font_scale=1.3)\n", "\n", "%matplotlib inline\n", "%config 
InlineBackend.figure_format = 'png'\n", "%config InlineBackend.savefig.dpi = 150" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Task\n", "\n", "We're going to teach a car to drive up a hill! This is the `MountainCarContinuous-v0` environment from the OpenAI [gym](https://gym.openai.com).\n", "\n", "![car](images/car.gif)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Vrooooom!\n", "Let's get driving, uphill! First let's try PPO.\n", "\n", "\n", "## PPO\n", "\n", "The default hyperparameters are:\n", "\n", "    gamma = 0.99              # Try me?\n", "    lam = 0.98                # Try me?\n", "    actor_hidden1 = 64        # Try me?\n", "    actor_hidden2 = 64        # Try me?\n", "    actor_hidden3 = 64        # Try me?\n", "    critic_hidden1 = 64       # Try me?\n", "    critic_lr = 0.0003        # Try me? (small changes)\n", "    actor_lr = 0.0003         # Try me? (small changes)\n", "    batch_size = 64           # Leave me be\n", "    l2_rate = 0.001           # Leave me be\n", "    clip_param = 0.2          # Leave me be\n", "    num_training_epochs = 10  # Try me?\n", "    num_episodes = 10         # Try me?\n", "    num_memories = 24         # Try me?\n", "    clip_actions = True       # Leave me be\n", "    clip_std = 1.0            # Leave me be\n", "    seed_value = None         # Try me (with int only)\n", "\n", "Parameters can be changed by passing them to `run_ppo`. For example, `run_ppo(num_episodes=20, actor_lr=0.0006)` doubles the training time and the actor learning rate."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "episodes, scores = run_ppo(render=True, num_episodes=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot the average reward per episode." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.plot(episodes, scores)\n", "plt.xlabel(\"Episode\")\n", "plt.ylabel(\"Reward\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compare, say, 10 episodes of PPO to 10 of...\n", "\n", "\n", "## ARS\n", "\n", "The [ARS](https://arxiv.org/abs/1803.07055) code was modified from Recht's [original source](https://github.com/modestyachts/ARS). 
\n", "\n", "\n", "The default hyperparameters are:\n", "\n", "    num_episodes = 10         # Try me?\n", "    n_directions = 8          # Try me?\n", "    deltas_used = 8           # Try me?\n", "    step_size = 0.02          # Try me?\n", "    delta_std = 0.03          # Try me?\n", "    n_workers = 1             # Leave me be\n", "    rollout_length = 240      # Try me?\n", "    shift = 0                 # Leave me be (all below)\n", "    seed = 237\n", "    policy_type = 'linear'\n", "    dir_path = 'data'\n", "    filter = 'MeanStdFilter'  # Leave me be\n", "\n", "_Note_: Due to the way the backend of ARS works (it uses [ray](https://ray.readthedocs.io/en/latest/), a distributed job system), we can't render experiments here. Sorry. :(" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "episodes, scores = run_ars(num_episodes=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.plot(episodes, scores)\n", "plt.xlabel(\"Episode\")\n", "plt.ylabel(\"Reward\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 2 }