{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hide_input": false
   },
   "outputs": [],
   "source": [
    "import fastai\n",
    "from fastai import *          # Quick access to most common functionality\n",
    "from fastai.vision import *   # Quick access to computer vision functionality"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Vision example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Images can be stored in labeled folders (the folder name is used as the label), or in a single folder with a CSV file that provides the labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PosixPath('/data1/jhoward/git/fastai/fastai/../data/mnist_sample')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "path = untar_data(URLs.MNIST_SAMPLE)\n",
    "path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Image folder version"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a `DataBunch`, optionally with transforms:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAcABwDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD+f+vtL9mv/gkx8Efjh8ILH41ePP8Agr1+zb4EsLnSLe4vtB1nxBenW9Nu5XIFpNZSW0TMUUfPJC0iBs4ym2Rvi2v0n/4JW/t+/wDBDL9ir4b+HvG/7S3/AATz8afEj4xWc9z/AGprN/cWeo6RFiQGCS2tLqdIlbZjJeJnRwxD4K4AJv2jf+CBfwH+FX/BMzWP+ClPwc/4KeaX400TTLiG00/Tdb+Euo+GrfxBdl41lg025vp/MvcbnMckduY5fJkG5PLk2fmlX7Ff8FuviBqP/BW39h3/AIem/softB+Nrr4S/D/xXZaF40+A3iqC0s4fAt3JDDbwX9tHbPsnjke4jj3sJJQbltrCNXSP8daAPq//AIJ1/wDBGj9sj/gqH4N8V+Ov2XT4Pex8GXsNtrw8ReJVsZITLG8iPtKN8hVH+Y4GVPoa+r/hx/wQs/4Jl/s0JB4r/wCCqP8AwWc+GWnNaIZNW+Hnwd1WPV9UAaB5EXzVWWZGA8tuLN1flFblWP5Q0UAfpp/wUw/4KyfsOWn7GDf8Es/+CQfwMufDPwk1e8gvviN4t8U6Zt1XxPd288FxbsrmVpCokgQvJMqsdioiRop3/mXRRQB//9k=\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAYAAAByDd+UAAAABHNCSVQICAgIfAhkiAAAAmZJREFUSInllr1rImEQxh/PJFgpYSMo4hIbCaYXQqo0iRLQasUihSBY2mvlv2Bjk0ZIo6QJpggpYkJEkIA2oliooNEgilgofrC8zDWJ3J6a3Vw8i7uBaYZn57fz8e67KgCEDdqPTcL+D+CWUuHJyQkeHx9hsVhwcXExj3Mch2AwKNHe3d3h/Px8ZS6S82KxSKIo0nQ6pdlsRowxWQ+FQktzKarQ7/fD6/Vif38fbrcbb29vuL6+BgAkk0mUSiUYjUZUKhUAAGMM+Xz+zyv88J2dHeI4jnQ6nSQuCAL1+31ijNF0OiWXy/VZHuXAX12v15Pf76f7+3sajUbzVqbTablnvwYyGAwUDoep1WpJZpbNZkkQBNJqtd8Hms1mEgSBHh4eqNFoLCxIo9FYaPO3gMfHxzQYDFZu5GQyoUwmQ06nUzaXGkBk5Tq92+vrK3iex+7uLl5eXlCtViWu1Wphs9ng8XigUqnw/Pz8aT7F8+N5fmncarVSNBolxhgNh8OVOsUtVeJ6vZ7a7TYxxqhcLv99IACKxWLzma7SrO3jvb29DY7jFGm/Xdnh4SGlUqn51jabzfW31GAwkMVioVAoRLVabQ7r9XoUCATWA9RoNGQymejo6IiazebCeWy1WmSz2eTyyEP29vYoEonQ7e3t0oMviiJ1u13y+XyyLy25ns7OzuBwOCQDPjg4wOnp6cLgGWMYj8dIJBLI5XKIx+MLmmUmAdrt9oXb+8NEUQQRoV6vIx6Po9Pp4OrqShFkJfDp6QmMMajVagBAoVDAzc0NAODy8hLdbvfLgN9NhfdBbsr+/d/EjQN/AqmGrig38GGyAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "Image (3, 28, 28)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = ImageDataBunch.from_folder(path, ds_tfms=(rand_pad(2, 28), []), bs=64)\n",
    "data.normalize(imagenet_stats)\n",
    "img,label = data.train_ds[0]\n",
    "img"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create and fit a `Learner`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(IntProgress(value=0, max=1), HTML(value='0.00% [0/1 00:00<00:00]'))), HTML(value…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 00:04\n",
      "epoch  train loss  valid loss  accuracy\n",
      "0      0.053222    0.013478    0.996565  (00:04)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "learn = create_cnn(data, models.resnet18, metrics=accuracy)\n",
    "learn.fit_one_cycle(1, 0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=16), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=16), HTML(value='0.00% [0/16 00:00<00:00]')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "tensor(0.9966)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy(*learn.get_preds())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CSV version"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Same as above, but using a CSV file instead of the folder names for the labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAcABwDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD8Zv2ZP2OrP9ozw9qPizX/ANrf4MfC/T9Pu2t1f4m+MJrW4unWISMYbSytrq6ZACqiQxBGZtqMxRwnUftQf8E0PiP+zf8As66N+1t4c+P3ws+Kvw31fxYfC7+K/hd4iurqLTtZ+zNdLY3MN7aWs8crW6NKMRsu0csCVB+cK/T39pnxz/wTV/4KC/D/AMA/syfsw/8ABQK2/Z/+H3gi2ji8E/Cv4pfCO60+zuNXmSGG51TWNc0u6vY7q+mKsWvJoIYoYgEUJmRnAPzCor1P9rD9jH9oX9izxxaeC/jz4KWzi1ezF/4W8R6XeR32j+JNPYAx32nX0BaG8t3VkYPGxK7wHCtlR5ZQAV71+zRrf7C3jX4XT/s//tT6Vr3gfW7zxT/aOk/HDw3a/wBqnTbc26Rf2ffaSTG1xaB1Mvm28qTI0jHy5wFQeC0UAfo74X+G3wV+CX7Dnxf+Bf7Q/wDwUm+C3xR+FjeGrrXfg94Z8Ea/qN1r1l42Hy6fc2dnc2cMlhFKGeK+STEbRSFsO8cci/nFRRQB/9k=\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAYAAAByDd+UAAAABHNCSVQICAgIfAhkiAAAAShJREFUSIntlsFtwzAMRX86h+U5zEHMQZw5nEE4COGsQWWP35OCBE1sJbZzaEuAF+FDjyRIUQcAxAft65Owf+D+wJQSzAzuDpI/3MyQUloFPOCmS90dl8sFOWcAwPl8BgA0TYOu69D3PXLOaNt2FZS1HhEkWa1/4nVCVSVJmtn+wHEcSZIRsRa2DCyZRQRTSvsDN8ysDlgapZi708yoqu9mvCxKKVFVOQwDzewugHEctwc+C8LdryUXkX2Bt+BS9spst2mGYRhq57T+UhGZLV0ZIVXdBhgRi2Ur0JnA6oHlxTGz2ZEws6ez+9I+PB6PaNsWOWdEBNwdqvpQO7fG3moSEbmOxe2DsNSxd/vwHRMRiAi6rrueTdOE0+n0UL8a+Kr9sU/UrwR+A2R+RsRJ08RxAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "Image (3, 28, 28)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = ImageDataBunch.from_csv(path, ds_tfms=(rand_pad(2, 28), []), bs=64)\n",
    "data.normalize(imagenet_stats)\n",
    "img,label = data.train_ds[0]\n",
    "img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(IntProgress(value=0, max=1), HTML(value='0.00% [0/1 00:00<00:00]'))), HTML(value…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 00:06\n",
      "epoch  train loss  valid loss  accuracy\n",
      "0      0.047875    0.014735    0.995466  (00:06)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "learn = create_cnn(data, models.resnet18, metrics=accuracy)\n",
    "learn.fit_one_cycle(1, 0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}