{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2\n",
"\n",
"from fastai.vision import *\n",
"from fastai.vision.models.wrn import wrn_22\n",
"\n",
"torch.backends.cudnn.benchmark = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PosixPath('/home/ubuntu/.fastai/data/cifar10')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path = untar_data(URLs.CIFAR)\n",
"path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ds_tfms = ([*rand_pad(4, 32), flip_lr(p=0.5)], [])\n",
"data = ImageDataBunch.from_folder(path, valid='test', ds_tfms=ds_tfms, bs=512).normalize(cifar_stats)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 07:08
\n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
"
\n",
" \n",
" 1 | \n",
" 1.493729 | \n",
" 1.288911 | \n",
" 0.532400 | \n",
"
\n",
" \n",
" 2 | \n",
" 1.161237 | \n",
" 1.103286 | \n",
" 0.604100 | \n",
"
\n",
" \n",
" 3 | \n",
" 0.958353 | \n",
" 0.996172 | \n",
" 0.649300 | \n",
"
\n",
" \n",
" 4 | \n",
" 0.829848 | \n",
" 1.120279 | \n",
" 0.638500 | \n",
"
\n",
" \n",
" 5 | \n",
" 0.716744 | \n",
" 0.724809 | \n",
" 0.752300 | \n",
"
\n",
" \n",
" 6 | \n",
" 0.634061 | \n",
" 1.139240 | \n",
" 0.626800 | \n",
"
\n",
" \n",
" 7 | \n",
" 0.574845 | \n",
" 1.627489 | \n",
" 0.506100 | \n",
"
\n",
" \n",
" 8 | \n",
" 0.531848 | \n",
" 0.912567 | \n",
" 0.712200 | \n",
"
\n",
" \n",
" 9 | \n",
" 0.489270 | \n",
" 0.791987 | \n",
" 0.745500 | \n",
"
\n",
" \n",
" 10 | \n",
" 0.459794 | \n",
" 0.646239 | \n",
" 0.782000 | \n",
"
\n",
" \n",
" 11 | \n",
" 0.431601 | \n",
" 0.640238 | \n",
" 0.789400 | \n",
"
\n",
" \n",
" 12 | \n",
" 0.402780 | \n",
" 0.648663 | \n",
" 0.793200 | \n",
"
\n",
" \n",
" 13 | \n",
" 0.387314 | \n",
" 0.614063 | \n",
" 0.793000 | \n",
"
\n",
" \n",
" 14 | \n",
" 0.366800 | \n",
" 0.594612 | \n",
" 0.813600 | \n",
"
\n",
" \n",
" 15 | \n",
" 0.338351 | \n",
" 0.620742 | \n",
" 0.804600 | \n",
"
\n",
" \n",
" 16 | \n",
" 0.324927 | \n",
" 0.470762 | \n",
" 0.841500 | \n",
"
\n",
" \n",
" 17 | \n",
" 0.302258 | \n",
" 0.468217 | \n",
" 0.844900 | \n",
"
\n",
" \n",
" 18 | \n",
" 0.286116 | \n",
" 0.421791 | \n",
" 0.859000 | \n",
"
\n",
" \n",
" 19 | \n",
" 0.257466 | \n",
" 0.428825 | \n",
" 0.859200 | \n",
"
\n",
" \n",
" 20 | \n",
" 0.233121 | \n",
" 0.343100 | \n",
" 0.887100 | \n",
"
\n",
" \n",
" 21 | \n",
" 0.205734 | \n",
" 0.342273 | \n",
" 0.887500 | \n",
"
\n",
" \n",
" 22 | \n",
" 0.176312 | \n",
" 0.318532 | \n",
" 0.896700 | \n",
"
\n",
" \n",
" 23 | \n",
" 0.144774 | \n",
" 0.328396 | \n",
" 0.896100 | \n",
"
\n",
" \n",
" 24 | \n",
" 0.119999 | \n",
" 0.287829 | \n",
" 0.910800 | \n",
"
\n",
" \n",
" 25 | \n",
" 0.087010 | \n",
" 0.232755 | \n",
" 0.928600 | \n",
"
\n",
" \n",
" 26 | \n",
" 0.060723 | \n",
" 0.236310 | \n",
" 0.931400 | \n",
"
\n",
" \n",
" 27 | \n",
" 0.042571 | \n",
" 0.207955 | \n",
" 0.942000 | \n",
"
\n",
" \n",
" 28 | \n",
" 0.027802 | \n",
" 0.217585 | \n",
" 0.938900 | \n",
"
\n",
" \n",
" 29 | \n",
" 0.020010 | \n",
" 0.209865 | \n",
" 0.943000 | \n",
"
\n",
" \n",
" 30 | \n",
" 0.016231 | \n",
" 0.209546 | \n",
" 0.943000 | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(data, wrn_22(), metrics=accuracy).to_fp16()\n",
"learn.fit_one_cycle(30, 3e-3, wd=0.4, div_factor=10, pct_start=0.5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With mixup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 05:42 \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
"
\n",
" \n",
" 1 | \n",
" 1.806122 | \n",
" 1.413667 | \n",
" 0.504500 | \n",
"
\n",
" \n",
" 2 | \n",
" 1.592665 | \n",
" 1.189260 | \n",
" 0.590700 | \n",
"
\n",
" \n",
" 3 | \n",
" 1.461559 | \n",
" 1.018693 | \n",
" 0.655400 | \n",
"
\n",
" \n",
" 4 | \n",
" 1.370229 | \n",
" 0.874307 | \n",
" 0.712100 | \n",
"
\n",
" \n",
" 5 | \n",
" 1.296808 | \n",
" 0.913873 | \n",
" 0.704000 | \n",
"
\n",
" \n",
" 6 | \n",
" 1.250895 | \n",
" 0.836409 | \n",
" 0.733900 | \n",
"
\n",
" \n",
" 7 | \n",
" 1.209640 | \n",
" 0.736776 | \n",
" 0.778600 | \n",
"
\n",
" \n",
" 8 | \n",
" 1.186605 | \n",
" 0.753798 | \n",
" 0.767200 | \n",
"
\n",
" \n",
" 9 | \n",
" 1.166516 | \n",
" 0.757842 | \n",
" 0.767700 | \n",
"
\n",
" \n",
" 10 | \n",
" 1.137516 | \n",
" 0.699450 | \n",
" 0.806500 | \n",
"
\n",
" \n",
" 11 | \n",
" 1.120571 | \n",
" 0.736078 | \n",
" 0.780600 | \n",
"
\n",
" \n",
" 12 | \n",
" 1.103785 | \n",
" 0.909942 | \n",
" 0.710700 | \n",
"
\n",
" \n",
" 13 | \n",
" 1.073971 | \n",
" 0.530825 | \n",
" 0.856600 | \n",
"
\n",
" \n",
" 14 | \n",
" 1.055455 | \n",
" 0.583879 | \n",
" 0.831600 | \n",
"
\n",
" \n",
" 15 | \n",
" 1.035860 | \n",
" 0.509721 | \n",
" 0.868300 | \n",
"
\n",
" \n",
" 16 | \n",
" 1.017207 | \n",
" 0.510995 | \n",
" 0.867800 | \n",
"
\n",
" \n",
" 17 | \n",
" 0.995223 | \n",
" 0.446647 | \n",
" 0.889100 | \n",
"
\n",
" \n",
" 18 | \n",
" 0.962532 | \n",
" 0.378901 | \n",
" 0.904300 | \n",
"
\n",
" \n",
" 19 | \n",
" 0.940812 | \n",
" 0.352570 | \n",
" 0.917800 | \n",
"
\n",
" \n",
" 20 | \n",
" 0.922071 | \n",
" 0.332144 | \n",
" 0.928500 | \n",
"
\n",
" \n",
" 21 | \n",
" 0.899262 | \n",
" 0.326830 | \n",
" 0.932000 | \n",
"
\n",
" \n",
" 22 | \n",
" 0.880337 | \n",
" 0.312892 | \n",
" 0.936600 | \n",
"
\n",
" \n",
" 23 | \n",
" 0.874789 | \n",
" 0.306469 | \n",
" 0.940000 | \n",
"
\n",
" \n",
" 24 | \n",
" 0.865873 | \n",
" 0.305611 | \n",
" 0.939200 | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(data, wrn_22(), metrics=accuracy).to_fp16().mixup()\n",
"learn.fit_one_cycle(24, 3e-3, wd=0.2, div_factor=10, pct_start=0.5)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}