{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Serial Value Iteration" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "After installing the package, we first load PLite into the current process." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "# Pkg.clone(\"https://github.com/haoyio/PLite.jl\") # do this once\n", "using PLite" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Problem definition\n", "\n", "For this example, we define a simple 1-D gridworld type problem where the goal of the agent is to move and stop at the center of the grid starting from anywhere. " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "The idea of this section is to define the problem mathematically without fussing about how we might choose to solve it later (e.g., use a discretized representation). \n", "\n", "In general, given the wide variety of MDP solvers available, it's important not to restrict our problem representation by what solver we might eventually choose. Rather, we want to choose the best solver after understanding our problem (e.g., find some problem structure that we can exploit)." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### MDP initialization\n", "\n", "We first define some constants and initialize the empty `MDP` object." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "slideshow": { "slide_type": "-" } }, "outputs": [ { "data": { "text/plain": [ "PLite.MDP(Dict{AbstractString,PLite.LazyVar}(),Dict{AbstractString,PLite.LazyVar}(),PLite.LazyFunc(true,ASCIIString[],##emptyfunc#7358),PLite.LazyFunc(true,ASCIIString[],##emptyfunc#7357))" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# constants\n", "const MinX = 0\n", "const MaxX = 100\n", "const StepX = 20\n", "\n", "# mdp definition\n", "mdp = MDP()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### State and action spaces definition\n", "\n", "We define the state and action spaces using a factored representation. If we can't factor the representation, we can simply define the space using a single discrete or continuous variable. " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "For each state and action variable, we either define a continuous or discrete variable. Note that we use the state variables' \"natural representation\" and avoid discretizing `x` even though for value iteration we would eventually have to. 
" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false, "slideshow": { "slide_type": "-" } }, "outputs": [ { "data": { "text/plain": [ "PLite.ValuesVar(\"move\",ASCIIString[\"W\",\"E\",\"stop\"])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# state space\n", "statevariable!(mdp, \"x\", MinX, MaxX) # continuous\n", "statevariable!(mdp, \"goal\", [\"no\", \"yes\"]) # discrete\n", "\n", "# action space\n", "actionvariable!(mdp, \"move\", [\"W\", \"E\", \"stop\"]) # discrete" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Transition and reward functions definition\n", "\n", "To define either the transition or reward function, we pass in the `MDP` object, the transition function itself, and an ordered set of the transition function's argument names. \n", "\n", "In the case of the transition function, we can define it using either the $T(s,a)$ or $T(s,a,s')$ format. Our example here uses the former way. In the case of the reward function, we define it using the $R(s,a)$ format." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "First define the transition function." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "slideshow": { "slide_type": "-" } }, "outputs": [ { "data": { "text/plain": [ "mytransition (generic function with 1 method)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "function isgoal(x::Float64)\n", " if abs(x - MaxX / 2) < StepX\n", " return \"yes\"\n", " else\n", " return \"no\"\n", " end\n", "end\n", "\n", "function mytransition(x::Float64, goal::AbstractString, move::AbstractString)\n", " if isgoal(x) == \"yes\" && goal == \"yes\"\n", " return [([x, isgoal(x)], 1.0)]\n", " end\n", "\n", " if move == \"E\"\n", " if x >= MaxX\n", " return [\n", " ([x, isgoal(x)], 0.9),\n", " ([x - StepX, isgoal(x - StepX)], 0.1)]\n", " elseif x <= MinX\n", " return [\n", " ([x, isgoal(x)], 0.2),\n", " ([x + StepX, isgoal(x + StepX)], 0.8)]\n", " else\n", " return [\n", " ([x, isgoal(x)], 0.1),\n", " ([x - StepX, isgoal(x - StepX)], 0.1),\n", " ([x + StepX, isgoal(x + StepX)], 0.8)]\n", " end\n", " elseif move == \"W\"\n", " if x >= MaxX\n", " return [\n", " ([x, isgoal(x)], 0.1),\n", " ([x - StepX, isgoal(x - StepX)], 0.9)]\n", " elseif x <= MinX\n", " return [\n", " ([x, isgoal(x)], 0.9),\n", " ([x + StepX, isgoal(x + StepX)], 0.1)]\n", " else\n", " return [\n", " ([x, isgoal(x)], 0.1),\n", " ([x - StepX, isgoal(x - StepX)], 0.8),\n", " ([x + StepX, isgoal(x + StepX)], 0.1)]\n", " end\n", " elseif move == \"stop\"\n", " return [([x, isgoal(x)], 1.0)]\n", " end\n", "end" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "After defining the transition function, we pass it into `mdp`." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "slideshow": { "slide_type": "-" } }, "outputs": [ { "data": { "text/plain": [ "PLite.LazyFunc(false,ASCIIString[\"x\",\"goal\",\"move\"],mytransition)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "transition!(mdp, [\"x\", \"goal\", \"move\"], mytransition)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Repeat this process for the reward function." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false, "slideshow": { "slide_type": "-" } }, "outputs": [ { "data": { "text/plain": [ "PLite.LazyFunc(false,ASCIIString[\"x\",\"goal\",\"move\"],myreward)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "function myreward(x::Float64, goal::AbstractString, move::AbstractString)\n", " if goal == \"yes\" && move == \"stop\"\n", " return 1\n", " else\n", " return 0\n", " end\n", "end\n", "\n", "reward!(mdp, [\"x\", \"goal\", \"move\"], myreward)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Solver selection\n", "\n", "So far, we've only been interested in coding up the mathematical formulation of the problem. From here on, we're only interested in selecting the solver and giving additional information to the solver object such that we can solve the MDP (such as the discretization scheme)." 
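] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Before handing the model to a solver, a quick sanity check of the formulation can catch mistakes early. The cell below is a minimal sketch, not part of PLite: it assumes the states of interest are the grid points `MinX:StepX:MaxX` and checks that `mytransition` returns probabilities that sum to one for every state-action pair. The `1e-9` tolerance is an arbitrary choice for floating-point round-off." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "# illustrative sanity check (not part of PLite)\n", "# assumption: we only check the grid points MinX:StepX:MaxX\n", "for x in MinX:StepX:MaxX, goal in [\"no\", \"yes\"], move in [\"W\", \"E\", \"stop\"]\n", "    # mytransition expects a Float64 position, so convert the integer grid point\n", "    outcomes = mytransition(float(x), goal, move) # list of (next state, probability)\n", "    total = sum([p for (s, p) in outcomes])\n", "    @assert abs(total - 1.0) < 1e-9 # fails if the probabilities do not sum to one\n", "end" 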
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Infinite horizon MDP\n", "\n", "For an infinite horizon problem with discount $\\gamma$, it can be proven that the value of an optimal policy satisfies the *Bellman equation*\n", "$$U^{\\star}(s) = \\max_{a}\\left( R\\left(s,a\\right) + \\gamma\\sum_{s'}T\\left(s'\\mid s, a\\right) U^{\\star}\\left(s'\\right) \\right)\\text{.}$$\n", "For the convergence proof, see http://web.stanford.edu/class/ee266/lectures/dpproof.pdf." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "To initialize the serial value iteration solver, simply type the following. The algorithm's parameters `maxiter`, `tol`, and `discount` are keyword arguments with default values, which appear in the output below. If we don't like them, we can change their values." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/plain": [ "PLite.SerialValueIteration(true,1000,0.0001,0.99,Dict{AbstractString,PLite.LazyDiscrete}(),Dict{AbstractString,PLite.LazyDiscrete}(),GridInterpolations.RectangleGrid with 1 points,GridInterpolations.RectangleGrid with 1 points)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "solver = SerialValueIteration()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "In PLite, value iteration requires all variables to be discretized. In the above problem, we need to discretize `x`, so we write the following.\n", "\n", "Note that when the transition is defined in the $T(s, a)$ format, the solver uses the `GridInterpolations.jl` package to approximate the values between the discretized state variable values by multilinear interpolation." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false, "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/plain": [ "PLite.LazyDiscrete(\"x\",20.0)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "discretize_statevariable!(solver, \"x\", StepX)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Finite horizon MDP\n", "\n", "In value iteration, we compute the optimal value function $U_{n}$ associated with a finite horizon of $n$ and no discounting. If $n=0$, then $U_{0}(s)=0$ for all $s$. We can compute $U_{n}$ recursively from this base case\n", "$$U_{n}(s) = \\max_{a}\\left( R\\left(s,a\\right) + \\sum_{s'}T\\left(s'\\mid s, a\\right) U_{n-1}\\left(s'\\right) \\right)\\text{.}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Notice that the value iteration solver was built with infinite horizon problems in mind. It’s easy, however, to use it for finite horizon problems by changing the solver's parameters: set `maxiter` to the horizon and `discount` to 1. \n", "\n", "When the iterations terminate, the solver will warn that the maximum number of iterations has been reached. This message was built to warn the user about convergence issues for infinite horizon problems, so just ignore it since we know what we're doing." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For an MDP with a horizon of 40 and no discounting, we can define the solver as follows. 
(Note that we redefine the discretization scheme, since it was lost when we overwrote the `solver` variable.)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false, "slideshow": { "slide_type": "-" } }, "outputs": [ { "data": { "text/plain": [ "PLite.LazyDiscrete(\"x\",20.0)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "solver = SerialValueIteration(maxiter=40, discount=1, verbose=false)\n", "discretize_statevariable!(solver, \"x\", StepX)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Solution and policy extraction\n", "\n", "The optimal value function $U^{\\star}$ appears on both sides of the Bellman equation. Value iteration approximates $U^{\\star}$ by repeatedly applying that equation as an update to the current estimate of $U^{\\star}$. Once we know $U^{\\star}$, we can extract an optimal policy via\n", "$$\\pi\\left(s\\right) \\leftarrow \\underset{a} {\\mathrm{argmax}} \\left( R\\left(s,a\\right) + \\gamma\\sum_{s'}T\\left(s'\\mid s, a\\right) U^{\\star}\\left(s'\\right) \\right)\\text{.}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "To generate the solution object, we simply input the following." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false, "slideshow": { "slide_type": "-" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO: mdp and value iteration solver passed basic checks\n", "WARNING: maximum number of iterations reached; check accuracy of solutions" ] }, { "data": { "text/plain": [ "PLite.ValueIterationSolution(3x12 Array{Float64,2}:\n", " 36.4688 36.7344 37.875 38.8611 38.6111 … 39.0 39.0 38.6111 37.5 \n", " 37.3438 38.5938 38.8594 37.8889 36.8611 39.0 39.0 36.8611 36.6111\n", " 36.3438 37.5938 39.0 39.0 37.6111 40.0 40.0 38.6111 37.5 ,GridInterpolations.RectangleGrid with 12 points,GridInterpolations.RectangleGrid with 3 points,0.06014998299999998,40,1.0)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "name": "stderr", "output_type": "stream", "text": [ "\n", "INFO: value iteration solution generated\n", "cputime [s] = 0.06014998299999998\n", "number of iterations = 40\n", "residual = 1.0\n" ] } ], "source": [ "solution = solve(mdp, solver)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Finally, to generate the optimal policy function, we type in the following." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false, "slideshow": { "slide_type": "-" } }, "outputs": [ { "data": { "text/plain": [ "policy (generic function with 1 method)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "policy = getpolicy(mdp, solution)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "We can then extract the optimal action at a given state by querying `policy` as follows." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false, "slideshow": { "slide_type": "-" } }, "outputs": [ { "data": { "text/plain": [ "1-element Array{Any,1}:\n", " \"E\"" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "stateq = (12, \"no\")\n", "actionq = policy(stateq...) # equally valid to type actionq = policy(12, \"no\")" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "The result says that we should move \"east,\" or move right on the grid. This makes sense since we're to the left of the grid center. :D" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Julia 0.4.0", "language": "julia", "name": "julia-0.4" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "0.4.0" } }, "nbformat": 4, "nbformat_minor": 0 }