{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Class Confusion Widget\n", "This widget is designed to help you interpret your model's decisions through visuals, such as graphs and confusion matrices, that go more in depth than the standard `plot_confusion_matrix`. Class Confusion can be used with **both** Tabular and Image classification models. (Note: because widgets do not export well, images are shown here in place of the live output. The code is still included for you to run!)\n", "\n", "This widget was developed for both the regular environment and Google Colaboratory (not affiliated with Fast.AI). For those using the latter, a repository is available [here](https://github.com/muellerzr/ClassConfusion)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Images\n", "\n", "Before we can use the widget, we need to finish training our model and generate a [`ClassificationInterpretation`](/train.html#ClassificationInterpretation) object." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.vision import *\n", "from fastai.widgets import ClassConfusion" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.PETS)\n", "path_img = path/'images'\n", "fnames = get_image_files(path_img)\n", "pat = r'/([^/]+)_\\d+.jpg$'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), \n", " size=224, bs=64).normalize(imagenet_stats)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = cnn_learner(data, models.resnet34, metrics=error_rate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
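<table><tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>error_rate</th><th>time</th></tr>
<tr><td>0</td><td>1.066235</td><td>0.374091</td><td>0.109608</td><td>00:32</td></tr>
<tr><td>1</td><td>0.479194</td><td>0.318278</td><td>0.098106</td><td>00:28</td></tr></table>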
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Class Confusion's constructor differs depending on our use-case. For **Images**, we are interested in the `classlist`, `is_ordered`, and `figsize` variables.\n", "\n", "* `interp`: Either a Tabular or Image ClassificationInterpretation object\n", "\n", "\n", "* `classlist`: Here you pass in the list of classes you are interested in looking at. Depending on if you have specific combinations or not you want to try will determine how you pass them in. If we just want to look at all combinations between a few classes, we can pass their class names as a normal array, `['Abyssinian', 'Bengal', 'Birman']`. If we want to pass in a specific combination or three, we pass them in as a list of arrays or tuples, `[('Abyssian', 'Bengal'), ('Bengal', 'Birman')]`. Here we have what our **actual** class was first, and the **prediction** second.\n", "\n", "\n", "* `is_ordered`: This will determine whether to generate all the combinations from the set of names you passed in. If you have a specific listed set of combinations, we want `is_ordered` to be True.\n", "\n", "\n", "* `figsize`: This is a tuple for the size you want your photos to return as. Defaults to (8,8)\n", "\n", "Also when you call the function, it will ask for a `k` value. `k` is the same as `k` from `plot_top_losses`, which is the number of images you want to look at." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "interp = ClassificationInterpretation.from_learner(learn)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at an example set for the 'Ragdoll', 'Birman', and 'Maine_Coon' classes" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Please enter a value for `k`, or the top images you will see: 5\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:54<00:00, 12.22s/it]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b8ac317f64dd4137879461758a249b7e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Tab(children=(Output(), Output(), Output(), Output()), _titles={'0': 'Ragdoll | Birman', '1': 'Birman | Ragdol…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "classlist = ['Ragdoll', 'Birman', 'Maine_Coon']\n", "ClassConfusion(interp, classlist, is_ordered=False, figsize=(8,8))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://i.imgur.com/jAE6BVm.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The output is now our confused images as well as their filenames, in case we want to go find those particular instances.\n", "\n", "Next, let's look at a set of classes in a particular order." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "classlist = [('Ragdoll', 'Birman'), ('British_Shorthair', 'Russian_Blue')]\n", "ClassConfusion(interp, classlist, is_ordered=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://i.imgur.com/EFLUEnQ.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we are looking at exact cells from our Confusion Matrix!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tabular\n", "Tabular has a bit more bells and whistles than the Images does. We'll look at the `ADULT_SAMPLE` dataset for an example. \n", "\n", "\n", "Along with the standard constructor items above, there are two more, `cut_off` and `varlist`:\n", "\n", "* `cut_off`: This is the cut-off number, an integer, for plotting categorical variables. It sets a maximum to 100 bars on the graph at a given moment, else it will defaulty show a `Number of values is above 100` messege, and move onto the next set.\n", "\n", "\n", "* `varlist`: This is a list of variables that you specifically want to look at. Defaulty ClassConfusion will use every variable that was used in the model, including `_na`'s." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.tabular import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.ADULT_SAMPLE)\n", "df = pd.read_csv(path/'adult.csv')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dep_var = 'salary'\n", "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n", "cont_names = ['age', 'fnlwgt', 'education-num']\n", "procs = [FillMissing, Categorify, Normalize]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test = TabularList.from_df(df.iloc[800:1000].copy(), path=path, cat_names=cat_names, cont_names=cont_names)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = (TabularList.from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)\n", " .split_by_idx(list(range(800,1000)))\n", " .label_from_df(cols=dep_var)\n", " .add_test(test)\n", " .databunch())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = tabular_learner(data, layers=[200,100], metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
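<table><tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>accuracy</th><th>time</th></tr>
<tr><td>0</td><td>0.358833</td><td>0.381552</td><td>0.810000</td><td>00:16</td></tr></table>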
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit(1, 1e-2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "interp = ClassificationInterpretation.from_learner(learn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAARoAAAEmCAYAAAC9C19sAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGB5JREFUeJzt3XucVWXd9/HPFwaIk4KiIhqaiiCQgpwMU9HSgPKQZ+TWRMz0fszURzPNyvIQj+WdFimmaR4yvb1TQ8G0fMQDHhKQSA1PISIgcvAACMrhd/+x1uiWZpgB59qL2fN9v177xdrXuvZav83MfPd1rbX23ooIzMxSalZ0AWZW+Rw0Zpacg8bMknPQmFlyDhozS85BY2bJOWhsg0hqLeleSe9KuvNTbGekpAcbsraiSNpH0otF17Epk6+jqUySjgPOBnoAS4HpwKUR8fin3O7xwLeBwRGx+lMXuomTFEC3iHil6FoaM49oKpCks4ErgcuAbYCuwNXAoQ2w+R2Al5pCyNSHpKqia2gUIsK3CroBmwPLgKPW06cVWRDNy29XAq3ydUOAN4D/C7wFzAdG5et+DHwIrMr3MRq4CLi1ZNs7AgFU5fdPBP5FNqqaBYwsaX+85HGDgWeAd/N/B5esmwRcDEzOt/Mg0KmW51Zd/3dL6j8MGA68BCwBLijpPxB4Engn7zsWaJmvezR/Lsvz53tMyfbPA94Ebqluyx+zc76PPfP7XYBFwJCifzcK/b0sugDfGvgHCkOB1dV/6LX0+QnwFLA1sBXwBHBxvm5I/vifAC3yP9D3gY75+nWDpdagAdoC7wHd83XbAr3y5Y+CBtgCeBs4Pn/ciPz+lvn6ScCrwK5A6/z+mFqeW3X9P8zr/yawELgNaA/0AlYCO+X9+wF75fvdEfgncGbJ9gLYpYbt/z+ywG5dGjR5n2/m22kDPAD8vOjfi6JvnjpVni2BRbH+qc1I4CcR8VZELCQbqRxfsn5Vvn5VREwkezXvvpH1rAV6S2odEfMj4vka+nwVeDkibomI1RHxB2AmcHBJnxsj4qWIWAH8N9BnPftcRXY8ahVwO9AJuCoilub7fx7YHSAipkbEU/l+XwOuBfarx3P6UUR8kNfzCRFxHfAy8DRZuH6/ju1VPAdN5VkMdKrj2EEXYHbJ/dl520fbWCeo3gfabWghEbGcbLpxKjBf0gRJPepRT3VN25Xcf3MD6lkcEWvy5eogWFCyfkX14yXtKuk+SW9Keo/suFan9WwbYGFErKyjz3VAb+BXEfFBHX0rnoOm8jxJNjU4bD195pEd1K3WNW/bGMvJpgjVOpeujIgHIuJAslf2mWR/gHXVU13T3I2saUNcQ1ZXt4jYDLgAUB2PWe+pWkntyI57/Ra4SNIWDVFoY+agqTAR8S7Z8YlfSzpMUhtJLSQNk3R53u0PwIWStpLUKe9/60bucjqwr6SukjYHzq9eIWkbSYdIagt8QDYFW1PDNiYCu0o6TlKVpGOAnsB9G1nThmhPdhxpWT7aOm2d9QuAnTZwm1cBUyPiZGACMO5TV9nIOWgqUET8F9k1NBeSHQidA5wO3JN3uQSYAswA/gFMy9s2Zl9/Ae7ItzW
VT4ZDM7KzV/PIzsTsB/xnDdtYDHwt77uY7IzR1yJi0cbUtIHOAY4jO5t1HdlzKXURcJOkdyQdXdfGJB1KdkD+1LzpbGBPSSMbrOJGyBfsmVlyHtGYWXIOGjNLzkFjZsk5aMwsuSb9hjBVtQ61bF90GVaLvrt1LboEq8O0aVMXRcRWdfVr2kHTsj2tutd5xtIKMvnpsUWXYHVo3ULrXtFdI0+dzCw5B42ZJeegMbPkHDRmlpyDxsySc9CYWXIOGjNLzkFjZsk5aMwsOQeNmSXnoDGz5Bw0Zpacg8bMknPQmFlyDhozS85BY2bJOWjMLDkHjZkl56Axs+QcNGaWnIPGzJJz0JhZcg4aM0vOQWNmyTlozCw5B42ZJeegMbPkHDRmlpyDxsySc9CYWXIOGjNLzkFjZsk5aMwsOQeNmSXnoDGz5Bw0Zpacg8bMknPQmFlyDhozS85BY2bJOWjMLDkHjZkl56Axs+QcNGaWnIPGzJJz0JhZclVFF2D1N+5HIxm2b28WLllK/6MuA+D73xrOSYcPZuHbywD40djxPPD4C/TvtQNjfzACAAkuHTeR8Q/PKKz2pmzOnDmcPOoEFix4k2bNmnHS6FM4/YzvFF1WWW3SQSPpd8B+wLt504kRMV2SgKuA4cD7efs0SUOAcyLia0XUm9ot9z7FuDse4fqLT/hE+69ufZgrb3noE23PvzqPvUdezpo1a+ncaTOevuN8Jjz6HGvWrC1nyQZUVVUx5vIr6LvnnixdupTBg/rxpS8fyG49exZdWtlsclMnSS0ltS1pOjci+uS36XnbMKBbfjsFuKbcdRZh8rRXWfLu+/Xqu2Llqo9CpVXLFkREytJsPbbddlv67rknAO3bt6dHj92YN29uwVWV1yYTNJJ2k3QF8CKwax3dDwVujsxTQAdJ266zvQGSnpW0U6KSNxmnHrsvf7vjfMb9aCQd2rf+qH1A7x2Y+j/fZ8qdF3DGpbd7NLMJmP3aa0yf/iwDBg4qupSyKjRoJLWVNErS48D1wD+B3SPi2ZJul0qaIekXklrlbdsBc0r6vJG3VW93MDAOODQi/pX2WRTrujsfo+fBFzHo2DG8ueg9xpx9+EfrnnluNv2OvJQv/sflnHvSQbRquUnPlCvesmXLGHH0EfzsiivZbLPNii6nrIoe0cwHRgMnR8TeEXF9RCwtWX8+0AMYAGwBnJe3q4ZtVc8NdgN+AxwcEa+v20nSKZKmSJoSq1c01PMozFtLlrJ2bRAR3HDXZPr33uHf+rw4awHLV3xIr126FFChAaxatYoRRx/BMSNGctjXD6/7ARWm6KA5EpgL3C3ph5I+8VcSEfPz6dEHwI3AwHzVG8BnS7puD8zLl+cDK4G+Ne0wIn4TEf0jor+qWtfUpVHp3OnjV8ZDD9iDF16dD8AOXbakefPsx9t1247suuM2zJ63uJAam7qI4NRvjqZ7j934zllnF11OIQodS0fEg8CDkrYE/gP4k6RFZCOc1yRtGxHz87NMhwHP5Q8dD5wu6XZgEPBu3q878A7ZKOlBScsjYlK5n1cqN/30RPbp141OHdrxyp8v5uJxE9m3Xzd27749EcHs+Uv49iV/AGBw3504Z9RBrFq9hrVrg+9cdgeL31le8DNomp6YPJnbfn8LvXt/nkH9+gDw40suY+iw4QVXVj7a1M5GSBoIzI+IOZL+P7AV2VRpOnBqRCzLg2csMJTs9PaoiJhSenpbUlfgfuCkiHi6pn01a7N1tOp+dBmelW2Mt58ZW3QJVofWLTQ1IvrX1W+TOzoYEX8rWT6glj4B/J8a2icBk/Ll14FeSYo0sw1S9DEaM2sCHDRmlpyDxsySc9CYWXIOGjNLzkFjZsk5aMwsOQeNmSXnoDGz5Bw0Zpacg8bMknPQmFlyDhozS85BY2bJOWjMLDkHjZkl56Axs+QcNGaWnIPGzJJz0JhZcg4aM0vOQWNmyTlozCw5B42ZJeegMbPkHDRmlpyDxsySc9CYWXIOGjN
LzkFjZsk5aMwsOQeNmSXnoDGz5Bw0ZpZcVW0rJN0LRG3rI+KQJBWZWcWpNWiAn5etCjOraLUGTUQ8Us5CzKxyrW9EA4CkbsBPgZ7AZ6rbI2KnhHWZWQWpz8HgG4FrgNXA/sDNwC0pizKzylKfoGkdEQ8BiojZEXERcEDassysktQ5dQJWSmoGvCzpdGAusHXassysktRnRHMm0AY4A+gHHA98I2VRZlZZ6hzRRMQz+eIyYFTacsysEtXnrNPD1HDhXkT4OI2Z1Ut9jtGcU7L8GeAIsjNQZmb1Up+p09R1miZL8sV8ZlZv9Zk6bVFytxnZAeHOySoqox47b8fNd11adBlWixUfrim6BGsg9Zk6TSU7RiOyKdMsYHTKosysstQnaHaLiJWlDZJaJarHzCpQfa6jeaKGticbuhAzq1zr+zyazsB2QGtJfcmmTgCbkV3AZ2ZWL+ubOn0FOBHYHriCj4PmPeCCtGWZWSVZ3+fR3ATcJOmIiPhjGWsyswpTn2M0/SR1qL4jqaOkSxLWZGYVpj5BMywi3qm+ExFvA8PTlWRmlaY+QdO89HS2pNaAT2+bWb3V5zqaW4GHJN2Y3x8F3JSuJDOrNPV5r9PlkmYAXyY78/RnYIfUhZlZ5ajvF8i9Cawle+f2l4B/JqvIzCrO+i7Y2xU4FhgBLAbuIPvc4P3LVJuZVYj1TZ1mAo8BB0fEKwCSzipLVWZWUdY3dTqCbMr0sKTrJH2Jj68ONjOrt1qDJiLujohjgB7AJOAsYBtJ10g6qEz1mVkFqPNgcEQsj4jfR8TXyN73NB34XvLKzKxi1PesEwARsSQirvUHk5vZhtigoDEz2xgOGjNLzkFjZsk5aMwsOQeNmSXnoDGz5Bw0Zpacg8bMknPQmFlyDhozS85BY2bJOWjMLDkHjZkl56Axs+QcNGaWnIPGzJJz0JhZcg4aM0vOQWNmyTlozCw5B42ZJeegMbPkHDRmlpyDxsySc9CYWXIOGjNLzkFjZsk5aMwsOQeNmSXnoDGz5Bw0jdQHH6zkG4cdwHHD9+bor+zFtb+4DICI4OqfX8wRB/TjqAMHcvvvxhVcqQGM+/UvGdx/D77Qf3euGXtV0eWUXVXRBdRG0hDgT8CsvOmuiPhJvm4ocBXQHLg+Isbk7a8B/SNiUdkLLrOWLVtxze/H06ZtO1avWsXJRw9l8JADmfXKiyyY/wZ3/vUZmjVrxpJFC4sutcl74fnnuPnG3/LXR5+kZcuWHHXocA4aOpydd+lWdGllk2xEI6ljA2zmsYjok9+qQ6Y58GtgGNATGCGpZwPsq1GRRJu27QBYvXoVq1evQhJ//P0NnPzt82jWLPvRbtFpqyLLNOClF2fSf+Ag2rRpQ1VVFYP32ZcJ4+8puqyySjl1miLpNkkHSFIDbncg8EpE/CsiPgRuBw4t7SCptaQ/S/pmA+53k7NmzRqO++oXOWhANwbtvT+9+/Rn7uuz+MuEuzjhkCGcMepIXp/1atFlNnm79ezFk5MfY8nixbz//vv85YH7mTv3jaLLKquUQbMrcBtwOvCCpAskdaleKekXkqbXcPteyTa+IOnvku6X1Ctv2w6YU9LnjbytWjvgXuC2iLhu3aIknSJpiqQpby9Z3EBPtRjNmzfntgmPM+GJ53l+xlReefEFPvzwQ1q2asXN4ydx2DEncPF5pxddZpPXvcdunHH2uRx+8FCOOmw4vT+/B82bNy+6rLJKFjQRsSYi7ouIw4F9gZ2A1yUNzNefVTItKr2NyTcxDdghIvYAfgVUjzVrGh1FyfKfgBsj4uZa6vpNRPSPiP4dt9iyAZ5p8dpv1oF+g77Ik48+xNadu3DA0EMA2P8rB/PyzOcLrs4Ajv/GSUx64hkmPDiJjh07NqnjM5D4rJOkzSWdAownG+GMBmbk69Y7oomI9yJiWb48EWghqRPZCOazJbvZHphXcn8yMKyBp2ubnLcXL2Lpe+8AsHLlCv42+RF23Kkb+x34VaY88SgA055
+nK6f27nIMi238K23AHhjzuvcN/4ejjjq2IIrKq9kZ50k3Qp8AbgTOCEiXi5dHxFn1fH4zsCCiIh8FNQMWAy8A3ST9DlgLnAscFzJQ38I/AC4GjitgZ7OJmfRW29y0bmnsXbNGtZG8OXhh7HPl4bSZ8Be/ODMU7jthmto07YtF475ZdGlGvCNkUexZMkSWlS14PL/+iUdOjbEuZLGQxFRd6+N2bB0CDAxIlZv5ONPJwuK1cAK4OyIeCJfNxy4kuz09g0RcWne/hrQnyyQbgAWRsR3a9tHz8/3jZvHT9qY8qwMdt6mXdElWB22aFs1NSL619Uv2YgmIsZ/ysePBcbWsm4iMLGG9h1L7o76NPs3s4bjK4PNLDkHjZkl56Axs+QcNGaWnIPGzJJz0JhZcg4aM0vOQWNmyTlozCw5B42ZJeegMbPkHDRmlpyDxsySc9CYWXIOGjNLzkFjZsk5aMwsOQeNmSXnoDGz5Bw0Zpacg8bMknPQmFlyDhozS85BY2bJOWjMLDkHjZkl56Axs+QcNGaWnIPGzJJz0JhZcg4aM0vOQWNmyTlozCw5B42ZJeegMbPkHDRmlpyDxsySc9CYWXIOGjNLzkFjZsk5aMwsOQeNmSXnoDGz5Bw0Zpacg8bMknPQmFlyioiiayiMpIXA7KLraECdgEVFF2HrVWk/ox0iYqu6OjXpoKk0kqZERP+i67DaNdWfkadOZpacg8bMknPQVJbfFF2A1alJ/ox8jMbMkvOIxsySc9CYWXIOGjNLzkHTyEnasYa2AeWvxKx2DprG7y5J21XfkbQfcEOB9dg6JI2uoW1MEbUUxUHT+H0LuEdSZ0nDgauA4QXXZJ90pKSR1XckXQ3Uedl+JfHp7Qog6QvAtcBK4KsRsbDgkqyEpNbAeLKR5jBgSUScWWxV5eWgaaQk3QuU/vB6AvOBtwEi4pAi6rKPSdqi5G574B5gMvBDgIhYUkRdRXDQNFL5sZhaRcQj5arFaiZpFtmLgWpYHRGxU5lLKoyDppGTtA2wHdkv9LyIWFBwSWb/xkHTSEnqA4wDNgfm5s3bA+8Ap0XEs0XVZh+T1AM4lJIXA+BPETGz0MLKzEHTSEmaDnwrIp5ep30v4NqI2KOYyqyapPOAEcDtwBt58/bAscDtEdFkTnE7aBopSS9HRLda1r0SEbuUuyb7JEkvAb0iYtU67S2B52v7+VWiqqILsI12v6QJwM3AnLzts8AJwJ8Lq8pKrQW68O8fF7ttvq7J8IimEZM0jI/n/yIbno+PiImFFmYASBoKjAVe5uMXg67ALsDpEdFkXhAcNGYJSWoGDOSTLwbPRMSaQgsrMwdNIyVp94iYkS+3AM4j+4V+DrgkIt4vsj6rmaQtmtKFetX8XqfG63cly2PIhuNXAK3JTntbwSRdWLLcMz84PFXSa5IGFVha2XlE00hJejYi+ubL04EBEbFKkoC/R8TuxVZokqZFxJ758gRgbETcL2kgcGVEDC62wvLxWafGa3NJXycblbaqPoUaESHJrx6bni4RcT9ARPwtf6Nlk+GgabweAarfOPmUpG0iYoGkzlTWNyE2ZjtJGk92EHh7SW1Kjp21KLCusvPUySyRGt74OjUiluXvTzsyIn5dRF1FcNA0YpLaAN0i4u8lbV2BNRExt/ZHmpWXzzo1bqvIPsqzbUnb9WRXntomRNJ3S/9tahw0jVh+APhu4Bj4aDSzVURMKbQwq8mx6/zbpDhoGr/rgVH58gnAjQXWYnWr6UOwKp7POjVyETFTEpJ2JftIgi8WXZPZujyiqQy/JRvZzIiIt4suxmxdDprK8N/AHmSBY7bJ8dSpAuQXgW1edB22XpPyfx8usoii+DoaM0vOUyezMpHUP/8YzybHQWNWBpK2BZ4Aji66liJ46mRWBpK+B+xM9paRIQWXU3Ye0ZiVx/HA+UBLSTsXXUy5OWjMEpO0PzAzIhaRXbk9uuCSys5BY5beaD6+xukO4Kj8Q8ubjCb1ZM3KTVI
HYC+g+tP13gOeAoYXWVe5+WCwmSXnEY2ZJeegMbPkHDT2qUlaI2m6pOck3Zl/xOjGbmuIpPvy5UPy609q69tB0n9uxD4uknTOxtZoG85BYw1hRUT0iYjewIfAqaUrldng37WIGB8RY9bTpQOwwUFj5eegsYb2GLCLpB0l/VPS1cA04LOSDpL0pKRp+cinHYCkoZJmSnocOLx6Q5JOlDQ2X95G0t2S/p7fBpN9Q+fO+WjqZ3m/cyU9I2mGpB+XbOv7kl6U9Fege9n+Nwxw0FgDklQFDAP+kTd1B27Ov1FzOXAh8OX82xunAGdL+gxwHXAwsA/QuZbN/xJ4JCL2APYEnge+B7yaj6bOlXQQ0I3sO8j7AP0k7SupH9ln9fYlC7IBDfzUrQ7+PBprCK3zr+WFbETzW6ALMDsinsrb9wJ6ApOzb+2lJfAk0AOYFREvA0i6FTilhn0cQPaZyETEGuBdSR3X6XNQfns2v9+OLHjaA3dXf3lb/qVuVkYOGmsIKyKiT2lDHibLS5uAv0TEiHX69QEa6mIuAT+NiGvX2ceZDbgP2wieOlm5PAXsLWkXyL78Lv9A9ZnA50reaDiilsc/BJyWP7a5pM2ApWSjlWoPACeVHPvZTtLWwKPA1yW1ltSebJpmZeSgsbKIiIXAicAfJM0gC54eEbGSbKo0IT8YPLuWTXwH2F/SP4CpQK+IWEw2FXtO0s8i4kHgNuDJvN//AO0jYhrZe4ymA38km95ZGfktCGaWnEc0Zpacg8bMknPQmFlyDhozS85BY2bJOWjMLDkHjZkl979fjsdVs6CPmwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "interp.plot_confusion_matrix()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With tabular problems, looking at each *individual* row will probably not help us much. Instead what **Class Confusion** will do is plot every variable at whatever combination we dictate, and we can see how the distrobution of those variables in our misses changed in relative to our overall dataset distribution. For example, let's explore `>=50k` and `<50k`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 10/10 [00:06<00:00, 1.26it/s]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8891caf76d444f98bb618e698a3a554b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Tab(children=(Output(), Output(), Output(), Output(), Output(), Output(), Output(), Output(), Output(), Output…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ClassConfusion(interp, ['>=50k', '<50k'], figsize=(12,12))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://i.imgur.com/iUUSp2A.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can see the distrobutions for each of those two missed boxes in our confusion matrix, and look at what is really going on there. If we look at education, we can see that for many times where we thought people were making above or below 50k, they were often graduates of high school and persuing some college degree. \n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also look at the distrobution for continuous categories as well. Shown below is `age`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 10/10 [00:06<00:00, 1.29it/s]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4eabd93c293d4a5ab8796ea770867bbc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Tab(children=(Output(), Output(), Output(), Output(), Output(), Output(), Output(), Output(), Output(), Output…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ClassConfusion(interp, ['>=50k', '<50k'], figsize=(12,12))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://i.imgur.com/jMiTb3y.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we want to look at specific variables, we pass them into `varlist`. Below is `age`, `education`, and `relationship`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 3/3 [00:01<00:00, 1.40it/s]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ab9628b93ba24c67a305c4146361428b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Tab(children=(Output(), Output(), Output()), _titles={'0': 'age', '1': 'education', '2': 'relationship'})" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ClassConfusion(interp, ['>=50k', '<50k'], varlist=['age', 'education', 'relationship'],\n", " figsize=(12,12))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://i.imgur.com/ZIqwljr.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can plot the distrobution for our true positives as well, if we want to compare those by using histograms:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 3/3 [00:01<00:00, 1.49it/s]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "36ea71ed9c7c48fc9860b81189a375bf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Tab(children=(Output(), Output(), Output()), _titles={'0': 'age', '1': 'education', '2': 'relationship'})" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ClassConfusion(interp, [['>=50k', '>=50k'], ['>=50k', '<50k']], varlist=['age', 'education', 'relationship'],\n", " is_ordered=True, figsize=(12,12))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://i.imgur.com/xNUUPz0.png)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }