{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Adaptive Stress Testing: Walk1D Example\n", "\n", "This notebook expands on the Walk1D example; see the [Walk1D.jl](https://github.com/sisl/POMDPStressTesting.jl/blob/master/test/Walk1D.jl) file for the non-notebook version and [walk1d.pdf](https://github.com/sisl/POMDPStressTesting.jl/blob/master/test/pdf/walk1d.pdf) for the write-up.\n", "\n", "---\n", "\n", "See the [documentation](https://sisl.github.io/POMDPStressTesting.jl/dev/) for more details.\n", "\n", "## Abstract\n", "In this self-contained tutorial, we define a simple problem for adaptive stress testing (AST)\n", "to find failures. This problem, called Walk1D, samples random walking distances from a standard\n", "normal distribution $\\mathcal{N}(0,1)$ and defines failures as walking past a threshold\n", "(set to ±10 in this example). AST either selects the seed that deterministically\n", "controls the sampled value from the distribution (i.e., from the transition model) or directly\n", "samples the provided environmental distributions. These two action modes correspond to the seed-action\n", "(`ASTSeedAction`) and sample-action (`ASTSampleAction`) options. AST guides the simulation toward failure events using a notion of distance to failure,\n", "while simultaneously searching for the sequence of actions that maximizes the log-likelihood of the samples.\n", "\n", "Some definitions to note for this example problem:\n", "- **System**: a one-dimensional walking agent\n", "- **Environment**: distribution of random walking actions, sampled from a standard normal distribution $\\mathcal{N}(0,1)$\n", "- **Failure event**: agent walks outside of the ±10 region\n", "- **Distance metric**: how close to the ±10 edge is the agent?"
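, "\n",
"\n",
"To make the dynamics concrete before introducing the AST interface, the following minimal sketch performs one plain Monte Carlo rollout of Walk1D with the settings used below (start at 0, threshold ±10, at most 30 steps). The `rollout` function is hypothetical and not part of POMDPStressTesting.jl; it samples every step directly from `Normal(0, 1)` instead of letting AST choose the samples.\n",
"\n",
"```julia\n",
"using Distributions, Random\n",
"\n",
"# Hypothetical helper (not part of POMDPStressTesting.jl): one random Walk1D rollout.\n",
"# Steps are drawn from N(0,1) until |x| >= threshx (failure) or t == endtime.\n",
"function rollout(rng::AbstractRNG; threshx=10.0, endtime=30)\n",
"    d = Normal(0, 1)\n",
"    x, t, loglik = 0.0, 0, 0.0\n",
"    while abs(x) < threshx && t < endtime\n",
"        step = rand(rng, d)          # sampled walking distance\n",
"        x += step                    # move the agent\n",
"        loglik += logpdf(d, step)    # accumulate the log-likelihood of the samples\n",
"        t += 1\n",
"    end\n",
"    miss = max(threshx - abs(x), 0)  # distance to failure (0 once the event occurs)\n",
"    return (x=x, t=t, loglik=loglik, event=abs(x) >= threshx, miss=miss)\n",
"end\n",
"\n",
"result = rollout(MersenneTwister(0))\n",
"```\n",
"\n",
"The returned named tuple holds the final position, the number of steps taken, the summed log-likelihood, the failure indication, and the miss distance. These are the same quantities that the gray-box and black-box interface functions below compute one step at a time."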
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "scrolled": false }, "outputs": [], "source": [ "using POMDPStressTesting # this package\n", "using Distributions # for the Normal distribution\n", "using Parameters # for @with_kw default struct parameters" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "# Gray-box Simulator and Environment\n", "The simulator and environment are treated as gray-box because we need\n", "access to the state-transition distributions and their associated likelihoods. Refer to the [gray-box definition](https://sisl.github.io/POMDPStressTesting.jl/dev/#sim_env) section in the documentation for further details." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Parameters\n", "First, we define the parameters of our simulation." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Walk1DParams" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@with_kw mutable struct Walk1DParams\n", " startx::Float64 = 0 # Starting x-position\n", " threshx::Float64 = 10 # ± boundary threshold\n", " endtime::Int64 = 30 # Simulation end time\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### GrayBox.Simulation\n", "Next, we define a `GrayBox.Simulation` structure that stores simulation-related values."
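, "\n",
"\n",
"Like `Walk1DParams` above, the simulation structure in the next cell is declared with `@with_kw` from Parameters.jl, so any field (including the nested `params` field) can be overridden by keyword while the others keep their defaults. The sketch below illustrates this with hypothetical stand-in types `DemoParams` and `DemoSim` that mirror the notebook's structures but omit the `GrayBox.Simulation` supertype, so they run without POMDPStressTesting.jl.\n",
"\n",
"```julia\n",
"using Parameters, Distributions\n",
"\n",
"# Hypothetical stand-ins mirroring Walk1DParams and Walk1DSim\n",
"# (no GrayBox.Simulation supertype, so POMDPStressTesting.jl is not needed).\n",
"@with_kw mutable struct DemoParams\n",
"    startx::Float64 = 0\n",
"    threshx::Float64 = 10\n",
"    endtime::Int64 = 30\n",
"end\n",
"\n",
"@with_kw mutable struct DemoSim\n",
"    params::DemoParams = DemoParams()\n",
"    x::Float64 = 0\n",
"    t::Int64 = 0\n",
"    distribution::Distribution = Normal(0, 1)\n",
"end\n",
"\n",
"# Override only selected fields; every other field keeps its default.\n",
"sim = DemoSim(params=DemoParams(threshx=5), distribution=Normal(0, 2))\n",
"sim.params.threshx   # 5.0\n",
"sim.params.endtime   # 30\n",
"```\n",
"\n",
"The same keyword style would let you, for example, shorten the episode or widen the step distribution for an experiment without editing the struct definitions."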
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Walk1DSim" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@with_kw mutable struct Walk1DSim <: GrayBox.Simulation\n", " params::Walk1DParams = Walk1DParams() # Parameters\n", " x::Float64 = 0 # Current x-position\n", " t::Int64 = 0 # Current time\n", " distribution::Distribution = Normal(0, 1) # Transition distribution\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### GrayBox.environment\n", "Then, we define our `GrayBox.Environment` distributions.\n", "When using the `ASTSampleAction`, as opposed to `ASTSeedAction`,\n", "we need to provide access to the sampleable environment." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "GrayBox.environment(sim::Walk1DSim) = GrayBox.Environment(:x => sim.distribution)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### GrayBox.transition!\n", "We override the transition function from the `GrayBox` interface,\n", "which takes an environment sample as input. We apply the sample in our simulator\n", "and return the log-likelihood." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "function GrayBox.transition!(sim::Walk1DSim, sample::GrayBox.EnvironmentSample)\n", " sim.t += 1 # Keep track of time\n", " sim.x += sample[:x].value # Move agent using sampled value from input\n", " return logpdf(sample)::Real # Summation handled by `logpdf()`\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "# Black-box System\n", "The system under test, in this case a simple single-dimensional moving agent,\n", "is always treated as *black-box*. The following interface functions are overridden\n", "to interact minimally with the system and to use its outputs to\n", "determine failure event indications and distance metrics. 
Refer to the [black-box definition](https://sisl.github.io/POMDPStressTesting.jl/dev/#system) section of the documentation for further details." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### BlackBox.initialize!\n", "Now we override the `BlackBox` interface, starting with the\n", "function that initializes the simulation object. Interface functions\n", "ending in `!` may modify the `sim` object in place." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "function BlackBox.initialize!(sim::Walk1DSim)\n", " sim.t = 0\n", " sim.x = sim.params.startx\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### BlackBox.distance\n", "We define how close we are to a failure event using a non-negative distance metric (zero once the agent reaches the threshold)." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "BlackBox.distance(sim::Walk1DSim) = max(sim.params.threshx - abs(sim.x), 0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### BlackBox.isevent\n", "We define an indicator of whether a failure event occurred." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "BlackBox.isevent(sim::Walk1DSim) = abs(sim.x) >= sim.params.threshx" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### BlackBox.isterminal\n", "Similarly, we define an indicator of whether the simulation (or system) is in a terminal state: either a failure event occurred or the end time was reached." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "BlackBox.isterminal(sim::Walk1DSim) = BlackBox.isevent(sim) || sim.t >= sim.params.endtime" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### BlackBox.evaluate!\n", "Lastly, we use our defined interface to evaluate the system under test.\n", "Using the input sample, we step the simulation and return the log-likelihood, the distance to an event, and the event indication."
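, "\n",
"\n",
"AST combines these three outputs into a reward signal. The sketch below shows one common shape from the AST literature: the per-step log-likelihood, plus a distance-based penalty when an episode terminates without a failure event. It is a hypothetical illustration with an assumed `miss_penalty` weight and is not necessarily the reward that POMDPStressTesting.jl's `ASTMDP` computes; consult the package documentation for its exact definition.\n",
"\n",
"```julia\n",
"using Distributions\n",
"\n",
"# Hypothetical AST-style reward (an assumption, not the package's exact formula):\n",
"# reward likely samples at every step, and penalize terminal states that miss\n",
"# the failure event in proportion to the remaining distance to failure.\n",
"function ast_reward(logprob::Real, d::Real, event::Bool, terminal::Bool; miss_penalty=1.0)\n",
"    r = logprob\n",
"    if terminal && !event\n",
"        r -= miss_penalty * d   # a smaller miss distance gives a smaller penalty\n",
"    end\n",
"    return r\n",
"end\n",
"\n",
"# One step that moves the agent 0.5 away from the origin (threshold 10):\n",
"step = 0.5\n",
"logprob = logpdf(Normal(0, 1), step)\n",
"d = max(10 - abs(step), 0)            # 9.5\n",
"ast_reward(logprob, d, false, false)  # not terminal: just the log-likelihood\n",
"ast_reward(logprob, d, false, true)   # terminal miss: logprob - 9.5\n",
"```\n",
"\n",
"Under this shape, a search that maximizes cumulative reward is pushed toward trajectories that end in a failure while keeping each sampled step likely."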
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "function BlackBox.evaluate!(sim::Walk1DSim, sample::GrayBox.EnvironmentSample)\n", " logprob::Real = GrayBox.transition!(sim, sample) # Step simulation\n", " d::Real = BlackBox.distance(sim) # Calculate miss distance\n", " event::Bool = BlackBox.isevent(sim) # Check event indication\n", " return (logprob::Real, d::Real, event::Bool)\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "# AST Setup and Running\n", "To set up AST, we instantiate our simulation object and\n", "pass it to the Markov decision process (MDP) object of the adaptive stress testing\n", "formulation. We use Monte Carlo tree search (MCTS) with progressive widening on the action\n", "space as our solver. Hyperparameters are passed to `MCTSPWSolver`, which is\n", "a simple wrapper around [MCTS.jl](https://github.com/JuliaPOMDP/MCTS.jl), the MCTS implementation for `POMDPs.jl`. Lastly, we solve the MDP\n", "to produce a `planner`. 
Note that we are using `ASTSampleAction`.\n", "\n", "### Setup/create AST planner\n", "* `planner` is used to play out the search\n", "* `planner.mdp::ASTMDP` is the main MDP problem formulation object for AST (it holds the reward metrics)\n", "* `planner.mdp.sim::Walk1DSim` is the main simulation object, holding all simulation information (e.g., current x-position, simulation settings, etc.)\n", "* `solver::MCTSPWSolver` holds solver-specific parameters and is used to generate the `planner`\n", " * *See below for additional solvers*" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "setup_ast (generic function with 2 methods)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "function setup_ast(seed=0)\n", " # Create gray-box simulation object\n", " sim::GrayBox.Simulation = Walk1DSim()\n", "\n", " # AST MDP formulation object\n", " mdp::ASTMDP = ASTMDP{ASTSampleAction}(sim)\n", " mdp.params.debug = true # record metrics\n", " mdp.params.top_k = 10 # record top k best trajectories\n", " mdp.params.seed = seed # set RNG seed for determinism\n", "\n", " # Hyperparameters for MCTS-PW as the solver\n", " solver = MCTSPWSolver(n_iterations=1000, # number of algorithm iterations\n", " exploration_constant=1.0, # UCT exploration\n", " k_action=1.0, # action widening coefficient\n", " alpha_action=0.5, # action widening exponent\n", " depth=sim.params.endtime) # tree depth\n", "\n", " # Get online planner (no work done yet)\n", " planner = solve(solver, mdp)\n", "\n", " return planner\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Search\n", "After setup, we *search* for failures using the planner and output the best action trace."
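, "\n",
"\n",
"The `k_action=1.0` and `alpha_action=0.5` hyperparameters in `setup_ast` control progressive widening during this search. As described for MCTS.jl's double progressive widening solver, a node visited `N` times may add a new action only while its number of children is at most `k * N^alpha`. The hypothetical helpers below reproduce that rule in isolation (they are not MCTS.jl internals) to show how slowly the action set grows with these settings.\n",
"\n",
"```julia\n",
"# Standalone sketch of the progressive-widening rule (hypothetical helpers):\n",
"# a node visited N times may hold up to k * N^alpha sampled actions.\n",
"max_children(N; k=1.0, alpha=0.5) = k * N^alpha\n",
"\n",
"# A new sampled action may be added while the child count is within the limit.\n",
"can_widen(nchildren, N; k=1.0, alpha=0.5) = nchildren <= max_children(N; k=k, alpha=alpha)\n",
"\n",
"[max_children(N) for N in (1, 4, 100, 1000)]  # approximately [1.0, 2.0, 10.0, 31.6]\n",
"can_widen(9, 100)    # true\n",
"can_widen(11, 100)   # false\n",
"```\n",
"\n",
"With `alpha_action=0.5`, the number of distinct sampled actions per node grows roughly with the square root of its visit count, so repeated visits mostly refine value estimates of existing actions rather than adding new ones."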
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "planner = setup_ast();" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00\u001b[39m\n" ] }, { "data": { "text/plain": [ "10-element Array{ASTAction,1}:\n", " ASTSampleAction\n", " sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample}\n", "\n", " ASTSampleAction\n", " sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample}\n", "\n", " ASTSampleAction\n", " sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample}\n", "\n", " ASTSampleAction\n", " sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample}\n", "\n", " ASTSampleAction\n", " sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample}\n", "\n", " ASTSampleAction\n", " sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample}\n", "\n", " ASTSampleAction\n", " sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample}\n", "\n", " ASTSampleAction\n", " sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample}\n", "\n", " ASTSampleAction\n", " sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample}\n", "\n", " ASTSampleAction\n", " sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample}\n" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "action_trace = search!(planner)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Playback\n", "We can also *playback* specific trajectories and print intermediate $x$-values." 
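, "\n",
"\n",
"The playback cell that follows prints the x-position at each step of the best trajectory, starting from `startx = 0`. As a sanity check, the sketch below copies those printed values by hand and recovers the sampled walking distances with `diff`, then recomputes their log-likelihood under the standard normal environment. It assumes the printed list is the complete trajectory.\n",
"\n",
"```julia\n",
"using Distributions\n",
"\n",
"# x-positions copied by hand from the playback output in the next cell\n",
"xs = [0.0, 0.015327970235574946, 0.30370765717228826, 1.4997663258450742,\n",
"      2.109577482502034, 2.952826155527804, 4.426053954911257,\n",
"      6.803697217082685, 8.03548741774146, 9.178959112657289,\n",
"      10.490317369782634]\n",
"\n",
"steps = diff(xs)                                       # sampled walking distances\n",
"loglik = sum(logpdf(Normal(0, 1), s) for s in steps)   # trajectory log-likelihood under N(0,1)\n",
"failed = abs(xs[end]) >= 10                            # true: the walk crossed +10\n",
"```\n",
"\n",
"All ten recovered steps are positive, so this trajectory walks steadily toward the +10 boundary and crosses it on the final step."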
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.0\n", "0.015327970235574946\n", "0.30370765717228826\n", "1.4997663258450742\n", "2.109577482502034\n", "2.952826155527804\n", "4.426053954911257\n", "6.803697217082685\n", "8.03548741774146\n", "9.178959112657289\n", "10.490317369782634\n" ] }, { "data": { "text/plain": [ "ASTState\n", " t_index: Int64 11\n", " parent: ASTState\n", " action: ASTSampleAction\n", " hash: UInt64 0x12ff54a6c31b4cce\n", " q_value: Float64 -1.7787687724700838\n", " terminal: Bool true\n" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "final_state = playback(planner, action_trace, sim->sim.x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Metrics\n", "Finally, we can print metrics associated with the AST run for further analysis." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "First failure: 23 of 62506\n", "Number of failures: 510\n", "Failure rate: 0.81592%\n" ] }, { "data": { "text/plain": [ "0.8159216715195341" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "failure_rate = print_metrics(planner)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "# Visualize interactive MCTS tree (using D3.js)\n", "When using the `MCTSPWSolver`, we can output the tree from the `search!` function and visualize it using `D3Trees.jl`." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00\u001b[39m\n" ] }, { "data": { "text/html": [ " \n", " \n", "
\n", " \n", "\n", " Attempting to display the tree. If the tree is large, this may take some time.\n", "
\n", "\n", " Note: D3Trees.jl requires an internet connection. If no tree appears, please check your connection. To help fix this, please see this issue. You may also diagnose errors with the javascript console (Ctrl-Shift-J in chrome).\n", "
\n", "