{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## CIFAR-10\n", "\n", "Fine-tune a pretrained ResNet-18 on CIFAR-10, then train a small network built from `SimpleConv` blocks (\"Simplenet\") from scratch on the same data, and compare validation accuracy." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.conv_learner import *\n", "PATH = Path(\"data/cifar10/\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs=64\n", "sz=32" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = tfms_from_model(resnet18, sz, aug_tfms=[RandomFlip()], pad=sz//8)\n", "data = ImageClassifierData.from_csv(PATH, 'train', PATH/'train.csv', tfms=tfms, bs=bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner.pretrained(resnet18, data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr=1e-2; wd=1e-5" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "178615044d0445b39b5d41485c8540d6", "version_major": 2, "version_minor": 0 }, "text/html": [ "
Failed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "learn.lr_find()\n", "learn.sched.plot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "60c148e603be4d758a07fc2894a4fdf0", "version_major": 2, "version_minor": 0 }, "text/html": [ "Failed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.249359 1.116181 0.604056 \n", " 1 1.215158 1.07421 0.613115 \n" ] }, { "data": { "text/plain": [ "[1.0742103, 0.6131150265957447]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.fit(lr, 1, cycle_len=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lrs = np.array([lr/9,lr/3,lr])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.unfreeze()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "20f2b140b25e4397a2d798b7ed213512", "version_major": 2, "version_minor": 0 }, "text/html": [ "Failed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " \r" ] }, { "data": { "image/png": "\n", "text/plain": [ "Failed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 0.863994 0.792711 0.726479 \n" ] }, { "data": { "text/plain": [ "[0.7927107, 0.7264793882978723]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.fit(lrs, 1, cycle_len=1, wds=wd)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simplenet" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = tfms_from_stats(stats, sz, aug_tfms=[RandomFlip()], pad=sz//8)\n", "data = ImageClassifierData.from_csv(PATH, 'train', PATH/'train.csv', tfms=tfms, bs=bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class SimpleConv(nn.Module):\n", " def __init__(self, ic, oc, ks=3, drop=0.2, bn=True):\n", " super().__init__()\n", " self.conv = nn.Conv2d(ic, oc, ks, padding=(ks-1)//2)\n", " self.bn = nn.BatchNorm2d(oc, momentum=0.05) if bn else None\n", " self.drop = nn.Dropout(drop, inplace=True)\n", " self.act = nn.ReLU(True)\n", " \n", " def forward(self, x):\n", " x = self.conv(x)\n", " if self.bn: x = self.bn(x)\n", " return self.drop(self.act(x))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "net = nn.Sequential(\n", " SimpleConv(3, 64),\n", " SimpleConv(64, 128),\n", " SimpleConv(128, 128),\n", " SimpleConv(128, 128),\n", " nn.MaxPool2d(2),\n", " SimpleConv(128, 128),\n", " SimpleConv(128, 128),\n", " SimpleConv(128, 256),\n", " nn.MaxPool2d(2),\n", " SimpleConv(256, 256),\n", " SimpleConv(256, 256),\n", " nn.MaxPool2d(2),\n", " 
SimpleConv(256, 512),\n", " SimpleConv(512, 2048, ks=1, bn=False),\n", " SimpleConv(2048, 256, ks=1, bn=False),\n", " nn.MaxPool2d(2),\n", " SimpleConv(256, 256, bn=False, drop=0),\n", " nn.MaxPool2d(2),\n", " Flatten(),\n", " nn.Linear(256, 10)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bm = BasicModel(net.cuda(), name='simplenet')\n", "learn = ConvLearner(data, bm)\n", "learn.crit = nn.CrossEntropyLoss()\n", "learn.opt_fn = optim.Adam\n", "learn.unfreeze()\n", "learn.metrics=[accuracy]\n", "lr = 1e-3\n", "wd = 5e-3" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6c2e6102551b4bd5bfedb8195031c800", "version_major": 2, "version_minor": 0 }, "text/html": [ "Failed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " \r" ] }, { "data": { "image/png": "\n", "text/plain": [ "Failed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " \r" ] }, { "data": { "image/png": "\n", "text/plain": [ "Failed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.464019 1.812134 0.324219 \n", " 1 1.299797 1.872144 0.301779 \n", " 2 1.152769 1.641428 0.405336 \n", " 3 1.06013 1.531731 0.46875 \n", " 4 1.001071 1.344982 0.546875 \n", " 5 0.957563 1.159598 0.629405 \n", " 6 0.895986 1.152674 0.619265 \n", " 7 0.852257 1.277312 0.607713 \n", " 8 0.844254 1.373495 0.538813 \n", " 9 0.784301 0.972733 0.717586 \n", " 10 0.751162 0.859369 0.741606 \n", " 11 0.735842 0.921104 0.729555 \n", " 12 0.690585 0.966144 0.706034 \n", " 13 0.662635 0.824769 0.759142 \n", " 14 0.626122 0.784435 0.775598 \n", " 15 0.61732 0.772561 0.772689 \n", " 16 0.570246 0.727107 0.785322 \n", " 17 0.526993 0.718699 0.786652 \n", " 18 0.499946 0.645241 0.812916 \n", " 19 0.499634 0.630276 0.816572 \n" ] }, { "data": { "text/plain": [ "[0.630276, 0.8165724734042553]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.fit(lr, 1, wds=wd, cycle_len=20, use_clr=(32,10))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7f8e07243c194ada863b64921bf0ada3", "version_major": 2, "version_minor": 0 }, "text/html": [ "Failed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=5), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.603266 2.02473 0.271941 \n", " 1 1.326654 1.682021 0.391955 \n", " 2 1.124686 1.564738 0.427776 \n", " 3 0.963391 1.164936 0.603225 \n", " 4 0.82219 1.19409 0.578291 \n" ] }, { "data": { "text/plain": [ "[1.1940901, 0.5782912234042553]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.fit(lr, 1, wds=wd, cycle_len=5, use_clr=(32,10))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save('0')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c0051276049a4827b6f5e82a8d4170d6", "version_major": 2, "version_minor": 0 }, "text/html": [ "Failed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 0.819311 1.080679 0.636386 \n", " 1 0.90712 1.294629 0.547457 \n", " 2 0.717722 0.938504 0.700881 \n", " 3 0.898441 1.263396 0.586187 \n", " 4 0.803364 1.037912 0.666888 \n", " 5 0.668088 0.855235 0.737616 \n", " 6 0.616654 0.754756 0.770778 \n" ] }, { "data": { "text/plain": [ "[0.75475585, 0.7707779255319149]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.fit(lr, 3, cycle_len=1, cycle_mult=2, wds=wd)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save('1')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6d693d920dcb482fbd0759c4589348e5", "version_major": 2, "version_minor": 0 }, "text/html": [ "Failed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 0.833685 1.148864 0.620928 \n", " 1 0.819332 1.212562 0.608627 \n", " 2 0.803363 0.984564 0.697224 \n", " 3 0.790965 1.016013 0.702045 \n", " 4 0.733683 0.902306 0.735622 \n", " 5 0.698549 0.878661 0.732131 \n", " 6 0.648197 0.783731 0.758311 \n", " 7 0.597658 0.738099 0.782912 \n", " 8 0.557584 0.646611 0.80768 \n", " 9 0.507423 0.603345 0.822058 \n" ] }, { "data": { "text/plain": [ "[0.60334516, 0.8220578457446809]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.fit(lr, 1, wds=wd, cycle_len=10, use_clr=(32,10))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save('2')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fin" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }