{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CIFAR-10 with a Wide ResNet-22\n",
    "\n",
    "Trains fastai's `wrn_22` on CIFAR-10 with the one-cycle policy in mixed precision (fp16): first a plain baseline, then the same network with mixup.\n",
    "\n",
    "Results recorded in this notebook (final epoch):\n",
    "\n",
    "- **Baseline**, 30 epochs, `wd=0.4`: 94.1% accuracy, total time 07:22.\n",
    "- **Mixup**, 24 epochs, `wd=0.2`: 93.9% accuracy, total time 05:52.\n",
    "\n",
    "Caveats:\n",
    "\n",
    "- The CIFAR-10 `test` folder doubles as the validation set, so these are test-set accuracies and no separate hold-out set is kept.\n",
    "- The two runs differ in epoch count and weight decay as well as in mixup, so the comparison is not a controlled ablation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# fastai v1 is conventionally star-imported; it provides untar_data, URLs, ImageDataBunch,\n",
    "# rand_pad, flip_lr, cifar_stats, Learner and accuracy used below.\n",
    "from fastai import *\n",
    "from fastai.vision import *\n",
    "from fastai.vision.models.wrn import wrn_22"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 42         # seed for the Python, NumPy and PyTorch RNGs (see seed_everything)\n",
    "BS = 512          # batch size\n",
    "MAX_LR = 3e-3     # peak learning rate of the one-cycle schedule\n",
    "DIV_FACTOR = 10   # the schedule starts at MAX_LR / DIV_FACTOR\n",
    "PCT_START = 0.5   # fraction of iterations spent increasing the learning rate\n",
    "\n",
    "# Let cuDNN benchmark convolution algorithms and keep the fastest. This speeds up training,\n",
    "# but the chosen algorithms can differ between runs, so results may vary slightly even with SEED fixed.\n",
    "torch.backends.cudnn.benchmark = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def seed_everything(seed: int) -> None:\n",
    "    \"\"\"Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) random number generators.\"\"\"\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "\n",
    "def train_wrn22(data: ImageDataBunch, epochs: int, wd: float, use_mixup: bool = False) -> Learner:\n",
    "    \"\"\"Train a freshly initialised fp16 WRN-22 on `data` with the one-cycle policy.\n",
    "\n",
    "    All RNGs are reseeded with SEED first, so each run starts from the same state\n",
    "    no matter which cells were executed before it. Returns the trained learner.\n",
    "    \"\"\"\n",
    "    seed_everything(SEED)\n",
    "    learn = Learner(data, wrn_22(), metrics=accuracy).to_fp16()\n",
    "    if use_mixup:\n",
    "        learn = learn.mixup()\n",
    "    learn.fit_one_cycle(epochs, MAX_LR, wd=wd, div_factor=DIV_FACTOR, pct_start=PCT_START)\n",
    "    return learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "Download and extract CIFAR-10 (skipped if it is already on disk), then build a data bunch that uses the `test` folder as the validation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PosixPath('/home/ubuntu/.fastai/data/cifar10')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "path = untar_data(URLs.CIFAR)\n",
    "path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training augmentation: pad by 4 px and take a random 32x32 crop, plus a random horizontal flip.\n",
    "# Validation images get no augmentation; everything is normalized with cifar_stats.\n",
    "ds_tfms = ([*rand_pad(4, 32), flip_lr(p=0.5)], [])\n",
    "data = ImageDataBunch.from_folder(path, valid='test', ds_tfms=ds_tfms, bs=BS).normalize(cifar_stats)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Baseline\n",
    "\n",
    "30 epochs of one-cycle training with weight decay 0.4."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(IntProgress(value=0, max=30), HTML(value='0.00% [0/30 00:00<00:00]'))), HTML(val…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 07:22\n",
      "epoch train loss valid loss accuracy\n",
      "0 1.395528 1.194204 0.564700 (00:23)\n",
      "1 1.044369 0.964681 0.652700 (00:14)\n",
      "2 0.859872 1.128236 0.631100 (00:14)\n",
      "3 0.705556 0.816264 0.724100 (00:14)\n",
      "4 0.608872 0.850188 0.717200 (00:14)\n",
      "5 0.544764 0.768168 0.745700 (00:14)\n",
      "6 0.501907 0.717987 0.758500 (00:14)\n",
      "7 0.467273 0.690202 0.770600 (00:14)\n",
      "8 0.438323 0.557466 0.822600 (00:14)\n",
      "9 0.414395 0.535755 0.815700 (00:14)\n",
      "10 0.390469 0.460306 0.846000 (00:14)\n",
      "11 0.376003 0.528630 0.820200 (00:14)\n",
      "12 0.355230 0.699035 0.782800 (00:14)\n",
      "13 0.343304 0.442456 0.849400 (00:14)\n",
      "14 0.337318 0.559917 0.815800 (00:14)\n",
      "15 0.321129 0.733412 0.769400 (00:14)\n",
      "16 0.303034 0.561098 0.810200 (00:14)\n",
      "17 0.278229 0.402861 0.864800 (00:14)\n",
      "18 0.257787 0.504758 0.845600 (00:14)\n",
      "19 0.237806 0.401391 0.873000 (00:14)\n",
      "20 0.209473 0.487110 0.854300 (00:14)\n",
      "21 0.184482 0.434822 0.861900 (00:14)\n",
      "22 0.154364 0.310643 0.899500 (00:14)\n",
      "23 0.117682 0.278385 0.911500 (00:14)\n",
      "24 0.091463 0.249205 0.922400 (00:14)\n",
      "25 0.066957 0.249959 0.928400 (00:14)\n",
      "26 0.043469 0.219788 0.936900 (00:14)\n",
      "27 0.028288 0.213345 0.939300 (00:14)\n",
      "28 0.020876 0.209119 0.942300 (00:14)\n",
      "29 0.017776 0.210493 0.941400 (00:14)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "learn_base = train_wrn22(data, epochs=30, wd=0.4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## With mixup\n",
    "\n",
    "Same network and learning-rate schedule, trained for 24 epochs with weight decay 0.2 and fastai's mixup callback. The training loss is computed on mixed samples and targets, so it stays much higher than the baseline's and is not comparable with it. Compare validation loss and accuracy instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(IntProgress(value=0, max=24), HTML(value='0.00% [0/24 00:00<00:00]'))), HTML(val…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 05:52\n",
      "epoch train loss valid loss accuracy\n",
      "0 1.754659 1.418296 0.510000 (00:14)\n",
      "1 1.536941 1.274806 0.569400 (00:14)\n",
      "2 1.398102 0.972869 0.675200 (00:14)\n",
      "3 1.308902 1.026916 0.680600 (00:14)\n",
      "4 1.252701 1.091272 0.649400 (00:14)\n",
      "5 1.199010 0.735006 0.765200 (00:14)\n",
      "6 1.181147 0.814292 0.750400 (00:14)\n",
      "7 1.152908 0.697707 0.791000 (00:14)\n",
      "8 1.132457 0.746398 0.768600 (00:14)\n",
      "9 1.116421 0.715845 0.788800 (00:14)\n",
      "10 1.100030 0.711368 0.792600 (00:14)\n",
      "11 1.088852 0.572201 0.841300 (00:14)\n",
      "12 1.075793 0.733747 0.790300 (00:14)\n",
      "13 1.055172 0.555941 0.851700 (00:14)\n",
      "14 1.036253 0.516668 0.866700 (00:14)\n",
      "15 1.013288 0.522676 0.859600 (00:14)\n",
      "16 0.996412 0.494712 0.866500 (00:14)\n",
      "17 0.971465 0.406641 0.902100 (00:14)\n",
      "18 0.943768 0.393006 0.906800 (00:14)\n",
      "19 0.919369 0.337167 0.924500 (00:14)\n",
      "20 0.900417 0.327883 0.930300 (00:14)\n",
      "21 0.887146 0.313054 0.936300 (00:14)\n",
      "22 0.872677 0.309915 0.937800 (00:14)\n",
      "23 0.864825 0.305209 0.938900 (00:14)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "learn_mixup = train_wrn22(data, epochs=24, wd=0.2, use_mixup=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}