{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp torch_core" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from local.test import *\n", "from local.core.all import *\n", "from local.torch_imports import *\n", "from fastprogress import progress_bar,master_bar" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from local.notebook.showdoc import *\n", "from PIL import Image" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_all_ = ['progress_bar','master_bar']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "if torch.cuda.is_available():\n", " if torch.cuda.current_device()==0:\n", " torch.cuda.set_device(int(os.environ.get('DEFAULT_GPU') or 0))\n", " torch.backends.cudnn.benchmark = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Torch Core\n", "\n", "> Basic pytorch functions used in the fastai library" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basics" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@patch\n", "def __array_eq__(self:Tensor,b):\n", " return torch.equal(self,b) if self.dim() else self==b" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _array2tensor(x):\n", " if x.dtype==np.uint16: x = x.astype(np.float32) \n", " return torch.from_numpy(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def tensor(x, *rest, **kwargs):\n", " \"Like `torch.as_tensor`, but handle lists too, and can pass multiple vector elements directly.\"\n", " if len(rest): x = (x,)+rest\n", " # There was a Pytorch bug in dataloader using num_workers>0. 
Haven't confirmed if fixed\n", " # if isinstance(x, (tuple,list)) and len(x)==0: return tensor(0)\n", " res = (x if isinstance(x, Tensor)\n", " else torch.tensor(x, **kwargs) if isinstance(x, (tuple,list))\n", " else _array2tensor(x) if isinstance(x, ndarray)\n", " else as_tensor(x.values, **kwargs) if isinstance(x, (pd.Series, pd.DataFrame))\n", " else as_tensor(x, **kwargs) if hasattr(x, '__array__') or is_iter(x)\n", " else _array2tensor(array(x), **kwargs))\n", " if res.dtype is torch.float64: return res.float()\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(tensor(torch.tensor([1,2,3])), torch.tensor([1,2,3]))\n", "test_eq(tensor(array([1,2,3])), torch.tensor([1,2,3]))\n", "test_eq(tensor(1,2,3), torch.tensor([1,2,3]))\n", "test_eq_type(tensor(1.0), torch.tensor(1.0))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def set_seed(s):\n", " \"Set random seed for `random`, `torch`, and `numpy` (where available)\"\n", " try: torch.manual_seed(s)\n", " except NameError: pass\n", " try: np.random.seed(s%(2**32-1))\n", " except NameError: pass\n", " random.seed(s)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "set_seed(2*33)\n", "a1 = np.random.random()\n", "a2 = torch.rand(())\n", "a3 = random.random()\n", "set_seed(2*33)\n", "b1 = np.random.random()\n", "b2 = torch.rand(())\n", "b3 = random.random()\n", "test_eq(a1,b1)\n", "test_eq(a2,b2)\n", "test_eq(a3,b3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def unsqueeze(x, dim=-1, n=1):\n", " \"Same as `torch.unsqueeze` but can add `n` dims\"\n", " for _ in range(n): x = x.unsqueeze(dim)\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t = tensor([1])\n", "t2 = unsqueeze(t, n=2)\n", 
"test_eq(t2,t[:,None,None])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def unsqueeze_(x, dim=-1, n=1):\n", " \"Same as `torch.unsqueeze_` but can add `n` dims\"\n", " for _ in range(n): x.unsqueeze_(dim)\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t = tensor([1])\n", "unsqueeze_(t, n=2)\n", "test_eq(t, tensor([1]).view(1,1,1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _fa_rebuild_tensor (cls, *args, **kwargs): return cls(torch._utils._rebuild_tensor_v2(*args, **kwargs))\n", "def _fa_rebuild_qtensor(cls, *args, **kwargs): return cls(torch._utils._rebuild_qtensor (*args, **kwargs))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def apply(func, x, *args, **kwargs):\n", " \"Apply `func` recursively to `x`, passing on args\"\n", " if is_listy(x): return type(x)([apply(func, o, *args, **kwargs) for o in x])\n", " if isinstance(x,dict): return {k: apply(func, v, *args, **kwargs) for k,v in x.items()}\n", " res = func(x, *args, **kwargs)\n", " return res if x is None else retain_type(res, x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def to_detach(b, cpu=True):\n", " \"Recursively detach lists of tensors in `b `; put them on the CPU if `cpu=True`.\"\n", " def _inner(x, cpu=True):\n", " if not isinstance(x,Tensor): return x\n", " x = x.detach()\n", " return x.cpu() if cpu else x\n", " return apply(_inner, b, cpu=cpu)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def to_half(b):\n", " \"Recursively map lists of tensors in `b ` to FP16.\"\n", " return apply(lambda x: x.half() if torch.is_floating_point(x) else x, b)" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [], "source": [ "#export\n", "def to_float(b):\n", " \"Recursively map lists of int tensors in `b ` to float.\"\n", " return apply(lambda x: x.float() if torch.is_floating_point(x) else x, b)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "# None: True if available; True: error if not availabe; False: use CPU\n", "defaults.use_cuda = None" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def default_device(use_cuda=-1):\n", " \"Return or set default device; `use_cuda`: None - CUDA if available; True - error if not availabe; False - CPU\"\n", " if use_cuda != -1: defaults.use_cuda=use_cuda\n", " use = defaults.use_cuda or (torch.cuda.is_available() and defaults.use_cuda is None)\n", " assert torch.cuda.is_available() or not use\n", " return torch.device(torch.cuda.current_device()) if use else torch.device('cpu')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#cuda\n", "_td = torch.device(torch.cuda.current_device())\n", "test_eq(default_device(None), _td)\n", "test_eq(default_device(True), _td)\n", "test_eq(default_device(False), torch.device('cpu'))\n", "default_device(None);" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def to_device(b, device=None):\n", " \"Recursively put `b` on `device`.\"\n", " if device is None: device=default_device()\n", " def _inner(o): return o.to(device, non_blocking=True) if isinstance(o,Tensor) else o\n", " return apply(_inner, b)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t = to_device((3,(tensor(3),tensor(2))))\n", "t1,(t2,t3) = t\n", "test_eq_type(t,(3,(tensor(3).cuda(),tensor(2).cuda())))\n", "test_eq(t2.type(), \"torch.cuda.LongTensor\")\n", "test_eq(t3.type(), \"torch.cuda.LongTensor\")" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def to_cpu(b):\n", " \"Recursively map lists of tensors in `b ` to the cpu.\"\n", " return to_device(b,'cpu')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t3 = to_cpu(t3)\n", "test_eq(t3.type(), \"torch.LongTensor\")\n", "test_eq(t3, 2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def to_np(x):\n", " \"Convert a tensor to a numpy array.\"\n", " return apply(lambda o: o.data.cpu().numpy(), x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t3 = to_np(t3)\n", "test_eq(type(t3), np.ndarray)\n", "test_eq(t3, 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tensor subtypes" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TensorBase(Tensor):\n", " def __new__(cls, x, **kwargs): \n", " res = torch.Tensor._make_subclass(cls, tensor(x))\n", " res._meta = kwargs\n", " return res\n", "\n", " def __reduce_ex__(self,proto):\n", " torch.utils.hooks.warn_if_has_hooks(self)\n", " args = (type(self), self.storage(), self.storage_offset(), tuple(self.size()), self.stride())\n", " if self.is_quantized: args = args + (self.q_scale(), self.q_zero_point())\n", " f = _fa_rebuild_qtensor if self.is_quantized else _fa_rebuild_tensor\n", " return (f, args + (self.requires_grad, OrderedDict()))\n", "\n", " def gi(self, i):\n", " res = self[i]\n", " return type(self)(res) if isinstance(res,Tensor) else res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _patch_tb():\n", " if getattr(TensorBase,'_patched',False): return\n", " TensorBase._patched = True\n", "\n", " def get_f(fn):\n", " def _f(self, *args, **kwargs):\n", " cls = self.__class__\n", " res = getattr(super(TensorBase, self), 
fn)(*args, **kwargs)\n", " return cls(res) if isinstance(res,Tensor) else res\n", " return _f\n", "\n", " t = tensor([1])\n", " skips = '__getitem__ __class__ __deepcopy__ __delattr__ __dir__ __doc__ __getattribute__ __hash__ __init__ \\\n", " __init_subclass__ __new__ __reduce__ __reduce_ex__ __module__ __setstate__'.split()\n", "\n", " for fn in dir(t):\n", " if fn in skips: continue\n", " f = getattr(t, fn)\n", " if isinstance(f, (MethodWrapperType, BuiltinFunctionType, BuiltinMethodType, MethodType, FunctionType)):\n", " setattr(TensorBase, fn, get_f(fn))\n", "\n", "_patch_tb()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export \n", "class TensorCategory(TensorBase): pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export \n", "class TensorMultiCategory(TensorCategory): pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class _T(TensorBase): pass\n", "\n", "t = _T(range(5))\n", "test_eq(t[0], 0)\n", "test_eq_type(t.gi(0), _T(0))\n", "test_eq_type(t.gi(slice(2)), _T([0,1]))\n", "test_eq_type(t+1, _T(range(1,6)))\n", "\n", "test_eq(type(pickle.loads(pickle.dumps(t))), _T)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t = tensor([1,2,3])\n", "m = TensorBase([False,True,True])\n", "test_eq(t[m], tensor([2,3]))\n", "\n", "t = tensor([[1,2,3],[1,2,3]])\n", "m = TensorBase([[False,True,True],\n", " [False,True,True]])\n", "test_eq(t[m], tensor([2,3,2,3]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t = TensorBase([[1,2,3],[1,2,3]], a=1)\n", "test_eq(t._meta, {'a': 1})\n", "x = retain_type(tensor([4,5,6]), t)\n", "test_eq(x._meta, {'a': 1})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TensorImageBase(TensorBase):\n", " _show_args = 
ArrayImageBase._show_args\n", " def show(self, ctx=None, **kwargs):\n", " return show_image(self, ctx=ctx, **{**self._show_args, **kwargs})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TensorImage(TensorImageBase): pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TensorImageBW(TensorImage): _show_args = ArrayImageBW._show_args" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TensorMask(TensorImageBase): _show_args = ArrayMask._show_args" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "im = Image.open(TEST_IMAGE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "im_t = TensorImage(array(im))\n", "test_eq(type(im_t), TensorImage)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "im_t2 = TensorMask(tensor(1))\n", "test_eq(type(im_t2), TensorMask)\n", "test_eq(im_t2, tensor(1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ax = im_t.show(figsize=(2,2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_fig_exists(ax)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## L -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@patch\n", "def tensored(self:L):\n", " \"`mapped(tensor)`\"\n", " return self.map(tensor)\n", "@patch\n", "def stack(self:L, dim=0):\n", " \"Same as `torch.stack`\"\n", " return torch.stack(list(self.tensored()), dim=dim)\n", "@patch\n", "def cat (self:L, dim=0):\n", " \"Same as `torch.cat`\"\n", " return torch.cat (list(self.tensored()), dim=dim)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

L.tensored[source]

\n", "\n", "> L.tensored()\n", "\n", "`mapped(tensor)`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(L.tensored)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are shortcuts for `torch.stack` and `torch.cat` if your `L` contains tensors or something convertible. You can manually convert with `tensored`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t = L(([1,2],[3,4]))\n", "test_eq(t.tensored(), [tensor(1,2),tensor(3,4)])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

L.stack[source]

\n", "\n", "> L.stack(**`dim`**=*`0`*)\n", "\n", "Same as `torch.stack`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(L.stack)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(t.stack(), tensor([[1,2],[3,4]]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

L.cat[source]

\n", "\n", "> L.cat(**`dim`**=*`0`*)\n", "\n", "Same as `torch.cat`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(L.cat)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(t.cat(), tensor([1,2,3,4]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Chunks" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def concat(*ls):\n", " \"Concatenate tensors, arrays, lists, or tuples\"\n", " if not len(ls): return []\n", " it = ls[0]\n", " if isinstance(it,torch.Tensor): res = torch.cat(ls)\n", " elif isinstance(it,ndarray): res = np.concatenate(ls)\n", " else:\n", " res = itertools.chain.from_iterable(map(L,ls))\n", " if isinstance(it,(tuple,list)): res = type(it)(res)\n", " else: res = L(res)\n", " return retain_type(res, it)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "a,b,c = [1],[1,2],[1,1,2]\n", "test_eq(concat(a,b), c)\n", "test_eq_type(concat(tuple (a),tuple (b)), tuple (c))\n", "test_eq_type(concat(array (a),array (b)), array (c))\n", "test_eq_type(concat(tensor(a),tensor(b)), tensor(c))\n", "test_eq_type(concat(TensorBase(a),TensorBase(b)), TensorBase(c))\n", "test_eq_type(concat([1,1],1), [1,1,1])\n", "test_eq_type(concat(1,1,1), L(1,1,1))\n", "test_eq_type(concat(L(1,2),1), L(1,2,1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Chunks:\n", " \"Slice and int indexing into a list of lists\"\n", " def __init__(self, chunks, lens=None):\n", " self.chunks = chunks\n", " self.lens = L(map(len,self.chunks) if lens is None else lens)\n", " self.cumlens = np.cumsum(0+self.lens)\n", " self.totlen = self.cumlens[-1]\n", "\n", " def __getitem__(self,i):\n", " if isinstance(i,slice): return retain_type(self.getslice(i), old=self.chunks[0])\n", " di,idx = 
self.doc_idx(i)\n", " return retain_type(self.chunks[di][idx], old=self.chunks[0])\n", "\n", " def getslice(self, i):\n", " st_d,st_i = self.doc_idx(ifnone(i.start,0))\n", " en_d,en_i = self.doc_idx(ifnone(i.stop,self.totlen+1))\n", " res = [self.chunks[st_d][st_i:(en_i if st_d==en_d else sys.maxsize)]]\n", " for b in range(st_d+1,en_d): res.append(self.chunks[b])\n", " if st_d!=en_d and en_dclass Module[source]\n", "\n", "> Module() :: [`Module`](/torchcore.html#Module)\n", "\n", "Same as `nn.Module`, but no need for subclasses to call `super().__init__`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Module, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-1.0893], grad_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class _T(Module):\n", " def __init__(self): self.f = nn.Linear(1,1)\n", " def forward(self,x): return self.f(x)\n", "\n", "t = _T()\n", "t(tensor([1.]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "from torch.nn.parallel import DistributedDataParallel\n", "\n", "def get_model(model):\n", " \"Return the model maybe wrapped inside `model`.\"\n", " return model.module if isinstance(model, (DistributedDataParallel, nn.DataParallel)) else model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def one_hot(x, c):\n", " \"One-hot encode `x` with `c` classes.\"\n", " res = torch.zeros(c, dtype=torch.uint8)\n", " if isinstance(x, Tensor) and x.numel()>0: res[x] = 1.\n", " else: res[list(L(x, use_list=None))] = 1.\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(one_hot([1,4], 5), tensor(0,1,0,0,1).byte())\n", "test_eq(one_hot(torch.tensor([]), 5), 
tensor(0,0,0,0,0).byte())\n", "test_eq(one_hot(2, 5), tensor(0,0,1,0,0).byte())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def one_hot_decode(x, vocab=None):\n", " return L(vocab[i] if vocab else i for i,x_ in enumerate(x) if x_==1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(one_hot_decode(tensor(0,1,0,0,1)), [1,4])\n", "test_eq(one_hot_decode(tensor(0,0,0,0,0)), [ ])\n", "test_eq(one_hot_decode(tensor(0,0,1,0,0)), [2 ])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def params(m):\n", " \"Return all parameters of `m`\"\n", " return [p for p in m.parameters()]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def trainable_params(m):\n", " \"Return all trainable parameters of `m`\"\n", " return [p for p in m.parameters() if p.requires_grad]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m = nn.Linear(4,5)\n", "test_eq(trainable_params(m), [m.weight, m.bias])\n", "m.weight.requires_grad_(False)\n", "test_eq(trainable_params(m), [m.bias])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def bn_bias_params(m, with_bias=True):\n", " \"Return all bias and BatchNorm parameters\"\n", " if isinstance(m, bn_types): return L(m.parameters()) if with_bias else L(m.weight)\n", " res = L(m.children()).map(bn_bias_params, with_bias=with_bias).concat()\n", " #if with_bias and hasattr(m, 'bias'): res.append(m.bias)\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = 
#export
def batch_to_samples(b, max_n=10):
    "Turn batch `b` into a collection of samples, keeping at most `max_n` of them"
    if not isinstance(b, Tensor):
        # recurse into every element of the batch, then zip the per-element samples together
        per_item = L(b).map(partial(batch_to_samples, max_n=max_n))
        return retain_types(per_item.zip(), [b])
    return retain_types(list(b[:max_n]), [b])

#export
@patch
def interp_1d(x:Tensor, xp, fp):
    "Piecewise-linear interpolation of `x` over breakpoints `xp` with values `fp` (cf. `np.interp`)"
    dx, dy = xp[1:]-xp[:-1], fp[1:]-fp[:-1]
    slopes = dy/dx
    intercepts = fp[:-1] - (slopes*xp[:-1])
    # segment index = number of breakpoints at or below each point, minus one
    seg = (x[:,None]>=xp[None,:]).long().sum(1)-1
    # points outside the breakpoints fall back to the first / last segment
    seg = seg.clamp(0,len(slopes)-1)
    return slopes[seg]*x + intercepts[seg]
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "brks = tensor(0,1,2,4,8,64).float()\n", "ys = tensor(range_of(brks)).float()\n", "ys /= ys[-1].item()\n", "pts = tensor(0.2,0.5,0.8,3,5,63)\n", "\n", "preds = pts.interp_1d(brks, ys)\n", "test_close(preds.numpy(), np.interp(pts.numpy(), brks.numpy(), ys.numpy()))\n", "\n", "plt.scatter(brks,ys)\n", "plt.scatter(pts,preds)\n", "plt.legend(['breaks','preds']);" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def logit(x):\n", " \"Logit of `x`, clamped to avoid inf.\"\n", " x = x.clamp(1e-7, 1-1e-7)\n", " return -(1/x-1).log()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def num_distrib():\n", " \"Return the number of processes in distributed training (if applicable).\"\n", " return int(os.environ.get('WORLD_SIZE', 0))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def rank_distrib():\n", " \"Return the distributed rank of this process (if applicable).\"\n", " return int(os.environ.get('RANK', 0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Image helpers" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def make_cross_image(bw=True):\n", " \"Create a tensor containing a cross image, either `bw` (True) or color\"\n", " if bw:\n", " im = torch.zeros(5,5)\n", " im[2,:] = 1.\n", " im[:,2] = 1.\n", " else:\n", " im = torch.zeros(3,5,5)\n", " im[0,2,:] = 1.\n", " im[1,:,2] = 1.\n", " return im" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAI40lEQVR4nO3dT4ichR3G8efpJqJgwUPmELKh60GkQaiSIQj2FDzEKtqjgj0JuVSIUBDtzUOvxYuXYEVBUQQ9iFhEUGsLVp34r6ZRCJJiUMgEkeqloj49zBxiu7vzzuR959355fuBhZ3dycyD7nff2dnlHScRgDp+0vcAAO0iaqAYogaKIWqgGKIGitnVxY3u2bMnGxsbXdz0Je/EiRN9T5jLwYMH+55Q0pkzZ3T+/Hlv9rlOot7Y2NBoNOripi959qb/H3csvg66MRwOt/wcD7+BYogaKIaogWKIGiiGqIFiiBoohqiBYogaKIaogWKIGiiGqIFiiBoohqiBYogaKIaogWKIGiiGqIFiGkVt+4jtT2yftv1A16MALG5m1LbXJD0i6RZJByTdZftA18MALKbJkfqQpNNJPk3yraRnJN3R7SwAi2oS9T5Jn11w+ez0Yz9i+6jtke3ReDxuax+AOTWJerPTV/7fq+olOZ5kmGQ4GAwufhmAhTSJ+qyk/RdcXpf0eTdzAFysJlG/I+ka21fbvkzSnZJe6HYWgEXNPJl/ku9s3yvpZUlrkh5LcrLzZQAW0ugVOpK8JOmljrcAaAF/UQYUQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDEzo7b9mO1ztj9axiAAF6fJkfpxSUc63gGgJTOjTvKGpC+XsAVAC/iZGiimtahtH7U9sj0aj8dt3SyAObUWdZLjSYZJhoPBoK2bBTAnHn4DxTT5ldbTkt6UdK3ts7bv6X4WgEXtmnWFJHctYwiAdvDwGyiGqIFiiBoohqiBYogaKIaogWKIGiiGqIFiiBoohqiBYogaKIaogWKIGiiGqIFiiBoohqiBYpyk/Ru1279RAD+SxJt9nCM1UAxRA8UQNVAMUQPFEDVQDFEDxRA1UAxRA8UQNVAMUQPFEDVQDFEDxRA1UAxRA8UQNVAMUQPFEDVQDFEDxcyM2vZ+26/ZPmX7pO1jyxgGYDEzz1Fme6+kvUnetf1TSSck/TrJP7f5N5yjDOjYwucoS/JFknen738t6ZSkfe3OA9CWXfNc2faGpBskvbXJ545KOtrKKgALa3yKYNtXSvqLpD8keX7GdXn4DXTsok4RbHu3pOckPTUraAD9avJEmSU9IenLJPc1ulGO1EDntjpSN4n6l5L+Kukfkn6Yfvj3SV7a5t8QNdCxhaNeBFED3eNld4BLBFEDxRA1UAxRA8UQNVAMUQPFEDVQDFEDxRA1UAxRA8UQNVAMUQPFEDVQDFEDxRA1UAxRA8XMdTbRpg4ePKjRaNTFTV/yJmeXWh1dnIQD0nA43PJzHKmBYogaKIaogWKIGiiGqIFiiBoohqiBYogaKIaogWKIGiiGqIFiiBoohqiBYogaKIaogWKIGiiGqIFiZkZt+3Lbb9v+wPZJ2w8tYxiAxTQ5ndF/JB1O8o3t3ZL+ZvvPSf7e8TYAC5gZdSYnmfpmenH39I0TTwE7VKOfqW2v2X5f0jlJryR5q9tZABbVKOok3ye5XtK6pEO2r/vf69g+antkezQej9veCaChuZ79TvKVpNclHdnkc8eTDJMMB4NBS/MAzKvJs98D21dN379C0s2SPu56GIDFNHn2e6+kJ2yvafJN4NkkL3Y7C8Cimjz7/aGkG5awBUAL+IsyoBiiBoohaqAYogaKIWqgGKIGiiFqoBiiBoo
haqAYogaKIWqgGKIGiiFqoBiiBoohaqAYogaKIWqgGKIGiiFqoBiiBoohaqAYogaKIWqgGKIGiiFqoBiiBoohaqAYogaKIWqgGKIGiiFqoBiiBoohaqAYogaKIWqgmMZR216z/Z7tF7scBODizHOkPibpVFdDALSjUdS21yXdKunRbucAuFhNj9QPS7pf0g9bXcH2Udsj26PxeNzKOADzmxm17dsknUtyYrvrJTmeZJhkOBgMWhsIYD5NjtQ3Sbrd9hlJz0g6bPvJTlcBWNjMqJM8mGQ9yYakOyW9muTuzpcBWAi/pwaK2TXPlZO8Lun1TpYAaAVHaqAYogaKIWqgGKIGiiFqoBiiBoohaqAYogaKIWqgGKIGiiFqoBiiBoohaqAYogaKIWqgGKIGinGS9m/UHkv6V8s3u0fS+ZZvs0urtHeVtkqrtberrT9LsukZPjuJugu2R0mGfe9oapX2rtJWabX29rGVh99AMUQNFLNKUR/ve8CcVmnvKm2VVmvv0reuzM/UAJpZpSM1gAaIGihmJaK2fcT2J7ZP236g7z3bsf2Y7XO2P+p7yyy299t+zfYp2ydtH+t701ZsX277bdsfTLc+1PemJmyv2X7P9ovLus8dH7XtNUmPSLpF0gFJd9k+0O+qbT0u6UjfIxr6TtLvkvxc0o2SfruD/9v+R9LhJL+QdL2kI7Zv7HlTE8cknVrmHe74qCUdknQ6yadJvtXklTfv6HnTlpK8IenLvnc0keSLJO9O3/9aky++ff2u2lwmvple3D1929HP8tpel3SrpEeXeb+rEPU+SZ9dcPmsdugX3iqzvSHpBklv9btka9OHsu9LOifplSQ7duvUw5Lul/TDMu90FaL2Jh/b0d+hV43tKyU9J+m+JP/ue89Wknyf5HpJ65IO2b6u701bsX2bpHNJTiz7vlch6rOS9l9weV3S5z1tKcf2bk2CfirJ833vaSLJV5q8+upOfu7iJkm32z6jyY+Mh20/uYw7XoWo35F0je2rbV+myQvfv9DzphJsW9KfJJ1K8se+92zH9sD2VdP3r5B0s6SP+121tSQPJllPsqHJ1+yrSe5exn3v+KiTfCfpXkkva/JEzrNJTva7amu2n5b0pqRrbZ+1fU/fm7Zxk6TfaHIUeX/69qu+R21hr6TXbH+oyTf6V5Is7ddEq4Q/EwWK2fFHagDzIWqgGKIGiiFqoBiiBoohaqAYogaK+S/20vv5In3GxwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(make_cross_image(), cmap=\"Greys\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAI3UlEQVR4nO3dQYic9R3G8efpJqLUgofmINnQeBCpBBoxBCGXEiykGvSqUE9CLhUiWMSeivciXnoJGiwoiqAHyUUCTRHBxmxiLMbVEsTiorAtUjQ9VKK/HmYOqd3ZeWf2fefd98n3AwM7u7Pv/HiZ777vzCz/cVUJQI4f9D0AgHYRNRCGqIEwRA2EIWogzI4uNmqbl9S7cnffA8zofN8D5Koqb/R9d/GWFlF3aGh7dsOHHdowKWpOv4EwRA2EIWogDFEDYYgaCEPUQBiiBsIQNRCGqIEwRA2EIWogDFEDYYgaCEPUQBiiBsIQNRCGqIEwjaK2fcT2x7Yv236q66EAzG/qcka2lyT9TdIvJK1JOifp4ar6cJPfGdqiO8MxtD3Lckad2cpyRgclXa6qT6rqG0mvSHqwzeEAtKdJ1LslfXbN9bXx9/6H7WO2V2yvtDUcgNk1WSJ4o0P8/50EVtUJSSckTr+BPjU5Uq9J2nPN9WVJn3czDoCtahL1OUm3277N9g2SHpL0RrdjAZjX1NPvqrpq+zFJb0paknSyqi51PhmAufAJHUMztD3LW1qd4RM6gOsEUQNhiBoIQ9RAGKIGwhA1EIaogTBEDYQhaiAMUQNhiBoIQ9RAGKIGwhA1EIaogTBEDYQhaiAMUQNhiBoIQ9RAGKIGwhA1EIaogTBEDYQhaiAMUQNhiBoIQ9RAGKIGwhA1EIaogTBEDYQhaiAMUQNhpkZt+6TtddsfLGIgAFvT5Ej9gqQjHc8BoCVTo66qtyR9uYBZALSA59RAmB1tbcj2MUnH2toegPm4qqbfyN4r6VRV7Wu0UXv6RjGfoe1Z9z1ArqracO9y+g2EafKW1suS3pF0h+012492PxaAeTU6/Z55o5x+d2doe5bT785w+g1cJ4gaCEPUQBiiBsIQNRCGqIEwRA2EIWogDFEDYYgaCEPUQBiiBsIQNRCGqIEwRA2EIWogTGsLD17rbkkrXWwYgCTpwCY/40gNhCFqIAxRA2GIGghD1EAYogbCEDUQhqiBMEQNhCFqIAxRA2GIGghD1EAYogbCEDUQhqiBMEQNhCFqIMzUqG3vsX3G9qrtS7aPL2IwAPNpskbZVUlPVNUF2z+SdN726ar6sOPZAMxh6pG6qr6oqgvjr7+WtCppd9eDAZjPTM+pbe+VdJeksxv87JjtFdsr/2hnNgBzaBy17ZslvSbp8ar66vs/r6oTVXWgqg7sanNCADNpFLXtnRoF/VJVvd7tSAC2osmr35b0vKTVqnqm+5EAbEWTI/UhSY9IOmz74vhyX8dzAZjT1Le0quptSV7ALABawH+UAWGIGghD1EAYogbCEDUQhqiBMEQNhCFqIAxRA2GIGghD1EAYogbCEDUQhqiBMEQNhCFqIIyrqv2N2u1vFCND27Msr9GZqtpw73KkBsIQNRCGqIEwRA2EIWogDFEDYYgaCEPUQBiiBsIQNRCGqIEwRA2EIWogDFEDYYgaCEPUQBiiBsJMjdr2jbbftf2+7Uu2n17EYADmM3U5I9uW9MOqumJ7p6S3JR2vqr9s8jtDW3R
nOIa2Z1nOqDOTljPa0eAXS9KV8dWd48vQHlrAdaPRc2rbS7YvSlqXdLqqznY7FoB5NYq6qr6tqv2SliUdtL3v+7exfcz2iu2VtocE0NzMSwTb/p2kf1fV7ze5DafnXRnanuU5dWfmXiLY9i7bt4y/vknSvZI+anc8AG2Z+kKZpFsl/dH2kkZ/BF6tqlPdjgVgXnxCx9AMbc9y+t0ZPqEDuE4QNRCGqIEwRA2EIWogDFEDYYgaCEPUQBiiBsIQNRCGqIEwRA2EIWogDFEDYYgaCEPUQBiiBsIQNRCGqIEwRA2EIWogDFEDYYgaCEPUQBiiBsIQNRCGqIEwRA2EIWogDFEDYYgaCEPUQBiiBsIQNRCGqIEwjaO2vWT7PdunuhwIwNbMcqQ+Lmm1q0EAtKNR1LaXJd0v6bluxwGwVU2P1M9KelLSd5NuYPuY7RXbK61MBmAuU6O2fVTSelWd3+x2VXWiqg5U1YHWpgMwsyZH6kOSHrD9qaRXJB22/WKnUwGYm6uq+Y3tn0v6TVUdnXK75hvFbIa2Z933ALmqasO9y/vUQJiZjtSNN8qRujtD27McqTvDkRq4ThA1EIaogTBEDYQhaiAMUQNhiBoIQ9RAGKIGwhA1EIaogTBEDYQhaiAMUQNhiBoIQ9RAmB0dbfefkv7e8jZ/PN7uUHQzbzeLDrBvu9PVrD+Z9INOVj7pgu2VIa1UOqR5hzSrNKx5+5iV028gDFEDYYYU9Ym+B5jRkOYd0qzSsOZd+KyDeU4NoJkhHakBNEDUQJhBRG37iO2PbV+2/VTf82zG9knb67Y/6HuWaWzvsX3G9qrtS7aP9z3TJLZvtP2u7ffHsz7d90xN2F6y/Z7tU4u6z20fte0lSX+Q9EtJd0p62Pad/U61qRckHel7iIauSnqiqn4q6R5Jv97G+/Y/kg5X1c8k7Zd0xPY9Pc/UxHFJq4u8w20ftaSDki5X1SdV9Y1Gn7z5YM8zTVRVb0n6su85mqiqL6rqwvjrrzV68O3ud6qN1ciV8dWd48u2fpXX9rKk+yU9t8j7HULUuyV9ds31NW3TB96Q2d4r6S5JZ/udZLLxqexFSeuSTlfVtp117FlJT0r6bpF3OoSoN/pv5239F3pobN8s6TVJj1fVV33PM0lVfVtV+yUtSzpoe1/fM01i+6ik9ao6v+j7HkLUa5L2XHN9WdLnPc0Sx/ZOjYJ+qape73ueJqrqX5L+rO392sUhSQ/Y/lSjp4yHbb+4iDseQtTnJN1u+zbbN0h6SNIbPc8UwbYlPS9ptaqe6XuezdjeZfuW8dc3SbpX0kf9TjVZVf22qparaq9Gj9k/VdWvFnHf2z7qqroq6TFJb2r0Qs6rVXWp36kms/2ypHck3WF7zfajfc+0iUOSHtHoKHJxfLmv76EmuFXSGdt/1egP/emqWtjbREPCv4kCYbb9kRrAbIgaCEPUQBiiBsIQNRCGqIEwRA2E+S+hrujkVjaWiAAAAABJRU5ErkJggg==\n", "text/plain": [ "
def requires_grad(m):
    "Check if the first parameter of `m` requires grad or not"
    # Only the first parameter is inspected; a module without parameters reports `False`.
    # `next` avoids materializing the whole parameter list just to read one element.
    first = next(iter(m.parameters()), None)
    return first.requires_grad if first is not None else False

def init_default(m, func=nn.init.kaiming_normal_):
    """Initialize `m` weights with `func` and set `bias` to 0.

    A falsy `func` leaves `m` completely untouched. `m` is returned so the call can be chained.
    """
    if func:
        # `weight` can exist as an attribute and still be `None`
        # (e.g. `nn.LayerNorm(..., elementwise_affine=False)`): skip it instead of crashing in `func`.
        if getattr(m, 'weight', None) is not None: func(m.weight)
        # `hasattr(m.bias, 'data')` is False for a `None` bias, so that case is skipped too.
        if hasattr(m, 'bias') and hasattr(m.bias, 'data'): m.bias.data.fill_(0.)
    return m

def cond_init(m, func):
    "Apply `init_default` to `m` unless it's a batchnorm module"
    # Also skipped when `requires_grad(m)` is False: first parameter frozen, or no parameters at all.
    if isinstance(m, bn_types) or not requires_grad(m): return
    init_default(m, func)

def apply_leaf(m, f):
    "Apply `f` to `m` itself and then, recursively, to every module below it."
    c = m.children()
    if isinstance(m, nn.Module): f(m)
    for l in c: apply_leaf(l,f)

def apply_init(m, func=nn.init.kaiming_normal_):
    "Initialize all non-batchnorm layers of `m` with `func`."
    # `cond_init` does the filtering; `apply_leaf` walks `m` and all of its descendants.
    apply_leaf(m, partial(cond_init, func=func))
from multiprocessing import Process, Queue

def set_num_threads(nt):
    "Get numpy (and others) to use `nt` threads"
    # mkl is optional, so this stays best-effort; `except Exception` (rather than a bare `except`)
    # lets KeyboardInterrupt/SystemExit propagate.
    try: import mkl; mkl.set_num_threads(nt)
    except Exception: pass
    # This used to be hard-coded to 1, silently ignoring `nt` for torch.
    torch.set_num_threads(nt)
    os.environ['IPC_ENABLE']='1'
    for o in ['OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS']:
        os.environ[o] = str(nt)

@delegates(concurrent.futures.ProcessPoolExecutor)
class ProcessPoolExecutor(concurrent.futures.ProcessPoolExecutor):
    "Like `concurrent.futures.ProcessPoolExecutor`, but `max_workers=0` makes `map` run serially in this process"
    def __init__(self, max_workers=None, on_exc=print, **kwargs):
        self.not_parallel = max_workers==0
        self.on_exc = on_exc
        # the parent class still needs a valid worker count, even if `map` never uses the pool
        if self.not_parallel: max_workers=1
        super().__init__(max_workers, **kwargs)

    def map(self, f, items, *args, **kwargs):
        "Map `f` over `items`; each call is `f(*args, item, **kwargs)`"
        g = partial(f, *args, **kwargs)
        if self.not_parallel: return map(g, items)
        # NOTE(review): only exceptions raised by the `map` call itself reach `on_exc`, and in that
        # case `None` is returned. Per the `Executor.map` docs, exceptions raised by `f` in a worker
        # surface later, when the returned iterator is consumed.
        try: return super().map(g, items)
        except Exception as e: self.on_exc(e)

def parallel(f, items, *args, n_workers=defaults.cpus, total=None, progress=True, **kwargs):
    """Applies `f` in parallel to `items`, using `n_workers`

    Extra `args`/`kwargs` are forwarded to every call of `f`. `n_workers=0` runs serially in this
    process. With `progress`, a progress bar of length `total` (default: `len(items)`) is shown.
    """
    if progress and total is None:
        # the progress bar needs a length: materialize iterables that do not have one
        if not hasattr(items, '__len__'): items = list(items)
        total = len(items)
    with ProcessPoolExecutor(n_workers) as ex:
        r = ex.map(f,items, *args, **kwargs)
        if progress: r = progress_bar(r, total=total, leave=False)
        # consume the results while the executor is still open
        return L(r)

def run_procs(f, f_done, args):
    "Call `f` for each item in `args` in parallel, yielding `f_done`"
    # one `Process(target=f, args=item)` per item of `args`
    processes = L(args).map(Process, args=arg0, target=f)
    for o in processes: o.start()
    # exceptions raised while consuming `f_done()` are printed, not re-raised
    try: yield from f_done()
    except Exception as e: print(e)
    # always wait for the workers, even on error or early generator close
    finally: processes.map(Self.join())

def parallel_gen(cls, items, n_workers=defaults.cpus, as_gen=False, **kwargs):
    """Instantiate `cls` in `n_workers` procs & call each on a subset of `items` in parallel.

    `cls(**kwargs)` is created once per process and called on that process's batch; it must return
    an iterable with one result per item. This is always a generator of `(index, result)` tuples,
    yielded in completion order (sort them to recover the order of `items`). `as_gen` is accepted
    for compatibility but is not used. `n_workers=0` runs serially in this process.
    """
    if n_workers==0:
        # same convention as `parallel`: no subprocesses, same `(index, result)` output
        yield from enumerate(cls(**kwargs)(items))
        return
    batches = np.array_split(items, n_workers)
    # start offset of each batch; assumes `0 + L(...)` prepends 0 to the batch lengths -- TODO confirm
    idx = np.cumsum(0 + L(batches).map(len))
    queue = Queue()
    def f(batch, start_idx):
        for i,b in enumerate(cls(**kwargs)(batch)): queue.put((start_idx+i,b))
    # one `queue.get()` per input item
    def done(): return (queue.get() for _ in progress_bar(items, leave=False))
    yield from run_procs(f, done, L(batches,idx).zip())
def script_use_ctx(f):
    "Decorator: create jit script and pass everything in `ctx.saved_tensors` to `f`, after `*args`"
    sf = torch.jit.script(f)
    # `saved_tensors` replaces the deprecated `saved_variables` alias
    def _f(ctx, *args, **kwargs): return sf(*args, *ctx.saved_tensors, **kwargs)
    return update_wrapper(_f,f)

def script_save_ctx(static, *argidx):
    """Decorator: create jit script and save args with indices `argidx` using `ctx.save_for_backward`

    With `argidx`, the selected positional args are saved on `ctx` and `f` is called without `ctx`.
    Without `argidx`, nothing is saved and `ctx` is passed to `f` as its first argument.
    If `static`, the wrapper is turned into a `staticmethod`.
    """
    def _dec(f):
        sf = torch.jit.script(f)
        def _f(ctx, *args, **kwargs):
            if argidx: ctx.save_for_backward(*[args[o] for o in argidx])
            # `args` is a tuple: the previous `[ctx]+args` raised TypeError
            else: args = (ctx,)+args
            return sf(*args, **kwargs)
        if static: _f = staticmethod(_f)
        return update_wrapper(_f,f)
    return _dec

def script_fwd(*argidx):
    "Decorator: create static jit script and save args with indices `argidx` using `ctx.save_for_backward`"
    return script_save_ctx(True, *argidx)

def script_bwd(f):
    "Decorator: create static jit script and pass everything in `ctx.saved_tensors` to `f`, after `*args`"
    return staticmethod(script_use_ctx(f))

def grad_module(cls):
    "Decorator: wrap `cls` (which must provide `apply`) in an `nn.Module` class whose `forward` calls `cls.apply`"
    class _c(nn.Module):
        def forward(self, *args, **kwargs): return cls.apply(*args, **kwargs)
    return _c
"Converted 23_tutorial_transfer_learning.ipynb.\n", "Converted 30_text_core.ipynb.\n", "Converted 31_text_data.ipynb.\n", "Converted 32_text_models_awdlstm.ipynb.\n", "Converted 33_text_models_core.ipynb.\n", "Converted 34_callback_rnn.ipynb.\n", "Converted 35_tutorial_wikitext.ipynb.\n", "Converted 36_text_models_qrnn.ipynb.\n", "Converted 37_text_learner.ipynb.\n", "Converted 38_tutorial_ulmfit.ipynb.\n", "Converted 40_tabular_core.ipynb.\n", "Converted 41_tabular_model.ipynb.\n", "Converted 42_tabular_rapids.ipynb.\n", "Converted 50_data_block_examples.ipynb.\n", "Converted 60_medical_imaging.ipynb.\n", "Converted 65_medical_text.ipynb.\n", "Converted 70_callback_wandb.ipynb.\n", "Converted 71_callback_tensorboard.ipynb.\n", "Converted 90_notebook_core.ipynb.\n", "Converted 91_notebook_export.ipynb.\n", "Converted 92_notebook_showdoc.ipynb.\n", "Converted 93_notebook_export2html.ipynb.\n", "Converted 94_notebook_test.ipynb.\n", "Converted 95_index.ipynb.\n", "Converted 96_data_external.ipynb.\n", "Converted 97_utils_test.ipynb.\n", "Converted notebook2jekyll.ipynb.\n" ] } ], "source": [ "#hide\n", "from local.notebook.export import notebook2script\n", "notebook2script(all_fs=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }