{ "cells": [ { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "# Swish Implementation Comparison" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "# Minimal fork of https://github.com/rwightman/gen-efficientnet-pytorch\n", "# Adds setup and lets you set the activation function\n", "# Note changes on setup branch\n", "# !pip install git+https://github.com/thomasbrandon/gen-efficientnet-pytorch@setup" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "from fastai.vision import *\n", "from gen_efficientnet.gen_efficientnet import efficientnet_b0, model_urls\n", "import swish_torch" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "SIZE = 256 # Resize crop to 256x256\n", "BS = 48 # Could probably be a little higher for CUDA/Function but will use same for all\n", "LR=1e-3" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "PATH = untar_data(URLs.IMAGEWOOF_320)\n", "data = (ImageList\n", " .from_folder(PATH)\n", " .split_by_folder(valid='val')\n", " .label_from_folder()\n", " .transform(([flip_lr(p=0.5)], []), size=SIZE)\n", " .databunch(bs=BS, num_workers=6)\n", " .presize(SIZE, scale=(0.35,1))\n", " .normalize(imagenet_stats))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "class PeakMemMetric(LearnerCallback):\n", " \"Callback that measures used and peak GPU memory.\"\n", " _order=-20 # Needs to run before the recorder\n", "\n", " def __init__(self, learn:Learner, device=None):\n", " super().__init__(learn)\n", " assert torch.cuda.is_available(), \"pytorch CUDA is required\"\n", " self._dev = ifnone(device, 
torch.cuda.current_device())\n", "\n", " def on_train_begin(self, **kwargs):\n", " self.learn.recorder.add_metric_names(['cache MB', 'alloc MB'])\n", "\n", " def on_epoch_begin(self, **kwargs):\n", " torch.cuda.reset_max_memory_cached(self._dev)\n", " torch.cuda.reset_max_memory_allocated(self._dev)\n", " \n", " def on_epoch_end(self, last_metrics, **kwargs):\n", " b2mb = lambda num: int(num/2**20)\n", " cache = torch.cuda.max_memory_cached(self._dev)\n", " alloc = torch.cuda.max_memory_allocated(self._dev)\n", " return add_metrics(last_metrics, [b2mb(cache), b2mb(alloc)])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "def load_pretrained(mdl):\n", " # Load pretrained data, except for differently size linear layers\n", " state_dict = torch.utils.model_zoo.load_url(model_urls['efficientnet_b0'])\n", " for attr in ['weight','bias']: state_dict[f'classifier.{attr}'] = getattr(mdl.classifier, attr)\n", " mdl.load_state_dict(state_dict)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "Collapsed": "false" }, "outputs": [ { "data": { "text/plain": [ "ImageDataBunch;\n", "\n", "Train: LabelList (12454 items)\n", "x: ImageList\n", "Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)\n", "y: CategoryList\n", "n02111889,n02111889,n02111889,n02111889,n02111889\n", "Path: /home/user/.fastai/data/imagewoof-320;\n", "\n", "Valid: LabelList (500 items)\n", "x: ImageList\n", "Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)\n", "y: CategoryList\n", "n02111889,n02111889,n02111889,n02111889,n02111889\n", "Path: /home/user/.fastai/data/imagewoof-320;\n", "\n", "Test: None" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# https://github.com/fastai/imagenette\n", "# Subset of 10 dog breeds from Imagenet, 320px shortest side\n", "data" ] }, { "cell_type": 
"markdown", "metadata": { "Collapsed": "false" }, "source": [ "## Original Implementation" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "mdl = efficientnet_b0(num_classes=data.c)\n", "load_pretrained(mdl)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "Collapsed": "false" }, "outputs": [ { "data": { "text/plain": [ "\u001b[0;31mSignature:\u001b[0m \u001b[0mmdl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mact_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mDocstring:\u001b[0m \n", "\u001b[0;31mSource:\u001b[0m \n", "\u001b[0;32mdef\u001b[0m \u001b[0mswish\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmul_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msigmoid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msigmoid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mFile:\u001b[0m ~/.conda/envs/fastai/lib/python3.7/site-packages/gen_efficientnet/efficientnet_builder.py\n", "\u001b[0;31mType:\u001b[0m function\n" ] }, "metadata": 
{}, "output_type": "display_data" } ], "source": [ "mdl.act_fn??" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "Collapsed": "false" }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracycache MBalloc MBtime
00.4009870.3706520.8900007204689001:12
10.4396660.3857240.8900007106687901:11
20.2985810.2746520.9100007106687901:12
30.1365970.2313830.9180007106687901:11
40.0759610.2117510.9320007106687901:11
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "lrn = Learner(data, mdl, callback_fns=[PeakMemMetric], metrics=[accuracy])\n", "lrn.fit_one_cycle(5, LR)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "lrn.destroy()\n", "del lrn, mdl" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "## Autograd Function Implementation" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "Collapsed": "false" }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracycache MBalloc MBtime
00.4500810.5934700.8820006432542101:14
10.4369540.3684580.8800006432542101:13
20.2621580.3686610.8900006432542101:14
30.1427930.2466730.9280006432542101:14
40.0753770.2405330.9240006432542101:14
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "class SwishFunction(torch.autograd.Function):\n", " @staticmethod\n", " def forward(ctx, i):\n", " result = i * torch.sigmoid(i)\n", " ctx.save_for_backward(i)\n", " return result\n", "\n", " @staticmethod\n", " def backward(ctx, grad_output):\n", " i, = ctx.saved_tensors\n", " if not ctx.needs_input_grad[0]: return (None,)\n", " sigmoid_i = torch.sigmoid(i)\n", " return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))\n", " \n", "# Activation function for gen_efficientnet has an inplace keyword\n", "# Can't be inplace so just ignore\n", "def swish_function(x, inplace=False): return SwishFunction.apply(x)\n", "\n", "mdl = efficientnet_b0(num_classes=data.c, act_fn=swish_function)\n", "load_pretrained(mdl)\n", "lrn = Learner(data, mdl, callback_fns=[PeakMemMetric], metrics=[accuracy])\n", "lrn.fit_one_cycle(5, LR)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "lrn.destroy()\n", "del lrn, mdl" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "## CUDA Implementation" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "Collapsed": "false" }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracycache MBalloc MBtime
00.4447610.3947720.8740005934540001:02
10.4415380.4345010.8660005934540001:01
20.2933200.2760600.9060005934540001:02
30.1494190.2453420.9180005934540001:02
40.0616240.2584650.9180005934540001:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Activation function for gen_efficientnet has an inplace keyword\n", "# Can't be inplace so just ignore\n", "def swish_cuda_fn(x, inplace=False): return swish_torch.swish(x)\n", "\n", "mdl = efficientnet_b0(num_classes=data.c, act_fn=swish_cuda_fn)\n", "load_pretrained(mdl)\n", "lrn = Learner(data, mdl, callback_fns=[PeakMemMetric], metrics=[accuracy])\n", "lrn.fit_one_cycle(5, LR)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "lrn.destroy()\n", "del lrn, mdl" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "# Results\n", "```\n", "\t\t train_loss valid_loss accuracy cache MB alloc MB time\n", "Original 0.075961 0.211751 0.932000 7106 6879 01:11\n", "Autograd 0.075377 0.240533 0.924000 6432 5421 01:14\n", "CUDA 0.061624 0.258465 0.918000 5934 5400 01:02\n", "```\n", "\n", "So the CUDA version is (slightly) faster than the original, with the memory usage of the Autograd version." ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:.conda-fastai]", "language": "python", "name": "conda-env-.conda-fastai-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 4 }