{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "\n", "from fastai.vision import *\n", "from fastai.vision.models.wrn import wrn_22\n", "\n", "# Turn on cuDNN benchmark mode for the whole session\n", "torch.backends.cudnn.benchmark = True" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PosixPath('/home/ubuntu/.fastai/data/cifar10')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Resolve the local CIFAR dataset directory through fastai's untar_data helper\n", "path = untar_data(URLs.CIFAR)\n", "path" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Two transform lists are passed as ds_tfms: rand_pad + flip_lr in the first, nothing in the second\n", "ds_tfms = (\n", "    [*rand_pad(4, 32), flip_lr(p=0.5)],\n", "    [],\n", ")\n", "\n", "# Build the DataBunch from the folder layout (valid='test', batch size 512), then normalize with cifar_stats\n", "data = (\n", "    ImageDataBunch\n", "    .from_folder(path, valid='test', ds_tfms=ds_tfms, bs=512)\n", "    .normalize(cifar_stats)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 07:08

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy
11.4937291.2889110.532400
21.1612371.1032860.604100
30.9583530.9961720.649300
40.8298481.1202790.638500
50.7167440.7248090.752300
60.6340611.1392400.626800
70.5748451.6274890.506100
80.5318480.9125670.712200
90.4892700.7919870.745500
100.4597940.6462390.782000
110.4316010.6402380.789400
120.4027800.6486630.793200
130.3873140.6140630.793000
140.3668000.5946120.813600
150.3383510.6207420.804600
160.3249270.4707620.841500
170.3022580.4682170.844900
180.2861160.4217910.859000
190.2574660.4288250.859200
200.2331210.3431000.887100
210.2057340.3422730.887500
220.1763120.3185320.896700
230.1447740.3283960.896100
240.1199990.2878290.910800
250.0870100.2327550.928600
260.0607230.2363100.931400
270.0425710.2079550.942000
280.0278020.2175850.938900
290.0200100.2098650.943000
300.0162310.2095460.943000
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Baseline run: WRN-22 learner converted with to_fp16, no mixup\n", "learn = Learner(data, wrn_22(), metrics=accuracy)\n", "learn = learn.to_fp16()\n", "\n", "# One-cycle training for 30 epochs\n", "learn.fit_one_cycle(30, 3e-3, wd=0.4, div_factor=10, pct_start=0.5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With mixup" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 05:42

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy
11.8061221.4136670.504500
21.5926651.1892600.590700
31.4615591.0186930.655400
41.3702290.8743070.712100
51.2968080.9138730.704000
61.2508950.8364090.733900
71.2096400.7367760.778600
81.1866050.7537980.767200
91.1665160.7578420.767700
101.1375160.6994500.806500
111.1205710.7360780.780600
121.1037850.9099420.710700
131.0739710.5308250.856600
141.0554550.5838790.831600
151.0358600.5097210.868300
161.0172070.5109950.867800
170.9952230.4466470.889100
180.9625320.3789010.904300
190.9408120.3525700.917800
200.9220710.3321440.928500
210.8992620.3268300.932000
220.8803370.3128920.936600
230.8747890.3064690.940000
240.8658730.3056110.939200
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Fresh WRN-22 learner: to_fp16 first, then mixup, in the same order as a single chained call\n", "learn = Learner(data, wrn_22(), metrics=accuracy)\n", "learn = learn.to_fp16()\n", "learn = learn.mixup()\n", "\n", "# One-cycle training for 24 epochs with wd=0.2\n", "learn.fit_one_cycle(24, 3e-3, wd=0.2, div_factor=10, pct_start=0.5)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }