{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## CIFAR 10" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from fastai.conv_learner import *\n", "PATH = \"data/cifar10/\"\n", "os.makedirs(PATH,exist_ok=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load classes" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n", "stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159]))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def get_data(sz,bs):\n", " tfms = tfms_from_stats(stats, sz, aug_tfms=[RandomFlip()], pad=sz//8)\n", " return ImageClassifierData.from_paths(PATH, val_name='test', tfms=tfms, bs=bs)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "bs=128" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "\n", "class tofp16(nn.Module):\n", " def __init__(self):\n", " super(tofp16, self).__init__()\n", "\n", " def forward(self, input):\n", " return input.half()\n", "\n", "\n", "def copy_in_params(net, params):\n", " net_params = list(net.parameters())\n", " for i in range(len(params)):\n", " net_params[i].data.copy_(params[i].data)\n", "\n", "\n", "def set_grad(params, params_with_grad):\n", "\n", " for param, param_w_grad in zip(params, params_with_grad):\n", " if param.grad is None:\n", " param.grad = torch.nn.Parameter(param.data.new().resize_(*param.data.size()))\n", " param.grad.data.copy_(param_w_grad.grad.data)\n", "\n", "\n", "#BatchNorm layers to have parameters in single precision.\n", "#Find all layers and convert them back to float. This can't\n", "#be done with built in .apply as that function will apply\n", "#fn to all modules, parameters, and buffers. 
Thus we wouldn't\n", "#be able to guard the float conversion based on the module type.\n", "def BN_convert_float(module):\n", " if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):\n", " module.float()\n", " for child in module.children():\n", " BN_convert_float(child)\n", " return module\n", "\n", "def network_to_half(network):\n", " return nn.Sequential(tofp16(), BN_convert_float(network.cuda().half()))" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "## Measure fp16 - half" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "hidden": true }, "outputs": [], "source": [ "from fastai.models.cifar10.resnext import resnext29_8_64\n", "\n", "m = resnext29_8_64()\n", "# m = resnet50(False)\n", "bm = BasicModel(network_to_half(m).cuda(), name='cifar10_resnet50')" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "hidden": true }, "outputs": [], "source": [ "data = get_data(8,bs*4*4)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "hidden": true }, "outputs": [], "source": [ "learn = ConvLearner(data, bm)\n", "learn.unfreeze()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "hidden": true }, "outputs": [], "source": [ "lr=4e-2; wd=5e-4" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "hidden": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cb8d6d896da44f4096f4d9d4989d364d", "version_major": 2, "version_minor": 0 }, "text/html": [ "
Failed to display Jupyter Widget of type HBox.
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 3.748584 4.073438 0.199611 \n", " 1 2.67269 1.886133 0.30006 \n", " 2 2.250531 1.782031 0.36761 \n", "\n", "CPU times: user 1min 40s, sys: 47.9 s, total: 2min 28s\n", "Wall time: 1min 39s\n" ] }, { "data": { "text/plain": [ "[1.7820313, 0.367610102891922]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(lr, 1, cycle_len=3, use_clr=(20,8))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Measure time on 32x32" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "data = get_data(32,bs*4).half()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5ad88634768f4e1e9b1eca897ce6ac37", "version_major": 2, "version_minor": 0 }, "text/html": [ "Failed to display Jupyter Widget of type HBox.
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.583998 1.559375 0.462343 \n", " 1 1.508549 1.411523 0.492938 \n", " 2 1.443428 1.383203 0.507284 \n", "\n", "CPU times: user 1min 48s, sys: 45.1 s, total: 2min 33s\n", "Wall time: 1min 46s\n" ] }, { "data": { "text/plain": [ "[1.3832031, 0.50728360414505]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(lr, 1, cycle_len=3, use_clr=(20,8))" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "## Measure fp32 - full" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "hidden": true }, "outputs": [], "source": [ "from fastai.models.cifar10.resnext import resnext29_8_64\n", "\n", "mf = resnext29_8_64()\n", "# m = resnet50(False)\n", "bmf = BasicModel(mf.cuda(), name='cifar10_resnet50')" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "hidden": true }, "outputs": [], "source": [ "dataf = get_data(8,bs*4*4)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "hidden": true }, "outputs": [], "source": [ "learnf = ConvLearner(dataf, bmf)\n", "learnf.unfreeze()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "hidden": true }, "outputs": [], "source": [ "lr=4e-2; wd=5e-4" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "hidden": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a429bb88acfe46a6905ef511a31add87", "version_major": 2, "version_minor": 0 }, "text/html": [ "Failed to display Jupyter Widget of type HBox.
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 3.712339 2.205959 0.234404 \n", " 1 2.59699 1.735602 0.357074 \n", " 2 2.128743 1.61046 0.409041 \n", "\n", "CPU times: user 1min 22s, sys: 46.6 s, total: 2min 8s\n", "Wall time: 1min 33s\n" ] }, { "data": { "text/plain": [ "[1.6104597, 0.4090405464172363]" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learnf.fit(lr, 1, cycle_len=3, use_clr=(20,8))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Measure time on 32x32" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "data = get_data(32,bs*4)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "161c9f185a4f4c978a853e6c143ed2d8", "version_major": 2, "version_minor": 0 }, "text/html": [ "Failed to display Jupyter Widget of type HBox.
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.606625 1.529435 0.445549 \n", " 1 1.535117 1.434556 0.485208 \n", " 2 1.466418 1.392065 0.501437 \n", "\n", "CPU times: user 1min 20s, sys: 47.4 s, total: 2min 8s\n", "Wall time: 1min 33s\n" ] }, { "data": { "text/plain": [ "[1.3920648, 0.5014371871948242]" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learnf.fit(lr, 1, cycle_len=3, use_clr=(20,8))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "FP16 is actually slower than FP32 in these tests; the cause still needs investigation.\n", "\n", "Possible reasons:\n", "* Image size or batch size is too small\n", "  * Training isn't long enough for the difference to show?\n", "* Data loader is too slow\n", "  * Training time doesn't increase when training on 8x8 vs 32x32" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" }, "toc": { "nav_menu": { "height": "266px", "width": "252px" }, "number_sections": true, "sideBar": true, "skip_h1_title": false, "toc_cell": false, "toc_position": {}, "toc_section_display": "block", "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }