{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Multi-label classification\n", "Previously, we classified images into one of 2 classes: dogs or cats. But what if an image belongs to multiple classes at once?\n", "\n", "The process in fast.ai is largely the same, but there are a few differences.\n", "\n", "## 1. Softmax --> Sigmoid\n", "Instead of using a `softmax` activation function to evaluate our classes, we'll use a __sigmoid__ function. Softmax wants to pick one thing -- remember, the $e^{activation}$ means differences between numbers will be greatly accentuated, and the outputs are forced to sum to 1. Sigmoid works better here because it squashes each activation into the range 0-1 independently, so several classes can score close to 1 at the same time." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# The sigmoid function, aka the logistic function\n", "# Sigmoid is close to 0 for large negative inputs, and close to 1 for large positive inputs\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "def sigmoid(z):\n", "    # Apply sigmoid activation function\n", "    return 1/(1+np.exp(-z))\n", "\n", "test_input = np.arange(-6, 6, 0.01)\n", "plt.plot(test_input, sigmoid(test_input), linewidth=2)\n", "plt.grid(True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Read from CSV\n", "Previously, we were able to use a Keras-style setup -- put each class in a folder with that class name. Since each image can now belong to multiple classes, we'd have to copy each image into multiple folders, which isn't practical. So instead we read our labels from a .csv file.\n", "\n", "## Note on one-hot encoding\n", "In single-label classification, our data might be one-hot encoded. 
Imagine we have a bunch of data like this:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([['Dog', 'Cat', 'Frog', 'Bird'],\n", " ['0', '0', '1', '0']], dtype='" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Let's take a look at the first image\n", "# Note, because the image is just a matrix of numbers, we can *1.4 to make it brighter\n", "plt.imshow(data.val_ds.denorm(to_np(images))[0]*1.4);" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('agriculture', 1.0),\n", " ('artisinal_mine', 0.0),\n", " ('bare_ground', 0.0),\n", " ('blooming', 0.0),\n", " ('blow_down', 0.0),\n", " ('clear', 1.0),\n", " ('cloudy', 0.0),\n", " ('conventional_mine', 0.0),\n", " ('cultivation', 0.0),\n", " ('habitation', 0.0),\n", " ('haze', 0.0),\n", " ('partly_cloudy', 0.0),\n", " ('primary', 1.0),\n", " ('road', 0.0),\n", " ('selective_logging', 0.0),\n", " ('slash_burn', 0.0),\n", " ('water', 1.0)]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# And to see which labels this has\n", "list(zip(data.classes, labels[0]))" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Start off training on small images\n", "image_size=64" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": true }, "outputs": [], "source": [ "data = get_data(image_size)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "31148b41e48344949d8ce54289e7534e", "version_major": 2, "version_minor": 0 }, "text/html": [ "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, max=6), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Pre-resize all of our images once and cache the smaller copies in the \"tmp\" folder,\n", "# so that later epochs don't have to shrink the full-size files every time they are loaded\n", "data = data.resize(int(image_size*1.3), \"tmp\")" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Now we want to find the learning rate\n", "# First set up our model\n", "learn = ConvLearner.pretrained(f_model, data, metrics=metrics)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d10efd9527314edaa84cd57252fe5ae9", "version_major": 2, "version_minor": 0 }, "text/html": [ "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss f2 \n", " 0 0.220537 0.422463 0.789931 \n", "\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learning_rate_finder = learn.lr_find()\n", "learn.sched.plot()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Note that the loss is decreasing fastest at a learning rate of around 0.2\n", "learning_rate = 0.2" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2beac4b5ac8c4f17b50484ebc403462d", "version_major": 2, "version_minor": 0 }, "text/html": [ "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss f2 \n", " 0 0.14736 0.13469 0.879414 \n", " 1 0.142267 0.129391 0.885923 \n", " 2 0.138263 0.127576 0.887258 \n", " 3 0.140281 0.126938 0.888762 \n", " 4 0.137914 0.124001 0.890997 \n", " 5 0.131529 0.123408 0.893647 \n", " 6 0.130181 0.12329 0.891949 \n", "\n" ] }, { "data": { "text/plain": [ "[0.12329047, 0.8919489753927757]" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Train the final layers only (the pretrained layers are still frozen at this point)\n", "# In general, training means setting the kernel (image feature filter) values\n", "# and the numbers in the weight matrices of the fully connected layers\n", "learn.fit(learning_rate, 3, cycle_len=1, cycle_mult=2)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# The satellite images are quite different from the ImageNet images that this network was trained on\n", "# Therefore, we give the earlier layers higher learning rates than in our cats/dogs examples (each layer group's rate is only 3x smaller than the next)\n", "differential_learning_rates = np.array([learning_rate/9, learning_rate/3, learning_rate])" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2da22cfecd9144f2880ebf591ef70933", "version_major": 2, "version_minor": 0 }, "text/html": [ "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss f2 \n", " 0 0.120932 0.108382 0.909156 \n", " 1 0.115432 0.105932 0.909191 \n", " 2 0.105167 0.10103 0.915758 \n", " 3 0.111772 0.102488 0.91311 \n", " 4 0.105998 0.100651 0.915267 \n", " 5 0.101719 0.09802 0.916691 \n", " 6 0.098998 0.097133 0.917914 \n", "\n" ] }, { "data": { "text/plain": [ "[0.09713308, 0.91791370766977]" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Unfreeze the earlier layers so we can train them with our differential learning rates\n", "learn.unfreeze()\n", "learn.fit(differential_learning_rates, 3, cycle_len=1, cycle_mult=2)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Save the weights\n", "learn.save(f'{image_size}')" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Take a look -- we can see the loss decreasing as we train\n", "learn.sched.plot_loss()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Now we train again with larger images\n", "# Training on small images first is fast; to the network, the larger images then look like new data,\n", "# which should help reduce overfitting\n", "image_size = 128" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "edcf55e4065249f0b9f96d90c1f671a1", "version_major": 2, "version_minor": 0 }, "text/html": [ "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=13), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss f2 \n", " 0 0.098714 0.096537 0.917149 \n", " 1 0.100112 0.095297 0.919537 \n", " 2 0.095763 0.094022 0.920467 \n", " 3 0.095169 0.093924 0.919777 \n", " 4 0.098707 0.09404 0.921207 \n", " 5 0.097631 0.093178 0.921026 \n", " 6 0.095902 0.093244 0.920144 \n", " 7 0.093888 0.092771 0.921678 \n", " 8 0.092554 0.092677 0.921674 \n", " 9 0.09418 0.091928 0.922913 \n", " 10 0.092501 0.091704 0.922293 \n", " 11 0.091666 0.091564 0.922321 \n", " 12 0.090078 0.091956 0.921488 \n", "\n" ] }, { "data": { "text/plain": [ "[0.09195557, 0.9214877960864891]" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = get_data(image_size)\n", "learn.set_data(data)\n", "learn.freeze()\n", "learn.fit(learning_rate, 3, cycle_len=1, cycle_mult=3)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "54cb7bc1cf5044f2b38a4b0a6f3fb2e6", "version_major": 2, "version_minor": 0 }, "text/html": [ "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=13), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss f2 \n", " 0 0.092793 0.087805 0.927363 \n", " 1 0.09623 0.088546 0.925771 \n", " 2 0.091078 0.086233 0.92801 \n", " 3 0.087241 0.085489 0.9284 \n", " 4 0.095144 0.089547 0.926124 \n", " 5 0.090785 0.087691 0.926577 \n", " 6 0.087424 0.086907 0.926908 \n", " 7 0.086251 0.086164 0.926724 \n", " 8 0.082321 0.085769 0.92639 \n", " 9 0.080621 0.084496 0.929468 \n", " 10 0.080943 0.084145 0.929588 \n", " 11 0.075996 0.084212 0.929369 \n", " 12 0.077381 0.084321 0.9293 \n", "\n" ] }, { "data": { "text/plain": [ "[0.08432135, 0.9292995924747179]" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Now tweak the initial layers again\n", "learn.unfreeze()\n", "learn.fit(differential_learning_rates, 3, cycle_len=1, cycle_mult=3)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "collapsed": true }, "outputs": [], "source": [ "learn.save(f'{image_size}')" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Now do it AGAIN, with even bigger image sizes\n", "image_size = 256" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2de642c9cf5848e383a4ec40af258159", "version_major": 2, "version_minor": 0 }, "text/html": [ "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss f2 \n", " 0 0.091138 0.090089 0.924436 \n", " 1 0.089965 0.089359 0.925116 \n", " 2 0.08823 0.089175 0.926036 \n", " 3 0.088119 0.088418 0.926243 \n", " 4 0.085408 0.088381 0.927063 \n", " 5 0.085668 0.088337 0.926008 \n", " 6 0.086202 0.087519 0.927309 \n", "\n" ] }, { "data": { "text/plain": [ "[0.0875194, 0.9273087637207408]" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = get_data(image_size)\n", "learn.set_data(data)\n", "learn.freeze()\n", "learn.fit(learning_rate, 3, cycle_len=1, cycle_mult=2)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4ab85ff4d87247f28d6ebc2922ef7309", "version_major": 2, "version_minor": 0 }, "text/html": [ "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=21), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss f2 \n", " 0 0.085767 0.083354 0.932566 \n", " 1 0.08356 0.081909 0.93292 \n", " 2 0.078167 0.080999 0.932661 \n", " 3 0.086357 0.083911 0.930388 \n", " 4 0.083617 0.083095 0.93072 \n", " 5 0.081474 0.083832 0.927965 \n", " 6 0.076187 0.08217 0.93121 \n", " 7 0.075701 0.081892 0.932569 \n", " 34%|███▍ | 174/506 [03:31<06:43, 1.22s/it, loss=0.0751]" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "---------------------------------------------------------------------",
"KeyboardInterrupt Traceback (most recent call last)",
" in ()\n 1 learn.unfreeze()\n----> 2 learn.fit(differential_learning_rates, 3, cycle_len=3, cycle_mult=2)\n",
"~/fastai/courses/dl1/fastai/learner.py in fit(self, lrs, n_cycle, wds, **kwargs)\n 207 self.sched = None\n 208 layer_opt = self.get_layer_opt(lrs, wds)\n--> 209 return self.fit_gen(self.model, self.data, layer_opt, n_cycle, **kwargs)\n 210 \n 211 def warm_up(self, lr, wds=None):\n",
"~/fastai/courses/dl1/fastai/learner.py in fit_gen(self, model, data, layer_opt, n_cycle, cycle_len, cycle_mult, cycle_save_name, use_clr, metrics, callbacks, use_wd_sched, norm_wds, wds_sched_mult, **kwargs)\n 154 n_epoch = sum_geom(cycle_len if cycle_len else 1, cycle_mult, n_cycle)\n 155 return fit(model, data, n_epoch, layer_opt.opt, self.crit,\n--> 156 metrics=metrics, callbacks=callbacks, reg_fn=self.reg_fn, clip=self.clip, **kwargs)\n 157 \n 158 def get_layer_groups(self): return self.models.get_layer_groups()\n",
"~/fastai/courses/dl1/fastai/model.py in fit(model, data, epochs, opt, crit, metrics, callbacks, **kwargs)\n 94 batch_num += 1\n 95 for cb in callbacks: cb.on_batch_begin()\n---> 96 loss = stepper.step(V(x),V(y))\n 97 avg_loss = avg_loss * avg_mom + loss * (1-avg_mom)\n 98 debias_loss = avg_loss / (1 - avg_mom**batch_num)\n",
"~/fastai/courses/dl1/fastai/model.py in step(self, xs, y)\n 47 nn.utils.clip_grad_norm(trainable_params_(self.m), self.clip)\n 48 self.opt.step()\n---> 49 return raw_loss.data[0]\n 50 \n 51 def evaluate(self, xs, y):\n",
"KeyboardInterrupt: " ] } ], "source": [ "# (Interrupted the kernel here. The run was set for 21 epochs, which is expected rather than a fast.ai bug:\n", "# 3 cycles with cycle_len=3 and cycle_mult=2 last 3 + 6 + 12 = 21 epochs)\n", "learn.unfreeze()\n", "learn.fit(differential_learning_rates, 3, cycle_len=3, cycle_mult=2)" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "collapsed": true }, "outputs": [], "source": [ "learn.save(f'{image_size}')" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "# Finally, get our predictions using test time augmentation\n", "# Reminder: TTA = Test Time Augmentation, applying the augmentation transforms to each image at prediction time and averaging the predictions\n", "multi_label_classification_predictions,target_values = learn.TTA()" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "collapsed": true }, "outputs": [], "source": [ "predictions = np.mean(multi_label_classification_predictions, 0)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0.76759, 0.0016 , 0.02022, 0.00089, 0.00083, 0.99982, 0.00006, 0.00021, 0.43698, 0.02853, 0.00011,\n", " 0.00003, 0.99958, 0.22834, 0.00342, 0.01352, 0.73973],\n", " [0.18095, 0.00006, 0.00077, 0.00014, 0.00077, 0.00035, 0.00022, 0.00009, 0.02922, 0.00091, 0.00005,\n", " 0.99785, 0.99984, 0.01146, 0.0006 , 0.001 , 0.05649],\n", " [0.99405, 0.00013, 0.01334, 0.00012, 0.00006, 0.99976, 0.00001, 0.00005, 0.09454, 0.1095 , 0.00048,\n", " 0.00003, 0.99309, 0.9827 , 0.00023, 0.00075, 0.07534],\n", " [0.01294, 0.9995 , 0.04651, 0.00125, 0.0003 , 0.99966, 0.00012, 0.03368, 0.00482, 0.02961, 0.00009,\n", " 0.00026, 0.96083, 0.30619, 0.00502, 0.00017, 0.99083],\n", " [0.99758, 0.00004, 0.00231, 0.00024, 0.00009, 0.99609, 0.00003, 0.00016, 0.11167, 0.79664, 0.00222,\n", " 0.00135, 0.97242, 0.99619, 0.00044, 0.00086, 0.0222 ]], dtype=float32)" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Predictions contains the predicted probability for each of our 17 classes\n", "predictions[:5]" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0.],\n", " [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],\n", " [0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1.],\n", " [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0.]], dtype=float32)" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Target values contains the actual correct classes for each image\n", "target_values[:5]" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9308049768554749" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f2(predictions, target_values)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Note on structured/unstructured data\n", "In the next lesson, we will look at processing structured and unstructured data.\n", "\n", "These are not 100% standardized terms, but generally we think of them as:\n", "\n", "1. __Unstructured data__: Every \"piece\" of the data objects we're processing is the same kind of thing. For example, in an image, everything is a pixel.\n", "\n", "2. __Structured data__: The parts of the data objects are different kinds of things. For example, if we are processing sales reports, one column could contain profit, another weather data, another quarterly dates, etc." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.3" } }, "nbformat": 4, "nbformat_minor": 2 }