{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[?25l\r", "\u001b[K |█▉ | 10kB 26.2MB/s eta 0:00:01\r", "\u001b[K |███▋ | 20kB 846kB/s eta 0:00:01\r", "\u001b[K |█████▍ | 30kB 1.3MB/s eta 0:00:01\r", "\u001b[K |███████▏ | 40kB 1.4MB/s eta 0:00:01\r", "\u001b[K |█████████ | 51kB 1.0MB/s eta 0:00:01\r", "\u001b[K |██████████▊ | 61kB 1.2MB/s eta 0:00:01\r", "\u001b[K |████████████▌ | 71kB 1.4MB/s eta 0:00:01\r", "\u001b[K |██████████████▎ | 81kB 1.5MB/s eta 0:00:01\r", "\u001b[K |████████████████ | 92kB 1.6MB/s eta 0:00:01\r", "\u001b[K |█████████████████▉ | 102kB 1.4MB/s eta 0:00:01\r", "\u001b[K |███████████████████▋ | 112kB 1.4MB/s eta 0:00:01\r", "\u001b[K |█████████████████████▍ | 122kB 1.4MB/s eta 0:00:01\r", "\u001b[K |███████████████████████▏ | 133kB 1.4MB/s eta 0:00:01\r", "\u001b[K |█████████████████████████ | 143kB 1.4MB/s eta 0:00:01\r", "\u001b[K |██████████████████████████▊ | 153kB 1.4MB/s eta 0:00:01\r", "\u001b[K |████████████████████████████▌ | 163kB 1.4MB/s eta 0:00:01\r", "\u001b[K |██████████████████████████████▎ | 174kB 1.4MB/s eta 0:00:01\r", "\u001b[K |████████████████████████████████| 184kB 1.4MB/s \n", "\u001b[?25h" ] } ], "source": [ "!pip install fastai fastdot -q" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 04 TabNet\n", "\n", "In this notebook we'll be looking at comparing the `TabNet` architecture to our regular `fastai` fully connected models. We'll be utilizing Michael Grankin's `fast_tabnet` wrapper to utilize the model." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Keyring is skipped due to an exception: Failed to unlock the collection!\u001b[0m\n", "\u001b[33mWARNING: Keyring is skipped due to an exception: Failed to unlock the collection!\u001b[0m\n", "\u001b[33m WARNING: Keyring is skipped due to an exception: Failed to unlock the collection!\u001b[0m\n" ] } ], "source": [ "!pip install fast_tabnet==0.0.8 pytorch_tabnet==1.0.6 -q" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TabNet is an attention-based network for tabular data, originating [here](https://arxiv.org/pdf/1908.07442.pdf). Let's first look at our fastai architecture and then compare it with TabNet utilizing the `fastdot` library." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First let's build our data real quick so we know just what we're visualizing. We'll use `ADULTs` again" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.tabular.all import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.ADULT_SAMPLE)\n", "df = pd.read_csv(path/'adult.csv')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll build our `TabularPandas` object and the `DataLoaders`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n", "cont_names = ['age', 'fnlwgt', 'education-num']\n", "procs = [Categorify, FillMissing, Normalize]\n", "splits = RandomSplitter()(range_of(df))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "to = TabularPandas(df, procs, cat_names, cont_names, y_names=\"salary\", splits=splits)\n", "dls = to.dataloaders(bs=1)" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "And let's look at one batch to understand how the data is coming into the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dls.one_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So we can see first is our categoricals, second is our continuous, and the third is our `y`. With this in mind, let's make a `TabularModel` with 200 and 100 layer sizes:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'tabular_learner' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mlearn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtabular_learner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdls\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m200\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maccuracy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mNameError\u001b[0m: name 'tabular_learner' is not defined" ] } ], "source": [ "learn = tabular_learner(dls, layers=[200,100], metrics=accuracy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now a basic visualization of this model can be made with `fastdot` like below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastdot import *" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "## Hidden" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true 
}, "outputs": [], "source": [ "layers_cat = ['Embedding Matrix', 'Dropout']\n", "inp = ['Input']\n", "cont_bn = ['BatchNorm1d']\n", "lin_bn_drop = ['BatchNorm1d', 'Dropout', 'Linear', 'ReLU']\n", "full_lin = ['LinBnDrop\\n(ni, 200)', 'LinBnDrop\\n(200,100)', 'LinBnDrop\\n(100,2)']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "block1, block2, block3, block4, block5 = ['Preprocessed Input', 'Categorical\\nEmbeddings',\n", " 'Continous\\nBatch Normalization', 'Fully\\nConnected Layers',\n", " 'LinBnDrop']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "conns = ((block1, block2),\n", " (block1, block3),\n", " (block2, block4),\n", " (block3, block4))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "def color(o):\n", " if o == 'Embedding Matrix': return 'white'\n", " if o == 'Input': return 'gray'\n", " if 'Dropout' in o: return 'gold'\n", " if 'BatchNorm' in o: return 'pink'\n", " if 'Lin' in o: return 'lightblue'\n", " if 'ReLU' in o: return 'gray'" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "node_defaults['fillcolor'] = color" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(#4) [,,,]" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "model = [\n", " seq_cluster(inp, block1),\n", " seq_cluster(layers_cat, block2),\n", " seq_cluster(cont_bn, block3),\n", " seq_cluster(full_lin, block4)]\n", "\n", "fcc = seq_cluster(lin_bn_drop, block5)\n", "g = graph_items(*model)\n", "g.add_items(fcc)\n", "g.add_items(*object_connections(conns))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Graph" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "G\n", "\n", "\n", "cluster_n80b6434c22cb4d8da0bbace4de11edc7\n", "\n", "\n", "Preprocessed Input\n", "\n", "\n", "\n", "\n", "cluster_nbf9d3d1f410e4729bb633b2c8b05c3b0\n", "\n", "\n", "Categorical\n", "Embeddings\n", "\n", "\n", "\n", "\n", "cluster_nf3d12e9abb6141959f639f645330f61e\n", "\n", "\n", "Continous\n", "Batch Normalization\n", "\n", "\n", "\n", "\n", "cluster_n24949ae5aaa7450a90745cb63b088d3f\n", "\n", "\n", "Fully\n", "Connected Layers\n", "\n", "\n", "\n", "\n", "cluster_n9cc233b5e81b45358b7aaf1c30fbb77b\n", "\n", "\n", "LinBnDrop\n", "\n", "\n", "\n", "\n", "\n", "n45dee0203e324fd8b7c19ed66dd3488c\n", "\n", "\n", "Input\n", "\n", "\n", "\n", "\n", "\n", "n02680aac1cd04b40931a88cbccecf9c8\n", "\n", "\n", "Embedding Matrix\n", "\n", "\n", "\n", "\n", "\n", "n45dee0203e324fd8b7c19ed66dd3488c->n02680aac1cd04b40931a88cbccecf9c8\n", "\n", "\n", "\n", "\n", "\n", "n6f8061439682402581456b9d2038df95\n", "\n", "\n", "BatchNorm1d\n", "\n", "\n", "\n", "\n", "\n", "n45dee0203e324fd8b7c19ed66dd3488c->n6f8061439682402581456b9d2038df95\n", "\n", "\n", "\n", "\n", "\n", "nc9d2854b9c5843638fe80c616e7b9e96\n", "\n", "\n", "Dropout\n", "\n", "\n", "\n", "\n", "\n", "n02680aac1cd04b40931a88cbccecf9c8->nc9d2854b9c5843638fe80c616e7b9e96\n", "\n", "\n", "\n", "\n", "\n", "n4997f871d208435c9663e0389e5d4765\n", "\n", "\n", "LinBnDrop\n", "(ni, 200)\n", "\n", "\n", "\n", "\n", "\n", "nc9d2854b9c5843638fe80c616e7b9e96->n4997f871d208435c9663e0389e5d4765\n", "\n", "\n", "\n", "\n", "\n", "n6f8061439682402581456b9d2038df95->n4997f871d208435c9663e0389e5d4765\n", "\n", "\n", "\n", "\n", "\n", "n25548523949d4d8db3c0e1efe9dfa826\n", "\n", "\n", "LinBnDrop\n", "(200,100)\n", "\n", "\n", "\n", "\n", "\n", "n4997f871d208435c9663e0389e5d4765->n25548523949d4d8db3c0e1efe9dfa826\n", "\n", "\n", "\n", "\n", "\n", "n9a6cda4896cd422faf57311520a15c49\n", "\n", "\n", "LinBnDrop\n", "(100,2)\n", "\n", "\n", "\n", 
"\n", "\n", "n25548523949d4d8db3c0e1efe9dfa826->n9a6cda4896cd422faf57311520a15c49\n", "\n", "\n", "\n", "\n", "\n", "n12553fd1051e4971bcc0e7b214273f3b\n", "\n", "\n", "BatchNorm1d\n", "\n", "\n", "\n", "\n", "\n", "n3f0ba4c43a174e2986a4440d71f9b6b0\n", "\n", "\n", "Dropout\n", "\n", "\n", "\n", "\n", "\n", "n12553fd1051e4971bcc0e7b214273f3b->n3f0ba4c43a174e2986a4440d71f9b6b0\n", "\n", "\n", "\n", "\n", "\n", "n75e9b07d307d41499e203d6f49ee34f2\n", "\n", "\n", "Linear\n", "\n", "\n", "\n", "\n", "\n", "n3f0ba4c43a174e2986a4440d71f9b6b0->n75e9b07d307d41499e203d6f49ee34f2\n", "\n", "\n", "\n", "\n", "\n", "n22116590bb8f4cc3bdbadca821569594\n", "\n", "\n", "ReLU\n", "\n", "\n", "\n", "\n", "\n", "n75e9b07d307d41499e203d6f49ee34f2->n22116590bb8f4cc3bdbadca821569594\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "g" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "How does this compare to TabNet? 
This is TabNet:" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "## Hide" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "block1, block2, block3, block4, block5 = ['Preprocessed Input', 'Categorical\\nEmbeddings',\n", " 'Continous\\nBatch Normalization', 'TabNet', 'Output']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "conns = ((block1, block2),\n", " (block1, block3),\n", " (block2, block4),\n", " (block3, block4),\n", " (block4, block5))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "feat_tfmer = [*shared, *specifics]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "fcc = ['Linear']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "tabnet = ['Attention\\nTransformer', 'Feature\\nTransformer', 'Final\\nMapping (Linear)']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "def color(o):\n", " if o == 'Embedding Matrix': return 'white'\n", " if o == 'Input': return 'gray'\n", " if 'Dropout' in o: return 'white'\n", " if 'BatchNorm' in o: return 'pink'\n", " if 'Att' in o: return 'lightblue'\n", " if 'Feat' in o: return 'darkseagreen2'\n", " if 'Mask' in o: return 'lightgray'\n", " if 'Lin' in o: return 'gold2'\n", " if 'Out' in o: return 'lightgray'\n", " return 'white'\n", "\n", "node_defaults['fillcolor'] = color" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "def cluster_color(o):\n", " if 'Attention' in o: return 'lightblue'\n", " if 'Feat' in o: return 'darkseagreen2'\n", " return 'lightgray'\n", "\n", "cluster_defaults['fillcolor'] = cluster_color" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "block1, block2, block3, block4, block5 = ['Preprocessed Input', 'Categorical\\nEmbeddings',\n", " 'Continous\\nBatch Normalization', 'TabNet', 'Output']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "out = ['Output', 'Mask_Loss', 'Mask_Explain', 'Masks']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "model = [\n", " seq_cluster(inp, block1),\n", " seq_cluster(layers_cat, block2),\n", " seq_cluster(cont_bn, block3),\n", " seq_cluster(tabnet, block4),\n", " *Cluster(block5).add_items(*out),\n", " ]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "conns = ((block1, block2),\n", " (block1, block3),\n", " (block2, block4),\n", " (block3, block4),\n", " (block4, out[0]),\n", " (block4, out[1]),\n", " (block4, out[2]),\n", " (block4, out[3]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(#8) [,,,,,,,]" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "g = graph_items(*model)\n", "g.add_items(*object_connections(conns))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "att_tfmer = ['Linear', 'GhostBatchNorm', 'torch.mul(x, prior)', 'Sparsemax']\n", "shared = ['Linear\\n(ni, 80)', 'Linear\\n(ni-2, 80)']\n", "specifics = ['GLU Block']\n", "feat_tfmer = [*shared, *specifics]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "att_clus = seq_cluster(att_tfmer, 'Attention Transformer')\n", "feat_clus = seq_cluster(feat_tfmer, 'Feature Transformer')" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, 
"source": [ "## Graphs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "G\n", "\n", "\n", "cluster_n69bba691913145a29ae3ba24a356379b\n", "\n", "\n", "Preprocessed Input\n", "\n", "\n", "\n", "\n", "cluster_n4a0489b5df2e4ea99f3a7e64f3311bee\n", "\n", "\n", "Categorical\n", "Embeddings\n", "\n", "\n", "\n", "\n", "cluster_n55518dfbcfc946798753482761137097\n", "\n", "\n", "Continous\n", "Batch Normalization\n", "\n", "\n", "\n", "\n", "cluster_n5d99cbd28efd4b1fb81063ba30948b13\n", "\n", "\n", "TabNet\n", "\n", "\n", "\n", "\n", "\n", "n973d0f7e41d14bc099d00a2668f94793\n", "\n", "\n", "Input\n", "\n", "\n", "\n", "\n", "\n", "n082baf498bdb4aff8cab21ed0af7b1bc\n", "\n", "\n", "Embedding Matrix\n", "\n", "\n", "\n", "\n", "\n", "n973d0f7e41d14bc099d00a2668f94793->n082baf498bdb4aff8cab21ed0af7b1bc\n", "\n", "\n", "\n", "\n", "\n", "n15bf013a5d794c269157211616b18dd0\n", "\n", "\n", "BatchNorm1d\n", "\n", "\n", "\n", "\n", "\n", "n973d0f7e41d14bc099d00a2668f94793->n15bf013a5d794c269157211616b18dd0\n", "\n", "\n", "\n", "\n", "\n", "nfee2c6e95bbc4182b14b8a580600eff7\n", "\n", "\n", "Dropout\n", "\n", "\n", "\n", "\n", "\n", "n082baf498bdb4aff8cab21ed0af7b1bc->nfee2c6e95bbc4182b14b8a580600eff7\n", "\n", "\n", "\n", "\n", "\n", "n9aff5b9db3b341568d65b04d368ce086\n", "\n", "\n", "Attention\n", "Transformer\n", "\n", "\n", "\n", "\n", "\n", "nfee2c6e95bbc4182b14b8a580600eff7->n9aff5b9db3b341568d65b04d368ce086\n", "\n", "\n", "\n", "\n", "\n", "n15bf013a5d794c269157211616b18dd0->n9aff5b9db3b341568d65b04d368ce086\n", "\n", "\n", "\n", "\n", "\n", "n1b1c2177713843b89209afb55442048a\n", "\n", "\n", "Feature\n", "Transformer\n", "\n", "\n", "\n", "\n", "\n", "n9aff5b9db3b341568d65b04d368ce086->n1b1c2177713843b89209afb55442048a\n", "\n", "\n", "\n", "\n", "\n", "n742b31293ee245d29c4b4882d704612d\n", "\n", "\n", "Final\n", "Mapping (Linear)\n", "\n", "\n", "\n", "\n", 
"\n", "n1b1c2177713843b89209afb55442048a->n742b31293ee245d29c4b4882d704612d\n", "\n", "\n", "\n", "\n", "\n", "n08cae2e4ad514f54a3e98bf2d040784b\n", "\n", "\n", "Output\n", "\n", "\n", "\n", "\n", "\n", "n742b31293ee245d29c4b4882d704612d->n08cae2e4ad514f54a3e98bf2d040784b\n", "\n", "\n", "\n", "\n", "\n", "n02545a85b1c1458fbedcfe5f4484b84f\n", "\n", "\n", "Mask_Loss\n", "\n", "\n", "\n", "\n", "\n", "n742b31293ee245d29c4b4882d704612d->n02545a85b1c1458fbedcfe5f4484b84f\n", "\n", "\n", "\n", "\n", "\n", "nad73b0e0cccc4dda924fb6b280bf556d\n", "\n", "\n", "Mask_Explain\n", "\n", "\n", "\n", "\n", "\n", "n742b31293ee245d29c4b4882d704612d->nad73b0e0cccc4dda924fb6b280bf556d\n", "\n", "\n", "\n", "\n", "\n", "n16824f8a0ea347e5a722ef84d59cc614\n", "\n", "\n", "Masks\n", "\n", "\n", "\n", "\n", "\n", "n742b31293ee245d29c4b4882d704612d->n16824f8a0ea347e5a722ef84d59cc614\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "g" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "G\n", "\n", "\n", "cluster_n2f150d778aac43468cc10b8800f80b7c\n", "\n", "\n", "Feature Transformer\n", "\n", "\n", "\n", "\n", "\n", "n29e909cfd081451aba834363a466e0c9\n", "\n", "\n", "Linear\n", "(ni, 80)\n", "\n", "\n", "\n", "\n", "\n", "n726c77f9b69e4ec989cbf80fafc6c7cf\n", "\n", "\n", "Linear\n", "(ni-2, 80)\n", "\n", "\n", "\n", "\n", "\n", "n29e909cfd081451aba834363a466e0c9->n726c77f9b69e4ec989cbf80fafc6c7cf\n", "\n", "\n", "\n", "\n", "\n", "n54fefc2cc3434f4eb123a97121e6a1e8\n", "\n", "\n", "GLU Block\n", "\n", "\n", "\n", "\n", "\n", "n726c77f9b69e4ec989cbf80fafc6c7cf->n54fefc2cc3434f4eb123a97121e6a1e8\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ 
"graph_items(feat_clus)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "G\n", "\n", "\n", "cluster_n0b220cff35e64f53a0d5650052bbc4e3\n", "\n", "\n", "Attention Transformer\n", "\n", "\n", "\n", "\n", "\n", "ncdac5fa12000419895f3112fd542956b\n", "\n", "\n", "Linear\n", "\n", "\n", "\n", "\n", "\n", "n5fae93e1fe804ead94bfe97cc5efb5c8\n", "\n", "\n", "GhostBatchNorm\n", "\n", "\n", "\n", "\n", "\n", "ncdac5fa12000419895f3112fd542956b->n5fae93e1fe804ead94bfe97cc5efb5c8\n", "\n", "\n", "\n", "\n", "\n", "n7034cf7c75034e57a04bf7e33aab5123\n", "\n", "\n", "torch.mul(x, prior)\n", "\n", "\n", "\n", "\n", "\n", "n5fae93e1fe804ead94bfe97cc5efb5c8->n7034cf7c75034e57a04bf7e33aab5123\n", "\n", "\n", "\n", "\n", "\n", "n21339f5f2f0d4ed1b4efd68f01239f98\n", "\n", "\n", "Sparsemax\n", "\n", "\n", "\n", "\n", "\n", "n7034cf7c75034e57a04bf7e33aab5123->n21339f5f2f0d4ed1b4efd68f01239f98\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "graph_items(att_clus)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "So a few things to note, we now have two transformers, one that keeps an eye on the features and another that keeps an eye on the attention. We could call the Attention transformer the **encoder** and the Feature transformer the **decoder**. What this attention let's us do is see *exactly* how our model is behaving, moreso than just how our feature importance and other techniques \"guess\"" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Now that we have this done, how do we make a model?" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Using TabNet\n", "\n", "I have found in my experiments that TabNet isn't quite as good as fastai's tabular model, but as attention can be important and is a hot topic, we'll use it here. Another con of this model is it takes *many* epochs to get a decent accuracy as we will see:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from fast_tabnet.core import *" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "ename": "ImportError", "evalue": "cannot import name 'TabNetNoEmbeddings' from 'fast_tabnet.core' (/home/ml1/anaconda3/envs/fastai/lib/python3.7/site-packages/fast_tabnet/core.py)", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mfast_tabnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcore\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mTabNetNoEmbeddings\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mImportError\u001b[0m: cannot import name 'TabNetNoEmbeddings' from 'fast_tabnet.core' (/home/ml1/anaconda3/envs/fastai/lib/python3.7/site-packages/fast_tabnet/core.py)" ] } ], "source": [ "from fast_tabnet.core import TabNetNoEmbeddings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First we need to grab the embedding matrix sizes:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[(10, 6), (17, 8), (8, 5), (16, 8), (7, 5), (6, 4), (3, 3)]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "emb_szs = get_emb_sz(to); emb_szs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And now we 
can make use of our model! There's many different values we can pass in, here's a brief summary:\n", "\n", "* `n_d`: Dimensions of the prediction layer (usually between 4 to 64)\n", "* `n_a`: Dimensions of the attention layer (similar to `n_d`)\n", "* `n_steps`: Number of sucessive steps in our network (usually 3 to 10)\n", "* `gamma`: A scalling factor for updating attention (usually between 1.0 to 2.0)\n", "* `momentum`: Momentum in all batch normalization\n", "* `n_independent`: Number of independant GLU layers in each block (default is 2)\n", "* `n_shared`: Number of shared GLU layers in each block (default is 2)\n", "* `epsilon`: Should be kept very low (avoid `log(0)`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's build one similar to the model we showed in the above. To do so we'll set the dimensions of the prediction layer to 8, the number of attention layer dimensions to 32, and our steps to 4:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "class TabNetModel(Module):\n", " \"Attention model for tabular data.\"\n", " def __init__(self, emb_szs, n_cont, out_sz, embed_p=0., y_range=None,\n", " n_d=8, n_a=8,\n", " n_steps=3, gamma=1.3,\n", " n_independent=2, n_shared=2, epsilon=1e-15,\n", " virtual_batch_size=128, momentum=0.02):\n", " self.embeds = nn.ModuleList([Embedding(ni, nf) for ni,nf in emb_szs])\n", " self.emb_drop = nn.Dropout(embed_p)\n", " self.bn_cont = nn.BatchNorm1d(n_cont)\n", " n_emb = sum(e.embedding_dim for e in self.embeds)\n", " self.n_emb,self.n_cont,self.y_range = n_emb,n_cont,y_range\n", " self.tab_net = TabNetNoEmbeddings(n_emb + n_cont, out_sz, n_d, n_a, n_steps,\n", " gamma, n_independent, n_shared, epsilon, virtual_batch_size, momentum)\n", "\n", " def forward(self, x_cat, x_cont, att=False):\n", " if self.n_emb != 0:\n", " x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]\n", " x = torch.cat(x, 1)\n", " x = self.emb_drop(x)\n", " if self.n_cont != 0:\n", " 
x_cont = self.bn_cont(x_cont)\n", " x = torch.cat([x, x_cont], 1) if self.n_emb != 0 else x_cont\n", " x, m_loss, m_explain, masks = self.tab_net(x)\n", " if self.y_range is not None:\n", " x = (self.y_range[1]-self.y_range[0]) * torch.sigmoid(x) + self.y_range[0]\n", " if att:\n", " return x, m_loss, m_explain, masks\n", " else:\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First we need to make new `DataLoaders` because we currently have a batch size of 1" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "dls = to.dataloaders(bs=1024)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then build the model:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "net = TabNetModel(emb_szs, len(to.cont_names), dls.c, n_d=8, n_a=32, n_steps=1); " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally we'll build our `Learner` and use the `ranger` optimizer:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "learn = Learner(dls, net, CrossEntropyLossFlat(), metrics=accuracy, opt_func=ranger)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
00.00000000:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "ename": "ValueError", "evalue": "not enough values to unpack (expected 4, got 2)", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_flat_cos\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1e-1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/anaconda3/envs/fastai/lib/python3.7/site-packages/fastcore/utils.py\u001b[0m in \u001b[0;36m_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 450\u001b[0m \u001b[0minit_args\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 451\u001b[0m \u001b[0msetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minst\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'init_args'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minit_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 452\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0minst\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mto_return\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 453\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_f\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 454\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 
"\u001b[0;32m~/anaconda3/envs/fastai/lib/python3.7/site-packages/fastai/callback/schedule.py\u001b[0m in \u001b[0;36mfit_flat_cos\u001b[0;34m(self, n_epoch, lr, div_final, pct_start, wd, cbs, reset_opt)\u001b[0m\n\u001b[1;32m 135\u001b[0m \u001b[0mlr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'lr'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mh\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhypers\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[0mscheds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'lr'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcombined_cos\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpct_start\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mdiv_final\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 137\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_epoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcbs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mParamScheduler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscheds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mL\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcbs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreset_opt\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreset_opt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwd\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwd\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 138\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0;31m# Cell\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/fastai/lib/python3.7/site-packages/fastcore/utils.py\u001b[0m in \u001b[0;36m_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 450\u001b[0m \u001b[0minit_args\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 451\u001b[0m \u001b[0msetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minst\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'init_args'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minit_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 452\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0minst\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mto_return\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 453\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_f\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 454\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/fastai/lib/python3.7/site-packages/fastai/learner.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, n_epoch, lr, wd, cbs, reset_opt)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m;\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'begin_epoch'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 205\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_do_epoch_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 206\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_do_epoch_validate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mCancelEpochException\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'after_cancel_epoch'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/fastai/lib/python3.7/site-packages/fastai/learner.py\u001b[0m in \u001b[0;36m_do_epoch_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 175\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 176\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'begin_train'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 177\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall_batches\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 178\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mCancelTrainException\u001b[0m\u001b[0;34m:\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'after_cancel_train'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 179\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'after_train'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/fastai/lib/python3.7/site-packages/fastai/learner.py\u001b[0m in \u001b[0;36mall_batches\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mall_batches\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 154\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_iter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 155\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mo\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mone_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 156\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mone_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/fastai/lib/python3.7/site-packages/fastai/learner.py\u001b[0m in 
\u001b[0;36mone_batch\u001b[0;34m(self, i, b)\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_split\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'begin_batch'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 161\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'after_pred'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 162\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0myb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0myb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'after_loss'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 
"\u001b[0;32m~/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 720\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 721\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 722\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 723\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 724\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x_cat, x_cont, att)\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mx_cont\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn_cont\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_cont\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_cont\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_emb\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mx_cont\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm_explain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmasks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtab_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 25\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my_range\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my_range\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my_range\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msigmoid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my_range\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mValueError\u001b[0m: not enough values to unpack (expected 4, got 2)" ] } ], "source": [ "learn.fit_flat_cos(5, 1e-1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now as you can see it actually didn't take that long to get to 
83% accuracy. On my other tests I wasn't able to do quite as well, but try it out! The code is here for you to use and play with." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dl = learn.dls.test_dl(df.iloc[:20], bs=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "batch = dl.one_batch()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[5, 8, 3, 0, 6, 5, 1]]),\n", " tensor([[ 0.7629, -0.8397, 0.7556]]),\n", " tensor([[1]], dtype=torch.int8))" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "batch" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pytorch_tabnet.tab_model import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cat_dims = [emb_szs[i][1] for i in range(len(emb_szs))]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[6, 8, 5, 8, 5, 4, 3]" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "cat_dims" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`cat_dims` takes the second element of each of our embedding sizes, i.e. the embedding output widths. `cat_idxs` gives the indices in the batch that our categorical variables come from.
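", "\n", "\n", "As a quick sketch of how these two line up (the embedding sizes here are made up for illustration; in the notebook `emb_szs` was built earlier from the model's embeddings):\n", "\n", "```python\n", "emb_szs = [(9, 6), (16, 8)]  # hypothetical (cardinality, width) pairs\n", "cat_dims = [sz[1] for sz in emb_szs]  # embedding output widths -> [6, 8]\n", "n_cont = 3  # continuous variables sit first in the concatenated input\n", "cat_idxs = list(range(n_cont, n_cont + len(cat_dims)))  # -> [3, 4]\n", "```\n", "\n", "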
In our case that's every index from 3 onwards, since the first three columns are our continuous variables." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cat_idxs = [3, 4, 5, 6, 7, 8, 9]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[3, 4, 5, 6, 7, 8, 9]" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "cat_idxs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tot = len(to.cont_names) + len(to.cat_names)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The 42 is the size of the input after the embeddings are applied, i.e. the sum of our embedding output widths (39) plus our three continuous variables:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "matrix = create_explain_matrix(tot,\n", " cat_dims,\n", " cat_idxs,\n", " 42)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dl = learn.dls.test_dl(df.iloc[:20], bs=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's patch an `explain` function into `Learner`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@patch\n", "def explain(x:Learner, dl:TabDataLoader):\n", " \"Get explain values for a set of predictions\"\n", " dec_y = []\n", " x.model.eval()\n", " for batch_nb, data in enumerate(dl):\n", " with torch.no_grad():\n", " out, M_loss, M_explain, masks = x.model(data[0], data[1], True)\n", " for key, value in masks.items():\n", " masks[key] = csc_matrix.dot(value.numpy(), matrix)\n", " if batch_nb == 0:\n", " res_explain = csc_matrix.dot(M_explain.numpy(),\n", " matrix)\n", " res_masks = masks\n", " else:\n", " res_explain = np.vstack([res_explain,\n", " csc_matrix.dot(M_explain.numpy(),\n", " matrix)])\n", " for key, value in masks.items():\n", " res_masks[key] = np.vstack([res_masks[key], value])\n", "\n", " dec_y.append(int(x.loss_func.decodes(out)))\n", " return dec_y,
res_masks, res_explain" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll pass in a `DataLoader`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dec_y, res_masks, res_explain = learn.explain(dl)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And now we can visualize them with `plot_explain`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def plot_explain(masks, lbls, figsize=(12,12)):\n", " \"Plots masks with `lbls` (`dls.x_names`)\"\n", " fig = plt.figure(figsize=figsize)\n", " ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])\n", " plt.yticks(np.arange(0, len(masks[0]), 1.0))\n", " plt.xticks(np.arange(0, len(masks[0][0]), 1.0))\n", " ax.set_xticklabels(lbls, rotation=90)\n", " plt.ylabel('Sample Number')\n", " plt.xlabel('Variable')\n", " plt.imshow(masks[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We pass in the masks and the `x_names` and we can see for *each* input how it affected the output:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lbls = dls.x_names" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYkAAAMzCAYAAABTE+6HAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3deZhlVX23/ftL03QDgoCKqKA4QeKI\n2CCKiSIacYjEMfBG45T04xBFk5hHnySiJmYwGmNMHFpFNCEYx0QzqIgimijQDDIozqI4EcMgosy/\n94+9q7u6rFVVTffZ+1B9f66rrq6zz6lav57qe9baa0hVIUnSfLYbuwBJ0vQyJCRJTYaEJKnJkJAk\nNRkSkqSm7ccuYCl2yKpazc5jl6GG7LBy7BIAuHHH6agDIFf8dOwSNqj9dhi7BADylWvHLkELuJLL\nflRVt5l7/WYREqvZmQfk8LHLUMP2e+09dgkAXHXv241dwgar/uOMsUvY4No33WnsEgDY4REXjV2C\nFvCJev+8f0EON0mSmgwJSVKTISFJajIkJElNhoQkqcmQkCQ1GRKSpCZDQpLUZEhIkpoMCUlSkyEh\nSWoyJCRJTaOERJIjknw5ydeSvHSMGiRJixs8JJKsAP4eeBRwD+DoJPcYug5J0uLG6EkcDHytqr5R\nVdcC7wGOHKEOSdIixgiJOwDfmfX44v6aJGnKTO2hQ0nWAmsBVrPTyNVI0rZpjJ7Ed4F9Zj3eu7+2\niapaV1VrqmrNSlYNVpwkaaMxQuIM4O5J7pxkB+Ao4MMj1CFJWsTgw01VdX2S3wE+BqwAjquqC4au\nQ5K0uFHuSVTVfwD/MUbbkqSlc8W1JKnJkJAkNRkSkqQmQ0KS1GRISJKaDAlJUpMhIUlqMiQkSU2G\nhCSpyZCQJDUZEpKkJkNCktQ0tYcO6ebj8kOm42DBW7zvtLFLmEo7/PEtxy5h6nzluDVjlwDAfs9a\nP3YJi7InIUlqMiQkSU2GhCSpyZCQJDUZEpKkJkNCktRkSEiSmgwJSVKTISFJajIkJElNhoQkqcmQ\nkCQ1GRKSpKZRQiLJcUkuSXL+GO1LkpZmrJ7E8cARI7UtSVqiUUKiqk4FLh2jbUnS0k3toUNJ1gJr\nAVaz08jVSNK2aWpvXFfVuqpaU1VrVrJq7HIkaZs0tSEhSRqfISFJahprCuyJwOeA/ZNcnOTZY9Qh\nSVrYKDeuq+roMdqVJG0eh5skSU2GhCSpyZCQJDUZEpKkJkNCktRkSEiSmgwJSVKTISFJajIkJElN\nhoQkqcmQkCQ1Te2hQ9Nqxe67j10CADdcdtnYJWxwyy/8aOwSALj64fcfu4QNVlxz49glbLDdZ84e\nu4Sps9+z1o9dAgAXveqBY5ew0R+/f97L9iQkSU2GhCSpyZCQJDUZEpKkJkNCktRkSEiSmgwJSVKT\nISFJajIkJElNhoQkqcmQkCQ1GRKSpCZDQpLUZEhIkpoGD4kk+yT5VJIvJrkgyTFD1yBJWpoxzpO4\nHvi9qjoryS7AmUlOqqovjlCLJGkBg/ckqur7VXVW//mVwJeAOwxdhyRpcaOeTJdkX+B+wGnzPLcW\nWAuwmp0GrUuS1BntxnWSWwAfAF5UVT+e+3xVrauqNVW1ZiWrhi9QkjROSCRZSRcQJ1TVB8eoQZK0\nuDFmNwV4B/ClqvrroduXJC3dGD2JQ4GnAQ9Lck7/8egR6pAkLWLwG9dV9VkgQ7crSdp8rriWJDUZ\nEpKkJkNCktRkSEiSmgwJSVKTISFJajIkJElNhoQkqcmQkCQ1GRKSpCZDQpLUNOqhQzdHN1x22dgl\nTJ1cedXYJQCw8hNfH7sE3Uz81bc+P3YJALxk37Er2Oirjev2JCRJTYaEJKnJkJAkNRkSkqQmQ0KS\n1GRISJKaDAlJUpMhIUlqMiQkSU2GhCSpyZCQJDUZEpKkJkNCk
tQ0eEgkWZ3k9CRfSHJBklcOXYMk\naWnG2Cr8GuBhVfWTJCuBzyb5z6qajr17JUkbDB4SVVXAT/qHK/uPGroOSdLiRrknkWRFknOAS4CT\nquq0eV6zNsn6JOuv45rhi5QkjRMSVXVDVR0A7A0cnORe87xmXVWtqao1K1k1fJGSpHFnN1XV5cCn\ngCPGrEOSNL8xZjfdJslu/ec7Ao8ALhy6DknS4saY3XQ74F1JVtCF1Hur6t9GqEOStIgxZjedC9xv\n6HYlSZvPFdeSpCZDQpLUZEhIkpoMCUlSkyEhSWoyJCRJTYaEJKnJkJAkNRkSkqQmQ0KS1GRISJKa\nxtjgT8vMt3/jLmOXAMA+75iiw6lutdvYFWxww9e+OXYJAHzrTx84dgkbvGTfsSu4+bAnIUlqMiQk\nSU2GhCSpyZCQJDUZEpKkJkNCktRkSEiSmgwJSVKTISFJajIkJElNhoQkqcmQkCQ1GRKSpCZDQpLU\nNFpIJFmR5Owk/zZWDZKkhY3ZkzgG+NKI7UuSFjFKSCTZG3gM8PYx2pckLc1YPYm/Af4AuLH1giRr\nk6xPsv46pujEMUnahgweEkkeC1xSVWcu9LqqWldVa6pqzUpWDVSdJGm2MXoShwKPS/It4D3Aw5L8\n4wh1SJIWMXhIVNXLqmrvqtoXOAr4ZFU9deg6JEmLc52EJKlp+zEbr6pTgFPGrEGS1GZPQpLUZEhI\nkpoMCUlSkyEhSWoyJCRJTYaEJKnJkJAkNRkSkqQmQ0KS1GRISJKaDAlJUlOqauwaFrVr9qgH5PCx\ny5Ckrepj3ztn7BI2WHG7r51ZVWvmXrcnIUlqMiQkSU2GhCSpyZCQJDUZEpKkJkNCktRkSEiSmgwJ\nSVKTISFJajIkJElNhoQkqcmQkCQ1GRKSpKbtx2g0ybeAK4EbgOvn23lQkjS+UUKid1hV/WjE9iVJ\ni3C4SZLUNFZIFPDxJGcmWTtSDZKkRYw13PTgqvpukj2Bk5JcWFWnzn5BHx5rAVaz0xg1StI2b5Se\nRFV9t//1EuBDwMHzvGZdVa2pqjUrWTV0iZIkRgiJJDsn2WXmc+BXgPOHrkOStLgxhptuC3woyUz7\n/1RVHx2hDknSIgYPiar6BnDfoduVJG0+p8BKkpoMCUlSkyEhSWoyJCRJTYaEJKnJkJAkNRkSkqQm\nQ0KS1GRISJKaDAlJUpMhIUlqMiQkSU1jnnEtbVUXv+xBY5ewwd5//t9jl7DB1/7xfmOXAMDdnnr2\n2CVMnUfe/oCxS5jla/NetSchSWoyJCRJTYaEJKnJkJAkNRkSkqQmQ0KS1GRISJKaDAlJUpMhIUlq\nMiQkSU2GhCSpyZCQJDUZEpKkplFCIsluSd6f5MIkX0rywDHqkCQtbKytwt8AfLSqnpRkB2CnkeqQ\nJC1g8JBIckvgl4FnAFTVtcC1Q9chSVrcGMNNdwb+B3hnkrOTvD3JznNflGRtkvVJ1l/HNcNXKUka\nJSS2Bw4E3lxV9wOuAl4690VVta6q1lTVmpWsGrpGSRLjhMTFwMVVdVr/+P10oSFJmjKDh0RV/QD4\nTpL9+0uHA18cug5J0uLGmt30AuCEfmbTN4BnjlSHJGkBo4REVZ0DrBmjbUnS0rniWpLUZEhIkpoM\nCUlSkyEhSWoyJCRJTYaEJKnJkJAkNRkSkqQmQ0KS1GRISJKaDAlJUtNYG/xpGbnkX39h7BIAuOPR\n54xdwgbZ7ZZjl7DB3Z569tglAHDlUYeMXcIGu7zn82OXcLNhT0KS1GRISJKaDAlJUpMhIUlqMiQk\nSU2GhCSpyZCQJDUZEpKkJkNCktRkSEiSmgwJSVKTISFJajIkJElNg4dEkv2TnDPr48dJXjR0HZKk\nxQ2+VXhVfRk4ACDJCuC7wIeGrkOStLixh5sOB75eVReNXIckaR5jh8RRwIkj1yBJahgtJJLsADwO\neF/j+bVJ1idZfx3XDFucJ
AkYtyfxKOCsqvrhfE9W1bqqWlNVa1ayauDSJEkwbkgcjUNNkjTVRgmJ\nJDsDjwA+OEb7kqSlGXwKLEBVXQXcaoy2JUlLN/bsJknSFDMkJElNhoQkqcmQkCQ1GRKSpCZDQpLU\nZEhIkpoMCUlSkyEhSWoyJCRJTYaEJKnJkJAkNY2ywZ+Wlz2PvHDsEgC4/D/vOnYJG+z6qK+PXcIG\nNz7kfmOXAMAu7/n82CVssN0B9xi7BABuPOeLY5ewKHsSkqQmQ0KS1GRISJKaDAlJUpMhIUlqMiQk\nSU2GhCSpyZCQJDUZEpKkJkNCktRkSEiSmgwJSVKTISFJaholJJK8OMkFSc5PcmKS1WPUIUla2OAh\nkeQOwAuBNVV1L2AFcNTQdUiSFjfWcNP2wI5Jtgd2Ar43Uh2SpAUMHhJV9V3gtcC3ge8DV1TVx+e+\nLsnaJOuTrL+Oa4YuU5LEOMNNuwNHAncGbg/snOSpc19XVeuqak1VrVnJqqHLlCQxznDTw4FvVtX/\nVNV1wAeBB41QhyRpEWOExLeBQ5LslCTA4cCXRqhDkrSIMe5JnAa8HzgLOK+vYd3QdUiSFrf9GI1W\n1bHAsWO0LUlaOldcS5KaDAlJUpMhIUlqMiQkSU2GhCSpyZCQJDUZEpKkpgVDIsl2SZ4yVDGSpOmy\nYEhU1Y3AHwxUiyRpyixluOkTSX4/yT5J9pj5mHhlkqTRLWVbjl/vf33+rGsF3GXrlyNJmiaLhkRV\n3XmIQqQt9aA9vzl2CRucP3YBs2x/2nRssnzdQw8cu4SNTjlr7ApuNhYdbuq39P6jJOv6x3dP8tjJ\nlyZJGttS7km8E7iWjQcDfRf404lVJEmaGksJibtW1WuA6wCq6qdAJlqVJGkqLCUkrk2yI93NapLc\nFbhmolVJkqbCUmY3HQt8FNgnyQnAocAzJlmUJGk6LGV200lJzgIOoRtmOqaqfjTxyiRJo1vq8aUP\nAR5MN+S0EvjQxCqSJE2NpUyBfRPwHOA8uunf/yfJ30+6MEnS+JbSk3gY8ItVNXPj+l3ABROtSpI0\nFZYyu+lrwB1nPd6nvyZJWuaaPYkkH6G7B7EL8KUkp/ePHwCcPkx5kqQxLTTc9NrBqpAkTaVmSFTV\np2c/TrLrQq+XJC0/i/7QT7IWeBVwNXAj3VoJtwqXpG3AUnoGLwHutTUX0CU5BvhtusB5W1X9zdb6\n3pKkrWcps5u+Dvx0azWY5F50AXEwcF/gsUnutrW+vyRp61lKT+JlwH8nOY1ZG/tV1QtvYpu/CJzW\n7yZLkk8DTwBecxO/nyRpQpYSEm8FPkm34vrGrdDm+cCrk9wK+BnwaGD93Bf190LWAqxmp63QrCRp\ncy0lJFZW1e9urQar6ktJ/hL4OHAVcA5wwzyvWwesA9g1e9TWal+StHRLuSfxn0nWJrldkj1mPrak\n0ap6R1Xdv6p+GbgM+MqWfD9J0mQspSdxdP/ry2Zd26IpsEn2rKpLktyR7n7EITf1e0mSJmcp50nc\neQLtfqC/J3Ed8PyqunwCbUiSttBSFtP95nzXq+rdN7XRqvqlm/q1kqThLGW46aBZn68GDgfOAm5y\nSEiSbh6WMtz0gtmPk+wGvGdiFUmSpsZSZjfNdRUwifsUkqQps5R7EjPnSkAXKvcA3jvJoiRJ02Ep\n9yRmnytxPXBRVV08oXokSVNkKfckPr3YayRJy9NCx5d+k43DTHNVVd11MiVJkqbFQj2JNXMebwc8\nBfh94OyJVSRJmhoLHV/6vwBJtgOeRnf40DnAY6rqi8OUJ0ka00LDTSuBZwEvBj4L/FpVfW2owiRJ\n41touOmbdLOZ/gb4NnCfJPeZebKqPjjh2iRJI1soJD5Bd+P6vv3HbAUYEpoq599/a5yJtfzcePXV\nY5cAwPY/uXbsEjZ4wpcuGbsEAD7wi3uOXcKiFron8YwB65AkTaGbsi2HJGkbYUhIkpoMCUl
S06Ih\nkWSnJH+c5G3947sneezkS5MkjW0pPYl3AtcAD+wffxf404lVJEmaGksJibtW1WvozqOmqn4KZKJV\nSZKmwlJC4tokO9Jv9pfkrnQ9C0nSMreU8ySOBT4K7JPkBOBQ4BmTLEqSNB2Wcp7ESUnOAg6hG2Y6\npqp+NPHKJEmjW2iDvwPnXPp+/+sdk9yxqs6aXFmSpGmwUE/idQs8V8DDtnItkqQps9DeTYcNWYgk\nafosZTHd6iS/m+SDST6Q5EVJVi/h645LckmS82dd2yPJSUm+2v+6+5b+BiRJk7OUKbDvBu4JvBH4\nu/7zf1jC1x0PHDHn2kuBk6vq7sDJ/WNJ0pRayhTYe1XVPWY9/lSSRY8vrapTk+w75/KRwEP7z98F\nnAL83yXUIEkawVJ6EmclOWTmQZIHAOtvYnu3raqZWVI/AG7bemGStUnWJ1l/nWv3JGkUS+lJ3B/4\n7yTf7h/fEfhykvOAqqr7tL+0raoqSS3w/DpgHcCu2aP5OknS5CwlJObeV9gSP0xyu6r6fpLbAdNx\nhqAkaV6LDjdV1UXAj4FbArea+aiqi/rnNseHgaf3nz8d+NfN/HpJ0oAW7Ukk+RO6vZq+Tr/JH0tY\nTJfkRLqb1LdOcjHdHlB/Abw3ybOBi4Cn3NTCJUmTt5ThpqfQbRd+7eZ846o6uvHU4ZvzfSRJ41nK\n7Kbzgd0mXYgkafospSfx58DZ/crpDXNRq+pxE6tKkjQVlhIS7wL+EjgPuHGy5UiSpslSQuKnVfW3\nE69EkjR1lhISn0ny53TTV2cPN3mehCQtc0sJifv1vx4y65rnSUjSNmApx5d6roQkbaOW0pMgyWPo\ntgjfcI5EVb1qUkVJkqbDUg4degvw68ALgABPBu404bokSVNgKYvpHlRVvwlcVlWvBB4I7DfZsiRJ\n02ApIfGz/tefJrk9cB1wu8mVJEmaFku5J/FvSXYD/go4i25m09smWpUkaSqkaunn+SRZBayuqism\nV9LP2zV71APivoBa2I6fbh50OLifPeSHY5ewwZVHHbL4iwawy3s+P3YJWsAn6v1nVtWaudebw01J\nDkqy16zHvwm8F/iTJHtMpkxJ0jRZ6J7EW4FrAZL8Mt1ZEO8GrqA/VlSStLwtdE9iRVVd2n/+68C6\nqvoA8IEk50y+NEnS2BbqSaxIMhMihwOfnPXckhbhSZJu3hb6YX8i8OkkP6KbBvsZgCR3oxtykiQt\nc82QqKpXJzmZbk3Ex2vjNKjt6FZfS5KWuQWHjarq5+asVdVXJleOJGmaLGXFtSRpG2VISJKaDAlJ\nUpMhIUlqMiQkSU2GhCSpaWIhkeS4JJckOX/WtScnuSDJjUl+brdBSdJ0mWRP4njgiDnXzgeeAJw6\nwXYlSVvJxPZgqqpTk+w759qXAJJMqllJ0lY0tRv1JVkLrAVYzU4jVyNJ26apvXFdVeuqak1VrVnJ\nqrHLkaRt0tSGhCRpfIaEJKlpklNgTwQ+B+yf5OIkz07y+CQXAw8E/j3JxybVviRpy01ydtPRjac+\nNKk2JUlbl8NNkqQmQ0KS1GRISJKaDAlJUpMhIUlqMiQkSU2GhCSpyZCQJDUZEpKkJkNCktRkSEiS\nmqb20CFpc/3s+pVjl7DBNY85aOwSNtjlPZ8fuwQAtt/rtmOXsMH1P/jh2CXcbNiTkCQ1GRKSpCZD\nQpLUZEhIkpoMCUlSkyEhSWoyJCRJTYaEJKnJkJAkNRkSkqQmQ0KS1GRISJKaDAlJUtPEQiLJcUku\nSXL+rGt/leTCJOcm+VCS3SbVviRpy02yJ3E8cMScaycB96qq+wBfAV42wfYlSVtoYiFRVacCl865\n9vGqur5/+Hlg70m1L0nacmPek3gW8J+tJ5OsTbI+yfrruGbAsiRJM0YJiSR/CFwPnNB6TVWtq6o1\nVbVmJauGK06StMHgx5cmeQbwWODwqqqh25ckLd2gIZH
kCOAPgIdU1U+HbFuStPkmOQX2ROBzwP5J\nLk7ybODvgF2Ak5Kck+Qtk2pfkrTlJtaTqKqj57n8jkm1J0na+lxxLUlqMiQkSU2GhCSpyZCQJDUZ\nEpKkJkNCktRkSEiSmgwJSVKTISFJajIkJElNhoQkqcmQkCQ1DX6ehDQpF332jmOXsEE9cHqOStn3\n38euoHP9D344dglTZ8Xd7jx2CRt9df7L9iQkSU2GhCSpyZCQJDUZEpKkJkNCktRkSEiSmgwJSVKT\nISFJajIkJElNhoQkqcmQkCQ1GRKSpCZDQpLUNLGQSHJckkuSnD/r2p8kOTfJOUk+nuT2k2pfkrTl\nJtmTOB44Ys61v6qq+1TVAcC/AS+fYPuSpC00sZCoqlOBS+dc+/GshzsD07PpviTp5wx+6FCSVwO/\nCVwBHLbA69YCawFWs9MwxUmSNjH4jeuq+sOq2gc4AfidBV63rqrWVNWalawarkBJ0gZjzm46AXji\niO1LkhYxaEgkufush0cCFw7ZviRp80zsnkSSE4GHArdOcjFwLPDoJPsDNwIXAc+ZVPuSpC03sZCo\nqqPnufyOSbUnSdr6XHEtSWoyJCRJTYaEJKnJkJAkNRkSkqQmQ0KS1GRISJKaDAlJUpMhIUlqMiQk\nSU2GhCSpafBDh6RJefYTPzZ2CRt84l67jF3CBtvvs/fYJXS2XzF2BRtc/82Lxi4BgBu+9s2xS1iU\nPQlJUpMhIUlqMiQkSU2GhCSpyZCQJDUZEpKkJkNCktRkSEiSmgwJSVKTISFJajIkJElNhoQkqcmQ\nkCQ1TSwkkhyX5JIk58/z3O8lqSS3nlT7kqQtN8mexPHAEXMvJtkH+BXg2xNsW5K0FUwsJKrqVODS\neZ56PfAHQE2qbUnS1jHoPYkkRwLfraovDNmuJOmmGexkuiQ7Af+PbqhpKa9fC6wFWM1OE6xMktQy\nZE/irsCdgS8k+RawN3BWkr3me3FVrauqNVW1ZiWrBixTkjRjsJ5EVZ0H7DnzuA+KNVX1o6FqkCRt\nnklOgT0R+Bywf5KLkzx7Um1JkiZjYj2Jqjp6kef3nVTbkqStwxXXkqQmQ0KS1GRISJKaDAlJUpMh\nIUlqMiQkSU2GhCSpyZCQJDUZEpKkJkNCktRkSEiSmgwJSVLTYFuFLxc/+7WDxy4BgB3/5fSxS5g6\nbzr18LFL2GA/pufv56f3vN3YJQCww0fPGLsE3QT2JCRJTYaEJKnJkJAkNRkSkqQmQ0KS1GRISJKa\nDAlJUpMhIUlqMiQkSU2GhCSpyZCQJDUZEpKkJkNCktQ0sZBIclySS5KcP+vaK5J8N8k5/cejJ9W+\nJGnLTbIncTxwxDzXX19VB/Qf/zHB9iVJW2hiIVFVpwKXTur7S5Imb4x7Er+T5Nx+OGr31ouSrE2y\nPsn667hmyPokSb2hQ+LNwF2BA4DvA69rvbCq1lXVmqpas5JVQ9UnSZpl0JCoqh9W1Q1VdSPwNmA6\nzgKVJM1r0JBIMvuw3ccD57deK0ka3/aT+sZJTgQeCtw6ycXAscBDkxwAFPAt4P9Mqn1J0pabWEhU\n1dHzXH7HpNqTJG19rriWJDUZEpKkJkNCktRkSEiSmgwJSVKTISFJajIkJElNhoQkqcmQkCQ1GRKS\npCZDQpLUZEhIkpomtsHfcrXjv5w+dglquNv+3x+7hA2ycoexS9hgh4+eMXYJAOSge49dwgZ1xnlj\nlwDAinvuP3YJGzUObrAnIUlqMiQkSU2GhCSpyZCQJDUZEpKkJkNCktRkSEiSmgwJSVKTISFJajIk\nJElNhoQkqcmQkCQ1TSwkkhyX5JIk58+5/oIkFya5IMlrJtW+JGnLTbIncTxwxOwLSQ4DjgTuW1X3\nBF47wfYlSVtoYiFRVacCl865/FzgL6rqmv41l0yqfUnSlhv6nsR+wC8lOS3Jp5McNHD7kqTNMPSh\nQ9sDewCHAAcB701
yl6qquS9MshZYC7CanQYtUpLUGboncTHwweqcDtwI3Hq+F1bVuqpaU1VrVrJq\n0CIlSZ2hQ+JfgMMAkuwH7AD8aOAaJElLNLHhpiQnAg8Fbp3kYuBY4DjguH5a7LXA0+cbapIkTYeJ\nhURVHd146qmTalOStHW54lqS1GRISJKaDAlJUpMhIUlqMiQkSU2GhCSpyZCQJDUZEpKkJkNCktRk\nSEiSmgwJSVKTISFJahr60CEtQ98+9kFjlwDArsdPz4bCu133nbFL2GDF3e8ydgkA3HDGeWOXsEEO\nuvfYJQDT9WfSYk9CktRkSEiSmgwJSVKTISFJajIkJElNhoQkqcmQkCQ1GRKSpCZDQpLUZEhIkpoM\nCUlSkyEhSWoyJCRJTRMLiSTHJbkkyfmzrv1zknP6j28lOWdS7UuSttwktwo/Hvg74N0zF6rq12c+\nT/I64IoJti9J2kITC4mqOjXJvvM9lyTAU4CHTap9SdKWG+vQoV8CflhVX229IMlaYC3AanYaqi5J\n0ixj3bg+GjhxoRdU1bqqWlNVa1ayaqCyJEmzDd6TSLI98ATg/kO3LUnaPGP0JB4OXFhVF4/QtiRp\nM0xyCuyJwOeA/ZNcnOTZ/VNHschQkyRpOkxydtPRjevPmFSbkqStyxXXkqQmQ0KS1GRISJKaDAlJ\nUpMhIUlqMiQkSU2GhCSpyZCQJDUZEpKkJkNCktRkSEiSmgwJSVLTWCfTaRnZ8Yc1dgkA7PbVq8Yu\nYSp98//ba+wSALjjK78xdgkbbPfN741dAgA3jF3AEtiTkCQ1GRKSpCZDQpLUZEhIkpoMCUlSkyEh\nSWoyJCRJTYaEJKnJkJAkNRkSkqQmQ0KS1GRISJKaDAlJUtPEQiLJcUkuSXL+rGsHJPl8knOSrE9y\n8KTalyRtuUn2JI4Hjphz7TXAK6vqAODl/WNJ0pSaWEhU1anApXMvA7v2n98SmI5N3SVJ8xr60KEX\nAR9L8lq6gHpQ64VJ1gJrAVaz0zDVSZI2MfSN6+cCL66qfYAXA+9ovbCq1lXVmqpas5JVgxUoSdpo\n6JB4OvDB/vP3Ad64lqQpNnRIfA94SP/5w4CvDty+JGkzTOyeRJITgYcCt05yMXAs8NvAG5JsD1xN\nf89BkjSdJhYSVXV046n7T6pNSdLW5YprSVKTISFJajIkJElNhoQkqcmQkCQ1GRKSpCZDQpLUZEhI\nkpoMCUlSkyEhSWoyJCRJTUMfOqRl6DZv+dzYJWgBV+91/dglTJ2v/d5+Y5cAwJ1fNv3/d+xJSJKa\nDAlJUpMhIUlqMiQkSU2GhCSpyZCQJDUZEpKkJkNCktRkSEiSmgwJSVKTISFJajIkJElNhoQkqWli\nIZHkuCSXJDl/1rX7JvlckgaNfi0AABtYSURBVPOSfCTJrpNqX5K05SbZkzgeOGLOtbcDL62qewMf\nAl4ywfYlSVtoYiFRVacCl865vB9wav/5ScATJ9W+JGnLDX1P4gLgyP7zJwP7tF6YZG2S9UnWX8c1\ngxQnSdrU0CHxLOB5Sc4EdgGubb2wqtZV1ZqqWrOSVYMVKEnaaNDjS6vqQuBXAJLsBzxmyPYlSZtn\n0J5Ekj37X7cD/gh4y5DtS5I2zySnwJ4IfA7YP8nFSZ4NHJ3kK8CFwPeAd06qfUnSlpvYcFNVHd14\n6g2TalOStHW54lqS1GRISJKaDAlJUpMhIUlqMiQkSU2GhCSpyZCQJDUZEpKkJkNCktRkSEiSmgwJ\nSVKTISFJakpVjV3DopL8D3DRFn6bWwM/2grlbA3TUsu01AHTU8u01AHWMp9pqQOmp5atVcedquo2\ncy/eLEJia0iyvqrWjF0HTE8t01IHTE8t01IHWMs01wHTU8uk63C4SZLUZEhIkpq2pZBYN3YBs0xL\nLdNSB0xPLdNSB1jLfKalDpieWiZaxzZzT0KStPm2pZ6EJGkzGRKSpCZDQpLUtKxDI
sldk6zqP39o\nkhcm2W3susbkn8n8kuyV5HFJfjXJXiPXcqckD+8/3zHJLmPWMy2S7DRy+6uWcm25WdYhAXwAuCHJ\n3ehmAOwD/NPQRSTZL8nbknw8ySdnPoauozctfyZPSPLVJFck+XGSK5P8eOg6+lp+CzgdeALwJODz\nSZ41Ui2/DbwfeGt/aW/gX8aopa9n9yQHJ/nlmY8RanhQki8CF/aP75vkTUPXAXxuidcmLsmaJB9K\nclaSc5Ocl+TcSbS1/SS+6RS5saquT/J44I1V9cYkZ49Qx/uAtwBvA24Yof3ZpuXP5DXAr1bVl0Zo\ne66XAPerqv8FSHIr4L+B40ao5fnAwcBpAFX11SR7jlDHTHgeQxdU5wCH0P1QfNjApbweeCTwYYCq\n+sKQYdX3LO8A7JjkfkD6p3YFxurdnED37/Y84MZJNrTcQ+K6JEcDTwd+tb+2coQ6rq+qN4/Q7nym\n5c/kh1MSEAD/C1w56/GV/bUxXFNV1ybdz6Ek2wNjzVM/BjgI+HxVHZbkF4A/G6OQqvrOzJ9Jb8g3\nW48EnkEXln896/qVwP8bsI7Z/qeqPjxEQ8s9JJ4JPAd4dVV9M8mdgX8YoY6PJHke8CHgmpmLVXXp\nCLVMy5/J+iT/TDeUMvvP5IMj1PI14LQk/0r3A/lI4Nwkv9vX9NcLffFW9ukk/4/uXesjgOcBHxmw\n/dmurqqrk5BkVVVdmGT/Eer4TpIHAZVkJV14DfYGo6reBbwryROr6gNDtbuIY5O8HTiZCf//2WYW\n0yXZHdinqiYybrdI29+c53JV1V2GrmVaJHnnPJerqga/F5Dk2IWer6pXDljLdsCzgV+hG9b4GPD2\nGuE/apIP0b2peBHdENNlwMqqevTAddwaeAPwcLo/k48Dx8wMDw5Yx+/Oc/kK4MyqOmfgWv4R+AXg\nAjYON03k/8+yDokkpwCPo+sxnQlcAvxXVc33l71N6APr5/7St+XAmiZJdqZ7B39D/3gFsKqqfjpy\nXQ8Bbgl8tKquHbOWsST5J2ANG3t2jwXOBfYF3ldVrxmwli9X1SC9uuU+3HTLqvpxfwPu3VV17KRm\nACyk7yI/F5i52XYK8Naqum7oWuj+kc9YDTwZ2GPoIpLsDbwROLS/9Bm6d4cXD1jD31TVi5J8hPmD\n83FD1TLLyXTvmH/SP96R7p3zg0aoZSakbgvM9Ib3Ar49cA1/O8/lK4D1VfWvA5ayN3BgVf2kr+tY\n4N/p/l+fSTcZYyj/neQeVfXFSTe03ENi+yS3A54C/OGIdbyZ7ubwzLS9p/XXfmvoQubpov9NkjOB\nlw9cyjvppt4+uX/81P7aIwasYeZezGsHbHMxq2d+CAFU1U/GWh+Q5AXAscAPmTWkAdxn4FJW0w2t\nvK9//ES60LpvksOq6kUD1bEns8b/geuA21bVz5Jc0/iaSTkEOKcfGbiGbhiuqmqr/90s95B4Fd2Y\n7mer6owkdwG+OkIdB1XVfWc9/mSSL4xQB0kOnPVwO7qexRj/Dm5TVbPvSxyfZKj/7ABU1Zn9r58e\nst1FXJXkwKo6CyDJ/YGfjVTLMcD+Q4/9z+M+wKGzhuDeTNfzfDDdFNChnMDGCQ7QzQ78p36IcOLv\n6Oc4YqiGlnVIVNX72Pjug6r6Bt27kKHdkOSuVfV1gD6sxlov8bpZn19P947sKSPU8b9Jngqc2D8+\nmpGmnSY5FHgFcCe6/xMz78rGuE/zIuB9Sb7X17EX8Osj1AHwHbphnbHtDtyCjbXsDOxRVTcM+Q6+\nqv4kyX+ycYj0OVW1vv/8N4aqY6acoRpa1iGRZDXdTJF70nVZARhhBs1LgE8l+Qbdf/w70c0aGcOz\n+7DcoJ8GO7Rn0d2TeD3dP/j/Zrw/k3cAL6YbVx51sWPf4/0FYOam5JdHuncF8A3glCT/zqbTLIec\nEgzdWP85/USU0N0D+LP+HfwnhioiyZ8Ap9LNN
rtqqHYb/p3u/03ofrbdGfgy3c+6rWq5z256H91S\n/v+PbujpN4AvVdUxI9Syik3/4w89hjlTx1lVdeCca2dW1f3HqGcaJDmtqh4wdh0zktwLuAebvrF5\n9wh1zDs1eMgpwbNquT3dvbwv0fUqLq6qUweu4ZnALwEPpFtI9xng1IFvns+rH0Z+XlVt9fucyz0k\nzq6q+yU5t6ru088y+kxVHTJQ+w+rqk8mecJ8zw+5cKx/d3pPundlL5n11K7AS6pqq78DadTxB1X1\nmiRvZP4ZRS8coo6+lpmwfAqwAvggm75jPmuoWmbVdCzwULqQ+A/gUXT31J40dC2LSfLGqnrBAO3M\nuz1IVQ29PchMPXvR/Zv5fWD3qpqKDRiTnFdV997a33dZDzfRzT4AuLx/d/YDuhkKQ3kI8Ek2bn8x\nW9H9UBrK/nTzunebU8+VwG8PWMfMStn1C75qGK+b83j29OBi+D2KoNtg8L7A2VX1zCS3Bf5xhDqW\n4tDFX7JVTMX2IP0K53vQzfb6DN3f1eBvJPpaZq/12g44EPjeJNpa7iGxrl9p/cd0m4PdggGnelbV\nTHf9VVW1yarroe8D9F3if03ywKoaZefKvo6ZhUg/7ScWbJDkyfN8ySRrOWzI9pbo6qq6Mcn1SXal\nWwC6z9hFjWxatge5FV2P83LgUuBHVXX9CHUAzO69XE93j2IiW4Ys6+GmaTFN9wGm5WZ+48/k564N\nVMsxdGs0rqTbqfdA4KVV9fGB6wjwduD3gKP6X38CnFNVY93Ubxrq72tatgeZVc8v0m3692JgRVXt\nPUYdQ1mWPYnGHisbDDU7Y9Z9gFvOuS+xK7N+QA/sH+hu5j+SWTfzh2o8yaOARwN3mLOSdle6d0Rj\neFZVvSHJI+neLT6N7s9p0JCoqkpycFVdDrwlyUeBXcfYb2yJsvhLtlxVPb7/9BVJPkW/PcgQbc+W\n5LF0N65/mW7Y9pN0w06DS7If3T2RfZn1c3wS92mWZUiwaVdsTNNyH2C2u1XVk5McWVXv6vejGfIf\n+vfo7kc8jm7K6Ywr6d6ZjWHmh92j6bZvuaB/Vz+Gs5IcVFVnVNW3Rqphqd4wdIMjL3w8gu7/yhuq\naiLj/5th5oyatzPhadsONw1g7PsAsyU5vaoOTnIq3TbUPwBOH3rhWJKVI87/30S/I+0d6Oaa35du\n3PmUkYYDLwTuBlwEXMUEt1tYQi1r6LazmbvIcPBatKkhh6uXdUgkeRfdpnGX9493B143wvj7VNwH\n6Gv5LbobXPcGjqe7mf/HVfXWhb5uAnXcHfhzfn49wOCrnNNtz30A8I2qujzdyXR3GGOYJ8md5rte\nVReNUMuXmef0szFqGVOSK5l/hfNMaO46cEkkeQXdpIaJn1GzXIebZtxnJiAAquqydMcPDm3U+wBz\nnFxVl9GtHL0LjLbi+p10m8e9HjiM7sbkKGeu97OJfgjcI91JcKOZsh/Ag51+Ns2mZR3EHE/vf529\n5qno/09vTcu9J/EF4KH9D0WS7AF8ehILThapY9RFfXNqmYqZVjNtzl4ANOKMr7+k2x/pi2wc360a\nZ6vwqZHkcLo9tSZ++tnNRTZunT77ZvGgW6cPbbn3JF4HfK7fngO6banHOKN37EV90zjT6pp+mOer\nSX4H+C7d0NcYfo1ut9NRtkqZYs+k26J7JZtuFb5NhkSmZ+v0TSRZV1VrJ/X9l3VIVNW7k6xn48rZ\nJ9QAh3TMY2ZR3x8xwqK+3rTNtDoG2Al4IfAndH9HT1/wKybnG3Q/CA2JTR1UA51+djMxLVunz7Vm\n8ZfcdMt9uOnZVfWOOdf+oqpeOlZNY5ummVYA/ariqqorR6zhA3SzmuYOqwy2j9Q06md9/dVIb6ym\nTr9G4xEjrrKeV5KPVtXEzpdY1j0J4IlJrq6qEwCS/D0jDK0k+TPgNXNmWf1eVf3R0LUAj09yAd1B\nNh+l6yq/u
KoG3R+on175Tvo1LUmuoFvUduaCXzgZH+4/tKnBTj+7mZiWrdM3McmAgOXfk9iR7j//\ncXQLYS6vcbYJP7uq7jfn2lhbUJxTVQckeTzd8NPv0m13fN9FvnRr13Eu8Pyq+kz/+MHAm8b6AZRk\nB2C//uGYZzhMjWmajjumJP9QVU9LcjndbLxN1Dhbp+9HN7NpZg3LTC2uuF6KfhbTjN8C/gX4L+CV\nSfaYxFziRazoNya7pq9vR2DVwDXMWNn/+hjgfVV1xUiLi2+YCQiAqvpsklG68UkeCrwL+Bbdu+V9\nkjy9Bj6vYAot33eQm+f+6c6z+DbdQVnTYGbF9duY8IrrZRkSdNs9zP4HHrofio9hQnOJF3ECcHI/\nxgvdrJF3DVzDjI/0q3p/Bjw3yW2Aq0eo49NJ3kp3fGnRTUE9Jf0ZDzXsWQ6vA36lqr4MG96lnQhs\nswcx9QY7/WzKvYXuftWd2XSL+zDOzxOA66vqzUM0tGyHm/rplQ+sqv8auxbYsLHd4f3Dk6rqYyPW\nsgdwRXVnBO8M7FJVPxi4hk8t8HRNotu8QC3nzh3mmu/ati4TPP3s5iDJm6vquWPXAcOuuF62IQHz\n3wvQRpOeX31zkeQ4unnvMzfvf4NuC+jBt02ZdpnQ6WfaPP1kgrlqEtvaLPeQeC3wOeCDNeJvdM7e\nLzvQ3Re4aow9X2Yb6+Z53/a860Sq6lUj1LIKeD7w4P7SZ+huom/T6yYy/+lnt6qqR45UkkawXO9J\nzPg/dLN3bkjyM0bakGv23i/9FtRH0k0vHNslI7Z91azPV9PNtBplP6s+DP66/9BGg51+ps3Tb+3z\nXLqzLQBOAd46iVl5y7onMc0cCttU/27+Y1X10AHbfG9VPSXJecwzk8d7EppW6c7bXsnGCTBPo5sx\nuNXvFy33ngRJHsestK2qfxuhhtl7JW1Ht4x+0BlFST7CAlMap2Azu52AoY+BnFkz89iB271ZyICn\nn2mzHTRnbdMn+w1Nt7plHRJJ/gI4iG4KKsAxSQ6tqpcNXMrsvZKup5uPf+TANbx24PYWNOfd+wrg\nNnTbqA+mqr7ff/q8qvq/s5/rd4b9vz//VduUwU4/02a7Icldq+rrAEnuwoT+jpb1cFO/qveAqrqx\nf7wCONthhPHNWc17PfDDsfbEaWyfvs1PgR1r63Ytrt/G/Z10W4WEbuX1M6tqoanlN8my7kn0dgNm\n5g7fcsiGk7yRhYd4Bt9ALtNzItztgAtmNvZLskuSe1TVaUMVkOS5dEe43qV/QzFjF7oV+tu6jyR5\nHgPMxdfmqaqT+//LM7v0fnlSs/GWe0/iKOAv6O78h+7exEur6p8Han9m6+tD6X4oz7T7ZOCLVfWc\nIeqYU9Nn2Xgi3K/SnwhXVYNuXZ7kbODAmanJ/eLH9UNOyU1yS2B3utCcvTPwlf4gHHYuvpYmycOq\n6pNz7nNuMIkDoZZ7SPwj8BXgMrr7AGcMvbK4r+PzwINnhlMy7sl0U3Ei3MxGg3OujTrEk2RPNu1d\nLesTx3Tzk+SVVXXsrC1+ZqtJLABd7sNN7wB+CXgccFfg7CSnVtUbBq5jd7oT4Gbend6ivzaGaTkR\n7htJXgjM7D/zPLrx1cEl+VW6NRK3p1s7cie6NRvb2h5FTa7Onw5VdWz/6auqapOeXiZ0Vv2y7knA\nhpvVBwGHAc8BflZVvzBwDc+kG+I5hY3DXq+oqsE3+UtyEN0PwN3oToTble6si8HuBfR17An8Ld2J\ndEW3gdqLqmrwBX791MGHAZ+o7izyw4CnVtWzh65lWo25Ol8/rzHZYiIjAsu6J5HkZGBnuq05PkM3\nt3iMVcbH001PexHwCuCPgb1GqANg36o6A/gJ3f0IkjwZGDQk+r+Ho4ZscwHXVdX/JtkuyXZV9akk\nfzN2UVNmzNX56mWEs+q3m8Q3nSLnAtcC96I7ge1e/VkOQ3sT8ABgx6r6MN2
50n8/Qh0A860RGXrd\nCEnelWS3WY937zfaG8PlSW4BnAqckOQNbLptyDavJnz6mZZs7ln1Mx8HMqGz6pf9cBN00yuBZ9Ct\nHt2rqgY98Gemazh7K44kX6gBT4Prtyp/NPAUNs6ygu4dyD2q6uChaunrme+0vlG2Kum3S7+abijw\nN+imSp9Q03fg/aAy4Oln2jwZ8Kz65T7c9Dt0N67vTze76Ti6YaehXdffG5mZ7nkbuq2ph/Q9ugNT\nHkd3KNOMK4EXD1wLwHZJdq+qy2DDGRej/Husqtm9hrEOg5pGg51+ps12dpLn0w09zZ6R5+ymzbSa\nbtbKmWOt5u39Ld2CpD2TvBp4EvBHQxZQVV8AvpDkhJH/LGa8Dvh8kvfSvYN/EvDqIQuYs4U7bDxp\nbJTdgqfQYKefabP9A3Ah8Ei67Wx+gwntorxNDDdNg/6G0+F0P4BOrqpBt8Wexh1PkzyIbrPDoltI\nN0j3WUsz5Oln2jwzQ7Mza4smufbKkNhGJLldVX1/zp5JG1TVRQPXcwzwW8AH6YLz14C3VdUoB80n\neTBw96p6Z5Jb0x3pOt+K422GK66nV5LTq+rgJKfSrTH6AXC6J9Npi/T3RT5RVYdNQS3n0p1BflX/\neGfgcyP1aI6l69HsX1X7Jbk98L6qOnToWqSlSPJbdAdA3Ztuiv0tgJdX1Vu2dlvL/Z6EZqmqG5Lc\nmOSWVXXFyOWETW+G3tBfG8PjgfsBZwFU1ff6GXHbtCFPP9Pmqaq395+eCky0Z2dIbHt+ApyX5CRm\nrQUYYUfadwKnJflQ//jX6LZRGcO1VVVJZmaf7TxSHdPmzXSnn72pf/y0/tpWP/1MmyfJn9HtlHB5\n/3h34PeqaqtPiHG4aRsza2faTYy0RciBwIP7h5+pqrNHqCF0K+DvADyCbkfYZwH/NNb9kWkx31qe\nodf3aH6NdUYT2TrFnsQ2ZowwaKmqs+iHeEasofptSX4X+DHditaXV9VJY9Y1JQY7/UybbUWSVTNn\nSPQ7SUxkkbAhsY2ZokOHpslZwOVV9ZKxC5kyLwE+lWST08/GLUm9E4CTZ20Z/kwmtBDU4aZtzLQc\nOjRNklwI3A24iE3v02zTx5cCJFnFAKefafP1W+0c3j88qao+NpF2DIlty7QcOjRNpmXtyLQY4/Qz\nTS+Hm7Y903Lo0NTYVsNgAQ8BPknX05yr6BZAakRztpTZgW4W2lWT2ErGnsQ2ZloOHdL0S3Ln+U4/\n29ZXok+bfobekcAhVfXSxV6/2d/fkNi2JFkD/CHdTciV/eVy/F1zDXn6mbbcpLbad7hp23MC3ayV\n8xh+u3LdDIxx+pk2z5y/l+3otpW5ehJtGRLbnv/pT8eTWuaefjbjSiZ0+pk22+y/l+vpzss5chIN\nOdy0jUlyOHA0cDKbbv/szUhtYsjTzzS97Else54J/ALd/YiZ4SZnrGg+g51+pqVJ8kbmOQ9mxiT2\nYDMktj0HVdX+i79MGu70My3Z+v7XQ+l2TZg5r/7JwBcn0aDDTduYfhn/X1XVRP5BafkY8vQzbZ4k\nnwcePHMU8ST/buxJbHsOAc7pTx27ho3nOTsFVnPNnBtxeZJ70Z1+tueI9Wij3elmm80cJXuL/tpW\nZ0hse44YuwDdbKzrzyn4I+DD9KefjVuSen8BnJXkFLo3er8MvGISDTncJEk3M/0q66cBL6ILh3OA\nvarq9K3d1nZb+xtKWh6S/FmS3WY93j3Jn45ZkzZ4E/AAYMd+3dOVwN9PoiFDQlLLo2aOxwSoqsuA\nR49YjzZ6QFU9n36Vdf93s8MkGjIkJLWs6M+TACZ7+pk223VJVtCvmUhyGya0zY43riW1DHb6mTbb\n3wIfAvZM8mrgSXQTDLY6b1xLahrq9DNtvn4jxsPpZjedXFUTWehoSEiSmhxukjSvIU8/0/QyJCTN\nq6p2mfl89uln41WkMTjcJGnJJnX6maa
XPQlJ8xry9DNNL0NCUstgp59pejncJElqsichaRNjnH6m\n6eW2HJLmWg+cSXdk6YHAV/uPA5jQ/kCaXg43SZrXkKefaXrZk5DUMnP62YyJnX6m6eU9CUktg51+\npullT0JSy/F0x5XeB/gA8BBgIpvIaXrZk5DU8ia6Mwp2rKoP9+ddfwA4aNyyNCRDQlLLA6rqwCRn\nQ3f6WRJnN21jHG6S1DLY6WeaXoaEpJa5p599FvizcUvS0FwnIalpqNPPNL0MCUlSk8NNkqQmQ0KS\n1GRISECSTyV55JxrL0ry5iV+/auSPHyR15ySZM0815+R5O82r2JpGIaE1DkROGrOtaP66wtKsqKq\nXl5Vn5hIZdKIDAmp837gMTOLxZLsC9weODrJ+iQXJHnlzIuTfCvJXyY5C3hykuOTPKl/7uVJzkhy\nfpJ1STKrnaclOad/7uC5RSS5TZIP9F9/RpJDJ/h7lhZlSEhAVV0KnA48qr90FPBe4A+rag3d/kUP\nSXKfWV/2v1V1YFW9Z863+7uqOqiq7gXsCDx21nM7VdUBwPOA4+Yp5Q3A66vqIOCJwNu39PcmbQlD\nQtpo9pDTzFDTU/rewtnAPYF7zHr9Pze+z2FJTktyHvCw/utmt0FVnQrsmmS3OV/7cODvkpwDfLh/\nzS224PckbRH3bpI2+lfg9UkOBHYCLgV+Hzio37foeLrT2mZcNfcbJFlNtzHemqr6TpJXzPmauQuT\n5j7eDjikqq7ekt+ItLXYk5B6VfUT4FN0w0An0h24cxVwRZLbsnEoaiEzgfCjvgfwpDnP/zpAkgcD\nV1TVFXOe/zjwgpkHSQ7Y3N+HtDXZk5A2dSLdfkVHVdWF/Q6oFwLfAf5rsS+uqsuTvA04H/gBcMac\nl1zdf8+VwLPm+RYvBP4+ybl0/z9PBZ5zU38z0pZyWw5JUpPDTZKkJkNCktRkSEiSmgwJSVKTISFJ\najIkJElNhoQkqen/B8ns4XRdlkyOAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "plot_explain(res_masks, lbls)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 1 }