{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#skip\n", "! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# all_cuda" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai.basics import *\n", "from fastai.callback.progress import *\n", "\n", "from torch.cuda.amp import GradScaler,autocast\n", "from torch.cuda.amp.grad_scaler import OptState" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp callback.fp16" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from fastai.test_utils import *\n", "from nbdev.showdoc import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Mixed precision training\n", "\n", "> Callback and utility functions to allow mixed precision training " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## A little bit of theory" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A very nice and clear introduction to mixed precision training is [this video from NVIDIA](https://on-demand.gputechconf.com/gtc/2019/video/_/S9143/)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### What's half precision?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In neural nets, all the computations are usually done in single precision, which means all the floats in all the arrays that represent inputs, activations, weights... are 32-bit floats (FP32 in the rest of this post). An idea to reduce memory usage (and avoid those annoying cuda errors) has been to try and do the same thing in half-precision, which means using 16-bits floats (or FP16 in the rest of this post). By definition, they take half the space in RAM, and in theory could allow you to double the size of your model and double your batch size.\n", "\n", "Another very nice feature is that NVIDIA developed its latest GPUs (the Volta generation) to take fully advantage of half-precision tensors. Basically, if you give half-precision tensors to those, they'll stack them so that each core can do more operations at the same time, and theoretically gives an 8x speed-up (sadly, just in theory).\n", "\n", "So training at half precision is better for your memory usage, way faster if you have a Volta GPU (still a tiny bit faster if you don't since the computations are easiest). How do we do it? Super easily in pytorch, we just have to put .half() everywhere: on the inputs of our model and all the parameters. Problem is that you usually won't see the same accuracy in the end (so it happens sometimes) because half-precision is... well... not as precise ;)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Problems with half-precision:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To understand the problems with half precision, let's look briefly at what an FP16 looks like (more information [here](https://en.wikipedia.org/wiki/Half-precision_floating-point_format)).\n", "\n", "![half float](images/half.png)\n", "\n", "The sign bit gives us +1 or -1, then we have 5 bits to code an exponent between -14 and 15, while the fraction part has the remaining 10 bits. 
 "That's what causes a certain number of problems, specifically three that can occur and mess up your training.\n", "1. The weight update is imprecise: inside your optimizer, you basically do w = w - lr * w.grad for each weight of your network. The problem in performing this operation in half precision is that very often, w.grad is several orders of magnitude below w, and the learning rate is also small. The situation where w=1 and lr*w.grad is 0.0001 (or lower) is therefore very common, but the update doesn't do anything in those cases.\n", "2. Your gradients can underflow. In FP16, your gradients can easily be replaced by 0 because they are too small.\n", "3. Your activations or loss can overflow. The opposite problem from the gradients: it's easier to hit nan (or infinity) in FP16 precision, and your training might more easily diverge." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The solution: mixed precision training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To address those three problems, we don't fully train in FP16 precision. As the name mixed precision training implies, some of the operations will be done in FP16, others in FP32. This is mainly to take care of the first problem listed above. For the next two there are additional tricks.\n", "\n", "The main idea is that we want to do the forward pass and the gradient computation in half precision (to go fast) but the update in single precision (to be more precise). It's okay if w and grad are both half floats, but when we do the operation w = w - lr * grad, we need to compute it in FP32. That way our 1 + 0.0001 is going to be 1.0001.\n", "\n", "This is why we keep a copy of the weights in FP32 (called the master model). Then, our training loop will look like:\n", "1. compute the output with the FP16 model, then the loss\n", "2. back-propagate the gradients in half-precision.\n", "3. copy the gradients in FP32 precision\n", "4. do the update on the master model (in FP32 precision)\n", "5. copy the master model into the FP16 model.\n", "\n", "Note that we lose precision during step 5, and that the 1.0001 in one of the weights will go back to 1. But if the next update corresponds to adding 0.0001 again, since the optimizer step is done on the master model, the 1.0001 will become 1.0002, and if we eventually go like this up to 1.0005, the FP16 model will be able to tell the difference.\n", "\n", "That takes care of problem 1. For the second problem, we use something called gradient scaling (also known as loss scaling): to avoid the gradients getting zeroed by the FP16 precision, we multiply the loss by a scale factor (scale=512 for instance). That way we can push the gradients to the right in the next figure, and have them not become zero.\n", "\n", "![half float representation](images/half_representation.png)\n", "\n", "Of course we don't want those 512-scaled gradients to be in the weight update, so after converting them into FP32, we can divide them by this scale factor (once they no longer risk becoming 0). This changes the loop to the following (a code sketch follows the list):\n", "1. compute the output with the FP16 model, then the loss.\n", "2. multiply the loss by scale then back-propagate the gradients in half-precision.\n", "3. copy the gradients in FP32 precision then divide them by scale.\n", "4. do the update on the master model (in FP32 precision).\n", "5. copy the master model into the FP16 model.\n",
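 "\n", "Putting these five steps together, here is a minimal sketch of such a scaled loop written directly in PyTorch. It is only an illustration of the idea, not fastai's implementation: it uses plain SGD, a fixed loss scale, and no overflow or batchnorm handling, and `model`, `data`, `loss_func` and `lr` are placeholders for your own objects (a CUDA GPU and float inputs are assumed):\n", "\n", "```python\n", "import torch\n", "\n", "def fit_one_epoch_fp16(model, data, loss_func, lr, loss_scale=512.):\n", "    model.cuda().half()                                      # the FP16 model\n", "    # FP32 master copy of the parameters\n", "    master = [p.detach().clone().float() for p in model.parameters()]\n", "    for xb, yb in data:\n", "        out = model(xb.cuda().half())                        # 1. forward pass in FP16...\n", "        loss = loss_func(out.float(), yb.cuda())             # ...but compute the loss in FP32\n", "        (loss * loss_scale).backward()                       # 2. scale the loss, backprop in FP16\n", "        with torch.no_grad():\n", "            for p, mp in zip(model.parameters(), master):\n", "                grad32 = p.grad.float() / loss_scale         # 3. copy grads to FP32 and unscale\n", "                mp -= lr * grad32                            # 4. update the FP32 master weights\n", "                p.copy_(mp)                                  # 5. copy them back into the FP16 model\n", "                p.grad = None\n", "```\n",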
 "\n", "For the last problem, the tricks offered by NVIDIA are to leave the batchnorm layers in single precision (they don't have many weights so it's not a big memory challenge) and to compute the loss in single precision (which means converting the last output of the model to single precision before passing it to the loss).\n", "\n", "![Mixed precision training](images/Mixed_precision.jpeg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dynamic loss scaling" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The only annoying thing with the previous implementation of mixed precision training is that it introduces one new hyper-parameter to tune, the value of the loss scaling. Fortunately for us, there is a way around this. We want the loss scaling to be as high as possible so that our gradients can use the whole range of representation, so let's first try a really high value. In all likelihood, this will cause our gradients or our loss to overflow, and we will try again with half that value, and again, until we get to the largest loss scale possible that doesn't make our gradients overflow.\n", "\n", "This value will be perfectly fitted to our model and can continue to be dynamically adjusted as training goes on, if it's still too high, by just halving it each time we overflow. After a while though, training will converge and gradients will start to get smaller, so we also need a mechanism to make this dynamic loss scale larger if it's safe to do so. The strategy used in the Apex library is to multiply the loss scale by 2 each time we have had a given number of iterations without overflowing."
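, "\n", "For illustration, this adjustment logic can be sketched as a small helper class (a hypothetical `DynamicScaler`, not the actual Apex or fastai implementation; the `div_factor` and `scale_wait` arguments mirror those of the `NonNativeMixedPrecision` callback defined below):\n", "\n", "```python\n", "class DynamicScaler:\n", "    def __init__(self, init_scale=2.**24, div_factor=2., scale_wait=500):\n", "        self.scale,self.div_factor,self.scale_wait,self.count = init_scale,div_factor,scale_wait,0\n", "    def update(self, overflowed):\n", "        # Halve the scale (and reset the counter) whenever the gradients overflowed;\n", "        # otherwise grow it again after `scale_wait` consecutive iterations without overflow.\n", "        if overflowed: self.scale /= self.div_factor; self.count = 0\n", "        else:\n", "            self.count += 1\n", "            if self.count == self.scale_wait: self.count,self.scale = 0,self.scale*self.div_factor\n", "```"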
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## MixedPrecision -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@delegates(GradScaler)\n", "class MixedPrecision(Callback):\n", " \"Mixed precision training using Pytorch's `autocast` and `GradScaler`\"\n", " order = 10\n", " def __init__(self, **kwargs): self.kwargs = kwargs\n", " def before_fit(self): \n", " self.autocast,self.learn.scaler,self.scales = autocast(),GradScaler(**self.kwargs),L()\n", " def before_batch(self): self.autocast.__enter__()\n", " def after_pred(self):\n", " if next(flatten(self.pred)).dtype==torch.float16: self.learn.pred = to_float(self.pred)\n", " def after_loss(self): self.autocast.__exit__(None, None, None)\n", " def before_backward(self): self.learn.loss_grad = self.scaler.scale(self.loss_grad)\n", " def before_step(self):\n", " self.skipped=True\n", " self.scaler.step(self)\n", " if self.skipped: raise CancelStepException()\n", " self.scales.append(self.scaler.get_scale())\n", " def after_step(self): self.learn.scaler.update()\n", "\n", " @property # pretend to be an optimizer for `GradScaler`\n", " def param_groups(self): return self.opt.param_groups\n", " def step(self, *args, **kwargs): self.skipped=False\n", " def after_fit(self): self.autocast,self.learn.scaler,self.scales = None,None,None" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class FP16TestCallback(Callback):\n", " \"Asserts that predictions are `float16` values\"\n", " order = 9\n", " def after_pred(self): assert listify(flatten(self.pred))[0].dtype==torch.float16" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  time
0      17.554865   14.357819   00:00
1      17.006779   13.436550   00:00
2      16.414442   12.542552   00:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#cuda\n", "set_seed(99, True)\n", "learn = synth_learner(cbs=[MixedPrecision,FP16TestCallback], cuda=True)\n", "learn.model = nn.Sequential(nn.Linear(1,1), nn.Linear(1,1)).cuda()\n", "learn.opt_func = partial(SGD, mom=0.)\n", "learn.splitter = lambda m: [list(m[0].parameters()), list(m[1].parameters())]\n", "learn.fit(3)\n", "assert learn.recorder.values[-1][-1]\n", " \n", " \n", " epoch\n", " train_loss\n", " valid_loss\n", " time\n", " \n", " \n", " \n", " \n", " 0\n", " 87.652245\n", " 72.425194\n", " 00:00\n", " \n", " \n", " 1\n", " 86.457306\n", " 70.571136\n", " 00:00\n", " \n", " \n", " 2\n", " 85.303947\n", " 68.533089\n", " 00:00\n", " \n", " \n", "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#hide\n", "#cuda\n", "#Multioutput version\n", "set_seed(99, True)\n", "learn = synth_learner(cbs=[MixedPrecision,FP16TestCallback], cuda=True)\n", "class MultiOutputModel(Module):\n", " def __init__(self): self.linear1, self.linear2 = nn.Linear(1,1) , nn.Linear(1,1)\n", " def forward(self,x): return self.linear1(x), self.linear2(x)\n", "def multioutputloss(pred, val): return ((val-pred[0]).abs() + 0.5 * (val-pred[1]).abs()).sum()\n", "learn.model = MultiOutputModel()\n", "learn.opt_func = partial(SGD, mom=0.)\n", "learn.splitter = lambda m: [list(m.linear1.parameters()), list(m.linear2.parameters())]\n", "learn.loss_func=multioutputloss\n", "learn.fit(3)\n", "assert learn.recorder.values[-1][-1]None:\n", " for (model_params,master_params) in zip(model_pgs,master_pgs):\n", " master_params_to_model_params(model_params, master_params, flat_master=flat_master)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#cuda\n", "learn.opt.params = master_p\n", "learn.opt.step()\n", "to_model_params(model_p, master_p)\n", "test_close([p.float() for pg in model_p for p in pg], [p for pg in master_p for p in pg], eps=1e-3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#cuda\n", "learn.opt.params = master_pf\n", "learn.opt.step()\n", "to_model_params(model_pf, master_pf, flat_master=True)\n", "test_close([p.float().squeeze() for pg in model_pf for p in pg], [p for pg in master_pf for p in pg[0]], eps=1e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Checking for overflow" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For dynamic loss scaling, we need to know when the gradients have gone up to infinity. It's faster to check it on the sum than to do `torch.isinf(x).any()`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export \n", "def test_overflow(x):\n", " s = float(x.float().sum())\n", " return (s == float('inf') or s == float('-inf') or s != s)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(3,4)\n", "assert not test_overflow(x)\n", "x[1,2] = float('inf')\n", "assert test_overflow(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we can use it in the following function that checks for gradient overflow:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export \n", "def grad_overflow(pgs):\n", " for pg in pgs:\n", " for p in pg:\n", " if p.grad is not None and test_overflow(p.grad.data): return True\n", " return False" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#cuda\n", "assert not grad_overflow(model_p)\n", "assert not grad_overflow(model_pf)\n", "model_p[1][0].grad.data[0,0] = float('inf')\n", "model_pf[0][1].grad.data[0] = float('inf')\n", "assert grad_overflow(model_p)\n", "assert grad_overflow(model_pf)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## NonNativeMixedPrecision -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def copy_clone(d):\n", " return {k:(v.detach().clone().float() if isinstance(v,Tensor) else v) for k,v in d.items()}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def _copy_state(opt, pgs1, pgs2):\n", " opt.param_lists = pgs2\n", " for pg1,pg2 in zip(pgs1, pgs2):\n", " for p1,p2 in zip(pg1, pg2): opt.state[p2] = copy_clone(opt.state.pop(p1, {}))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class ModelToHalf(Callback):\n", " \"Use with NonNativeMixedPrecision callback (but it needs to run at the very beginning)\"\n", " order=-50\n", " def before_fit(self): self.learn.model = convert_network(self.model, dtype=torch.float16)\n", " def after_fit (self): self.learn.model = convert_network(self.model, dtype=torch.float32)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@docs\n", "class NonNativeMixedPrecision(Callback):\n", " \"Run training in mixed precision\"\n", " order=10\n", " def __init__(self, loss_scale=512, flat_master=False, dynamic=True, max_loss_scale=2.**24,\n", " div_factor=2., scale_wait=500, clip=None):\n", " assert torch.backends.cudnn.enabled, \"Mixed precision training requires cudnn.\"\n", " self.flat_master,self.dynamic,self.max_loss_scale = flat_master,dynamic,max_loss_scale\n", " self.div_factor,self.scale_wait,self.clip = div_factor,scale_wait,clip\n", " self.loss_scale = max_loss_scale if dynamic else loss_scale\n", "\n", " def before_fit(self):\n", " assert self.dls.device.type == 'cuda', \"Mixed-precision training requires a GPU, remove the call `to_fp16`\"\n", " if self.learn.opt is None: self.learn.create_opt()\n", " self.model_pgs,self.master_pgs = get_master(self.opt, self.flat_master)\n", " self.old_pgs = self.opt.param_lists\n", " #Changes the optimizer so that the optimization step is done in FP32.\n", " _copy_state(self.learn.opt, self.model_pgs, self.master_pgs)\n", " if self.dynamic: self.count = 0\n", "\n", " def before_batch(self): self.learn.xb = to_half(self.xb)\n", " def 
after_pred(self): self.learn.pred = to_float(self.pred)\n", " def before_backward(self): self.learn.loss_grad *= self.loss_scale\n", "\n", " def before_step(self):\n", " #First, check for an overflow\n", " if self.dynamic and grad_overflow(self.model_pgs):\n", " self.loss_scale /= self.div_factor\n", " self.learn.loss_grad /= self.div_factor #to record correct loss\n", " self.model.zero_grad()\n", " raise CancelBatchException() #skip step and zero_grad\n", " to_master_grads(self.model_pgs, self.master_pgs, self.flat_master)\n", " for master_params in self.master_pgs:\n", " for param in master_params:\n", " if param.grad is not None: param.grad.div_(self.loss_scale)\n", " if self.clip is not None:\n", " for group in self.master_pgs: nn.utils.clip_grad_norm_(group, self.clip)\n", " # Check if it's been long enough without overflow\n", " if self.dynamic:\n", " self.count += 1\n", " if self.count == self.scale_wait:\n", " self.count = 0\n", " self.loss_scale *= self.div_factor\n", "\n", " def after_step(self):\n", " self.model.zero_grad() #Zero the gradients of the model manually (optimizer disconnected)\n", " to_model_params(self.model_pgs, self.master_pgs, self.flat_master)\n", "\n", " def after_batch(self):\n", " if self.training: self.learn.loss_grad /= self.loss_scale #Log correct loss\n", " def after_fit(self):\n", " if not hasattr(self,'master_pgs'): return\n", " _copy_state(self.learn.opt, self.master_pgs, self.model_pgs)\n", " self.learn.opt.param_lists = self.old_pgs\n", " delattr(self, \"master_pgs\")\n", " delattr(self, \"model_pgs\")\n", " delattr(self, \"old_pgs\")\n", "\n", " _docs = dict(before_fit=\"Put the model in FP16 and prepare the two copies of the parameters\",\n", " before_batch=\"Put the input in FP16\",\n", " after_pred=\"Put the output back to FP32 so that the loss is computed in FP32\",\n", " before_backward=\"Apply loss scaling to avoid gradient underflow\",\n", " before_step=\"Copy the gradients to the master param and undo the loss scaling\",\n", " after_step=\"Copy the master params to the model params\",\n", " after_batch=\"Ensure loss is logged correctly\",\n", " after_fit=\"Put the model back in FP32\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "class TestBeforeMixedPrecision(Callback):\n", " order=-55\n", " def before_fit(self): test_eq(first(self.model.parameters()).dtype, torch.float32)\n", " def before_batch(self): test_eq(self.x.dtype, torch.float32)\n", " def after_pred(self): test_eq(self.pred.dtype, torch.float16)\n", " def after_loss(self): self.tst_loss = self.learn.loss_grad.detach().clone()\n", " def before_step(self):\n", " self.learn.has_overflown = grad_overflow(self.non_native_mixed_precision.model_pgs)\n", " self.grads = [p.grad.data.clone() for p in self.model.parameters()]\n", " self.old_params = [p.data.clone() for p in self.model.parameters()]\n", " def after_cancel_step(self): assert self.has_overflown\n", "\n", "class TestAfterMixedPrecision(Callback):\n", " order=65\n", " def before_fit(self): test_eq(first(self.model.parameters()).dtype, torch.float16)\n", " def after_fit(self): test_eq(first(self.model.parameters()).dtype, torch.float32)\n", " def before_batch(self): test_eq(self.x.dtype, torch.float16)\n", " def after_pred(self): test_eq(self.pred.dtype, torch.float32)\n", " def before_backward(self):\n", " loss_scale = self.non_native_mixed_precision.loss_scale if self.training else 1.\n", " test_eq(self.loss_grad, self.test_before_mixed_precision.tst_loss * 
loss_scale) \n", " def before_step(self):\n", " tbmp = self.test_before_mixed_precision\n", " test_eq(self.loss_grad, tbmp.loss_grad)\n", " #Test gradients have been copied and scaled back\n", " test_close(sum([[p.grad.data for p in pg] for pg in self.non_native_mixed_precision.master_pgs], []),\n", " [g.float()/self.non_native_mixed_precision.loss_scale for g in tbmp.grads])\n", " def after_batch(self):\n", " if self.has_overflown: return\n", " tbmp,mp =self.test_before_mixed_precision,self.non_native_mixed_precision\n", " #Test master params have been copied to model\n", " test_close(sum([[p.data for p in pg] for pg in mp.master_pgs], []),\n", " [p.data.float() for p in self.model.parameters()], eps=1e-3)\n", " #Test update has been done properly\n", " for p,g,op in zip(self.model.parameters(), tbmp.grads, tbmp.old_params):\n", " test_close(p.data.float(), op.float() - self.lr*g.float()/self.non_native_mixed_precision.loss_scale, eps=1e-3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  time
0      7.187932    5.855845    00:00
1      7.148743    5.697717    00:00
2      7.048915    5.524172    00:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#hide\n", "#cuda\n", "learn = synth_learner(cbs=[ModelToHalf(), NonNativeMixedPrecision()], cuda=True)\n", "learn.model = nn.Sequential(nn.Linear(1,1), nn.Linear(1,1)).cuda()\n", "learn.opt_func = partial(SGD, mom=0.)\n", "learn.splitter = lambda m: [list(m[0].parameters()), list(m[1].parameters())]\n", "learn.fit(3, cbs=[TestAfterMixedPrecision(), TestBeforeMixedPrecision()])\n", "#Check loss scale did change\n", "assert 1 < learn.non_native_mixed_precision.loss_scale < 2**24\n", "#Check the model did train\n", "for v1,v2 in zip(learn.recorder.values[0], learn.recorder.values[-1]): assert v2\n", " \n", " \n", " epoch\n", " train_loss\n", " valid_loss\n", " time\n", " \n", " \n", " \n", " \n", " 0\n", " 11.927933\n", " 12.063744\n", " 00:00\n", " \n", " \n", " 1\n", " 11.539829\n", " 11.545557\n", " 00:00\n", " \n", " \n", " 2\n", " 11.266481\n", " 11.075830\n", " 00:00\n", " \n", " \n", "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#hide\n", "#cuda\n", "learn = synth_learner(cbs=[ModelToHalf(), NonNativeMixedPrecision(dynamic=False)], cuda=True)\n", "learn.model = nn.Sequential(nn.Linear(1,1), nn.Linear(1,1)).cuda()\n", "learn.opt_func = partial(SGD, mom=0.)\n", "learn.splitter = lambda m: [list(m[0].parameters()), list(m[1].parameters())]\n", "learn.fit(3, cbs=[TestAfterMixedPrecision(), TestBeforeMixedPrecision()])\n", "#Check loss scale did mot change\n", "test_eq(learn.non_native_mixed_precision.loss_scale,512)\n", "#Check the model did train\n", "for v1,v2 in zip(learn.recorder.values[0], learn.recorder.values[-1]): assert v2\n", " \n", " \n", " epoch\n", " train_loss\n", " valid_loss\n", " time\n", " \n", " \n", " \n", " \n", " 0\n", " 8.358611\n", " 10.943352\n", " 00:00\n", " \n", " \n", " 1\n", " 8.330508\n", " 10.722443\n", " 00:00\n", " \n", " \n", " 2\n", " 8.221409\n", " 10.485508\n", " 00:00\n", " \n", " \n", "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#cuda\n", "learn = synth_learner(cuda=True)\n", "learn.model = nn.Sequential(nn.Linear(1,1), nn.Linear(1,1)).cuda()\n", "learn.opt_func = partial(SGD, mom=0.)\n", "learn.splitter = lambda m: [list(m[0].parameters()), list(m[1].parameters())]\n", "learn.to_non_native_fp16()\n", "learn.fit(3, cbs=[TestAfterMixedPrecision(), TestBeforeMixedPrecision()])\n", "#Check the model did train\n", "for v1,v2 in zip(learn.recorder.values[0], learn.recorder.values[-1]): assert v2\n", " \n", " \n", " epoch\n", " train_loss\n", " valid_loss\n", " time\n", " \n", " \n", " \n", " \n", " 0\n", " 11.646567\n", " 10.883919\n", " 00:00\n", " \n", " \n", " 1\n", " 11.489956\n", " 9.904404\n", " 00:00\n", " \n", " \n", " 2\n", " 10.746455\n", " 7.914827\n", " 00:00\n", " \n", " \n", "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#hide\n", "#cuda\n", "learn = synth_learner(cuda=True)\n", "learn.model = nn.Sequential(nn.Linear(1,1), nn.Linear(1,1)).cuda()\n", "learn.opt_func = partial(SGD, mom=0.9)\n", "learn.splitter = lambda m: [list(m[0].parameters()), list(m[1].parameters())]\n", "learn.to_non_native_fp16()\n", "learn.freeze()\n", "learn.create_opt()\n", "init_ps = [p for pg in learn.opt.param_groups for p in pg]\n", "learn.fit(3)\n", "final_ps = [p for pg in learn.opt.param_groups for p in pg]\n", "for p1,p2 in zip(init_ps, final_ps): test_is(p1, p2)\n", "#First param groups has no state because not trained\n", 
"test_eq([learn.opt.state[p] for p in learn.opt.param_lists[0]], [{}, {'do_wd': False}])\n", "#Second param groups has state \n", "for p in learn.opt.param_lists[1]: assert 'grad_avg' in learn.opt.state[p]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@patch\n", "def to_non_native_fp32(self: Learner): return self.remove_cbs([ModelToHalf, NonNativeMixedPrecision])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#cuda\n", "learn = learn.to_non_native_fp32()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_torch_core.ipynb.\n", "Converted 01_layers.ipynb.\n", "Converted 01a_losses.ipynb.\n", "Converted 02_data.load.ipynb.\n", "Converted 03_data.core.ipynb.\n", "Converted 04_data.external.ipynb.\n", "Converted 05_data.transforms.ipynb.\n", "Converted 06_data.block.ipynb.\n", "Converted 07_vision.core.ipynb.\n", "Converted 08_vision.data.ipynb.\n", "Converted 09_vision.augment.ipynb.\n", "Converted 09b_vision.utils.ipynb.\n", "Converted 09c_vision.widgets.ipynb.\n", "Converted 10_tutorial.pets.ipynb.\n", "Converted 10b_tutorial.albumentations.ipynb.\n", "Converted 11_vision.models.xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_callback.core.ipynb.\n", "Converted 13a_learner.ipynb.\n", "Converted 13b_metrics.ipynb.\n", "Converted 14_callback.schedule.ipynb.\n", "Converted 14a_callback.data.ipynb.\n", "Converted 15_callback.hook.ipynb.\n", "Converted 15a_vision.models.unet.ipynb.\n", "Converted 16_callback.progress.ipynb.\n", "Converted 17_callback.tracker.ipynb.\n", "Converted 18_callback.fp16.ipynb.\n", "Converted 18a_callback.training.ipynb.\n", "Converted 18b_callback.preds.ipynb.\n", "Converted 19_callback.mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 20a_distributed.ipynb.\n", "Converted 21_vision.learner.ipynb.\n", "Converted 22_tutorial.imagenette.ipynb.\n", "Converted 23_tutorial.vision.ipynb.\n", "Converted 24_tutorial.image_sequence.ipynb.\n", "Converted 24_tutorial.siamese.ipynb.\n", "Converted 24_vision.gan.ipynb.\n", "Converted 30_text.core.ipynb.\n", "Converted 31_text.data.ipynb.\n", "Converted 32_text.models.awdlstm.ipynb.\n", "Converted 33_text.models.core.ipynb.\n", "Converted 34_callback.rnn.ipynb.\n", "Converted 35_tutorial.wikitext.ipynb.\n", "Converted 37_text.learner.ipynb.\n", "Converted 38_tutorial.text.ipynb.\n", "Converted 39_tutorial.transformers.ipynb.\n", "Converted 40_tabular.core.ipynb.\n", "Converted 41_tabular.data.ipynb.\n", "Converted 42_tabular.model.ipynb.\n", "Converted 43_tabular.learner.ipynb.\n", "Converted 44_tutorial.tabular.ipynb.\n", "Converted 45_collab.ipynb.\n", "Converted 46_tutorial.collab.ipynb.\n", "Converted 50_tutorial.datablock.ipynb.\n", "Converted 60_medical.imaging.ipynb.\n", "Converted 61_tutorial.medical_imaging.ipynb.\n", "Converted 65_medical.text.ipynb.\n", "Converted 70_callback.wandb.ipynb.\n", "Converted 71_callback.tensorboard.ipynb.\n", "Converted 72_callback.neptune.ipynb.\n", "Converted 73_callback.captum.ipynb.\n", "Converted 74_callback.azureml.ipynb.\n", "Converted 97_test_utils.ipynb.\n", "Converted 99_pytorch_doc.ipynb.\n", "Converted dev-setup.ipynb.\n", "Converted app_examples.ipynb.\n", "Converted camvid.ipynb.\n", "Converted migrating_catalyst.ipynb.\n", "Converted migrating_ignite.ipynb.\n", 
"Converted migrating_lightning.ipynb.\n", "Converted migrating_pytorch.ipynb.\n", "Converted migrating_pytorch_verbose.ipynb.\n", "Converted ulmfit.ipynb.\n", "Converted index.ipynb.\n", "Converted quick_start.ipynb.\n", "Converted tutorial.ipynb.\n" ] } ], "source": [ "#hide\n", "from nbdev.export import *\n", "notebook2script()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }