{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from fastai import *\n",
    "from fastai.vision import *\n",
    "from fastai.vision.models.wrn import wrn_22\n",
    "\n",
    "# Turn on the cuDNN benchmark flag once, up front, before any model is built.\n",
    "cudnn = torch.backends.cudnn\n",
    "cudnn.benchmark = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PosixPath('/data1/jhoward/git/fastai/fastai/../data/cifar10')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Resolve the CIFAR dataset through untar_data and echo the resulting path.\n",
    "cifar_url = URLs.CIFAR\n",
    "path = untar_data(cifar_url)\n",
    "path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First list: transforms from rand_pad(4, 32) plus a left-right flip applied with p=0.5.\n",
    "# Second list is empty (presumably the validation-set transforms -- confirm against fastai docs).\n",
    "train_tfms = [*rand_pad(4, 32), flip_lr(p=0.5)]\n",
    "valid_tfms = []\n",
    "ds_tfms = (train_tfms, valid_tfms)\n",
    "\n",
    "# The 'test' folder is passed as `valid`; batches of 512 with cifar_norm passed as tfms.\n",
    "data = ImageDataBunch.from_folder(\n",
    "    path,\n",
    "    valid='test',\n",
    "    ds_tfms=ds_tfms,\n",
    "    tfms=cifar_norm,\n",
    "    bs=512,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(IntProgress(value=0, max=30), HTML(value='0.00% [0/30 00:00<00:00]'))), HTML(val…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 07:22\n",
      "epoch  train loss  valid loss  accuracy\n",
      "0      1.395528    1.194204    0.564700  (00:23)\n",
      "1      1.044369    0.964681    0.652700  (00:14)\n",
      "2      0.859872    1.128236    0.631100  (00:14)\n",
      "3      0.705556    0.816264    0.724100  (00:14)\n",
      "4      0.608872    0.850188    0.717200  (00:14)\n",
      "5      0.544764    0.768168    0.745700  (00:14)\n",
      "6      0.501907    0.717987    0.758500  (00:14)\n",
      "7      0.467273    0.690202    0.770600  (00:14)\n",
      "8      0.438323    0.557466    0.822600  (00:14)\n",
      "9      0.414395    0.535755    0.815700  (00:14)\n",
      "10     0.390469    0.460306    0.846000  (00:14)\n",
      "11     0.376003    0.528630    0.820200  (00:14)\n",
      "12     0.355230    0.699035    0.782800  (00:14)\n",
      "13     0.343304    0.442456    0.849400  (00:14)\n",
      "14     0.337318    0.559917    0.815800  (00:14)\n",
      "15     0.321129    0.733412    0.769400  (00:14)\n",
      "16     0.303034    0.561098    0.810200  (00:14)\n",
      "17     0.278229    0.402861    0.864800  (00:14)\n",
      "18     0.257787    0.504758    0.845600  (00:14)\n",
      "19     0.237806    0.401391    0.873000  (00:14)\n",
      "20     0.209473    0.487110    0.854300  (00:14)\n",
      "21     0.184482    0.434822    0.861900  (00:14)\n",
      "22     0.154364    0.310643    0.899500  (00:14)\n",
      "23     0.117682    0.278385    0.911500  (00:14)\n",
      "24     0.091463    0.249205    0.922400  (00:14)\n",
      "25     0.066957    0.249959    0.928400  (00:14)\n",
      "26     0.043469    0.219788    0.936900  (00:14)\n",
      "27     0.028288    0.213345    0.939300  (00:14)\n",
      "28     0.020876    0.209119    0.942300  (00:14)\n",
      "29     0.017776    0.210493    0.941400  (00:14)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Baseline run: a fresh wrn_22 model, converted with to_fp16, tracked with accuracy.\n",
    "learn = Learner(data, wrn_22(), metrics=accuracy)\n",
    "learn = learn.to_fp16()\n",
    "\n",
    "# 30 epochs of fit_one_cycle, passing 3e-3 as the second positional argument.\n",
    "n_epochs, max_lr = 30, 3e-3\n",
    "learn.fit_one_cycle(n_epochs, max_lr, wd=0.4, div_factor=10, pct_start=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
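    "With mixup: the same WRN-22 model and data as above, trained with `.mixup()` for 24 epochs and `wd=0.2` (the baseline used 30 epochs and `wd=0.4`)."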
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(IntProgress(value=0, max=24), HTML(value='0.00% [0/24 00:00<00:00]'))), HTML(val…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 05:52\n",
      "epoch  train loss  valid loss  accuracy\n",
      "0      1.754659    1.418296    0.510000  (00:14)\n",
      "1      1.536941    1.274806    0.569400  (00:14)\n",
      "2      1.398102    0.972869    0.675200  (00:14)\n",
      "3      1.308902    1.026916    0.680600  (00:14)\n",
      "4      1.252701    1.091272    0.649400  (00:14)\n",
      "5      1.199010    0.735006    0.765200  (00:14)\n",
      "6      1.181147    0.814292    0.750400  (00:14)\n",
      "7      1.152908    0.697707    0.791000  (00:14)\n",
      "8      1.132457    0.746398    0.768600  (00:14)\n",
      "9      1.116421    0.715845    0.788800  (00:14)\n",
      "10     1.100030    0.711368    0.792600  (00:14)\n",
      "11     1.088852    0.572201    0.841300  (00:14)\n",
      "12     1.075793    0.733747    0.790300  (00:14)\n",
      "13     1.055172    0.555941    0.851700  (00:14)\n",
      "14     1.036253    0.516668    0.866700  (00:14)\n",
      "15     1.013288    0.522676    0.859600  (00:14)\n",
      "16     0.996412    0.494712    0.866500  (00:14)\n",
      "17     0.971465    0.406641    0.902100  (00:14)\n",
      "18     0.943768    0.393006    0.906800  (00:14)\n",
      "19     0.919369    0.337167    0.924500  (00:14)\n",
      "20     0.900417    0.327883    0.930300  (00:14)\n",
      "21     0.887146    0.313054    0.936300  (00:14)\n",
      "22     0.872677    0.309915    0.937800  (00:14)\n",
      "23     0.864825    0.305209    0.938900  (00:14)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Mixup run: a fresh wrn_22 model, converted with to_fp16, then mixup enabled on the learner.\n",
    "learn = Learner(data, wrn_22(), metrics=accuracy)\n",
    "learn = learn.to_fp16().mixup()\n",
    "\n",
    "# 24 epochs of fit_one_cycle, passing 3e-3 as the second positional argument, with wd=0.2.\n",
    "n_epochs, max_lr = 24, 3e-3\n",
    "learn.fit_one_cycle(n_epochs, max_lr, wd=0.2, div_factor=10, pct_start=0.5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}