{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Class Confusion Widget\n",
"This widget helps you interpret your model's decisions through visuals such as graphs and confusion matrices that go more in depth than the standard `plot_confusion_matrix`. Class Confusion can be used with **both** Tabular and Image classification models. (Note: because widgets do not export well, images of the output are shown instead. The code is still there for you to run!)\n",
"\n",
"This widget was developed for both the regular environment and Google Colaboratory (not affiliated with Fast.AI). For those using the latter, a repository is available [here](https://github.com/muellerzr/ClassConfusion)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Images\n",
"\n",
"Before using the widget, we need to finish training our model and generate a [`ClassificationInterpretation`](/train.html#ClassificationInterpretation) object."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from fastai.vision import *\n",
"from fastai.widgets import ClassConfusion"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.PETS)\n",
"path_img = path/'images'\n",
"fnames = get_image_files(path_img)\n",
"pat = r'/([^/]+)_\\d+\\.jpg$'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), \n",
" size=224, bs=64).normalize(imagenet_stats)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = cnn_learner(data, models.resnet34, metrics=error_rate)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
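"<table>\n",
"  <thead>\n",
"    <tr>\n",
"      <th>epoch</th>\n",
"      <th>train_loss</th>\n",
"      <th>valid_loss</th>\n",
"      <th>error_rate</th>\n",
"      <th>time</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <td>0</td>\n",
"      <td>1.066235</td>\n",
"      <td>0.374091</td>\n",
"      <td>0.109608</td>\n",
"      <td>00:32</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td>\n",
"      <td>0.479194</td>\n",
"      <td>0.318278</td>\n",
"      <td>0.098106</td>\n",
"      <td>00:28</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"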
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The arguments you pass to `ClassConfusion`'s constructor depend on the use case. For **Images**, we are interested in `interp`, `classlist`, `is_ordered`, and `figsize`.\n",
"\n",
"* `interp`: Either a Tabular or Image `ClassificationInterpretation` object.\n",
"\n",
"\n",
"* `classlist`: The list of classes you want to look at. How you pass them in depends on whether you have specific combinations in mind. To look at all combinations between a few classes, pass their names as a plain list, `['Abyssinian', 'Bengal', 'Birman']`. To look at specific combinations, pass a list of tuples (or lists), `[('Abyssinian', 'Bengal'), ('Bengal', 'Birman')]`, where the **actual** class comes first and the **prediction** second.\n",
"\n",
"\n",
"* `is_ordered`: Determines whether `classlist` is used exactly as given or expanded into all combinations of the names you passed in. If you are passing a specific set of combinations, set `is_ordered` to `True`; set it to `False` to generate all combinations.\n",
"\n",
"\n",
"* `figsize`: A tuple giving the size at which the images are displayed. Defaults to `(8,8)`.\n",
"\n",
"When you call the constructor, it also asks for a `k` value. This is the same `k` as in `plot_top_losses`: the number of images you want to look at."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"interp = ClassificationInterpretation.from_learner(learn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at an example set for the 'Ragdoll', 'Birman', and 'Maine_Coon' classes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Please enter a value for `k`, or the top images you will see: 5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 4/4 [00:54<00:00, 12.22s/it]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b8ac317f64dd4137879461758a249b7e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Tab(children=(Output(), Output(), Output(), Output()), _titles={'0': 'Ragdoll | Birman', '1': 'Birman | Ragdol…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"classlist = ['Ragdoll', 'Birman', 'Maine_Coon']\n",
"ClassConfusion(interp, classlist, is_ordered=False, figsize=(8,8))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![](https://i.imgur.com/jAE6BVm.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The output now shows the images the model confused, along with their filenames, in case we want to find those particular instances.\n",
"\n",
"Next, let's look at a set of classes in a particular order."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"classlist = [('Ragdoll', 'Birman'), ('British_Shorthair', 'Russian_Blue')]\n",
"ClassConfusion(interp, classlist, is_ordered=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![](https://i.imgur.com/EFLUEnQ.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we are looking at exact cells from our Confusion Matrix!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tabular\n",
"Tabular has a few more bells and whistles than Images. We'll use the `ADULT_SAMPLE` dataset as an example. \n",
"\n",
"\n",
"Along with the standard constructor items above, there are two more, `cut_off` and `varlist`:\n",
"\n",
"* `cut_off`: An integer cut-off for plotting categorical variables. By default it caps a graph at 100 bars at a given moment; if a variable has more values than that, the widget shows a `Number of values is above 100` message and moves on to the next set.\n",
"\n",
"\n",
"* `varlist`: A list of the variables you specifically want to look at. By default, ClassConfusion uses every variable that was used in the model, including the `_na` columns."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from fastai.tabular import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.ADULT_SAMPLE)\n",
"df = pd.read_csv(path/'adult.csv')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dep_var = 'salary'\n",
"cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n",
"cont_names = ['age', 'fnlwgt', 'education-num']\n",
"procs = [FillMissing, Categorify, Normalize]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test = TabularList.from_df(df.iloc[800:1000].copy(), path=path, cat_names=cat_names, cont_names=cont_names)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = (TabularList.from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)\n",
" .split_by_idx(list(range(800,1000)))\n",
" .label_from_df(cols=dep_var)\n",
" .add_test(test)\n",
" .databunch())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = tabular_learner(data, layers=[200,100], metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"