{ "nbformat": 4, "nbformat_minor": 2, "metadata": { "language_info": { "name": "python", "codemirror_mode": { "name": "ipython", "version": 3 }, "version": "3.7.0-final" }, "orig_nbformat": 2, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "npconvert_exporter": "python", "pygments_lexer": "ipython3", "version": 3, "kernelspec": { "name": "python37064bitbasecondaf1f4ce8bd9ee468caf98567667ef0765", "display_name": "Python 3.7.0 64-bit ('base': conda)" } }, "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Notes on Chapter 12.\n", "\n", "Two examples: the 19-state random walk and Mountain Car." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Example 1\n", "\n", "This example compares the performance of the three algorithms presented in the book:\n", "- the most basic one, built on the off-line lambda idea: `off-line lambda-return`\n", "- the one that combines temporal-difference learning with eligibility traces: `semi-gradient TD(lambda)`\n", "- the backward-view algorithm with low computational cost that the book arrives at after a series of derivations: `true online TD(lambda)`" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "#######################################################################\n", "# Copyright (C) #\n", "# 2016-2018 Shangtong Zhang(zhangshangtong.cpp@gmail.com) #\n", "# 2016 Kenta Shimada(hyperkentakun@gmail.com) #\n", "# Permission given to modify the code as long as you keep this #\n", "# declaration at the top #\n", "#######################################################################\n", "\n", "import numpy as np\n", "import matplotlib\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "from tqdm import tqdm\n", "\n", "# number of non-terminal states\n", "N_STATES = 19\n", "\n", "# all non-terminal states\n", "STATES = np.arange(1, N_STATES + 1)\n", "\n", "# start from the middle state\n", "START_STATE = 10\n", "\n", "# two terminal states\n", "# an action leading to the left terminal state has reward -1\n", "# an action leading to the right terminal state has reward 1\n", "END_STATES = [0, N_STATES + 1]\n", "\n", "# true state values from the Bellman equation\n", "TRUE_VALUE = np.arange(-20, 22, 2) / 20.0\n", "TRUE_VALUE[0] = TRUE_VALUE[N_STATES + 1] = 
0.0\n", "\n", "# base class for lambda-based algorithms in this chapter\n", "# In this example, we use the simplest linear feature function, state aggregation.\n", "# We use exactly 19 groups, so the weight of each group is exactly the value of that state\n", "class ValueFunction:\n", "    # @rate: lambda; lambda is a Python keyword, so it is called rate here\n", "    # @step_size: alpha, the step size for updates\n", "    def __init__(self, rate, step_size):\n", "        self.rate = rate\n", "        self.step_size = step_size\n", "        # w still maps one-to-one and linearly onto x,\n", "        # exactly as in the tabular case\n", "        self.weights = np.zeros(N_STATES + 2)\n", "\n", "    # the state value is just the weight\n", "    def value(self, state):\n", "        return self.weights[state]\n", "\n", "    # feed the algorithm with a new observation\n", "    # derived classes should override this function\n", "    def learn(self, state, reward):\n", "        return\n", "\n", "    # initialize some variables at the beginning of each episode\n", "    # must be called at the very beginning of each episode\n", "    # derived classes should override this function\n", "    def new_episode(self):\n", "        return\n", "\n",
"# Off-line lambda-return algorithm\n", "class OffLineLambdaReturn(ValueFunction):\n", "    def __init__(self, rate, step_size):\n", "        ValueFunction.__init__(self, rate, step_size)\n", "        # to accelerate learning, set a truncation threshold for the power of lambda\n", "        self.rate_truncate = 1e-3\n", "\n", "    def new_episode(self):\n", "        # initialize the trajectory\n", "        self.trajectory = [START_STATE]\n", "        # only need to track the last reward in one episode, as all others are 0\n", "        self.reward = 0.0\n", "\n", "    def learn(self, state, reward):\n", "        # add the new state to the trajectory\n", "        self.trajectory.append(state)\n", "        if state in END_STATES:\n", "            # start off-line learning once the episode ends\n", "            self.reward = reward\n", "            self.T = len(self.trajectory) - 1\n", "            self.off_line_learn()\n", "\n", "    # get the n-step return from the given time\n", "    def n_step_return_from_time(self, n, time):\n", "        # gamma is always 1 and rewards are zero except for the last reward,\n", "        # so the formula can be simplified.\n", "        # Based on Eq. (12.1). G_{t:t+n} here differs from the formula because of\n", "        # the special structure of this task: the sum of the preceding rewards R\n", "        # is dropped, since only the transition into an end state has a nonzero R.\n", "        end_time = min(time + n, self.T)\n", "        returns = self.value(self.trajectory[end_time])\n", "        if end_time == self.T:\n", "            returns += self.reward\n", "        return returns\n", "\n", "    # get the lambda-return from the given time\n", "    def lambda_return_from_time(self, time):\n", "        returns = 0.0\n", "        lambda_power = 1\n", "        for n in range(1, self.T - time):\n", "            returns += lambda_power * self.n_step_return_from_time(n, time)\n", "            lambda_power *= self.rate\n", "            if lambda_power < self.rate_truncate:\n", "                # The algorithm sums up to the (T - t - 1)-th term, but for efficiency\n", "                # this implementation drops terms that are too small: once the power\n", "                # of lambda is below the threshold, all following terms are discarded\n", "                break\n", "        returns *= 1 - self.rate\n", "        if lambda_power >= self.rate_truncate:\n", "            returns += lambda_power * self.reward\n", "        return returns\n", "\n", "    # perform off-line learning at the end of an episode\n", "    def off_line_learn(self):\n", "        for time in range(self.T):\n", "            # Each time step corresponds to Figure 12.1, only with a different number\n", "            # of blocks: at the start (t = 0) there are T - 0 - 1 blocks, and at the\n", "            # end (t = T - 1) there are T - (T - 1) - 1 = 0 blocks.\n", "            # Viewed another way, every visited state receives at least one\n", "            # lambda-return update.\n", "            # update for each state in the trajectory\n", "            state = self.trajectory[time]\n", "            delta = self.lambda_return_from_time(time) - self.value(state)\n", "            delta *= self.step_size\n", "            self.weights[state] += delta\n", "\n",
"# TD(lambda) algorithm\n", "class TemporalDifferenceLambda(ValueFunction):\n", "    def __init__(self, rate, step_size):\n", "        ValueFunction.__init__(self, rate, step_size)\n", "        self.new_episode()\n", "\n", "    def new_episode(self):\n", "        # initialize the eligibility trace\n", "        self.eligibility = np.zeros(N_STATES + 2)\n", "        # initialize the beginning state\n", "        self.last_state = START_STATE\n", "\n", "    def learn(self, state, reward):\n", "        # update the eligibility trace and weights\n", "        self.eligibility *= self.rate\n", "        self.eligibility[self.last_state] += 1\n", "        delta = reward + self.value(state) - self.value(self.last_state)\n", "        delta *= self.step_size\n", "        self.weights += delta * self.eligibility\n", "        self.last_state = state\n", "\n",
"# True online TD(lambda) algorithm\n", "class TrueOnlineTemporalDifferenceLambda(ValueFunction):\n", "    def __init__(self, rate, step_size):\n", "        ValueFunction.__init__(self, rate, step_size)\n", "\n", "    def new_episode(self):\n", "        # initialize the eligibility trace\n", "        self.eligibility = np.zeros(N_STATES + 2)\n", "        # initialize the beginning state\n", "        self.last_state = START_STATE\n", "        # initialize the old state value\n", "        self.old_state_value = 0.0\n", "\n", "    def learn(self, state, reward):\n", "        # update the eligibility trace and weights\n", "        last_state_value = self.value(self.last_state)\n", "        state_value = self.value(state)\n", "        dutch = 1 - self.step_size * self.rate * self.eligibility[self.last_state]\n", "        # Easy to overlook when reading the book: this class reproduces true online\n", "        # TD(lambda), and each iteration updates the whole trace vector, not a single\n", "        # trace element. That is why all of self.eligibility is multiplied by self.rate.\n", "        self.eligibility *= self.rate\n", "        self.eligibility[self.last_state] += dutch\n", "        delta = reward + state_value - last_state_value\n", "        self.weights += self.step_size * (delta + last_state_value - self.old_state_value) * self.eligibility\n", "        self.weights[self.last_state] -= self.step_size * (last_state_value - self.old_state_value)\n", "        self.old_state_value = state_value\n", "        self.last_state = state\n", "\n",
"# 19-state random walk\n", "def random_walk(value_function):\n", "    value_function.new_episode()\n", "    state = START_STATE\n", "    while state not in END_STATES:\n", "        next_state = state + np.random.choice([-1, 1])\n", "        if next_state == 0:\n", "            reward = -1\n", "        elif next_state == 
N_STATES + 1:\n", "            reward = 1\n", "        else:\n", "            reward = 0\n", "        value_function.learn(next_state, reward)\n", "        state = next_state\n", "\n", "# general plot framework\n", "# @value_function_generator: generate an instance of value function\n", "# @runs: specify the number of independent runs\n", "# @lambdas: a series of different lambda values\n", "# @alphas: sequences of step size for each lambda\n", "def parameter_sweep(value_function_generator, runs, lambdas, alphas):\n", "    # play for 10 episodes for each run\n", "    episodes = 10\n", "    # track the rms errors\n", "    errors = [np.zeros(len(alphas_)) for alphas_ in alphas]\n", "    for run in tqdm(range(runs)):\n", "        for lambdaIndex, rate in enumerate(lambdas):\n", "            for alphaIndex, alpha in enumerate(alphas[lambdaIndex]):\n", "                valueFunction = value_function_generator(rate, alpha)\n", "                for episode in range(episodes):\n", "                    random_walk(valueFunction)\n", "                    stateValues = [valueFunction.value(state) for state in STATES]\n", "                    errors[lambdaIndex][alphaIndex] += np.sqrt(np.mean(np.power(stateValues - TRUE_VALUE[1: -1], 2)))\n", "\n", "    # average over runs and episodes\n", "    for error in errors:\n", "        error /= episodes * runs\n", "\n", "    for i in range(len(lambdas)):\n", "        plt.plot(alphas[i], errors[i], label='lambda = ' + str(lambdas[i]))\n", "    plt.xlabel('alpha')\n", "    plt.ylabel('RMS error')\n", "    plt.legend()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Figure 12.3: Off-line lambda-return algorithm\n", "def figure_12_3():\n", "    lambdas = [0.0, 0.4, 0.8, 0.9, 0.95, 0.975, 0.99, 1]\n", "    alphas = [np.arange(0, 1.1, 0.1),\n", "              np.arange(0, 1.1, 0.1),\n", "              np.arange(0, 1.1, 0.1),\n", "              np.arange(0, 1.1, 0.1),\n", "              np.arange(0, 1.1, 0.1),\n", "              np.arange(0, 0.55, 0.05),\n", "              np.arange(0, 0.22, 0.02),\n", "              np.arange(0, 0.11, 0.01)]\n", "    parameter_sweep(OffLineLambdaReturn, 50, lambdas, alphas)\n", "\n", "    plt.show()" ] }, { "cell_type": "code", "execution_count": 3, 
"metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": "100%|██████████| 50/50 [03:50<00:00, 4.57s/it]\n" }, { "data": { "text/plain": "[Figure 12.3: off-line lambda-return algorithm, RMS error vs. alpha for each lambda]" }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "figure_12_3()" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Figure 12.6: TD(lambda) algorithm\n", "def figure_12_6():\n", "    lambdas = [0.0, 0.4, 0.8, 0.9, 0.95, 0.975, 0.99, 1]\n", "    alphas = [np.arange(0, 1.1, 0.1),\n", "              np.arange(0, 1.1, 0.1),\n", "              np.arange(0, 0.99, 0.09),\n", "              np.arange(0, 0.55, 0.05),\n", "              np.arange(0, 0.33, 0.03),\n", "              np.arange(0, 0.22, 0.02),\n", "              np.arange(0, 0.11, 0.01),\n", "              np.arange(0, 0.044, 0.004)]\n", "    parameter_sweep(TemporalDifferenceLambda, 50, lambdas, alphas)\n", "\n", "    plt.show()" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": "100%|██████████| 50/50 [00:51<00:00, 1.02s/it]\n" }, { "data": { "text/plain": "[Figure 12.6: TD(lambda) algorithm, RMS error vs. alpha for each lambda]" }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "figure_12_6()" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Figure 12.8: True online TD(lambda) algorithm\n", "def figure_12_8():\n", "    lambdas = [0.0, 0.4, 0.8, 0.9, 0.95, 0.975, 0.99, 1]\n", "    alphas = [np.arange(0, 1.1, 0.1),\n", "              np.arange(0, 1.1, 0.1),\n", "              np.arange(0, 1.1, 0.1),\n", "              np.arange(0, 1.1, 0.1),\n", "              np.arange(0, 1.1, 0.1),\n", "              np.arange(0, 0.88, 0.08),\n", "              np.arange(0, 0.44, 0.04),\n", "              np.arange(0, 0.11, 0.01)]\n", "    parameter_sweep(TrueOnlineTemporalDifferenceLambda, 50, lambdas, alphas)\n", "\n", "    plt.show()" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": "100%|██████████| 50/50 [00:59<00:00, 1.23s/it]\n" }, { "data": { "text/plain": "[Figure 12.8: true online TD(lambda) algorithm, RMS error vs. alpha for each lambda]" }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "figure_12_8()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Example 2\n", "\n", "Apply Sarsa(lambda) with different types of eligibility traces." ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "#######################################################################\n", "# Copyright (C) #\n", "# 2017-2018 Shangtong Zhang(zhangshangtong.cpp@gmail.com) #\n", "# Permission given to modify the code as long as you keep this #\n", "# declaration at the top #\n", "#######################################################################\n", "\n", "import numpy as np\n", "import matplotlib\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "from math import floor\n", "from tqdm import tqdm\n", "\n", "#######################################################################\n", "# Following are some utilities for tile coding from Rich.\n", "# To make each file self-contained, I copied them from\n", "# http://incompleteideas.net/tiles/tiles3.py-remove\n", "# with some naming convention changes\n", "#\n", "# Tile coding starts\n", "class IHT:\n", "    \"Structure to handle collisions\"\n", "    def __init__(self, size_val):\n", "        self.size = size_val\n", "        self.overfull_count = 0\n", "        self.dictionary = {}\n", "\n", "    def count(self):\n", "        return len(self.dictionary)\n", "\n", "    def full(self):\n", "        return len(self.dictionary) >= self.size\n", "\n", "    def get_index(self, obj, read_only=False):\n", "        d = self.dictionary\n", "        if obj in d:\n", "            return d[obj]\n", "        elif read_only:\n", "            return None\n", "        size = self.size\n", "        count = self.count()\n", "        if count >= size:\n", "            if self.overfull_count == 0: print('IHT full, starting to allow collisions')\n", "            self.overfull_count += 1\n", "            return hash(obj) % self.size\n", "        else:\n", "            d[obj] = count\n", "            return count\n", "\n", "def hash_coords(coordinates, m, read_only=False):\n", "    if isinstance(m, 
IHT): return m.get_index(tuple(coordinates), read_only)\n", "    if isinstance(m, int): return hash(tuple(coordinates)) % m\n", "    if m is None: return coordinates\n", "\n", "def tiles(iht_or_size, num_tilings, floats, ints=None, read_only=False):\n", "    \"\"\"returns num-tilings tile indices corresponding to the floats and ints\"\"\"\n", "    if ints is None:\n", "        ints = []\n", "    qfloats = [floor(f * num_tilings) for f in floats]\n", "    tiles = []\n", "    for tiling in range(num_tilings):\n", "        tilingX2 = tiling * 2\n", "        coords = [tiling]\n", "        b = tiling\n", "        for q in qfloats:\n", "            coords.append((q + b) // num_tilings)\n", "            b += tilingX2\n", "        coords.extend(ints)\n", "        tiles.append(hash_coords(coords, iht_or_size, read_only))\n", "    return tiles\n", "# Tile coding ends\n", "#######################################################################\n", "\n", "# all possible actions\n", "ACTION_REVERSE = -1\n", "ACTION_ZERO = 0\n", "ACTION_FORWARD = 1\n", "# order is important\n", "ACTIONS = [ACTION_REVERSE, ACTION_ZERO, ACTION_FORWARD]\n", "\n", "# bound for position and velocity\n", "POSITION_MIN = -1.2\n", "POSITION_MAX = 0.5\n", "VELOCITY_MIN = -0.07\n", "VELOCITY_MAX = 0.07\n", "\n", "# discount is always 1.0 in these experiments\n", "DISCOUNT = 1.0\n", "\n", "# use optimistic initial value, so it's ok to set epsilon to 0\n", "EPSILON = 0\n", "\n", "# maximum steps per episode\n", "STEP_LIMIT = 5000\n", "\n", "# take an @action at @position and @velocity\n", "# @return: new position, new velocity, reward (always -1)\n", "def step(position, velocity, action):\n", "    new_velocity = velocity + 0.001 * action - 0.0025 * np.cos(3 * position)\n", "    new_velocity = min(max(VELOCITY_MIN, new_velocity), VELOCITY_MAX)\n", "    new_position = position + new_velocity\n", "    new_position = min(max(POSITION_MIN, new_position), POSITION_MAX)\n", "    reward = -1.0\n", "    if new_position == POSITION_MIN:\n", "        new_velocity = 0.0\n", "    return new_position, new_velocity, reward\n", "\n", "# 
accumulating trace update rule\n", "# @trace: old trace (will be modified)\n", "# @active_tiles: current active tile indices\n", "# @lam: lambda\n", "# @return: new trace for convenience\n", "def accumulating_trace(trace, active_tiles, lam):\n", "    # note that w is still tile-coded here,\n", "    # so the trace z has the same dimension as w\n", "    trace *= lam * DISCOUNT\n", "    trace[active_tiles] += 1\n", "    return trace\n", "\n",
"# replacing trace update rule\n", "# @trace: old trace (will be modified)\n", "# @active_tiles: current active tile indices\n", "# @lam: lambda\n", "# @return: new trace for convenience\n", "def replacing_trace(trace, active_tiles, lam):\n", "    # np.arange() produces the sequence of indices; np.isin() returns a boolean\n", "    # array that is True wherever an index appears in active_tiles\n", "    active = np.isin(np.arange(len(trace)), active_tiles)\n", "    trace[active] = 1\n", "    # on a boolean array, ~ is element-wise logical NOT\n", "    # (on a Python int it is bitwise NOT, e.g. ~3 == -4)\n", "    trace[~active] *= lam * DISCOUNT\n", "    return trace\n", "\n",
"# replacing trace update rule, 'clearing' means set all tiles corresponding to non-selected actions to 0\n", "# @trace: old trace (will be modified)\n", "# @active_tiles: current active tile indices\n", "# @lam: lambda\n", "# @clearing_tiles: tiles to be cleared\n", "# @return: new trace for convenience\n", "def replacing_trace_with_clearing(trace, active_tiles, lam, clearing_tiles):\n", "    active = np.isin(np.arange(len(trace)), active_tiles)\n", "    trace[~active] *= lam * DISCOUNT\n", "    trace[clearing_tiles] = 0\n", "    trace[active] = 1\n", "    return trace\n", "\n",
"# dutch trace update rule\n", "# @trace: old trace (will be modified)\n", "# @active_tiles: current active tile indices\n", "# @lam: lambda\n", "# @alpha: step size for all tiles\n", "# @return: new trace for convenience\n", "def dutch_trace(trace, active_tiles, lam, alpha):\n", "    coef = 1 - alpha * DISCOUNT * lam * np.sum(trace[active_tiles])\n", "    trace *= DISCOUNT * lam\n", "    trace[active_tiles] += coef\n", "    return trace\n", "\n",
"# wrapper class for Sarsa(lambda)\n", "class Sarsa:\n", "    # In this example I use the tiling software instead of implementing standard tiling by myself\n", "    # One important thing is that tiling is only a map from (state, action) to a series of indices\n", "    # It doesn't matter whether the indices have meaning, only that this map satisfies certain properties\n", "    # View the following webpage for more information\n", "    # http://incompleteideas.net/sutton/tiles/tiles3.html\n", "    # @max_size: the maximum # of indices\n", "    def __init__(self, step_size, lam, trace_update=accumulating_trace, num_of_tilings=8, max_size=2048):\n", "        self.max_size = max_size\n", "        self.num_of_tilings = num_of_tilings\n", "        self.trace_update = trace_update\n", "        self.lam = lam\n", "\n", "        # divide step size equally to each tiling\n", "        self.step_size = step_size / num_of_tilings\n", "\n", "        self.hash_table = IHT(max_size)\n", "\n", "        # weight for each tile\n", "        self.weights = np.zeros(max_size)\n", "\n", "        # trace for each tile\n", "        self.trace = np.zeros(max_size)\n", "\n", "        # position and velocity need scaling to satisfy the tile software\n", "        self.position_scale = self.num_of_tilings / (POSITION_MAX - POSITION_MIN)\n", "        self.velocity_scale = self.num_of_tilings / (VELOCITY_MAX - VELOCITY_MIN)\n", "\n", "    # get indices of active tiles for given state and action\n", "    def get_active_tiles(self, position, velocity, action):\n", "        # I think position_scale * (position - POSITION_MIN) would be a good normalization.\n", "        # However position_scale * POSITION_MIN is a constant, so it's ok to ignore it.\n", "        active_tiles = tiles(self.hash_table, self.num_of_tilings,\n", "                             [self.position_scale * position, self.velocity_scale * velocity],\n", "                             [action])\n", "        return active_tiles\n", "\n", "    # estimate the value of given state and action\n", "    def value(self, position, velocity, action):\n", "        if position == POSITION_MAX:\n", "            return 0.0\n", "        active_tiles = self.get_active_tiles(position, velocity, action)\n", "        return np.sum(self.weights[active_tiles])\n", "\n",
"    # learn with given state, action and target\n", "    def learn(self, position, velocity, action, target):\n", "        active_tiles = self.get_active_tiles(position, velocity, action)\n", "        estimation = np.sum(self.weights[active_tiles])\n", "        delta = target - estimation\n", "        if self.trace_update == accumulating_trace or self.trace_update == replacing_trace:\n", "            self.trace_update(self.trace, active_tiles, self.lam)\n", "        elif self.trace_update == dutch_trace:\n", "            self.trace_update(self.trace, active_tiles, self.lam, self.step_size)\n", "        elif self.trace_update == replacing_trace_with_clearing:\n", "            clearing_tiles = []\n", "            for act in ACTIONS:\n", "                if act != action:\n", "                    # in this state s, action a is not the one currently selected,\n", "                    # so clear the tiles corresponding to (s, a)\n", "                    clearing_tiles.extend(self.get_active_tiles(position, velocity, act))\n", "            self.trace_update(self.trace, active_tiles, self.lam, clearing_tiles)\n", "        else:\n", "            raise Exception('Unexpected Trace Type')\n", "        self.weights += self.step_size * delta * self.trace\n", "\n", "    # get # of steps to reach the goal under current state value function\n", "    def cost_to_go(self, position, velocity):\n", "        costs = []\n", "        for action in ACTIONS:\n", "            costs.append(self.value(position, velocity, action))\n", "        return -np.max(costs)\n", "\n", "# get action at @position and @velocity based on epsilon greedy policy and @valueFunction\n", "def get_action(position, velocity, valueFunction):\n", "    if np.random.binomial(1, EPSILON) == 1:\n", "        return np.random.choice(ACTIONS)\n", "    values = []\n", "    for action in ACTIONS:\n", "        values.append(valueFunction.value(position, velocity, action))\n", "    return np.argmax(values) - 1\n", "\n", "# play Mountain Car for one episode based on given method @evaluator\n", "# @return: total steps in this episode\n", "def play(evaluator):\n", "    # reset the eligibility trace at the start of each episode (z <- 0)\n", "    evaluator.trace[:] = 0.0\n", "    position = np.random.uniform(-0.6, -0.4)\n", "    velocity = 0.0\n", "    action = get_action(position, velocity, 
evaluator)\n", "    steps = 0\n", "    while True:\n", "        next_position, next_velocity, reward = step(position, velocity, action)\n", "        next_action = get_action(next_position, next_velocity, evaluator)\n", "        steps += 1\n", "        target = reward + DISCOUNT * evaluator.value(next_position, next_velocity, next_action)\n", "        evaluator.learn(position, velocity, action, target)\n", "        position = next_position\n", "        velocity = next_velocity\n", "        action = next_action\n", "        if next_position == POSITION_MAX:\n", "            break\n", "        if steps >= STEP_LIMIT:\n", "            print('Step Limit Exceeded!')\n", "            break\n", "    return steps" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# figure 12.10, effect of the lambda and alpha on early performance of Sarsa(lambda)\n", "def figure_12_10():\n", "    runs = 30\n", "    episodes = 50\n", "    alphas = np.arange(1, 8) / 4.0\n", "    lams = [0.99, 0.95, 0.5, 0]\n", "\n", "    steps = np.zeros((len(lams), len(alphas), runs, episodes))\n", "    for lamInd, lam in enumerate(lams):\n", "        for alphaInd, alpha in enumerate(alphas):\n", "            for run in tqdm(range(runs)):\n", "                evaluator = Sarsa(alpha, lam, replacing_trace)\n", "                for ep in range(episodes):\n", "                    episode_steps = play(evaluator)\n", "                    steps[lamInd, alphaInd, run, ep] = episode_steps\n", "\n", "    # average over episodes\n", "    steps = np.mean(steps, axis=3)\n", "\n", "    # average over runs\n", "    steps = np.mean(steps, axis=2)\n", "\n", "    for lamInd, lam in enumerate(lams):\n", "        plt.plot(alphas, steps[lamInd, :], label='lambda = %s' % (str(lam)))\n", "    plt.xlabel('alpha * # of tilings (8)')\n", "    plt.ylabel('averaged steps per episode')\n", "    plt.ylim([180, 300])\n", "    plt.legend()\n", "\n", "    plt.show()" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": "20%|██ | 6/30 [00:14<00:58, 2.44s/it]Step Limit Exceeded!\n100%|██████████| 30/30 [01:15<00:00, 2.67s/it]\n100%|██████████| 30/30 [01:02<00:00, 
2.18s/it]\n100%|██████████| 30/30 [01:03<00:00, 2.16s/it]\n100%|██████████| 30/30 [01:02<00:00, 2.08s/it]\n100%|██████████| 30/30 [01:02<00:00, 2.12s/it]\n 70%|███████ | 21/30 [00:57<00:23, 2.65s/it]Step Limit Exceeded!\nStep Limit Exceeded!\nStep Limit Exceeded!\nStep Limit Exceeded!\nStep Limit Exceeded!\nStep Limit Exceeded!\nStep Limit Exceeded!\nStep Limit Exceeded!\nStep Limit Exceeded!\nStep Limit Exceeded!\nStep Limit Exceeded!\n100%|██████████| 30/30 [01:41<00:00, 3.20s/it]\n 0%| | 0/30 [00:00", "text/plain": "
" }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "figure_12_10()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# figure 12.11, summary comparison of Sarsa(lambda) algorithms\n", "# I use 8 tilings rather than 10 tilings\n", "def figure_12_11():\n", "    traceTypes = [dutch_trace, replacing_trace, replacing_trace_with_clearing, accumulating_trace]\n", "    # round to one decimal: np.arange yields 0.6000000000000001, which would wrongly satisfy alpha > 0.6 below\n", "    alphas = np.round(np.arange(0.2, 2.2, 0.2), 1)\n", "    episodes = 20\n", "    runs = 30\n", "    lam = 0.9\n", "    rewards = np.zeros((len(traceTypes), len(alphas), runs, episodes))\n", "\n", "    for traceInd, trace in enumerate(traceTypes):\n", "        for alphaInd, alpha in enumerate(alphas):\n", "            for run in tqdm(range(runs)):\n", "                evaluator = Sarsa(alpha, lam, trace)\n", "                for ep in range(episodes):\n", "                    if trace == accumulating_trace and alpha > 0.6:\n", "                        # accumulating traces with alpha > 0.6 are not simulated;\n", "                        # they are assigned the worst possible step count (STEP_LIMIT) directly\n", "                        steps = STEP_LIMIT\n", "                    else:\n", "                        steps = play(evaluator)\n", "                    rewards[traceInd, alphaInd, run, ep] = -steps\n", "\n", "    # average over episodes\n", "    rewards = np.mean(rewards, axis=3)\n", "\n", "    # average over runs\n", "    rewards = np.mean(rewards, axis=2)\n", "\n", "    for traceInd, trace in enumerate(traceTypes):\n", "        plt.plot(alphas, rewards[traceInd, :], label=trace.__name__)\n", "    plt.xlabel('alpha * # of tilings (8)')\n", "    plt.ylabel('averaged rewards per episode')\n", "    plt.ylim([-550, -150])\n", "    plt.legend()\n", "\n", "    plt.show()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": "100%|██████████| 30/30 [00:41<00:00, 1.38s/it]\n100%|██████████| 30/30 [00:33<00:00, 1.18s/it]\n100%|██████████| 30/30 [00:28<00:00, 1.04it/s]\n100%|██████████| 30/30 [00:27<00:00, 1.05it/s]\n100%|██████████| 30/30 [00:27<00:00, 1.11it/s]\n100%|██████████| 30/30 [00:28<00:00, 1.07it/s]\n100%|██████████| 30/30 [00:26<00:00, 1.12it/s]\n100%|██████████| 
30/30 [00:28<00:00, 1.08s/it]\n100%|██████████| 30/30 [00:27<00:00, 1.17it/s]\n100%|██████████| 30/30 [00:28<00:00, 1.11it/s]\n100%|██████████| 30/30 [01:10<00:00, 2.36s/it]\n100%|██████████| 30/30 [00:47<00:00, 1.55s/it]\n100%|██████████| 30/30 [00:39<00:00, 1.25s/it]\n100%|██████████| 30/30 [00:36<00:00, 1.18s/it]\n100%|██████████| 30/30 [00:35<00:00, 1.11s/it]\n100%|██████████| 30/30 [00:35<00:00, 1.31s/it]\n100%|██████████| 30/30 [00:34<00:00, 1.20s/it]\n100%|██████████| 30/30 [00:33<00:00, 1.02s/it]\n100%|██████████| 30/30 [00:34<00:00, 1.18s/it]\n100%|██████████| 30/30 [00:34<00:00, 1.14s/it]\n100%|██████████| 30/30 [01:39<00:00, 3.31s/it]\n100%|██████████| 30/30 [01:09<00:00, 2.23s/it]\n100%|██████████| 30/30 [01:01<00:00, 2.04s/it]\n100%|██████████| 30/30 [00:55<00:00, 1.77s/it]\n100%|██████████| 30/30 [00:50<00:00, 1.75s/it]\n100%|██████████| 30/30 [00:47<00:00, 1.53s/it]\n100%|██████████| 30/30 [00:45<00:00, 1.43s/it]\n100%|██████████| 30/30 [00:43<00:00, 1.52s/it]\n100%|██████████| 30/30 [00:51<00:00, 2.13s/it]\n 40%|████ | 12/30 [00:21<00:32, 1.82s/it]Step Limit Exceeded!\nStep Limit Exceeded!\nStep Limit Exceeded!\n100%|██████████| 30/30 [00:56<00:00, 1.71s/it]\n100%|██████████| 30/30 [00:44<00:00, 2.00s/it]\n100%|██████████| 30/30 [00:45<00:00, 1.23s/it]\n100%|██████████| 30/30 [00:00<00:00, 15040.54it/s]\n100%|██████████| 30/30 [00:00<00:00, 30109.86it/s]\n100%|██████████| 30/30 [00:00<00:00, 30037.99it/s]\n100%|██████████| 30/30 [00:00", "text/plain": "
" }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "figure_12_11()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ] }