{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Policy Iteration\n", "\n", "Evaluate the value function of the current policy, improve the policy using those values, and repeat, over and over. This is called \"policy iteration\".\n", "\n", "The algorithm is as follows.\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example 4.2 Jack’s Car Rental\n", "\n", "Jack manages two locations for a nationwide car rental company. Each day, some number of customers arrive at each location to rent cars. If Jack has a car available, he rents it out and is credited 10 dollars by the national company. If he is out of cars at that location, then the business is lost. Cars become available for renting the day after they are returned. To help ensure that cars are available where they are needed, Jack can move them between the two locations overnight, at a cost of 2 dollars per car moved. We assume that the number of cars requested and returned at each location are Poisson random variables, meaning that the probability that the number is $n$ is $\\frac{\\lambda^n}{n!}e^{-\\lambda}$, where $\\lambda$ is the expected number. Suppose $\\lambda$ is 3 and 4 for rental requests at the first and second locations and 3 and 2 for returns. To simplify the problem slightly, we assume that there can be no more than 20 cars at each location (any additional cars are returned to the nationwide company, and thus disappear from the problem) and a maximum of five cars can be moved from one location to the other in one night. We take the discount rate to be $\\gamma = 0.9$ and formulate this as a continuing finite MDP, where the time steps are days, the state is the number of cars at each location at the end of the day, and the actions are the net numbers of cars moved between the two locations overnight."
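, "\n", "\n", "As a sketch in standard notation (the symbols $v$, $\\pi$ and $p$ are assumptions introduced here, not defined in the text above), the two alternating steps of policy iteration can be written as a policy evaluation sweep, repeated until $v$ stops changing,\n", "\n", "$$v(s) \\leftarrow \\sum_{s', r} p(s', r \\mid s, \\pi(s)) \\left[ r + \\gamma v(s') \\right],$$\n", "\n", "followed by a greedy policy improvement step\n", "\n", "$$\\pi(s) \\leftarrow \\arg\\max_a \\sum_{s', r} p(s', r \\mid s, a) \\left[ r + \\gamma v(s') \\right].$$\n", "\n", "For Jack's problem, $s$ is the pair of car counts at the two locations, $a$ is the net number of cars moved overnight (between $-5$ and $5$), and $r$ is 10 dollars per car rented minus 2 dollars per car moved."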
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2020-01-16T13:04:57.490129Z", "start_time": "2020-01-16T13:04:52.199283Z" } }, "outputs": [], "source": [ "#######################################################################\n", "# Copyright (C) #\n", "# 2016 Shangtong Zhang(zhangshangtong.cpp@gmail.com) #\n", "# 2016 Kenta Shimada(hyperkentakun@gmail.com) #\n", "# 2017 Aja Rangaswamy (aja004@gmail.com) #\n", "# Permission given to modify the code as long as you keep this #\n", "# declaration at the top #\n", "#######################################################################\n", "\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import seaborn as sns\n", "from scipy.stats import poisson\n", "\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2020-01-16T13:07:49.225095Z", "start_time": "2020-01-16T13:07:49.220114Z" } }, "outputs": [], "source": [ "# maximum # of cars in each location\n", "MAX_CARS = 20\n", "\n", "# maximum # of cars to move during night\n", "MAX_MOVE_OF_CARS = 5\n", "\n", "# expectation for rental requests in first location\n", "RENTAL_REQUEST_FIRST_LOC = 3\n", "\n", "# expectation for rental requests in second location\n", "RENTAL_REQUEST_SECOND_LOC = 4\n", "\n", "# expectation for # of cars returned in first location\n", "RETURNS_FIRST_LOC = 3\n", "\n", "# expectation for # of cars returned in second location\n", "RETURNS_SECOND_LOC = 2\n", "\n", "DISCOUNT = 0.9\n", "\n", "# credit earned by a car\n", "RENTAL_CREDIT = 10\n", "\n", "# cost of moving a car\n", "MOVE_CAR_COST = 2\n", "\n", "# all possible actions\n", "actions = np.arange(-MAX_MOVE_OF_CARS, MAX_MOVE_OF_CARS + 1)\n", "\n", "# An upper bound for the Poisson distribution\n", "# If n is greater than or equal to this value, the probability of getting n is truncated to 0\n", "POISSON_UPPER_BOUND = 11\n", "\n", "# Probability for poisson 
distribution\n", "# @lam: lambda should be less than 10 for this function\n", "poisson_cache = dict()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2020-01-16T13:08:00.174475Z", "start_time": "2020-01-16T13:08:00.170485Z" } }, "outputs": [], "source": [ "def poisson_probability(n, lam):\n", "    global poisson_cache\n", "    key = n * 10 + lam\n", "    if key not in poisson_cache:\n", "        # cache each value so that no Poisson probability is computed twice\n", "        # key acts as a simple hash of (n, lam); it is unique only while lam < 10\n", "        poisson_cache[key] = poisson.pmf(n, lam)\n", "    return poisson_cache[key]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2020-01-16T13:08:11.358555Z", "start_time": "2020-01-16T13:08:11.349579Z" } }, "outputs": [], "source": [ "def expected_return(state, action, state_value, constant_returned_cars):\n", "    \"\"\"\n", "    @state: [# of cars in first location, # of cars in second location]\n", "    @action: positive if moving cars from first location to second location,\n", "            negative if moving cars from second location to first location\n", "    @state_value: state value matrix\n", "    @constant_returned_cars: if set True, model is simplified such that\n", "    the # of cars returned in daytime becomes constant\n", "    rather than a random value from poisson distribution, which will reduce calculation time\n", "    and leave the optimal policy/value state matrix almost the same\n", "    \"\"\"\n", "    # initialize total return\n", "    returns = 0.0\n", "\n", "    # cost for moving cars\n", "    returns -= MOVE_CAR_COST * abs(action)\n", "\n", "    # moving cars\n", "    NUM_OF_CARS_FIRST_LOC = min(state[0] - action, MAX_CARS)\n", "    NUM_OF_CARS_SECOND_LOC = min(state[1] + action, MAX_CARS)\n", "\n", "    # go through all possible rental requests\n", "    for rental_request_first_loc in range(POISSON_UPPER_BOUND):\n", "        for rental_request_second_loc in range(POISSON_UPPER_BOUND):\n", "            # probability for current combination of rental requests\n", "            prob = 
poisson_probability(rental_request_first_loc, RENTAL_REQUEST_FIRST_LOC) * \\\n", "                poisson_probability(rental_request_second_loc, RENTAL_REQUEST_SECOND_LOC)\n", "\n", "            num_of_cars_first_loc = NUM_OF_CARS_FIRST_LOC\n", "            num_of_cars_second_loc = NUM_OF_CARS_SECOND_LOC\n", "            # Note the author's naming convention: all-caps names hold the actual state\n", "            # (the number of cars really present at each of the two locations),\n", "            # while lowercase names are temporaries used in the calculation.\n", "\n", "            # valid rental requests should be less than actual # of cars\n", "            valid_rental_first_loc = min(num_of_cars_first_loc, rental_request_first_loc)\n", "            valid_rental_second_loc = min(num_of_cars_second_loc, rental_request_second_loc)\n", "\n", "            # get credits for renting\n", "            reward = (valid_rental_first_loc + valid_rental_second_loc) * RENTAL_CREDIT\n", "            num_of_cars_first_loc -= valid_rental_first_loc\n", "            num_of_cars_second_loc -= valid_rental_second_loc\n", "\n", "            # The two nested for loops enumerate every combination of rental requests\n", "            # so that the expectation can be computed.\n", "\n", "            if constant_returned_cars:\n", "                # get returned cars, those cars can be used for renting tomorrow\n", "                returned_cars_first_loc = RETURNS_FIRST_LOC\n", "                returned_cars_second_loc = RETURNS_SECOND_LOC\n", "                num_of_cars_first_loc = min(num_of_cars_first_loc + returned_cars_first_loc, MAX_CARS)\n", "                num_of_cars_second_loc = min(num_of_cars_second_loc + returned_cars_second_loc, MAX_CARS)\n", "                returns += prob * (reward + DISCOUNT * state_value[num_of_cars_first_loc, num_of_cars_second_loc])\n", "                # Important: as the docstring of this function says, treating the number of\n", "                # returned cars as a constant leaves the result almost unchanged. Summing over\n", "                # Poisson-distributed returns as well makes the computation O(n^4), which\n", "                # greatly increases the running time.\n", "            else:\n", "                for returned_cars_first_loc in range(POISSON_UPPER_BOUND):\n", "                    for returned_cars_second_loc in range(POISSON_UPPER_BOUND):\n", "                        prob_return = poisson_probability(\n", "                            returned_cars_first_loc, RETURNS_FIRST_LOC) * poisson_probability(returned_cars_second_loc, RETURNS_SECOND_LOC)\n", "                        num_of_cars_first_loc_ = min(num_of_cars_first_loc + returned_cars_first_loc, MAX_CARS)\n", "                        num_of_cars_second_loc_ = min(num_of_cars_second_loc + 
returned_cars_second_loc, MAX_CARS)\n", "                        prob_ = prob_return * prob\n", "                        returns += prob_ * (reward + DISCOUNT *\n", "                                            state_value[num_of_cars_first_loc_, num_of_cars_second_loc_])\n", "    return returns" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "ExecuteTime": { "end_time": "2020-01-16T14:12:29.413904Z", "start_time": "2020-01-16T14:12:29.404927Z" } }, "outputs": [], "source": [ "def figure_4_2(constant_returned_cars=True):\n", "    value = np.zeros((MAX_CARS + 1, MAX_CARS + 1))\n", "    policy = np.zeros(value.shape, dtype=int)\n", "\n", "    iterations = 0\n", "    _, axes = plt.subplots(2, 3, figsize=(40, 20))  # note: a 2x3 grid of subplots\n", "    plt.subplots_adjust(wspace=0.1, hspace=0.2)  # adjust the spacing between subplots\n", "    axes = axes.flatten()\n", "    while True:\n", "        fig = sns.heatmap(np.flipud(policy), cmap=\"YlGnBu\", ax=axes[iterations])\n", "        fig.set_ylabel('# cars at first location', fontsize=30)\n", "        fig.set_yticks(list(reversed(range(MAX_CARS + 1))))\n", "        fig.set_xlabel('# cars at second location', fontsize=30)\n", "        fig.set_title('policy {}'.format(iterations), fontsize=30)\n", "\n", "        # policy evaluation (in-place)\n", "        while True:\n", "            old_value = value.copy()\n", "            for i in range(MAX_CARS + 1):\n", "                for j in range(MAX_CARS + 1):\n", "                    new_state_value = expected_return([i, j], policy[i, j], value, constant_returned_cars)\n", "                    value[i, j] = new_state_value\n", "                    # Note the coding style: the intermediate name new_state_value\n", "                    # is used to improve readability.\n", "            max_value_change = abs(old_value - value).max()\n", "            print('max value change {}'.format(max_value_change))\n", "            # stop only once the value function of the current policy has converged\n", "            if max_value_change < 1e-4:\n", "                break\n", "\n", "        # policy improvement\n", "        policy_stable = True\n", "        for i in range(MAX_CARS + 1):\n", "            for j in range(MAX_CARS + 1):\n", "                old_action = policy[i, j]\n", "                action_returns = []\n", "                for action in actions:\n", "                    if (0 <= action <= i) or (-j <= action <= 0):\n", "                        action_returns.append(expected_return([i, j], action, value, 
constant_returned_cars))\n", "                    else:\n", "                        action_returns.append(-np.inf)\n", "                new_action = actions[np.argmax(action_returns)]  # pick the action with the highest expected return\n", "                policy[i, j] = new_action\n", "                if policy_stable and old_action != new_action:\n", "                    # check policy_stable first: once it is False\n", "                    # there is no need to set it to False again\n", "                    policy_stable = False\n", "        print('policy stable {}'.format(policy_stable))\n", "\n", "        if policy_stable:\n", "            fig = sns.heatmap(np.flipud(value), cmap=\"YlGnBu\", ax=axes[-1])\n", "            fig.set_ylabel('# cars at first location', fontsize=30)\n", "            fig.set_yticks(list(reversed(range(MAX_CARS + 1))))\n", "            fig.set_xlabel('# cars at second location', fontsize=30)\n", "            fig.set_title('optimal value', fontsize=30)\n", "            break\n", "\n", "        iterations += 1\n", "\n", "    plt.show()" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "ExecuteTime": { "end_time": "2020-01-16T14:13:36.797660Z", "start_time": "2020-01-16T14:12:30.598734Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "max value change 196.62783361783852\n", "max value change 134.98823859766583\n", "max value change 91.41415360228919\n", "max value change 67.17097732555729\n", "max value change 51.29055484635097\n", "max value change 38.49091000659837\n", "max value change 29.406139835126424\n", "max value change 25.7210573245398\n", "max value change 22.381602293031023\n", "max value change 19.40385808254939\n", "max value change 16.77577350573091\n", "max value change 14.47251552455765\n", "max value change 12.464101852186843\n", "max value change 10.719367983418692\n", "max value change 9.20806226246873\n", "max value change 7.9019189666795455\n", "max value change 6.775146571130392\n", "max value change 5.8045764710083745\n", "max value change 4.969618520007145\n", "max value change 4.252112693842776\n", "max value change 3.6361309524054946\n", "max value change 3.107761240497666\n", "max value change 2.654891834022692\n", "max value change 2.26700589940549\n", "max value 
change 1.9349911763441128\n", "max value change 1.650966802154585\n", "max value change 1.4081276418079938\n", "max value change 1.2006055672075036\n", "max value change 1.02334664187498\n", "max value change 0.8720029351049448\n", "max value change 0.7428376083516355\n", "max value change 0.6326419232035505\n", "max value change 0.5386628774742235\n", "max value change 0.45854026040933604\n", "max value change 0.3902520158000584\n", "max value change 0.33206690395809346\n", "max value change 0.28250355471067223\n", "max value change 0.2402951004837064\n", "max value change 0.20435866938208846\n", "max value change 0.1737691018435612\n", "max value change 0.14773633074884174\n", "max value change 0.12558593365213255\n", "max value change 0.10674242749371388\n", "max value change 0.09071493100810812\n", "max value change 0.07708486873008269\n", "max value change 0.06549543334426744\n", "max value change 0.05564256088280217\n", "max value change 0.047267206266042194\n", "max value change 0.040148735572074656\n", "max value change 0.03409927655258116\n", "max value change 0.028958890796900505\n", "max value change 0.02459144993093787\n", "max value change 0.02088111467702447\n", "max value change 0.01772932984539466\n", "max value change 0.0150522606048753\n", "max value change 0.012778605996800252\n", "max value change 0.010847734796413988\n", "max value change 0.00920809667559297\n", "max value change 0.007815868406908066\n", "max value change 0.006633800647819044\n", "max value change 0.005630235829983121\n", "max value change 0.0047782719802853535\n", "max value change 0.004055050928627679\n", "max value change 0.003441152547054571\n", "max value change 0.002920079317334512\n", "max value change 0.0024778178350288727\n", "max value change 0.0021024658457236\n", "max value change 0.0017839150607983356\n", "max value change 0.0015135814573454809\n", "max value change 0.0012841759864272717\n", "max value change 0.0010895096651211134\n", "max value change 
0.0009243279118891223\n", "max value change 0.0007841697620847299\n", "max value change 0.000665248240920846\n", "max value change 0.0005643487139082026\n", "max value change 0.00047874254232738167\n", "max value change 0.00040611372588728045\n", "max value change 0.0003444966009737982\n", "max value change 0.000292222921245866\n", "max value change 0.0002478769196159192\n", "max value change 0.00021025715051337102\n", "max value change 0.00017834408060934948\n", "max value change 0.00015127258103575514\n", "max value change 0.00012830856951495662\n", "max value change 0.00010882917825938421\n", "max value change 9.230592559106299e-05\n", "policy stable False\n", "max value change 72.93565506480746\n", "max value change 5.771584637253568\n", "max value change 2.1472970104344995\n", "max value change 1.070365975080108\n", "max value change 0.8619106467957636\n", "max value change 0.7181428891676092\n", "max value change 0.611364010490604\n", "max value change 0.5169059906119742\n", "max value change 0.4358272831748309\n", "max value change 0.3670218562992318\n", "max value change 0.30890785349942007\n", "max value change 0.259927010978231\n", "max value change 0.21868429274547907\n", "max value change 0.18397356667821896\n", "max value change 0.15476712387498992\n", "max value change 0.1301950284682789\n", "max value change 0.10952318723241206\n", "max value change 0.09213308201026393\n", "max value change 0.07750397308279844\n", "max value change 0.065197614125168\n", "max value change 0.054845259444334715\n", "max value change 0.046136675532011395\n", "max value change 0.038810871455780216\n", "max value change 0.03264828952961807\n", "max value change 0.02746423075132043\n", "max value change 0.02310332179803254\n", "max value change 0.019434859485443212\n", "max value change 0.016348893962117472\n", "max value change 0.013752933586260951\n", "max value change 0.011569172929284832\n", "max value change 0.009732160858675343\n", "max value change 
0.008186838889457704\n", "max value change 0.006886890992063854\n", "max value change 0.005793355427499591\n", "max value change 0.004873456997415815\n", "max value change 0.004099624717810002\n", "max value change 0.003448665460666689\n", "max value change 0.0029010688248831684\n", "max value change 0.0024404223665897007\n", "max value change 0.002052919695017863\n", "max value change 0.0017269466670768452\n", "max value change 0.0014527332954799022\n", "max value change 0.0012220609166320173\n", "max value change 0.001028015870474519\n", "max value change 0.0008647822832017482\n", "max value change 0.000727467756007627\n", "max value change 0.0006119567266296144\n", "max value change 0.0005147871261783621\n", "max value change 0.0004330466088617868\n", "max value change 0.00036428526584586507\n", "max value change 0.0003064422000988998\n", "max value change 0.0002577837491344326\n", "max value change 0.00021685153535599966\n", "max value change 0.00018241874482782805\n", "max value change 0.00015345336976224644\n", "max value change 0.00012908726307614415\n", "max value change 0.00010859013127628714\n", "max value change 9.134763803331225e-05\n", "policy stable False\n", "max value change 4.7865793901779625\n", "max value change 3.2947349349497017\n", "max value change 2.2411823866665372\n", "max value change 1.616931343950455\n", "max value change 1.1197864003121367\n", "max value change 0.7204544260453076\n", "max value change 0.443826224180043\n", "max value change 0.270089591177225\n", "max value change 0.16639579119885184\n", "max value change 0.1097569388878128\n", "max value change 0.09306955083684443\n", "max value change 0.07883243113371918\n", "max value change 0.06673516197616891\n", "max value change 0.05647744756430484\n", "max value change 0.04778890580797679\n", "max value change 0.04043363544485601\n", "max value change 0.03420889623009771\n", "max value change 0.02894175601244342\n", "max value change 0.0244852782967655\n", "max value change 
0.020714866829109724\n", "max value change 0.01752498189301832\n", "max value change 0.014826276628639334\n", "max value change 0.01254313570910881\n", "max value change 0.010611575530163009\n", "max value change 0.008977459770164842\n", "max value change 0.00759498609147613\n", "max value change 0.006425404193464601\n", "max value change 0.005435930515432119\n", "max value change 0.004598829696703888\n", "max value change 0.00389063733302919\n", "max value change 0.003291502340118768\n", "max value change 0.0027846305547996053\n", "max value change 0.0023558140037494013\n", "max value change 0.0019930326484427496\n", "max value change 0.0016861174636346732\n", "max value change 0.0014264653930240456\n", "max value change 0.001206798198722936\n", "max value change 0.0010209584465883381\n", "max value change 0.0008637369118673632\n", "max value change 0.0007307265586291578\n", "max value change 0.0006181990110007973\n", "max value change 0.0005230000359119913\n", "max value change 0.00044246113805002096\n", "max value change 0.00037432475153309497\n", "max value change 0.00031668096374914967\n", "max value change 0.0002679139766428307\n", "max value change 0.00022665681694888917\n", "max value change 0.00019175301355289776\n", "max value change 0.00016222418821598694\n", "max value change 0.00013724262737468962\n", "max value change 0.00011610807808892787\n", "max value change 9.822812404536307e-05\n", "policy stable False\n", "max value change 0.5643315459673204\n", "max value change 0.19760617142037518\n", "max value change 0.10013580858492332\n", "max value change 0.06076229858263105\n", "max value change 0.04080851176706801\n", "max value change 0.02724975517776329\n", "max value change 0.01637959485265128\n", "max value change 0.00917172069227945\n", "max value change 0.0049277609952014245\n", "max value change 0.0025834353657501197\n", "max value change 0.0013420404746966597\n", "max value change 0.0007016294298409775\n", "max value change 
0.00037558255417025066\n", "max value change 0.00020989058543818828\n", "max value change 0.00013043237390775175\n", "max value change 0.00011051700198549952\n", "max value change 9.361574132071837e-05\n", "policy stable False\n", "max value change 0.04079438312567163\n", "max value change 0.010408227162770345\n", "max value change 0.005110707129347247\n", "max value change 0.0032318390198042835\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "max value change 0.0021719229242762594\n", "max value change 0.0013911772695109903\n", "max value change 0.0008154469392138708\n", "max value change 0.0004459807777266178\n", "max value change 0.0002340408432246477\n", "max value change 0.00012037610895276885\n", "max value change 6.173777182993945e-05\n", "policy stable True\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 2880x1440 with 12 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "figure_4_2()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multiprocessing solution\n", "\n", "This cell takes forever to run in Jupyter; I suspect Jupyter cannot handle multiprocessing." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "start_time": "2020-01-16T14:33:22.506Z" } }, "outputs": [], "source": [ "# This file is contributed by Tahsincan Köse which implements a synchronous policy evaluation, while the car_rental.py\n", "# implements an asynchronous policy evaluation. 
This file also utilizes multi-processing for acceleration and contains\n", "# an answer to Exercise 4.5\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import math\n", "import tqdm\n", "import multiprocessing as mp\n", "from functools import partial\n", "import time\n", "import itertools\n", "%matplotlib inline\n", "\n", "############# PROBLEM SPECIFIC CONSTANTS #######################\n", "MAX_CARS = 20\n", "MAX_MOVE = 5\n", "MOVE_COST = -2\n", "ADDITIONAL_PARK_COST = -4\n", "\n", "RENT_REWARD = 10\n", "# expectation for rental requests in first location\n", "RENTAL_REQUEST_FIRST_LOC = 3\n", "# expectation for rental requests in second location\n", "RENTAL_REQUEST_SECOND_LOC = 4\n", "# expectation for # of cars returned in first location\n", "RETURNS_FIRST_LOC = 3\n", "# expectation for # of cars returned in second location\n", "RETURNS_SECOND_LOC = 2\n", "################################################################\n", "\n", "poisson_cache = dict()\n", "\n", "\n", "def poisson(n, lam):\n", "    global poisson_cache\n", "    key = n * 10 + lam\n", "    if key not in poisson_cache:\n", "        poisson_cache[key] = math.exp(-lam) * math.pow(lam, n) / math.factorial(n)\n", "    return poisson_cache[key]\n", "\n", "\n", "class PolicyIteration:\n", "    def __init__(self, truncate, parallel_processes, delta=1e-2, gamma=0.9, solve_4_5=False):\n", "        self.TRUNCATE = truncate\n", "        self.NR_PARALLEL_PROCESSES = parallel_processes\n", "        self.actions = np.arange(-MAX_MOVE, MAX_MOVE + 1)\n", "        self.inverse_actions = {el: ind[0] for ind, el in np.ndenumerate(self.actions)}\n", "        self.values = np.zeros((MAX_CARS + 1, MAX_CARS + 1))\n", "        self.policy = np.zeros(self.values.shape, dtype=int)\n", "        self.delta = delta\n", "        self.gamma = gamma\n", "        self.solve_extension = solve_4_5\n", "\n", "    def solve(self):\n", "        iterations = 0\n", "        total_start_time = time.time()\n", "        while True:\n", "            start_time = time.time()\n", "            self.values = 
self.policy_evaluation(self.values, self.policy)\n", "            elapsed_time = time.time() - start_time\n", "            print(f'PE => Elapsed time {elapsed_time} seconds')\n", "            start_time = time.time()\n", "\n", "            policy_change, self.policy = self.policy_improvement(self.actions, self.values, self.policy)\n", "            elapsed_time = time.time() - start_time\n", "            print(f'PI => Elapsed time {elapsed_time} seconds')\n", "            if policy_change == 0:\n", "                break\n", "            iterations += 1\n", "        total_elapsed_time = time.time() - total_start_time\n", "        print(f'Optimal policy is reached after {iterations} iterations in {total_elapsed_time} seconds')\n", "\n", "    # out-of-place (synchronous) update\n", "    def policy_evaluation(self, values, policy):\n", "\n", "        global MAX_CARS\n", "        while True:\n", "            new_values = np.copy(values)\n", "            k = np.arange(MAX_CARS + 1)\n", "            # cartesian product\n", "            all_states = ((i, j) for i, j in itertools.product(k, k))\n", "\n", "            results = []\n", "            with mp.Pool(processes=self.NR_PARALLEL_PROCESSES) as p:\n", "                # Multiprocessing: partial() fixes policy and values, and the pool\n", "                # maps the resulting function over all_states.\n", "                # (Is this a critical section? Not sure this reading is correct.)\n", "                cook = partial(self.expected_return_pe, policy, values)\n", "                results = p.map(cook, all_states)\n", "\n", "            for v, i, j in results:\n", "                new_values[i, j] = v\n", "\n", "            difference = np.abs(new_values - values).sum()\n", "            print(f'Difference: {difference}')\n", "            values = new_values\n", "            if difference < self.delta:\n", "                print('Values have converged!')\n", "                return values\n", "\n", "    def policy_improvement(self, actions, values, policy):\n", "        new_policy = np.copy(policy)\n", "\n", "        expected_action_returns = np.zeros((MAX_CARS + 1, MAX_CARS + 1, np.size(actions)))\n", "        cooks = dict()\n", "        with mp.Pool(processes=self.NR_PARALLEL_PROCESSES) as p:\n", "            for action in actions:\n", "                k = np.arange(MAX_CARS + 1)\n", "                all_states = ((i, j) for i, j in itertools.product(k, k))\n", "                cooks[action] = partial(self.expected_return_pi, values, action)\n", "                results = p.map(cooks[action], all_states)\n", "                for v, i, j, a in 
results:\n", "                    expected_action_returns[i, j, self.inverse_actions[a]] = v\n", "        for i in range(expected_action_returns.shape[0]):\n", "            for j in range(expected_action_returns.shape[1]):\n", "                new_policy[i, j] = actions[np.argmax(expected_action_returns[i, j])]\n", "\n", "        policy_change = (new_policy != policy).sum()\n", "        print(f'Policy changed in {policy_change} states')\n", "        return policy_change, new_policy\n", "\n", "    # O(n^4) computation for all possible requests and returns\n", "    def bellman(self, values, action, state):\n", "        expected_return = 0\n", "        if self.solve_extension:\n", "            if action > 0:\n", "                # Free shuttle to the second location\n", "                expected_return += MOVE_COST * (action - 1)\n", "            else:\n", "                expected_return += MOVE_COST * abs(action)\n", "        else:\n", "            expected_return += MOVE_COST * abs(action)\n", "\n", "        for req1 in range(0, self.TRUNCATE):\n", "            for req2 in range(0, self.TRUNCATE):\n", "                # moving cars\n", "                num_of_cars_first_loc = int(min(state[0] - action, MAX_CARS))\n", "                num_of_cars_second_loc = int(min(state[1] + action, MAX_CARS))\n", "\n", "                # valid rental requests should be less than actual # of cars\n", "                real_rental_first_loc = min(num_of_cars_first_loc, req1)\n", "                real_rental_second_loc = min(num_of_cars_second_loc, req2)\n", "\n", "                # get credits for renting\n", "                reward = (real_rental_first_loc + real_rental_second_loc) * RENT_REWARD\n", "\n", "                if self.solve_extension:\n", "                    if num_of_cars_first_loc >= 10:\n", "                        reward += ADDITIONAL_PARK_COST\n", "                    if num_of_cars_second_loc >= 10:\n", "                        reward += ADDITIONAL_PARK_COST\n", "\n", "                num_of_cars_first_loc -= real_rental_first_loc\n", "                num_of_cars_second_loc -= real_rental_second_loc\n", "\n", "                # probability for current combination of rental requests\n", "                prob = poisson(req1, RENTAL_REQUEST_FIRST_LOC) * \\\n", "                    poisson(req2, RENTAL_REQUEST_SECOND_LOC)\n", "                for ret1 in range(0, self.TRUNCATE):\n", "                    for ret2 in range(0, self.TRUNCATE):\n", "                        num_of_cars_first_loc_ = 
min(num_of_cars_first_loc + ret1, MAX_CARS)\n", "                        num_of_cars_second_loc_ = min(num_of_cars_second_loc + ret2, MAX_CARS)\n", "                        prob_ = poisson(ret1, RETURNS_FIRST_LOC) * \\\n", "                            poisson(ret2, RETURNS_SECOND_LOC) * prob\n", "                        # Classic Bellman equation for state-value\n", "                        # prob_ corresponds to p(s'|s,a) for each possible s' -> (num_of_cars_first_loc_,num_of_cars_second_loc_)\n", "                        expected_return += prob_ * (\n", "                            reward + self.gamma * values[num_of_cars_first_loc_, num_of_cars_second_loc_])\n", "        return expected_return\n", "\n", "    # Parallelization requires separate helper functions\n", "    # Expected return calculator for Policy Evaluation\n", "    def expected_return_pe(self, policy, values, state):\n", "\n", "        action = policy[state[0], state[1]]\n", "        expected_return = self.bellman(values, action, state)\n", "        return expected_return, state[0], state[1]\n", "\n", "    # Expected return calculator for Policy Improvement\n", "    def expected_return_pi(self, values, action, state):\n", "\n", "        if not ((action >= 0 and state[0] >= action) or (action < 0 and state[1] >= abs(action))):\n", "            return -float('inf'), state[0], state[1], action\n", "        expected_return = self.bellman(values, action, state)\n", "        return expected_return, state[0], state[1], action\n", "\n", "    def plot(self):\n", "        print(self.policy)\n", "        plt.figure()\n", "        plt.xlim(0, MAX_CARS + 1)\n", "        plt.ylim(0, MAX_CARS + 1)\n", "        plt.table(cellText=self.policy, loc=(0, 0), cellLoc='center')\n", "        plt.show()\n", "\n", "\n", "TRUNCATE = 9\n", "solver = PolicyIteration(TRUNCATE, parallel_processes=4, delta=1e-1, gamma=0.9, solve_4_5=True)\n", "solver.solve()\n", "solver.plot()\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": 
"3.7.0" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }