{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%reload_ext autoreload\n", "%autoreload 2" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from fastai.dataloader import *\n", "from fastai.dataset import *\n", "from fastai.transforms import *\n", "from fastai.models import *\n", "from fastai.conv_learner import *" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Explicit imports for every non-fastai name used below, so the notebook\n", "# does not depend on what the fastai star imports happen to re-export.\n", "import shutil\n", "from pathlib import Path\n", "\n", "import pandas as pd" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "DIR = Path('data/imagenet/')\n", "TRAIN_CSV = 'train.csv'" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Load ImageNet sample and format data" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wget --header=\"Host: files.fast.ai\" --header=\"User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/65.0.3325.146 Safari/537.36\" --header=\"Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8\" --header=\"Accept-Language: en-GB,en-US;q=0.9,en;q=0.8\" \"http://files.fast.ai/data/imagenet-sample-train.tar.gz\" -O \"imagenet-sample-train.tar.gz\" -c" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# -p: also create the parent `data/` directory and do not fail on re-run\n", "!mkdir -p {DIR}" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!tar -xzf imagenet-sample-train.tar.gz -C {DIR}" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!ls {DIR}/train/n01558993" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "TRAIN = DIR/'train'" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# One sub-directory per category; skip any stray regular files\n", "CATEGORIES = [p for p in TRAIN.iterdir() if p.is_dir()]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cats_to_files = {cat: list(cat.iterdir()) for cat in CATEGORIES}" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Flatten the per-category lists into a single list of image paths\n", "files = [file for cat_files in cats_to_files.values() for file in cat_files]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!mkdir -p {DIR}/train1" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Move every image out of its category folder into the flat `train1` folder\n", "for file in files:\n", "    shutil.move(str(file), str(DIR/'train1'))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ALL_FILES = DIR/'train1'" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "afiles = list(ALL_FILES.iterdir())" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IMAGE = 'image'\n", "CATEGORY = 'category'\n", "df = pd.DataFrame(afiles, columns=[IMAGE])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Keep only the file name; Path(...) makes this cell safe to re-run\n", "df[IMAGE] = df[IMAGE].apply(lambda p: Path(p).name)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# The category is the part of the file name before the first underscore\n", "df[CATEGORY] = df[IMAGE].str.split('_').str[0]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df.to_csv(DIR/TRAIN_CSV, index=False)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Training model" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "models\ttmp train 
train1 train.csv\r\n" ] } ], "source": [ "!ls {DIR}" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "arch = resnet34\n", "tfms = tfms_from_model(arch, 256, aug_tfms=transforms_side_on)\n", "bs = 128" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "data = ImageClassifierData.from_csv(DIR, 'train1', DIR/TRAIN_CSV, tfms=tfms, bs=bs)" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Single-precision baseline; built from `arch` so it matches the transforms above\n", "m = ConvnetBuilder(arch, data.c, data.is_multi, data.is_reg)\n", "learner = ConvLearner(data, m)" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "dd4985c144344c0ea21626702c9e5fc2", "version_major": 2, "version_minor": 0 }, "text/html": [ "
Failed to display Jupyter Widget of type HBox.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 2.918728 1.625027 0.620684 \n", "\n", "CPU times: user 5min 27s, sys: 2min 46s, total: 8min 13s\n", "Wall time: 1min 10s\n" ] }, { "data": { "text/plain": [ "[1.6250268, 0.6206835527573863]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learner.fit(0.5,1,cycle_len=1)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Training model in half precision (fp16)" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "\n", "class tofp16(nn.Module):\n", "    '''Input layer that casts the incoming batch to half precision.'''\n", "\n", "    def forward(self, input):\n", "        return input.half()\n", "\n", "\n", "def copy_in_params(net, params):\n", "    '''\n", "    Copy the values of `params`, in order, into the leading parameters of `net`.\n", "\n", "    `params` may be any iterable of parameters. IndexError is raised, before\n", "    anything is copied, if it holds more entries than `net` has parameters.\n", "    '''\n", "    net_params = list(net.parameters())\n", "    params = list(params)\n", "    if len(params) > len(net_params):\n", "        raise IndexError('got %d params but net has only %d parameters'\n", "                         % (len(params), len(net_params)))\n", "    for net_param, param in zip(net_params, params):\n", "        net_param.data.copy_(param.data)\n", "\n", "\n", "def set_grad(params, params_with_grad):\n", "    '''\n", "    Copy the gradient of each entry of `params_with_grad` onto the matching\n", "    entry of `params` (the two iterables are walked pairwise).\n", "\n", "    A gradient buffer is allocated for any target parameter that has none yet.\n", "    If the source parameter has no gradient, the target gradient is cleared\n", "    rather than failing with AttributeError.\n", "    '''\n", "    for param, param_w_grad in zip(params, params_with_grad):\n", "        if param_w_grad.grad is None:\n", "            param.grad = None\n", "            continue\n", "        if param.grad is None:\n", "            param.grad = torch.nn.Parameter(param.data.new().resize_(*param.data.size()))\n", "        param.grad.data.copy_(param_w_grad.grad.data)\n", "\n", "\n", "def BN_convert_float(module):\n", "    '''\n", "    Convert every BatchNorm layer in `module` (searched recursively) back to\n", "    single precision, and return `module`.\n", "\n", "    This can't be done with the built-in .apply, as that function would apply\n", "    fn to all modules, parameters, and buffers, so we wouldn't be able to\n", "    guard the float conversion based on the module type.\n", "    '''\n", "    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):\n", "        module.float()\n", "    for child in module.children():\n", "        BN_convert_float(child)\n", "    return module\n", "\n", "\n", "def network_to_half(network):\n", "    '''\n", "    Return `network` converted to half precision, with its BatchNorm layers\n", "    kept in single precision, preceded by a `tofp16` layer that casts the input.\n", "    '''\n", "    return nn.Sequential(tofp16(), BN_convert_float(network.half()))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "arch = resnet34\n", "tfms = tfms_from_model(arch, 256, aug_tfms=transforms_side_on)\n", "bs = 128" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = ImageClassifierData.from_csv(DIR, 'train1', DIR/TRAIN_CSV, tfms=tfms, bs=bs)" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Same architecture as the baseline, with the model wrapped for fp16\n", "m16 = ConvnetBuilder(arch, data.c, data.is_multi, data.is_reg)\n", "m16.model = network_to_half(m16.model)\n", "learner16 = ConvLearner(data, m16)" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Sequential(\n", " (0): tofp16(\n", " )\n", " (1): Sequential(\n", " (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n", " (2): ReLU(inplace)\n", " (3): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1), ceil_mode=False)\n", " (4): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n", " (relu): ReLU(inplace)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 
bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n", " (relu): ReLU(inplace)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n", " (relu): ReLU(inplace)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n", " )\n", " )\n", " (5): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n", " (relu): ReLU(inplace)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n", " (relu): ReLU(inplace)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n", " (relu): ReLU(inplace)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, 
affine=True)\n", " )\n", " (3): BasicBlock(\n", " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n", " (relu): ReLU(inplace)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n", " )\n", " )\n", " (6): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", " (relu): ReLU(inplace)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", " (relu): ReLU(inplace)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", " (relu): ReLU(inplace)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", " )\n", " (3): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", " (relu): ReLU(inplace)\n", " 
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", " )\n", " (4): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", " (relu): ReLU(inplace)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", " )\n", " (5): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", " (relu): ReLU(inplace)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", " )\n", " )\n", " (7): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n", " (relu): ReLU(inplace)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n", " (relu): ReLU(inplace)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), 
stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n", " (relu): ReLU(inplace)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n", " )\n", " )\n", " (8): AdaptiveConcatPool2d(\n", " (ap): AdaptiveAvgPool2d(output_size=(1, 1))\n", " (mp): AdaptiveMaxPool2d(output_size=(1, 1))\n", " )\n", " (9): Flatten(\n", " )\n", " (10): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True)\n", " (11): Dropout(p=0.25)\n", " (12): Linear(in_features=1024, out_features=512, bias=True)\n", " (13): ReLU()\n", " (14): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True)\n", " (15): Dropout(p=0.5)\n", " (16): Linear(in_features=512, out_features=774, bias=True)\n", " (17): LogSoftmax()\n", " )\n", ")" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner16.model" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "31cfab14bf844ba68eee81f4702a1eb5", "version_major": 2, "version_minor": 0 }, "text/html": [ "Failed to display Jupyter Widget of type HBox.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 3.002965 1.667827 0.600142 \n", "\n", "CPU times: user 5min 20s, sys: 2min 37s, total: 7min 58s\n", "Wall time: 1min 6s\n" ] }, { "data": { "text/plain": [ "[1.667827, 0.6001415579549728]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learner16.fit(0.5,1,cycle_len=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "No meaningful speedup here: one epoch took 1min 10s of wall time in single precision versus 1min 6s in half precision, with comparable validation accuracy (0.621 vs 0.600).  \n", "To confirm that the V100 tensor cores are actually being used, run:  \n", "`/usr/local/cuda/bin/nvprof --log-file nvprof_output.txt python profile_fp16_imagenet.py`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }