{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# GBDTs.jl -- Grammar-Based Decision Trees\n", "\n", "A grammar-based decision tree (GBDT) is an interpretable machine learning model for classifying and categorizing heterogeneous multivariate time series data. GBDTs combine decision trees with a grammar framework: each split of the decision tree is governed by a logical expression derived from a user-supplied grammar. The flexibility of the grammar framework lets GBDTs be applied to a wide range of problems. In particular, GBDTs have been applied to analyze heterogeneous multivariate time series data of failures in aircraft collision avoidance systems [1].\n", "\n", "[1] Lee et al. \"Interpretable Categorization of Heterogeneous Time Series Data\", preprint, 2018.\n", "\n", "GBDTs.jl depends on ExprOptimization.jl for optimizing the expressions at each node.\n", "\n", "To install the package:\n", "\n", "    using Pkg; Pkg.add(url=\"https://github.com/sisl/GBDTs.jl\")\n", " \n", "To get started:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "scrolled": true }, "outputs": [], "source": [ "using GBDTs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Australian Sign Language (Auslan) Dataset Example\n", "\n", "We analyze a subset of the Australian Sign Language dataset from the UCI Repository as an example. 
The dataset contains 8 words (labels): hello, please, yes, no, right, wrong, same, different.\n", "\n", "Load the labeled dataset from file:\n", "(Hint: if you unzip the data file to a directory of the same name, it will load much faster.)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "using MultivariateTimeSeries\n", "X, y = read_data_labeled(joinpath(dirname(pathof(GBDTs)), \"..\", \"data\", \"auslan_youtube8\"));" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define a grammar\n", "\n", "Here we use a simple grammar based on temporal logic:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1: b = G(bvec)\n", "2: b = F(bvec)\n", "3: b = G(implies(bvec, bvec))\n", "4: bvec = and(bvec, bvec)\n", "5: bvec = or(bvec, bvec)\n", "6: bvec = not(bvec)\n", "7: bvec = lt(rvec, rvec)\n", "8: bvec = lte(rvec, rvec)\n", "9: bvec = gt(rvec, rvec)\n", "10: bvec = gte(rvec, rvec)\n", "11: bvec = f_lt(x, xid, v, vid)\n", "12: bvec = f_lte(x, xid, v, vid)\n", "13: bvec = f_gt(x, xid, v, vid)\n", "14: bvec = f_gte(x, xid, v, vid)\n", "15: rvec = x[!, xid]\n", "16: xid = :x_1\n", "17: xid = :y_1\n", "18: xid = :z_1\n", "19: xid = :roll_1\n", "20: xid = :pitch_1\n", "21: xid = :yaw_1\n", "22: xid = :thumbbend_1\n", "23: xid = :forebend_1\n", "24: xid = :middlebend_1\n", "25: xid = :ringbend_1\n", "26: xid = :littlebend_1\n", "27: xid = :x_2\n", "28: xid = :y_2\n", "29: xid = :z_2\n", "30: xid = :roll_2\n", "31: xid = :pitch_2\n", "32: xid = :yaw_2\n", "33: xid = :thumbbend_2\n", "34: xid = :forebend_2\n", "35: xid = :middlebend_2\n", "36: xid = :ringbend_2\n", "37: xid = :littlebend_2\n", "38: vid = 1\n", "39: vid = 2\n", "40: vid = 3\n", "41: vid = 4\n", "42: vid = 5\n", "43: vid = 6\n", "44: vid = 7\n", "45: vid = 8\n", "46: vid = 9\n", "47: vid = 10\n" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grammar = @grammar 
begin\n", " b = G(bvec) | F(bvec) | G(implies(bvec,bvec))\n", " bvec = and(bvec, bvec)\n", " bvec = or(bvec, bvec)\n", " bvec = not(bvec)\n", " bvec = lt(rvec, rvec)\n", " bvec = lte(rvec, rvec)\n", " bvec = gt(rvec, rvec)\n", " bvec = gte(rvec, rvec)\n", " bvec = f_lt(x, xid, v, vid)\n", " bvec = f_lte(x, xid, v, vid)\n", " bvec = f_gt(x, xid, v, vid)\n", " bvec = f_gte(x, xid, v, vid)\n", " rvec = x[!,xid]\n", " xid = |([:x_1,:y_1,:z_1,:roll_1,:pitch_1,:yaw_1,:thumbbend_1,:forebend_1,:middlebend_1,:ringbend_1,:littlebend_1,:x_2,:y_2,:z_2,:roll_2,:pitch_2,:yaw_2,:thumbbend_2,:forebend_2,:middlebend_2,:ringbend_2,:littlebend_2])\n", " vid = |(1:10)\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the functions used in the grammar:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "gte (generic function with 1 method)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "G(v) = all(v) #globally\n", "F(v) = any(v) #eventually\n", "f_lt(x, xid, v, vid) = lt(x[!,xid], v[xid][vid]) #feature is less than a constant\n", "f_lte(x, xid, v, vid) = lte(x[!,xid], v[xid][vid]) #feature is less than or equal to a constant\n", "f_gt(x, xid, v, vid) = gt(x[!,xid], v[xid][vid]) #feature is greater than a constant\n", "f_gte(x, xid, v, vid) = gte(x[!,xid], v[xid][vid]) #feature is greater than or equal to a constant\n", "\n", "#workarounds for slow dot operators:\n", "implies(v1, v2) = (a = similar(v1); a .= v2 .| .!v1) #implies\n", "not(v) = (a = similar(v); a .= .!v) #not\n", "and(v1, v2) = (a = similar(v1); a .= v1 .& v2) #and\n", "or(v1, v2) = (a = similar(v1); a .= v1 .| v2) #or\n", "lt(x1, x2) = (a = Vector{Bool}(undef,length(x1)); a .= x1 .< x2) #less than\n", "lte(x1, x2) = (a = Vector{Bool}(undef,length(x1)); a .= x1 .≤ x2) #less than or equal to\n", "gt(x1, x2) = (a = Vector{Bool}(undef,length(x1)); a .= x1 .> x2) #greater than\n", "gte(x1, x2) = 
(a = Vector{Bool}(undef,length(x1)); a .= x1 .≥ x2) #greater than or equal to" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll specify the constants in the grammar by discretizing uniformly per feature:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "const v = Dict{Symbol,Vector{Float64}}()\n", "mins, maxes = minimum(X), maximum(X)\n", "for (i,xid) in enumerate(Symbol.(names(X)))\n", " v[xid] = collect(range(mins[i],stop=maxes[i],length=10))\n", "end;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Learn a GBDT from data\n", "\n", "GBDT uses ExprOptimization to optimize each split. A number of optimization algorithms are available (see ExprOptimization.jl).\n", "\n", "Specify the search parameters for the optimization:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MonteCarlo(2000, 5, ExprOptimization.MonteCarlos.NoTracking())" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p = MonteCarlo(2000, 5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Learn the GBDT:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "using Random; Random.seed!(1)\n", "model = induce_tree(grammar, :b, p, X, y, 6);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize the tree" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using TikzGraphs:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", 
"\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", "\n", "\n" ], "text/plain": [ "TikzPictures.TikzPicture(\"\\\\graph [layered layout, , ] {\\n1/\\\"1: G(f\\\\_lte(x, :middlebend\\\\_1, v, 1))\\\" [],\\n2/\\\"2: F(f\\\\_gt(x, :middlebend\\\\_2, v, 3))\\\" [],\\n3/\\\"3: G(f\\\\_lt(x, :littlebend\\\\_2, v, 5))\\\" [],\\n4/\\\"4: label=wrong\\\" [],\\n5/\\\"5: F(f\\\\_gt(x, :yaw\\\\_2, v, 7))\\\" [],\\n6/\\\"6: F(and(f\\\\_gt(x, :roll\\\\_2, v, 6), f\\\\_gte(x, :forebend\\\\_2, v, 6)))\\\" 
[],\\n7/\\\"7: label=no\\\" [],\\n8/\\\"8: label=please\\\" [],\\n9/\\\"9: F(f\\\\_gte(x, :thumbbend\\\\_2, v, 9))\\\" [],\\n10/\\\"10: label=yes\\\" [],\\n11/\\\"11: label=right\\\" [],\\n12/\\\"12: label=hello\\\" [],\\n13/\\\"13: G(f\\\\_gt(x, :roll\\\\_2, v, 6))\\\" [],\\n14/\\\"14: label=same\\\" [],\\n15/\\\"15: label=different\\\" [],\\n;\\n1 -> [,] 2;\\n1 -> [,] 13;\\n2 -> [,] 3;\\n2 -> [,] 12;\\n3 -> [,] 4;\\n3 -> [,] 5;\\n5 -> [,] 6;\\n5 -> [,] 9;\\n6 -> [,] 7;\\n6 -> [,] 8;\\n9 -> [,] 10;\\n9 -> [,] 11;\\n13 -> [,] 14;\\n13 -> [,] 15;\\n};\\n\", \"\", \"\\\\usepackage{fontspec}\\n\\\\setmainfont{Latin Modern Math}\\n\\\\usetikzlibrary{arrows}\\n\\\\usetikzlibrary{graphs}\\n\\\\usetikzlibrary{graphdrawing}\\n\\n% from: https://tex.stackexchange.com/questions/453132/fresh-install-of-tl2018-no-tikz-graph-drawing-libraries-found\\n\\\\usepackage{luacode}\\n\\\\begin{luacode*}\\n\\tfunction pgf_lookup_and_require(name)\\n\\tlocal sep = package.config:sub(1,1)\\n\\tlocal function lookup(name)\\n\\tlocal sub = name:gsub('%.',sep) \\n\\tif kpse.find_file(sub, 'lua') then\\n\\trequire(name)\\n\\telseif kpse.find_file(sub, 'clua') then\\n\\tcollectgarbage('stop') \\n\\trequire(name)\\n\\tcollectgarbage('restart')\\n\\telse\\n\\treturn false\\n\\tend\\n\\treturn true\\n\\tend\\n\\treturn\\n\\tlookup('pgf.gd.' .. name .. '.library') or\\n\\tlookup('pgf.gd.' .. name) or\\n\\tlookup(name .. 
'.library') or\\n\\tlookup(name) \\n\\tend\\n\\\\end{luacode*}\\n\\n\\\\usegdlibrary{layered}\", \"tikzpicture\", \"\", \"\", true, true)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "display(model; edgelabels=false) #suppress edge labels for clarity (left branch is true, right branch is false)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using AbstractTrees:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1: label=1, loss=75.4, G(f_lte(x, :middlebend_1, v, 1))\n", "├─ 2: label=1, loss=67.07, F(f_gt(x, :middlebend_2, v, 3))\n", "│ ├─ 3: label=2, loss=60.4, G(f_lt(x, :littlebend_2, v, 5))\n", "│ │ ├─ 4: label=6\n", "│ │ └─ 5: label=2, loss=50.4, F(f_gt(x, :yaw_2, v, 7))\n", "│ │ ├─ 6: label=2, loss=0.8, F(and(f_gt(x, :roll_2, v, 6), f_gte(x, :forebend_2, v, 6)))\n", "│ │ │ ├─ 7: label=4\n", "│ │ │ │ ⋮\n", "│ │ │ │ \n", "│ │ │ └─ 8: label=2\n", "│ │ │ ⋮\n", "│ │ │ \n", "│ │ └─ 9: label=3, loss=0.4, F(f_gte(x, :thumbbend_2, v, 9))\n", "│ │ ├─ 10: label=3\n", "│ │ │ ⋮\n", "│ │ │ \n", "│ │ └─ 11: label=5\n", "│ │ ⋮\n", "│ │ \n", "│ └─ 12: label=1\n", "└─ 13: label=7, loss=0.4, G(f_gt(x, :roll_2, v, 6))\n", " ├─ 14: label=7\n", " └─ 15: label=8\n" ] } ], "source": [ "show(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Use the GBDT model for prediction\n", "\n", "Here, we'll predict the training data and evaluate the accuracy." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.0" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ind = collect(1:length(X))\n", "y_pred = classify(model, X, ind)\n", "accuracy = count(y_pred .== y[ind]) / length(ind)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Determine the members of each node" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1×135 adjoint(::Vector{Int64}) with eltype Int64:\n", " 28 29 30 31 32 33 34 35 36 37 … 156 157 158 159 160 161 162" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mvec = node_members(model, X, ind)\n", "mvec[3]' #members of node 3" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Julia 1.7.3", "language": "julia", "name": "julia-1.7" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }