{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|hide\n", "#| eval: false\n", "! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|default_exp torch_core" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "from __future__ import annotations\n", "from fastai.imports import *\n", "from fastai.torch_imports import *\n", "from packaging.version import parse" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from PIL import Image" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|hide\n", "from nbdev.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "_all_ = ['progress_bar','master_bar']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "defaults.benchmark = True" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def setup_cuda(benchmark=defaults.benchmark):\n", "    \"Sets the main cuda device and sets `cudnn.benchmark` to `benchmark`\"\n", "    if torch.cuda.is_available():\n", "        if torch.cuda.current_device()==0:\n", "            # `DEFAULT_GPU` env var selects the device; valid indices are 0..device_count()-1,\n", "            # so the requested index must be strictly smaller than the device count\n", "            def_gpu = int(os.environ.get('DEFAULT_GPU') or 0)\n", "            if torch.cuda.device_count()>def_gpu: torch.cuda.set_device(def_gpu)\n", "        torch.backends.cudnn.benchmark = benchmark" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Torch Core\n", "\n", "> Basic pytorch functions used in the fastai library" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Arrays and show" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@delegates(plt.subplots, 
keep=True)\n", "def subplots(\n", " nrows:int=1, # Number of rows in returned axes grid\n", " ncols:int=1, # Number of columns in returned axes grid\n", " figsize:tuple=None, # Width, height in inches of the returned figure\n", " imsize:int=3, # Size (in inches) of images that will be displayed in the returned figure\n", " suptitle:str=None, # Title to be set to returned figure\n", " **kwargs\n", ") -> (plt.Figure, plt.Axes): # Returns both fig and ax as a tuple\n", " \"Returns a figure and set of subplots to display images of `imsize` inches\"\n", " if figsize is None:\n", " h=nrows*imsize if suptitle is None or imsize>2 else nrows*imsize+0.6 #https://github.com/matplotlib/matplotlib/issues/5355\n", " figsize=(ncols*imsize, h)\n", " fig,ax = plt.subplots(nrows, ncols, figsize=figsize, **kwargs)\n", " if suptitle is not None: fig.suptitle(suptitle)\n", " if nrows*ncols==1: ax = array([ax])\n", " return fig,ax" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is used in `get_grid`. `suptitle` is set on the returned figure with `fig.suptitle`, while extra keyword arguments such as `sharex`, `sharey`, `squeeze`, `subplot_kw` and `gridspec_kw` are passed down to [plt.subplots](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.subplots.html#matplotlib-pyplot-subplots)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|hide\n", "_,axs = subplots()\n", "test_eq(axs.shape,[1])\n", "plt.close()\n", "_,axs = subplots(2,3)\n", "test_eq(axs.shape,[2,3])\n", "plt.close()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def _fig_bounds(x):\n", " r = x//32\n", " return min(5, max(1,r))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@delegates(plt.Axes.imshow, keep=True, but=['shape', 'imlim'])\n", "def show_image(im, ax=None, figsize=None, title=None, ctx=None, **kwargs):\n", " \"Show a PIL or PyTorch image on `ax`.\"\n", " # Handle pytorch axis order\n", " if hasattrs(im, ('data','cpu','permute')):\n", " im = im.data.cpu()\n", " if im.shape[0]<5: im=im.permute(1,2,0)\n", " elif not isinstance(im,np.ndarray): im=array(im)\n", " # Handle 1-channel images\n", " if im.shape[-1]==1: im=im[...,0]\n", "\n", " ax = ifnone(ax,ctx)\n", " if figsize is None: figsize = (_fig_bounds(im.shape[0]), _fig_bounds(im.shape[1]))\n", " if ax is None: _,ax = plt.subplots(figsize=figsize)\n", " ax.imshow(im, **kwargs)\n", " if title is not None: ax.set_title(title)\n", " ax.axis('off')\n", " return ax" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`show_image` can show PIL images..." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGEAAABhCAYAAADGBs+jAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAALz0lEQVR4nO2dW08bRxuAn92117s+LdjY5hQIOTVAg9SokRK1UtTLVlUvc93+sPYv5KJSbtqbSpVID4lEKYVIGAgEjM82Pq+9u99FtFsIJCVpG0/67XOTSPHCep59Z+Z9Z2YjOY7j4DNU5GHfgI8vQQh8CQLgSxAAX4IA+BIEwJcgAL4EAfAlCIAvQQB8CQLgSxAAX4IA+BIEwJcgAL4EAfAlCIAvQQB8CQIQGPYNuFiWRb/fx3EcTNOk1+th2za9Xo/BYHDis+FwmGg0iizLSJIEgCzLBAIBAgFhvtK5EeaOu90upVKJbrfL1tYWa2trNJtNdnZ2yOfzJz578+ZNPv30U2KxGLIsoygKqqoyPj5ONBod0jd4c4SQ4DgOlmXRbDZpNptks1nW1tY4Ojri119/ZW9vD3dTiCRJSJLE4uIimUyGQCCALMuEw2HGxsZwHMeLjneFoUlwHIfBYECn02EwGLC3t8fy8jK1Wo1sNsvW1haNRoNWqwU8b3xXxLNnz3j48CGxWAxFUZBlmXg8jm3bXLhwgVAoRDgcJhAIIEkSsiz20DcUCW5jmqbJ/v4+1WqV5eVl7t+/z+HhIa1Wi2aziW3bmKbpXec+4Zubm+RyORRF8Ro5lUpRr9f58MMPSafTvPfee+i67n1G5OgYaiTYtk2z2aRSqVAqlTg4OKBUKtHv989sfJder0ev1zvx74PBgGKxyOHhIYFAgHa7TSAQ8LonNxpElPHWJbj9v23btFot1tfXyWaz7Ozs0Gq16Pf7p2ZD5/mZ7Xabx48fU6lUiMVirKysEI/HWVxcZH5+Hk3TCIVCqKoqnIihRIJt295A/Mcff7C+vs7Tp09pNpsnIuB1cCWsrKwQDAaJRCKoqsoXX3zBV199RTKZJJFIoKrqP/xt/j5D647cLiIUCnlP6cue0OO5gDtAu+PK8b9bloVlWQwGAxzHIRgMeoN7OBx+7Qh7W7x1CZIkEQgEsG2beDzOjRs30HWdUCjE1tYWpmli27bXsIqioCgKgUDAm/G4De0mc91u98TvcBzHS/Ly+Txra2tMTk6i6zqGYfjdETx/omVZRtM0pqenkSSJYrHo9dfu0y5JkidBVVUikQihUAjTNL2x46yn2x13LMuiUqmwv79Pv9/nypUriLgJfajJmqIojI6OAjA7O8vCwgL5fJ7BYIBpmt4MyrZtEokEi4uLxONxWq0WtVqNTqfD5uYmnU4H27bP/B29Xo9Go0EoFKLb7WJZlidalIgYqgRVVZmZmWFqagrDMAgGg9TrdTqdDt1ul36/T7lcptlscunSJe7du8fExATlcplsNku1WuW7774jl8u9dECv1+s8e/aMZrNJsVj0urtgMIiiKG/5G5/N0CNB13UAxsbGuHbtGtVqlVarRbvdptvteoW5yclJLl68SCqVIhaLAc8b+NGjR6/MiE3TpFqtEggEaLVaWJaFLMtCdUtC1I4AdF1nZmaGTCZDv9/3+vxCoUClUuHSpUtomvbaP7fT6ZDL5ej1elQqFUzT9CYHoiDMnYTDYWZnZ09NPd3Zj6ZpbySh0WjQ7Xa9bLrdbgMIlS8II0GW5VMNY9u299S6hbrXxXEcL6osyzqRV4iCMBLOwhXgJmlvIiEajXrZcjqd9nISkSqr74SEv7NGEI1GuXr1KolEgsnJSTRNIxgMCiVBnDt5BX9nPh8IBIhEIici4PiyqAgIHQn/BLFYjPHxccbG
xrwCntvFicJ/XkI4HGZ0dJRkMkk0GvUGeZF4JyWYpkmlUiGfz9NoNM412xGp+3mRd06C4zgcHR3x8OFDcrkcW1tbwpaoz4s4HeM5cJ/4Xq9HoVAgn89Tq9Ve+nn36X/xT9EQOhLcjNmyLNrtNvv7+xwdHfHo0SM2NzcpFArUarVTFdRQKEQikUDTNN5//31u3bpFOp1mdHRUSBFCS7Asi1arRb1e5+nTp3z99desrKzQaDSo1WpeJvzimBCJRFhYWCCVSnH79m3u3LnjDcoizYpchJYAzwfhdrtNqVRiZ2eHbDbrray9jEAggGEYGIZBIpEgHA6/cvl02AgtYTAYkMvlWF1dJZvNUi6Xz3zyXyQWi3H9+nXm5ua4ePGi8HuPhJZg2zY7Ozv88ssvFItFarUalmX95XWGYbCwsMD8/DzpdNrbiScq4nWQx5AkCU3T0HUdVVXP3Z/3+31qtRpHR0e0222vgvqyJdBhI7QEWZaZm5vj7t27LC0tYRjGua7L5XI8ePCAb775hu+//55isUij0fDWrUVD6O7I3Qhw5coVut0ukUjkXNdVq1V+/PFHVFVF13Vu377t/bxQKPRv3vIbIbQESZK8hkyn0ywtLaEoCqVSiXw+7y3av9jNuAs5gLfXVZZldF33RIo0RggvQdd1gsEgwWCQL7/8kmKxyM8//8z9+/cpl8t0u106nc6J69xNYf1+n93dXX766SfS6TQfffQRhmF4BTxRRAgvwZ3ZRKNRZmZmSKVSFAoF4vE4zWaTwWBw4uyCixshR0dHFAoFbNum3W57y5uiCADBJQDejEhVVeLxOJqmsbS0xGeffeZVUnd3d72yxou1pEqlwubmJqVSiZmZGZLJJJqmYRjGG20c+DcQWoK7ruyeSwsGgziO440RnU6HjY0NfvvtN6rVKj/88MMpCYeHh1SrVS9rVhSF8fFxlpaWfAln4XYpZw22x1EUBcMw0HWdTCZDJpNBlmUMwyAUCnn7UN0jWYPBwMsd8vk8six7hUERtkQKI8FtKMuyODg4YGtry5vhHEeSJDKZDFNTU942ymQyydHREZIkMTExQaFQ4Pfffz8RFZZlsb6+7h2tSiaT2LaNruvE4/Gh7kMSRoJlWd5MZ3l5mW+//fbUlnd4LuHu3bvcu3ePcDhMOp0mGAximiaDwYCRkRF2d3fZ3d09JWF7e5uDgwPS6TSZTAZFUUgmk4TDYV8C/HmGzS1fl8tlb7fccSRJolwuU6lUcByHWCyGJEleFL1qc5dbfT1eBBRhliSMBPizkarVKtvb26fm/y66rtPv94nH48zNzTE3N4dpmqysrLC3t0ehUDhxsNAlGAyiaRqxWIxEIkEikfAOpA8TYSQcP4vgrh+8TMLq6irlcplIJMLS0hKlUgnLstjf36dSqVCv10+NJ27OoaqqJyIajaJp2tCjQRgJr8NgMPAStVKpxO7uLrZtUyqVqNfrlEqlMxd9ZFkmGAyi6zrRaJRwOIymaX4kvAntdpt+v48sy1QqFR4/fuy9mMRd8nTfBHAcVVW9rmhqaopkMumdfxgmwkp4VRfh5gHAmYP3y3AjwY0GUSqqwqwnHD9Om0wmmZ6eZnx83DvJ819GGAlurT8SiXDhwgVu3LjB9evXicfjw761fx2huiO3RhSLxRgZGUGSJAzD8LY6ui+lcmdS58WtQamqSjQaFe58giTK/y7lJmq2bXN4eMj29jaNRsPb5FWv13ny5AnlcplqtUqhUDjXKxjc1+5omsYnn3zCrVu3mJiY4OOPP2ZiYuItfLO/RphIcKuljuMwNTVFJpPBsixu3rxJo9Hg8PCQBw8esL+/z87ODtVq1YuMV6GqKoZhMDo6ygcffMDnn39OOBw+93r120AYCS5u1+Eu5ui6jm3bjIyMMDk5CeAlZqqqemeej6MoCpFIBEVRSCQSTE9Pk0gkSKVS6LouRG5wHOEkwJ8iJEkiEomgaRrxeJzR0VG63S4bGxuk02nK5TKrq6ts
bGyc2I+USCS4c+cOY2NjTExMcPXqVQzDYH5+npGRESFyg+OIcyfHOF7fP36gwzAMHMchHA5TLBbJ5XJUq1WePHly4vpIJMKVK1eYnZ3l8uXLXLt2zStZi1CmeBEhJbwKd0PY5cuXMQwD0zTRdf1EmSKTyTA/P8/U1BSpVOrEqpqICDM7eh3c96a6axDdbvfEAO0eFnSPRrm7sd0psGi8kxL+a4gzRfg/xpcgAL4EAfAlCIAvQQB8CQLgSxAAX4IA+BIEwJcgAL4EAfAlCIAvQQB8CQLgSxAAX4IA+BIEwJcgAL4EAfgfwdHXP6ARORkAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "im = Image.open(TEST_IMAGE_BW)\n", "ax = show_image(im, cmap=\"Greys\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "...and color images with `HWC` dim order..." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "im2 = np.array(Image.open(TEST_IMAGE))\n", "ax = show_image(im2, figsize=(2,2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "...and color images with standard `CHW` dim order..." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "im3 = torch.as_tensor(im2).permute(2,0,1)\n", "ax = show_image(im3, figsize=(2,2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@delegates(show_image, keep=True)\n", "def show_titled_image(o, **kwargs):\n", " \"Call `show_image` destructuring `o` to `(img,title)`\"\n", " show_image(o[0], title=str(o[1]), **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_titled_image((im3,'A puppy'), figsize=(2,2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Show all images `ims` as subplots with `rows` using `titles`. `suptitle` provides a way to create a figure title for all images. If you use `suptitle`, `constrained_layout` is used unless you set `constrained_layout` to `False`. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@delegates(subplots)\n", "def show_images(ims, nrows=1, ncols=None, titles=None, **kwargs):\n", " \"Show all images `ims` as subplots with `rows` using `titles`.\"\n", " if ncols is None: ncols = int(math.ceil(len(ims)/nrows))\n", " if titles is None: titles = [None]*len(ims)\n", " axs = subplots(nrows, ncols, **kwargs)[1].flat\n", " for im,t,ax in zip(ims, titles, axs): show_image(im, ax=ax, title=t)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_images((im,im3),titles=('number','puppy'),suptitle='Number Puppy', imsize=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`ArrayImage`, `ArrayImageBW` and `ArrayMask` are subclasses of `ndarray` that know how to show themselves." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class ArrayBase(ndarray):\n", " \"An `ndarray` that can modify casting behavior\"\n", " @classmethod\n", " def _before_cast(cls, x): return x if isinstance(x,ndarray) else array(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class ArrayImageBase(ArrayBase):\n", " \"Base class for arrays representing images\"\n", " _show_args = {'cmap':'viridis'}\n", " def show(self, ctx=None, **kwargs):\n", " return show_image(self, ctx=ctx, **{**self._show_args, **kwargs})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class ArrayImage(ArrayImageBase):\n", " \"An array representing an image\"\n", " pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class ArrayImageBW(ArrayImage):\n", " \"An array representing an image\"\n", " _show_args = {'cmap':'Greys'}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class ArrayMask(ArrayImageBase):\n", " \"An array representing an image mask\"\n", " _show_args = {'alpha':0.5, 'cmap':'tab20', 'interpolation':'nearest'}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "im = Image.open(TEST_IMAGE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "im_t = cast(im, ArrayImage)\n", "test_eq(type(im_t), ArrayImage)" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ax = im_t.show(figsize=(2,2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_fig_exists(ax)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basics" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@patch\n", "def __array_eq__(self:Tensor,b):\n", " return torch.equal(self,b) if self.dim() else self==b" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def _array2tensor(x, requires_grad=False, pin_memory=False, **kwargs):\n", "    \"Convert ndarray `x` to a tensor with `torch.as_tensor`, optionally pinning its memory and setting `requires_grad`\"\n", "    if x.dtype==np.uint16: x = x.astype(np.float32)\n", "    # windows default numpy int dtype is int32, while torch tensor default int dtype is int64\n", "    # https://github.com/numpy/numpy/issues/9464\n", "    if sys.platform == \"win32\" and x.dtype==int: x = x.astype(np.int64)\n", "    t = torch.as_tensor(x, **kwargs)\n", "    # `Tensor.pin_memory()` returns a pinned copy rather than pinning in place, so keep its result;\n", "    # pin before `requires_grad_` so the returned tensor is still a leaf\n", "    if pin_memory: t = t.pin_memory()\n", "    t.requires_grad_(requires_grad)\n", "    return t" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@use_kwargs_dict(dtype=None, device=None, requires_grad=False, pin_memory=False)\n", "def tensor(x, *rest, **kwargs):\n", " \"Like `torch.as_tensor`, but handle lists too, and can pass multiple vector elements directly.\"\n", " if len(rest): x = (x,)+rest\n", " # There was a Pytorch bug in dataloader using num_workers>0. 
Haven't confirmed if fixed\n", " # if isinstance(x, (tuple,list)) and len(x)==0: return tensor(0)\n", " res = (x if isinstance(x, Tensor)\n", " else torch.tensor(x, **kwargs) if isinstance(x, (tuple,list,numbers.Number))\n", " else _array2tensor(x, **kwargs) if isinstance(x, ndarray)\n", " else as_tensor(x.values, **kwargs) if isinstance(x, (pd.Series, pd.DataFrame))\n", "# else as_tensor(array(x, **kwargs)) if hasattr(x, '__array__') or is_iter(x)\n", " else _array2tensor(array(x), **kwargs))\n", " if res.dtype is torch.float64: return res.float()\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(tensor(torch.tensor([1,2,3])), torch.tensor([1,2,3]))\n", "test_eq(tensor(array([1,2,3])), torch.tensor([1,2,3]))\n", "test_eq(tensor(1,2,3), torch.tensor([1,2,3]))\n", "test_eq_type(tensor(1.0), torch.tensor(1.0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`set_seed` is useful for reproducibility between runs. It is important to remember that certain classes such as `Dataloaders` have internal random number generators that are not affected by this function, so this must be run before such objects are created in order to guarantee reproducibility. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def set_seed(s, reproducible=False):\n", " \"Set random seed for `random`, `torch`, and `numpy` (where available)\"\n", " try: torch.manual_seed(s)\n", " except NameError: pass\n", " try: torch.cuda.manual_seed_all(s)\n", " except NameError: pass\n", " try: np.random.seed(s%(2**32-1))\n", " except NameError: pass\n", " random.seed(s)\n", " if reproducible:\n", " torch.backends.cudnn.deterministic = True\n", " torch.backends.cudnn.benchmark = False" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is an example of how `set_seed` can be used to reset the state of random number generators." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a's: 0.154 0.498 0.071\n", "b's: 0.154 0.498 0.071\n" ] } ], "source": [ "set_seed(2*33)\n", "a1 = np.random.random()\n", "a2 = torch.rand(())\n", "a3 = random.random()\n", "set_seed(2*33)\n", "b1 = np.random.random()\n", "b2 = torch.rand(())\n", "b3 = random.random()\n", "print('a\\'s: {0:3.3f} {1:3.3f} {2:3.3f}'.format(a1,a2,a3))\n", "print('b\\'s: {0:3.3f} {1:3.3f} {2:3.3f}'.format(b1,b2,b3))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(a1,b1)\n", "test_eq(a2,b2)\n", "test_eq(a3,b3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`get_random_states` and `set_random_states` are useful for storing a state so you can go back to it later. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def get_random_states():\n", " \"Gets states for `random`, `torch`, and `numpy` random number generators\"\n", " return {'random_state':random.getstate(),\n", " 'numpy_state':np.random.get_state(),\n", " 'torch_state':torch.get_rng_state(),\n", " 'torch_cuda_state':torch.cuda.get_rng_state_all(),\n", " 'torch_deterministic':torch.backends.cudnn.deterministic,\n", " 'torch_benchmark':torch.backends.cudnn.benchmark}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def set_random_states(random_state,numpy_state,torch_state,torch_cuda_state,torch_deterministic,torch_benchmark):\n", " \"Set states for `random`, `torch`, and `numpy` random number generators\"\n", " random.setstate(random_state)\n", " np.random.set_state(numpy_state)\n", " torch.set_rng_state(torch_state)\n", " torch.cuda.set_rng_state_all(torch_cuda_state)\n", " torch.backends.cudnn.deterministic=torch_deterministic\n", " torch.backends.cudnn.benchmark=torch_benchmark" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "Below, notice that the old values and the rewound values are the same because we were able to return to the previous state. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "olds: 0.435 0.134 0.023\n", "news: 0.246 0.363 0.227\n", "rewinds: 0.435 0.134 0.023\n" ] } ], "source": [ "old_states = get_random_states()\n", "olds = (random.random(),np.random.random(),torch.rand(()))\n", "news = (random.random(),np.random.random(),torch.rand(()))\n", "set_random_states(**old_states)\n", "rewinds = (random.random(),np.random.random(),torch.rand(()))\n", "\n", "print('olds: {0:3.3f} {1:3.3f} {2:3.3f}'.format(*olds))\n", "print('news: {0:3.3f} {1:3.3f} {2:3.3f}'.format(*news))\n", "print('rewinds: {0:3.3f} {1:3.3f} {2:3.3f}'.format(*rewinds))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_ne(olds,news)\n", "test_eq(olds,rewinds)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In `no_random` we combine the ideas of rewinding state with `get_random_states` and `set_random_states` with the ability to `set_seed` and create a context manager that can allow us to control randomness in a portion of our code. \n", "\n", "Note: Similar to `torch.random.fork_rng`, but also with `numpy` and `random`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@contextmanager\n", "def no_random(seed=42,reproducible=True):\n", " \"Stores and retrieves state of random number generators. 
Sets random seed for `random`, `torch`, and `numpy`.\"\n", " states = get_random_states()\n", " set_seed(seed,reproducible=reproducible)\n", " try:\n", " yield #we are managing global variables\n", " finally:\n", " set_random_states(**states)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here are some examples of how we can use `no_random` to control the randomness within a block of code. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "olds: 0.246 0.363 0.227\n", "new1: 0.639 0.375 0.882\n", "new2: 0.639 0.375 0.882\n", "seeded1: 0.146 0.543 0.112\n", "seeded2: 0.146 0.543 0.112\n", "rewinds: 0.246 0.363 0.227\n" ] } ], "source": [ "states=get_random_states()\n", "olds = (random.random(),np.random.random(),torch.rand(()))\n", "set_random_states(**states) #rewinding above random calls\n", "\n", "with no_random():\n", " new1 = (random.random(),np.random.random(),torch.rand(()))\n", "with no_random():\n", " new2 = (random.random(),np.random.random(),torch.rand(()))\n", "with no_random(seed=100):\n", " seeded1 = (random.random(),np.random.random(),torch.rand(()))\n", "with no_random(seed=100):\n", " seeded2 = (random.random(),np.random.random(),torch.rand(()))\n", " \n", "rewinds = (random.random(),np.random.random(),torch.rand(()))\n", "\n", "print('olds: {0:3.3f} {1:3.3f} {2:3.3f}'.format(*olds))\n", "print('new1: {0:3.3f} {1:3.3f} {2:3.3f}'.format(*new1))\n", "print('new2: {0:3.3f} {1:3.3f} {2:3.3f}'.format(*new2))\n", "print('seeded1: {0:3.3f} {1:3.3f} {2:3.3f}'.format(*seeded1))\n", "print('seeded2: {0:3.3f} {1:3.3f} {2:3.3f}'.format(*seeded2))\n", "print('rewinds: {0:3.3f} {1:3.3f} {2:3.3f}'.format(*rewinds))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that `olds` and `rewinds` are also equal to each other. From this we can see that everything in the `with` blocks did not update the state outside of the block. 
Inside of the block, the state is reset for any particular seed, so for the same seed you should get the same random number generator results. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note: It is important to remember that classes like `Dataloader` have internal random number generators, and `no_random` will have no effect on those random number generators." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_ne(olds,new1)\n", "test_eq(new1,new2)\n", "test_ne(new1,seeded1)\n", "test_eq(seeded1,seeded2)\n", "test_eq(olds,rewinds)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def unsqueeze(x, dim=-1, n=1):\n", " \"Same as `torch.unsqueeze` but can add `n` dims\"\n", " for _ in range(n): x = x.unsqueeze(dim)\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t = tensor([1])\n", "t2 = unsqueeze(t, n=2)\n", "test_eq(t2,t[:,None,None])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def unsqueeze_(x, dim=-1, n=1):\n", " \"Same as `torch.unsqueeze_` but can add `n` dims\"\n", " for _ in range(n): x.unsqueeze_(dim)\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t = tensor([1])\n", "unsqueeze_(t, n=2)\n", "test_eq(t, tensor([1]).view(1,1,1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def _fa_rebuild_tensor (cls, *args, **kwargs): return cls(torch._utils._rebuild_tensor_v2(*args, **kwargs))\n", "def _fa_rebuild_qtensor(cls, *args, **kwargs): return cls(torch._utils._rebuild_qtensor (*args, **kwargs))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def apply(func, x, *args, **kwargs):\n", " \"Apply `func` recursively to 
`x`, passing on args\"\n", " if is_listy(x): return type(x)([apply(func, o, *args, **kwargs) for o in x])\n", " if isinstance(x,(dict,MutableMapping)): return {k: apply(func, v, *args, **kwargs) for k,v in x.items()}\n", " res = func(x, *args, **kwargs)\n", " return res if x is None else retain_type(res, x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def maybe_gather(x, axis=0):\n", " \"Gather copies of `x` on `axis` (if training is distributed)\"\n", " if num_distrib()<=1: return x\n", " ndim = x.ndim\n", " res = [x.new_zeros(*x.shape if ndim > 0 else (1,)) for _ in range(num_distrib())]\n", " torch.distributed.all_gather(res, x.contiguous() if ndim > 0 else x[None])\n", " return torch.cat(res, dim=axis) if ndim > 0 else torch.cat(res, dim=axis).mean()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def to_detach(b, cpu=True, gather=True):\n", " \"Recursively detach lists of tensors in `b `; put them on the CPU if `cpu=True`.\"\n", " def _inner(x, cpu=True, gather=True):\n", " if not isinstance(x,Tensor): return x\n", " x = x.detach()\n", " if gather: x = maybe_gather(x)\n", " return x.cpu() if cpu else x\n", " return apply(_inner, b, cpu=cpu, gather=gather)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`gather` only applies during distributed training and the result tensor will be the one gathered across processes if `gather=True` (as a result, the batch size will be multiplied by the number of processes)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def to_half(b):\n", " \"Recursively map floating point tensors in `b ` to FP16.\"\n", " return apply(lambda x: x.half() if torch.is_floating_point(x) else x, b)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def to_float(b):\n", " \"Recursively map floating point tensors in `b ` to float.\"\n", " return apply(lambda x: x.float() if torch.is_floating_point(x) else x, b)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "# None: True if available; True: error if not available; False: use CPU\n", "defaults.use_cuda = None" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def _has_mps():\n", "    \"Whether PyTorch's `mps` backend reports itself as available, or at least built\"\n", "    if nested_attr(torch, 'backends.mps.is_available', noop)(): return True\n", "    # Default to `noop` (not `False`): when `torch.backends.mps` is missing, `nested_attr` returns the\n", "    # default and it is called immediately, so a `False` default would raise TypeError\n", "    return bool(nested_attr(torch, 'backends.mps.is_built', noop)())\n", "\n", "def default_device(use=-1):\n", "    \"Return or set default device; `use_cuda`: -1 - CUDA/mps if available; True - error if not available; False - CPU\"\n", "    if use == -1: use = defaults.use_cuda\n", "    else: defaults.use_cuda=use\n", "    if use is None:\n", "        if torch.cuda.is_available() or _has_mps(): use = True\n", "    if use:\n", "        if torch.cuda.is_available(): return torch.device(torch.cuda.current_device())\n", "        if _has_mps(): return torch.device('mps')\n", "    # NOTE(review): with `use=True` and neither CUDA nor mps present this falls back to CPU instead of raising\n", "    return torch.device('cpu')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|cuda\n", "if torch.cuda.is_available():\n", " _td = torch.device(torch.cuda.current_device())\n", " test_eq(default_device(-1), _td)\n", " test_eq(default_device(True), _td)\n", "else:\n", " test_eq(default_device(False), torch.device('cpu'))\n", "default_device(-1);" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": 
[ "#|export\n", "def to_device(b, device=None, non_blocking=False):\n", " \"Recursively put `b` on `device`.\"\n", " if defaults.use_cuda==False: device='cpu'\n", " elif device is None: device=default_device()\n", " def _inner(o):\n", " # ToDo: add TensorDict when released\n", " if isinstance(o,Tensor): return o.to(device, non_blocking=non_blocking)\n", " return o\n", " return apply(_inner, b)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t = to_device((3,(tensor(3),tensor(2))))\n", "t1,(t2,t3) = t" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|cuda\n", "if torch.cuda.is_available():\n", " test_eq_type(t,(3,(tensor(3).cuda(),tensor(2).cuda())))\n", " test_eq(t2.type(), \"torch.cuda.LongTensor\")\n", " test_eq(t3.type(), \"torch.cuda.LongTensor\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def to_cpu(b):\n", " \"Recursively map tensors in `b ` to the cpu.\"\n", " return to_device(b,'cpu')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t3 = to_cpu(t3)\n", "test_eq(t3.type(), \"torch.LongTensor\")\n", "test_eq(t3, 2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def to_np(x):\n", " \"Convert a tensor to a numpy array.\"\n", " return apply(lambda o: o.data.cpu().numpy(), x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t3 = to_np(t3)\n", "test_eq(type(t3), np.ndarray)\n", "test_eq(t3, 2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def to_concat(xs, dim=0):\n", " \"Concat the element in `xs` (recursively if they are tuples/lists of tensors)\"\n", " if not xs: return xs\n", " if is_listy(xs[0]): return type(xs[0])([to_concat([x[i] for x in xs], dim=dim) for i in 
range_of(xs[0])])\n", " if isinstance(xs[0],dict): return {k: to_concat([x[k] for x in xs], dim=dim) for k in xs[0].keys()}\n", " #We may receive xs that are not concatenable (inputs of a text classifier for instance),\n", " # in this case we return a big list\n", " # The fallback slices each tensor along `dim` (not dim 0) so no slice is dropped when `dim != 0`\n", " try: return retain_type(torch.cat(xs, dim=dim), xs[0])\n", " except Exception: return sum([L(retain_type(o_.index_select(dim, tensor(i)).squeeze(dim), xs[0])\n", " for i in range(o_.shape[dim])) for o_ in xs], L())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(to_concat([tensor([1,2]), tensor([3,4])]), tensor([1,2,3,4]))\n", "test_eq(to_concat([tensor([[1,2]]), tensor([[3,4]])], dim=1), tensor([[1,2,3,4]]))\n", "test_eq_type(to_concat([(tensor([1,2]), tensor([3,4])), (tensor([3,4]), tensor([5,6]))]), (tensor([1,2,3,4]), tensor([3,4,5,6])))\n", "test_eq_type(to_concat([[tensor([1,2]), tensor([3,4])], [tensor([3,4]), tensor([5,6])]]), [tensor([1,2,3,4]), tensor([3,4,5,6])])\n", "test_eq_type(to_concat([(tensor([1,2]),), (tensor([3,4]),)]), (tensor([1,2,3,4]),))\n", "\n", "test_eq(to_concat([tensor([[1,2]]), tensor([[3,4], [5,6]])], dim=1), [tensor([1]),tensor([2]),tensor([3, 5]),tensor([4, 6])])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(type(to_concat([dict(foo=tensor([1,2]), bar=tensor(3,4))])), dict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tensor subtypes" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|exporti\n", "# Parsed PyTorch versions for faster version checking\n", "_torch_version = parse(torch.__version__)\n", "_torch_20 = parse('2.0')\n", "_torch_113 = parse('1.13')\n", "_torch_112 = parse('1.12')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@patch\n", "def set_meta(self:Tensor, x, as_copy=False):\n", " \"Set all metadata in `__dict__`\"\n", " if not 
hasattr(x,'__dict__'): return\n", " # XXX: change to `deepcopy` once PyTorch 1.7.1 is out, and check nb 23 segmentation fit works\n", " self.__dict__ = copy(x.__dict__) if as_copy else x.__dict__" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "if not hasattr(torch,'as_subclass'): torch.as_subclass = torch.Tensor.as_subclass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@patch\n", "def as_subclass(self:Tensor, typ):\n", " \"Cast to `typ` and include `__dict__` and meta\"\n", " return retain_meta(self, torch.as_subclass(self, typ))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Tensor.set_meta` and `Tensor.as_subclass` work together to maintain `__dict__` after casting." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class _T(Tensor): pass\n", "t = tensor(1.).requires_grad_()\n", "t.img_size = 1\n", "t2 = t.as_subclass(_T)\n", "test_eq(t.img_size, t2.img_size)\n", "test_eq(t2.img_size, 1)\n", "assert(t2.requires_grad_)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def _torch_handled(args, opt, func):\n", " if func not in opt: return False\n", " for oks in opt[func]:\n", " if all(isinstance(arg,ok) for arg,ok in zip(args,oks) if ok): return True" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "# from https://github.com/pytorch/pytorch/blob/13c975684a220ec096216ec6468ccd0dc90ff50a/torch/_tensor.py#L34\n", "def _rebuild_from_type(func, type, args, dict):\n", " ret = func(*args).as_subclass(type)\n", " ret.__dict__ = dict\n", " return ret" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def _find_args(x):\n", " x0 = x[0] if is_listy(x[0]) and x[0] else x\n", " return [a for a in x0 if 
hasattr(a,'__dict__')]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class TensorBase(Tensor):\n", " \"A `Tensor` which support subclass pickling, and maintains metadata when casting or after methods\"\n", " debug,_opt = False,defaultdict(list)\n", " def __new__(cls, x, **kwargs):\n", " res = cast(tensor(x), cls)\n", " for k,v in kwargs.items(): setattr(res, k, v)\n", " return res\n", "\n", " @classmethod\n", " def _before_cast(cls, x): return tensor(x)\n", " def __repr__(self): return re.sub('tensor', self.__class__.__name__, super().__repr__())\n", "\n", " def __reduce_ex__(self, proto):\n", " if _torch_version >= _torch_20:\n", " return super().__reduce_ex__(proto)\n", " else:\n", " torch.utils.hooks.warn_if_has_hooks(self)\n", " args = (self.storage(), self.storage_offset(), tuple(self.size()), self.stride())\n", " if self.is_quantized: args = args + (self.q_scale(), self.q_zero_point())\n", " args = args + (self.requires_grad, OrderedDict())\n", " f = torch._utils._rebuild_qtensor if self.is_quantized else torch._utils._rebuild_tensor_v2\n", " return (_rebuild_from_type, (f, type(self), args, self.__dict__))\n", "\n", " @classmethod\n", " def register_func(cls, func, *oks): cls._opt[func].append(oks)\n", "\n", " @classmethod\n", " def __torch_function__(cls, func, types, args=(), kwargs=None):\n", " if cls.debug and func.__name__ not in ('__str__','__repr__'): print(func, types, args, kwargs)\n", " if _torch_handled(args, cls._opt, func): types = (torch.Tensor,)\n", " res = super().__torch_function__(func, types, args, ifnone(kwargs, {}))\n", " dict_objs = _find_args(args) if args else _find_args(list(kwargs.values()))\n", " if issubclass(type(res),TensorBase) and dict_objs: res.set_meta(dict_objs[0],as_copy=True)\n", " elif dict_objs and is_listy(res): [r.set_meta(dict_objs[0],as_copy=True) for r in res if issubclass(type(r),TensorBase)]\n", " return res\n", "\n", " def new_tensor(self, size, 
dtype=None, device=None, requires_grad=False):\n", " cls = type(self)\n", " return self.as_subclass(Tensor).new_tensor(size, dtype=dtype, device=device, requires_grad=requires_grad).as_subclass(cls)\n", "\n", " def new_ones(self, data, dtype=None, device=None, requires_grad=False):\n", " cls = type(self)\n", " return self.as_subclass(Tensor).new_ones(data, dtype=dtype, device=device, requires_grad=requires_grad).as_subclass(cls)\n", "\n", " def new(self, x=None):\n", " cls = type(self)\n", " res = self.as_subclass(Tensor).new() if x is None else self.as_subclass(Tensor).new(x)\n", " return res.as_subclass(cls)\n", "\n", " def requires_grad_(self, requires_grad=True):\n", " # Workaround https://github.com/pytorch/pytorch/issues/50219\n", " self.requires_grad = requires_grad\n", " return self\n", "\n", " def clone(self, *, memory_format=None):\n", " cls = type(self)\n", " return self.as_subclass(Tensor).clone(memory_format=memory_format).as_subclass(cls)\n", "\n", " # `*size` is forwarded unchanged to `Tensor.new_empty`, so callers may pass the shape as one tuple or as separate ints.\n", " # The result is cast back to the caller's subclass, like the other `new_*` helpers above.\n", " def new_empty(self, *size, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False):\n", " cls = type(self)\n", " # NOTE(review): presumably torch < 1.13 rejects `layout=None`, hence the explicit `torch.strided` default\n", " if _torch_version < _torch_113 and layout is None:\n", " layout = torch.strided\n", " if _torch_version < _torch_112:\n", " return super().new_empty(*size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad)\n", " return self.as_subclass(Tensor).new_empty(*size, dtype=dtype, 
layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad).as_subclass(cls)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`TensorBase` hooks into `__torch_function__` to ensure metadata is not lost. To see all functions being called, set `debug`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TensorBase(0.5000)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = TensorBase(1)\n", "TensorBase.debug=True\n", "# `debug` is a class attribute: switch it back off so later cells don't print every torch call\n", "try: res = 1/(a+1)\n", "finally: TensorBase.debug=False\n", "res" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`TensorBase` and its subclasses also allow for passing through metadata size as img_size..." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data._utils.collate import default_collate" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "a = TensorBase(1,img_size=(128,128))\n", "test_eq(a.img_size,(128,128))\n", "b = cast(a,TensorBase)\n", "test_eq(b.img_size,(128,128))\n", "test_eq(torch.stack([a,b],0).img_size,(128,128))\n", "\n", "test_eq(default_collate([a,b]).img_size,(128,128))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "_TImage2([2.])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class _TImage(TensorBase): pass\n", "class _TImage2(_TImage): pass\n", "t1 = _TImage([1.])\n", "t2 = _TImage2([1.])\n", "t2+t1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class _T(TensorBase): pass\n", "\n", "t = _T(range(5))\n", "test_eq(t[0], 0)\n", "test_eq_type(t+1, _T(range(1,6)))\n", "test_eq(repr(t), '_T([0, 1, 2, 3, 4])')\n", 
"test_eq_type(t[_T([False,False,True,True,True])], _T([2,3,4]))\n", "test_eq_type(t[_T([2,3,4])], _T([2,3,4]))\n", 
"test_eq(type(pickle.loads(pickle.dumps(t))), _T)\n", "test_eq_type(t.new_ones(1), _T([1]))\n", "test_eq_type(t.new_tensor([1,2]), _T([1,2]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t = tensor([1,2,3])\n", "m = TensorBase([False,True,True])\n", "test_eq(t[m], tensor([2,3]))\n", "t = tensor([[1,2,3],[1,2,3]])\n", "m = cast(tensor([[False,True,True],\n", " [False,True,True]]), TensorBase)\n", "test_eq(t[m], tensor([2,3,2,3]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t = tensor([[1,2,3],[1,2,3]])\n", "t.img_size = 1\n", "t2 = cast(t, TensorBase)\n", "test_eq(t2.img_size, t.img_size)\n", "x = retain_type(tensor([4,5,6]), t2)\n", "test_eq(x.img_size, t.img_size)\n", "t3 = TensorBase([[1,2,3],[1,2,3]], img_size=1)\n", "test_eq(t3.img_size, t.img_size)\n", "t4 = t2+1\n", "t4.img_size = 2\n", "test_eq(t2.img_size, 1)\n", "test_eq(t4.img_size, 2)\n", "# this will fail with `Tensor` but works with `TensorBase`\n", "test_eq(pickle.loads(pickle.dumps(t2)).img_size, t2.img_size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|hide\n", "# test of https://github.com/pytorch/pytorch/issues/47186\n", "class _T(TensorBase): ...\n", "t = _T([1.])\n", "test_eq_type(t.new([1,2]), _T([1.,2.]))\n", "test_eq_type(t.new(), _T([]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|hide\n", "# test of https://github.com/pytorch/pytorch/issues/50219\n", "x = TensorBase(torch.rand(4,3,16,16))\n", "with torch.no_grad():\n", " y = x.requires_grad_()\n", " assert y.requires_grad and x.requires_grad" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|hide\n", "x = TensorBase(torch.rand(4,3,16,16))\n", "x.test = 'test metadata'\n", "y = deepcopy(x)\n", "assert hasattr(y, 'test') and y.test == x.test" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class TensorImageBase(TensorBase):\n", " _show_args = ArrayImageBase._show_args\n", " def show(self, ctx=None, **kwargs):\n", " return show_image(self, ctx=ctx, **{**self._show_args, **kwargs})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class TensorImage(TensorImageBase): pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class TensorImageBW(TensorImage): _show_args = ArrayImageBW._show_args" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class TensorMask(TensorImageBase):\n", " _show_args = ArrayMask._show_args\n", "\n", " def show(self, ctx=None, **kwargs):\n", " codes = getattr(self, 'codes', None)\n", " if codes is not None: kwargs = merge({'vmin': 0, 'vmax': len(codes)}, kwargs)\n", " return super().show(ctx=ctx, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "for o in Tensor.__getitem__, Tensor.__ne__,Tensor.__eq__,Tensor.add,Tensor.sub,Tensor.mul,Tensor.div,Tensor.__rsub__,Tensor.__radd__,Tensor.matmul,Tensor.bmm:\n", " TensorBase.register_func(o, TensorMask, TensorImageBase)\n", " TensorBase.register_func(o, TensorImageBase, TensorMask)\n", "\n", "TensorMask.register_func(torch.einsum, str, TensorImageBase, TensorMask)\n", "TensorMask.register_func(torch.einsum, str, TensorMask, TensorImageBase)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "im = Image.open(TEST_IMAGE)\n", "im_t = cast(array(im), TensorImage)\n", "test_eq(type(im_t), TensorImage)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "im_t2 = cast(tensor(1), TensorMask)\n", "test_eq(type(im_t2), TensorMask)\n", "test_eq(im_t2, tensor(1))\n", "ax = im_t.show(figsize=(2,2))\n", "_ =(im_t == im_t2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_fig_exists(ax)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Operations between `TensorMask` and `TensorImageBase` objects return the type of the `TensorImageBase` object:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "a = TensorMask([1,2])\n", "test_eq_type(TensorImage(1)+a, TensorImage([2,3]))\n", "test_eq_type(1-a, TensorMask([0,-1]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|hide (last test of to_concat)\n", "test_eq_type(to_concat([TensorImage([1,2]), TensorImage([3,4])]), TensorImage([1,2,3,4]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class TensorFlowField(TensorBase): pass\n", "TensorImage.register_func(F.grid_sample, TensorImageBase, TensorFlowField)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t1 = TensorImage([1.]).view(1,1,1,1)\n", "t2 = TensorFlowField([1.,1.]).view(1,1,1,2)\n", "test_eq_type(F.grid_sample(t1, t2), TensorImage([[[[0.25]]]]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export \n", "class TensorCategory(TensorBase): pass\n", "\n", "TensorBase.register_func(Tensor.__getitem__, TensorImageBase, TensorCategory)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tc = TensorCategory([1,2,3])\n", "mask_t = TensorMask([0,2,4,5])\n", "im_t = TensorImage([0,2,4,5])\n", "test_eq(mask_t[tc], tensor([2,4,5]))\n", "test_eq(im_t[tc], 
tensor([2,4,5]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export \n", "class TensorMultiCategory(TensorCategory): pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class TitledTensorScalar(TensorBase):\n", " \"A tensor containing a scalar that has a `show` method\"\n", " def show(self, **kwargs): show_title(self.item(), **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## L -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@patch\n", "def tensored(self:L):\n", " \"`mapped(tensor)`\"\n", " return self.map(tensor)\n", "@patch\n", "def stack(self:L, dim=0):\n", " \"Same as `torch.stack`\"\n", " return torch.stack(list(self.tensored()), dim=dim)\n", "@patch\n", "def cat (self:L, dim=0):\n", " \"Same as `torch.cat`\"\n", " return torch.cat (list(self.tensored()), dim=dim)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "---\n", "\n", "#### L.tensored\n", "\n", "> L.tensored ()\n", "\n", "`mapped(tensor)`" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(L.tensored)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are shortcuts for `torch.stack` and `torch.cat` if your `L` contains tensors or something convertible. You can manually convert with `tensored`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t = L(([1,2],[3,4]))\n", "test_eq(t.tensored(), [tensor(1,2),tensor(3,4)])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "---\n", "\n", "#### L.stack\n", "\n", "> L.stack (dim=0)\n", "\n", "Same as `torch.stack`" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(L.stack)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(t.stack(), tensor([[1,2],[3,4]]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "---\n", "\n", "#### L.cat\n", "\n", "> L.cat (dim=0)\n", "\n", "Same as `torch.cat`" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(L.cat)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(t.cat(), tensor([1,2,3,4]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Chunks" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def concat(*ls):\n", " \"Concatenate tensors, arrays, lists, or tuples\"\n", " if not len(ls): return []\n", " it = ls[0]\n", " if isinstance(it,torch.Tensor): res = torch.cat(ls)\n", " elif isinstance(it,ndarray): res = np.concatenate(ls)\n", " else:\n", " res = itertools.chain.from_iterable(map(L,ls))\n", " if isinstance(it,(tuple,list)): res = type(it)(res)\n", " else: res = L(res)\n", " return retain_type(res, it)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "a,b,c = [1],[1,2],[1,1,2]\n", "test_eq(concat(a,b), c)\n", "test_eq_type(concat(tuple (a),tuple (b)), tuple (c))\n", "test_eq_type(concat(array (a),array (b)), array 
(c))\n", "test_eq_type(concat(tensor(a),tensor(b)), tensor(c))\n", "test_eq_type(concat(TensorBase(a),TensorBase(b)), TensorBase(c))\n", "test_eq_type(concat([1,1],1), [1,1,1])\n", "test_eq_type(concat(1,1,1), L(1,1,1))\n", "test_eq_type(concat(L(1,2),1), L(1,2,1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class Chunks:\n", " \"Slice and int indexing into a list of lists\"\n", " def __init__(self, chunks, lens=None):\n", " self.chunks = chunks\n", " self.lens = L(map(len,self.chunks) if lens is None else lens)\n", " self.cumlens = np.cumsum(0+self.lens)\n", " self.totlen = self.cumlens[-1]\n", "\n", " def __getitem__(self,i):\n", " if isinstance(i,slice): return retain_type(self.getslice(i), old=self.chunks[0])\n", " di,idx = self.doc_idx(i)\n", " return retain_type(self.chunks[di][idx], old=self.chunks[0])\n", "\n", " def getslice(self, i):\n", " st_d,st_i = self.doc_idx(ifnone(i.start,0))\n", " en_d,en_i = self.doc_idx(ifnone(i.stop,self.totlen+1))\n", " res = [self.chunks[st_d][st_i:(en_i if st_d==en_d else sys.maxsize)]]\n", " for b in range(st_d+1,en_d): res.append(self.chunks[b])\n", " if st_d!=en_d and en_d 0: o = t+'\\n'+str(o)\n", " ax.set_title(o, color=color)\n", " elif isinstance(ax, pd.Series):\n", " while label in ax: label += '_'\n", " ax = pd.concat([ax,pd.Series({label: o})])\n", " return ax" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_stdout(lambda: show_title(\"title\"), \"title\")\n", "# ensure that col names are unique when showing to a pandas series\n", "assert show_title(\"title\", ctx=pd.Series(dict(a=1)), label='a').equals(pd.Series(dict(a=1,a_='title')))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class ShowTitle:\n", " \"Base class that adds a simple `show`\"\n", " _show_args = {'label': 'text'}\n", " def show(self, ctx=None, **kwargs):\n", " 
\"Show self\"\n", " return show_title(str(self), ctx=ctx, **merge(self._show_args, kwargs))\n", "\n", "class TitledInt(Int, ShowTitle):\n", " _show_args = {'label': 'text'}\n", " def show(self, ctx=None, **kwargs):\n", " \"Show self\"\n", " return show_title(str(self), ctx=ctx, **merge(self._show_args, kwargs))\n", "\n", "class TitledFloat(Float, ShowTitle):\n", " _show_args = {'label': 'text'}\n", " def show(self, ctx=None, **kwargs):\n", " \"Show self\"\n", " return show_title(str(self), ctx=ctx, **merge(self._show_args, kwargs))\n", "\n", "class TitledStr(Str, ShowTitle):\n", " _show_args = {'label': 'text'}\n", " def show(self, ctx=None, **kwargs):\n", " \"Show self\"\n", " return show_title(str(self), ctx=ctx, **merge(self._show_args, kwargs))\n", "\n", "class TitledTuple(fastuple, ShowTitle):\n", " _show_args = {'label': 'text'}\n", " def show(self, ctx=None, **kwargs):\n", " \"Show self\"\n", " return show_title(str(self), ctx=ctx, **merge(self._show_args, kwargs))\n", "\n", "add_docs(TitledInt, \"An `int` with `show`\"); add_docs(TitledStr, \"An `str` with `show`\");\n", "add_docs(TitledFloat, \"A `float` with `show`\"); add_docs(TitledTuple, \"A `fastuple` with `show`\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "---\n", "\n", "### TitledInt\n", "\n", "\n", "\n", "An `int` with `show`" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(TitledInt, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "---\n", "\n", "### TitledStr\n", "\n", "\n", "\n", "An `str` with `show`" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(TitledStr, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ 
"---\n", "\n", "### TitledFloat\n", "\n", "> TitledFloat (x=0)\n", "\n", "A `float` with `show`" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(TitledFloat, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_stdout(lambda: TitledStr('s').show(), 's')\n", "test_stdout(lambda: TitledInt(1).show(), '1')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "---\n", "\n", "### TitledTuple\n", "\n", "> TitledTuple (x=None, *rest)\n", "\n", "A `fastuple` with `show`" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(TitledTuple, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|hide\n", "df = pd.DataFrame(index = range(1))\n", "row = df.iloc[0]\n", "x = TitledFloat(2.56)\n", "row = x.show(ctx=row, label='lbl')\n", "test_eq(float(row.lbl), 2.56)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@patch\n", "def truncate(self:TitledStr, n):\n", " \"Truncate self to `n`\"\n", " words = self.split(' ')[:n]\n", " return TitledStr(' '.join(words))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Other functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "if not hasattr(pd.DataFrame,'_old_init'): pd.DataFrame._old_init = pd.DataFrame.__init__" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@patch\n", "def __init__(self:pd.DataFrame, data=None, index=None, columns=None, dtype=None, copy=None):\n", " if data is not None and isinstance(data, Tensor): data = to_np(data)\n", " self._old_init(data, index=index, columns=columns, 
dtype=dtype, copy=copy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def get_empty_df(n):\n", " \"Return `n` empty rows of a dataframe\"\n", " df = pd.DataFrame(index = range(n))\n", " return [df.iloc[i] for i in range(n)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def display_df(df):\n", " \"Display `df` as HTML in a notebook, or fall back to `print` when IPython is not installed\"\n", " # Only a missing IPython should trigger the fallback; any other error must not be hidden\n", " try: from IPython.display import display, HTML\n", " except ImportError: return print(df)\n", " display(HTML(df.to_html()))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def get_first(c):\n", " \"Get the first element of c, even if c is a dataframe\"\n", " return getattr(c, 'iloc', c)[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def one_param(m):\n", " \"First parameter in `m`\"\n", " return first(m.parameters())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def item_find(x, idx=0):\n", " \"Take the `idx`-th element of `x` (for dicts: the `idx`-th key, or `idx` itself if not an int), then recursively the first element until reaching a non-listy, non-dict value\"\n", " if is_listy(x): return item_find(x[idx])\n", " if isinstance(x,dict):\n", " key = list(x.keys())[idx] if isinstance(idx, int) else idx\n", " return item_find(x[key])\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def find_device(b):\n", " \"Recursively search the device of `b`.\"\n", " return item_find(b).device" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t2 = to_device(tensor(0))\n", "dev = default_device()\n", 
"test_eq(find_device(t2), dev)\n", "test_eq(find_device([t2,t2]), dev)\n", "test_eq(find_device({'a':t2,'b':t2}), dev)\n", "test_eq(find_device({'a':[[t2],[t2]],'b':t2}), dev)" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def find_bs(b):\n", " \"Recursively search the batch size of `b`.\"\n", " res = item_find(b)\n", " if not hasattr(res, \"shape\"): return len(b)\n", " return res.shape[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(4,5)\n", "x1 = [1,2,3]\n", "test_eq(find_bs(x1), 3)\n", "test_eq(find_bs(x), 4)\n", "test_eq(find_bs((x,x)), 4)\n", "test_eq(find_bs([x, x]), 4)\n", "test_eq(find_bs({'a':x,'b':x}), 4)\n", "test_eq(find_bs({'a':[[x],[x]],'b':x}), 4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def np_func(f):\n", " \"Convert a function taking and returning numpy arrays to one taking and returning tensors\"\n", " def _inner(*args, **kwargs):\n", " nargs = [to_np(arg) if isinstance(arg,Tensor) else arg for arg in args]\n", " return tensor(f(*nargs, **kwargs))\n", " functools.update_wrapper(_inner, f)\n", " return _inner" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This decorator is particularly useful for using numpy functions as fastai metrics, for instance:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import f1_score" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@np_func\n", "def f1(inp,targ): return f1_score(targ, inp)\n", "\n", "a1,a2 = array([0,1,1]),array([1,0,1])\n", "t = f1(tensor(a1),tensor(a2))\n", "test_eq(f1_score(a1,a2), t)\n", "assert isinstance(t,Tensor)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class Module(nn.Module, metaclass=PrePostInitMeta):\n", " \"Same as `nn.Module`, but no need for subclasses to call `super().__init__`\"\n", " def __pre_init__(self, *args, **kwargs): super().__init__()\n", " def __init__(self): pass" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "---\n", "\n", "### Module\n", "\n", "> Module ()\n", "\n", "Same as `nn.Module`, but no need for subclasses to call `super().__init__`" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(Module, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-0.0832], grad_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class _T(Module):\n", " def __init__(self): self.f = nn.Linear(1,1)\n", " def forward(self,x): return self.f(x)\n", "\n", "t = _T()\n", "t(tensor([1.]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "from torch.nn.parallel import DistributedDataParallel" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def get_model(model):\n", " \"Return the model maybe wrapped inside `model`.\"\n", " return model.module if isinstance(model, (DistributedDataParallel, nn.DataParallel)) else model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def one_hot(x, c):\n", " \"One-hot encode `x` with `c` classes.\"\n", " res = torch.zeros(c, dtype=torch.uint8)\n", " if isinstance(x, Tensor) and x.numel()>0: res[x] = 1.\n", " else: res[list(L(x, use_list=None))] = 1.\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(one_hot([1,4], 5), tensor(0,1,0,0,1).byte())\n", "test_eq(one_hot(torch.tensor([]), 5), tensor(0,0,0,0,0).byte())\n", "test_eq(one_hot(2, 5), tensor(0,0,1,0,0).byte())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def one_hot_decode(x, 
vocab=None):\n", " return L(vocab[i] if vocab else i for i,x_ in enumerate(x) if x_==1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(one_hot_decode(tensor(0,1,0,0,1)), [1,4])\n", "test_eq(one_hot_decode(tensor(0,0,0,0,0)), [ ])\n", "test_eq(one_hot_decode(tensor(0,0,1,0,0)), [2 ])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def params(m):\n", " \"Return all parameters of `m`\"\n", " return [p for p in m.parameters()]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def trainable_params(m):\n", " \"Return all trainable parameters of `m`\"\n", " return [p for p in m.parameters() if p.requires_grad]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m = nn.Linear(4,5)\n", "test_eq(trainable_params(m), [m.weight, m.bias])\n", "m.weight.requires_grad_(False)\n", "test_eq(trainable_params(m), [m.bias])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, nn.LayerNorm)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def norm_bias_params(m, with_bias=True):\n", " \"Return all bias and BatchNorm parameters\"\n", " if isinstance(m, norm_types): return L(m.parameters())\n", " res = L(m.children()).map(norm_bias_params, with_bias=with_bias).concat()\n", " if with_bias and getattr(m, 'bias', None) is not None: res.append(m.bias)\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for norm_func in [nn.BatchNorm1d, partial(nn.InstanceNorm1d, affine=True)]:\n", " model = nn.Sequential(nn.Linear(10,20), norm_func(20), nn.Conv1d(3,4, 3))\n", " 
test_eq(norm_bias_params(model), [model[0].bias, model[1].weight, model[1].bias, model[2].bias])\n", " model = nn.ModuleList([nn.Linear(10,20, bias=False), nn.Sequential(norm_func(20), nn.Conv1d(3,4,3))])\n", " test_eq(norm_bias_params(model), [model[1][0].weight, model[1][0].bias, model[1][1].bias])\n", " model = nn.ModuleList([nn.Linear(10,20), nn.Sequential(norm_func(20), nn.Conv1d(3,4,3))])\n", " test_eq(norm_bias_params(model, with_bias=False), [model[1][0].weight, model[1][0].bias])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def batch_to_samples(b, max_n=10):\n", " \"'Transposes' a batch to (at most `max_n`) samples\"\n", " if isinstance(b, Tensor): return retain_types(list(b[:max_n]), [b])\n", " else:\n", " res = L(b).map(partial(batch_to_samples,max_n=max_n))\n", " return retain_types(res.zip(), [b])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t = tensor([1,2,3])\n", "test_eq(batch_to_samples([t,t+1], max_n=2), ([1,2],[2,3]))\n", "test_eq(batch_to_samples(tensor([1,2,3]), 10), [1, 2, 3])\n", "test_eq(batch_to_samples([tensor([1,2,3]), tensor([4,5,6])], 10), [(1, 4), (2, 5), (3, 6)])\n", "test_eq(batch_to_samples([tensor([1,2,3]), tensor([4,5,6])], 2), [(1, 4), (2, 5)])\n", "test_eq(batch_to_samples([tensor([1,2,3]), [tensor([4,5,6]),tensor([7,8,9])]], 10), \n", " [(1, (4, 7)), (2, (5, 8)), (3, (6, 9))])\n", "test_eq(batch_to_samples([tensor([1,2,3]), [tensor([4,5,6]),tensor([7,8,9])]], 2), [(1, (4, 7)), (2, (5, 8))])\n", "\n", "t = fastuple(tensor([1,2,3]),TensorBase([2,3,4]))\n", "test_eq_type(batch_to_samples(t)[0][1], TensorBase(2))\n", "test_eq(batch_to_samples(t).map(type), [fastuple]*3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@patch\n", "def interp_1d(x:Tensor, xp, fp):\n", " \"Same as `np.interp`\"\n", " slopes = (fp[1:]-fp[:-1])/(xp[1:]-xp[:-1])\n", " incx = 
fp[:-1] - (slopes*xp[:-1])\n", " locs = (x[:,None]>=xp[None,:]).long().sum(1)-1\n", " locs = locs.clamp(0,len(slopes)-1)\n", " return slopes[locs]*x + incx[locs]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "brks = tensor(0,1,2,4,8,64).float()\n", "ys = tensor(range_of(brks)).float()\n", "ys /= ys[-1].item()\n", "pts = tensor(0.2,0.5,0.8,3,5,63)\n", "\n", "preds = pts.interp_1d(brks, ys)\n", "test_close(preds.numpy(), np.interp(pts.numpy(), brks.numpy(), ys.numpy()))\n", "\n", "plt.scatter(brks,ys)\n", "plt.scatter(pts,preds)\n", "plt.legend(['breaks','preds']);" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@patch\n", "def pca(x:Tensor, k=2):\n", " \"Compute PCA of `x` with `k` dimensions.\"\n", " x = x-torch.mean(x,0)\n", " U,S,V = torch.svd(x.t())\n", " return torch.mm(x,U[:,:k])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def logit(x):\n", " \"Logit of `x`, clamped to avoid inf.\"\n", " x = x.clamp(1e-7, 1-1e-7)\n", " return -(1/x-1).log()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def num_distrib():\n", " \"Return the number of processes in distributed training (if applicable).\"\n", " return int(os.environ.get('WORLD_SIZE', 0))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def rank_distrib():\n", " \"Return the distributed rank of this process (if applicable).\"\n", " return int(os.environ.get('RANK', 0))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def distrib_barrier():\n", " \"Place a synchronization barrier in distributed training\"\n", " if num_distrib() > 1 and torch.distributed.is_initialized(): torch.distributed.barrier()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After calling this, ALL sub-processes in the pytorch process group must arrive here before proceeding." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "# Saving arrays requires pytables - optional dependency\n", "try: import tables\n", "except Exception: tables = None # pytables missing or broken: `_check_tables` reports it when arrays are saved/loaded\n", "\n", "def _check_tables():\n", "    \"Raise an informative `ImportError` if the optional `pytables` dependency could not be imported\"\n", "    if tables is None: raise ImportError('Saving or loading arrays requires the optional `pytables` package: pip install tables')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def _comp_filter(lib='lz4',lvl=3):\n", "    \"Build `tables.Filters` using the blosc compressor `lib` at compression level `lvl`\"\n", "    _check_tables()\n", "    return tables.Filters(complib=f'blosc:{lib}', complevel=lvl)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@patch\n", "def save_array(p:Path, o, complib='lz4', lvl=3):\n", "    \"Save numpy array to a compressed `pytables` file, using compression level `lvl`\"\n", "    _check_tables() # fail early: `tables.open_file` below is looked up before `_comp_filter` runs\n", "    if isinstance(o,Tensor): o = to_np(o)\n", "    with tables.open_file(p, mode='w', filters=_comp_filter(lib=complib,lvl=lvl)) as f: f.create_carray('/', 'data', obj=o)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The compression lib (`complib`) can be any of: blosclz, lz4, lz4hc, snappy, zlib or zstd." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@patch\n", "def load_array(p:Path):\n", "    \"Load numpy array from a `pytables` file written by `save_array`\"\n", "    _check_tables()\n", "    with tables.open_file(p, 'r') as f: return f.root.data.read()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def base_doc(elt):\n", "    \"Print a basic documentation of `elt`: its name, signature and docstring\"\n", "    name = getattr(elt, '__qualname__', getattr(elt, '__name__', ''))\n", "    print(f'{name}{inspect.signature(elt)}\\n{inspect.getdoc(elt)}\\n')\n", "    print('To get a prettier result with hyperlinks to source code and documentation, install nbdev: pip install nbdev')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def doc(elt):\n", "    \"Try to use doc from nbdev and fall back to `base_doc`\"\n", "    try:\n", "        from nbdev.showdoc import doc\n", "        doc(elt)\n", "    # best-effort: nbdev may be missing or unable to document `elt`; don't swallow KeyboardInterrupt/SystemExit\n", "    except Exception: base_doc(elt)" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def nested_reorder(t, idxs):\n", "    \"Reorder all tensors (and `L`s) in `t` using `idxs`, recursing into list-like containers\"\n", "    if t is None: return None\n", "    if isinstance(t, (Tensor,L)): return t[idxs]\n", "    if not is_listy(t): raise TypeError(f\"Expected tensor, tuple, list or L but got {type(t)}\")\n", "    return type(t)(nested_reorder(o, idxs) for o in t)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = tensor([0,1,2,3,4,5])\n", "idxs = tensor([2,5,1,0,3,4])\n", "test_eq_type(nested_reorder(([x], x), idxs), ([idxs], idxs))\n", "\n", "y = L(0,1,2,3,4,5)\n", "z = L(i.item() for i in idxs)\n", "test_eq_type(nested_reorder((y, x), idxs), (z,idxs))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def flatten_check(inp, targ):\n", "    \"Flatten `inp` and `targ` to 1d `TensorBase`s, checking they hold the same number of elements.\"\n", "    flat_inp = TensorBase(inp.contiguous()).view(-1)\n", "    flat_targ = TensorBase(targ.contiguous()).view(-1)\n", "    test_eq(len(flat_inp), len(flat_targ))\n", "    return flat_inp,flat_targ" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x1,x2 = torch.randn(5,4),torch.randn(20)\n", "x1,x2 = flatten_check(x1,x2)\n", "test_eq(x1.shape, [20])\n", "test_eq(x2.shape, [20])\n", "x1,x2 = torch.randn(5,4),torch.randn(21)\n", "test_fail(lambda: flatten_check(x1,x2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Image helpers" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def make_cross_image(bw=True):\n", "    \"Create a 5x5 cross image tensor: one channel if `bw`, else 3 channels (row in channel 0, column in channel 1)\"\n", "    if bw:\n", "        im = torch.zeros(5,5)\n", "        im[2,:] = im[:,2] = 1.\n", "        return im\n", "    im = torch.zeros(3,5,5)\n", "    im[0,2,:] = im[1,:,2] = 1.\n", "    return im" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAI40lEQVR4nO3dT4ichR3G8efpJqJgwUPmELKh60GkQaiSIQj2FDzEKtqjgj0JuVSIUBDtzUOvxYuXYEVBUQQ9iFhEUGsLVp34r6ZRCJJiUMgEkeqloj49zBxiu7vzzuR959355fuBhZ3dycyD7nff2dnlHScRgDp+0vcAAO0iaqAYogaKIWqgGKIGitnVxY3u2bMnGxsbXdz0Je/EiRN9T5jLwYMH+55Q0pkzZ3T+/Hlv9rlOot7Y2NBoNOripi959qb/H3csvg66MRwOt/wcD7+BYogaKIaogWKIGiiGqIFiiBoohqiBYogaKIaogWKIGiiGqIFiiBoohqiBYogaKIaogWKIGiiGqIFiGkVt+4jtT2yftv1A16MALG5m1LbXJD0i6RZJByTdZftA18MALKbJkfqQpNNJPk3yraRnJN3R7SwAi2oS9T5Jn11w+ez0Yz9i+6jtke3ReDxuax+AOTWJerPTV/7fq+olOZ5kmGQ4GAwufhmAhTSJ+qyk/RdcXpf0eTdzAFysJlG/I+ka21fbvkzSnZJe6HYWgEXNPJl/ku9s3yvpZUlrkh5LcrLzZQAW0ugVOpK8JOmljrcAaAF/UQYUQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDEzo7b9mO1ztj9axiAAF6fJkfpxSUc63gGgJTOjTvKGpC+XsAVAC/iZGiimtahtH7U9sj0aj8dt3SyAObUWdZLjSYZJhoPBoK2bBTAnHn4DxTT5ldbTkt6UdK3ts7bv6X4WgEXtmnWFJHctYwiAdvDwGyiGqIFiiBoohqiBYogaKIaogWKIGiiGqIFiiBoohqiBYogaKIaogWKIGiiGqIFiiBoohqiBYpyk/Ru1279RAD+SxJt9nCM1UAxRA8UQNVAMUQPFEDVQDFEDxRA1UAxRA8UQNVAMUQPFEDVQDFEDxRA1UAxRA8UQNVAMUQPFEDVQDFEDxcyM2vZ+26/ZPmX7pO1jyxgGYDEzz1Fme6+kvUnetf1TSSck/TrJP7f5N5yjDOjYwucoS/JFknen738t6ZSkfe3OA9CWXfNc2faGpBskvbXJ545KOtrOLACLanyKYNtXSvqLpD8keX7GdXn4DXTsok4RbHu3pOckPTUraAD9avJEmSU9IenLJPc1ulGO1EDntjpSN4n6l5L+Kukfkn6Yfvj3SV7a5t8QNdCxhaNeBFED3eNld4BLBFEDxRA1UAxRA8UQNVAMUQPFEDVQDFEDxRA1UAxRA8UQNVAMUQPFEDVQDFEDxRA1UAxRA8XMdTbRpg4ePKjRaNTFTV/yJmeXWh1dnIQD0nA43PJzHKmBYogaKIaogWKIGiiGqIFiiBoohqiBYogaKIaogWKIGiiGqIFiiBoohqiBYogaKIaogWKIGiiGqIFiZkZt+3Lbb9v+wPZJ2w8tYxiAxTQ5ndF/JB1O8o3t3ZL+ZvvPSf7e8TYAC5gZdSYnmfpmenH39I0TTwE7VKOfqW2v2X5f0jlJryR5q9NVABbWKOok3ye5XtK6pEO2r/vf69g+antkezQej1ueCaCpuZ79TvKVpNclHdnkc8eTDJMMB4NBO+sAzK3Js98D21dN379C0s2SPu54F4AFNXn2e6+kJ2yvafJN4NkkL3Y7C8Cimjz7/aGkG5awBUAL+Isyo
BiiBoohaqAYogaKIWqgGKIGiiFqoBiiBoohaqAYogaKIWqgGKIGiiFqoBiiBoohaqAYogaKIWqgGKIGiiFqoBiiBoohaqAYogaKIWqgGKIGiiFqoBiiBoohaqAYogaKIWqgGKIGiiFqoBiiBoohaqAYogaKIWqgmMZR216z/Z7tF7scBODizHOkPibpVFdDALSjUdS21yXdKunRbucAuFhNj9QPS7pf0g9bXcH2Udsj26PxeNzGNgALmBm17dsknUtyYrvrJTmeZJhkOBgMWhsIYD5NjtQ3Sbrd9hlJz0g6bPvJTlcBWNjMqJM8mGQ9yYakOyW9muTuzpcBWAi/pwaK2TXPlZO8Lun1TpYAaAVHaqAYogaKIWqgGKIGiiFqoBiiBoohaqAYogaKIWqgGKIGiiFqoBiiBoohaqAYogaKIWqgGKIGinGS9m/UHkv6V8s3u0fS+ZZvs0urtHeVtkqrtberrT9LsukZPjuJugu2R0mGfe9oapX2rtJWabX29rGVh99AMUQNFLNKUR/ve8CcVmnvKm2VVmvv0reuzM/UAJpZpSM1gAaIGihmJaK2fcT2J7ZP236g7z3bsf2Y7XO2P+p7yyy299t+zfYp2ydtH+t701ZsX277bdsfTLc+1PemJmyv2X7P9ovLus8dH7XtNUmPSLpF0gFJd9k+0O+qbT0u6UjfIxr6TtLvkvxc0o2SfruD/9v+R9LhJL+QdL2kI7Zv7HdSI8cknVrmHe74qCUdknQ6yadJvtXklTfv6HnTlpK8IenLvnc0keSLJO9O3/9aky++ff2u2lwmvple3D1929HP8tpel3SrpEeXeb+rEPU+SZ9dcPmsdugX3iqzvSHpBklv9TxlS9OHsu9LOifplSQ7duvUw5Lul/TDMu90FaL2Jh/b0d+hV43tKyU9J+m+JP/ue89Wknyf5HpJ65IO2b6u50lbsn2bpHNJTiz7vlch6rOS9l9weV3S5z1tKcf2bk2CfirJ833vaSLJV5q8+upOfu7iJkm32z6jyY+Mh20/uYw7XoWo35F0je2rbV+myQvfv9DzphJsW9KfJJ1K8se+92zH9sD2VdP3r5B0s6SPex21jSQPJllPsqHJ1+yrSe5exn3v+KiTfCfpXkkva/JEzrNJTva7amu2n5b0pqRrbZ+1fU/fm7Zxk6TfaHIUeX/69qu+R21hr6TXbH+oyTf6V5Is7ddEq4Q/EwWK2fFHagDzIWqgGKIGiiFqoBiiBoohaqAYogaK+S/20vv5FixpogAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(make_cross_image(), cmap=\"Greys\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAI3klEQVR4nO3dQYic9R3G8efpJqLUgofmINnQeBCpBBoxBCGXEiykGvSqUE/CXiqsYBF7Kt6LeOklaLCgKIIeJBcJNEUEG7OJsRhXSxCLi8K2SNH0UIn+epg5pHZ3553Z95133yffDwzs7M6+8+Nlvvu+M7P8x1UlADl+0PcAANpF1EAYogbCEDUQhqiBMLu62KhtXlLvyt19DzCl830PkKuqvNH33cVbWkTdoaHt2Q0fdmjDZlFz+g2EIWogDFEDYYgaCEPUQBiiBsIQNRCGqIEwRA2EIWogDFEDYYgaCEPUQBiiBsIQNRCGqIEwRA2EaRS17WO2P7Z92fZTXQ8FYHYTlzOyvSDpb5J+IWlN0jlJD1fVh1v8ztAW3RmOoe1ZljPqzHaWMzos6XJVfVJV30h6RdKDbQ4HoD1Not4r6bNrrq+Nv/c/bC/ZXrG90tZwAKbXZIngjQ7x/3cSWFUnJJ2QOP0G+tTkSL0mad811xclfd7NOAC2q0nU5yTdbvs22zdIekjSG92OBWBWE0+/q+qq7cckvSlpQdLJqrrU+WQAZsIndAzN0PYsb2l1hk/oAK4TRA2EIWogDFEDYYgaCEPUQBiiBsIQNRCGqIEwRA2EIWogDFEDYYgaCEPUQBiiBsIQNRCGqIEwRA2EIWogDFEDYYgaCEPUQBiiBsIQNRCGqIEwRA2EIWogDFEDYYgaCEPUQBiiBsIQNRCGqIEwRA2EmRi17ZO2121/MI+BAGxPkyP1C5KOdTwHgJZMjLqq3pL05RxmAdACnlMDYXa1tSHbS5KW2toegNm4qibfyN4v6VRVHWi0UXvyRjGboe1Z9z1ArqracO9y+g2EafKW1suS3pF0h+012492PxaAWTU6/Z56o5x+d2doe5bT785w+g1cJ4gaCEPUQBiiBsIQNRCGqIEwRA2EIWogDFEDYYgaCEPUQBiiBsIQNRCGqIEwRA2EIWogTGsLD17rbkkrXWwYgCTp0BY/40gNhCFqIAxRA2GIGghD1EAYogbCEDUQhqiBMEQNhCFqIAxRA2GIGghD1EAYogbCEDUQhqiBMEQNhCFqIMzEqG3vs33G9qrtS7aX5zEYgNk0WaPsqqQnquqC7R9JOm/7dFV92PFsAGYw8UhdVV9U1YXx119LWpW0t+vBAMxmqufUtvdLukvS2Q1+tmR7xfbKP1oaDsD0Gkdt+2ZJr0l6vKq++v7Pq+pEVR2qqkN72pwQwFQaRW17t0ZBv1RVr3c7EoDtaPLqtyU9L2m1qp7pfiQA29HkSH1E0iOSjtq+OL7c1/FcAGY08S2tqnpbkucwC4AW8B9lQBiiBsIQNRCGqIEwRA2EIWogDFEDYYgaCEPUQBiiBsIQNRCGqIEwRA2EIWogDFEDYYgaCOOqan+jdvsbxcjQ9izLa3SmqjbcuxypgTBEDYQhaiAMUQNhiBoIQ9RAGKIGwhA1EIaogTBEDYQhaiAMUQNhiBoIQ9RAGKIGwhA1EIaogTATo7Z9o+13bb9v+5Ltp+cxGIDZTFzOyLYl/bCqrtjeLeltSctV9Zctfmdoi+4Mx9D2LMsZdWaz5Yx2Nfj
FknRlfHX3+DK0hxZw3Wj0nNr2gu2LktYlna6qs51OBWBmjaKuqm+r6qCkRUmHbR/4/m1sL9lesb3S8owApjD1EsG2fyfp31X1+y1uw+l5V4a2Z3lO3ZmZlwi2vcf2LeOvb5J0r6SPWp0OQGsmvlAm6VZJf7S9oNEfgVer6lS3YwGYFZ/QMTRD27OcfneGT+gArhNEDYQhaiAMUQNhiBoIQ9RAGKIGwhA1EIaogTBEDYQhaiAMUQNhiBoIQ9RAGKIGwhA1EIaogTBEDYQhaiAMUQNhiBoIQ9RAGKIGwhA1EIaogTBEDYQhaiAMUQNhiBoIQ9RAGKIGwhA1EIaogTBEDYQhaiBM46htL9h+z/apLgcCsD3THKmXJa12NQiAdjSK2vaipPslPdftOAC2q+mR+llJT0r6brMb2F6yvWJ7pY3BAMxmYtS2j0tar6rzW92uqk5U1aGqOtTadACm1uRIfUTSA7Y/lfSKpKO2X+x0KgAzc1U1v7H9c0m/qarjE27XfKOYztD2rPseIFdVbbh3eZ8aCDPVkbrxRjlSd2doe5YjdWc4UgPXCaIGwhA1EIaogTBEDYQhaiAMUQNhiBoIQ9RAGKIGwhA1EIaogTBEDYQhaiAMUQNhiBoIs6uj7f5T0t9b3uaPx9sdim7m7WbRAfZtd7qa9Seb/aCTlU+6YHtlSCuVDmneIc0qDWvePmbl9BsIQ9RAmCFFfaLvAaY0pHmHNKs0rHnnPutgnlMDaGZIR2oADRA1EGYQUds+Zvtj25dtP9X3PFuxfdL2uu0P+p5lEtv7bJ+xvWr7ku3lvmfajO0bbb9r+/3xrE/3PVMTthdsv2f71Lzuc8dHbXtB0h8k/VLSnZIetn1nv1Nt6QVJx/oeoqGrkp6oqp9KukfSr3fwvv2PpKNV9TNJByUds31PvyM1sixpdZ53uOOjlnRY0uWq+qSqvtHokzcf7HmmTVXVW5K+7HuOJqrqi6q6MP76a40efHv7nWpjNXJlfHX3+LKjX+W1vSjpfknPzfN+hxD1XkmfXXN9TTv0gTdktvdLukvS2Z5H2dT4VPaipHVJp6tqx8469qykJyV9N887HULUG/23847+Cz00tm+W9Jqkx6vqq77n2UxVfVtVByUtSjps+0DPI23K9nFJ61V1ft73PYSo1yTtu+b6oqTPe5olju3dGgX9UlW93vc8TVTVvyT9WTv7tYsjkh6w/alGTxmP2n5xHnc8hKjPSbrd9m22b5D0kKQ3ep4pgm1Lel7SalU90/c8W7G9x/Yt469vknSvpI96HWoLVfXbqlqsqv0aPWb/VFW/msd97/ioq+qqpMckvanRCzmvVtWlfqfanO2XJb0j6Q7ba7Yf7XumLRyR9IhGR5GL48t9fQ+1iVslnbH9V43+0J+uqrm9TTQk/JsoEGbHH6kBTIeogTBEDYQhaiAMUQNhiBoIQ9RAmP8Coa7o5KFYWnEAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(make_cross_image(False).permute(1,2,0));" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def show_image_batch(b, show=show_titled_image, items=9, cols=3, figsize=None, **kwargs):\n", " \"Display batch `b` in a grid of size `items` with `cols` width\"\n", " if items" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "show_image_batch(([Image.open(TEST_IMAGE_BW),Image.open(TEST_IMAGE)],['bw','color']), items=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model init" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def requires_grad(m):\n", " \"Check if the first parameter of `m` requires grad or not\"\n", " ps = list(m.parameters())\n", " return ps[0].requires_grad if len(ps)>0 else False" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = nn.Linear(4,5)\n", "assert requires_grad(tst)\n", "for p in tst.parameters(): p.requires_grad_(False)\n", "assert not requires_grad(tst)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def init_default(m, func=nn.init.kaiming_normal_):\n", " \"Initialize `m` weights with `func` and set `bias` to 0.\"\n", " if func:\n", " if hasattr(m, 'weight'): func(m.weight)\n", " if hasattr(m, 'bias') and hasattr(m.bias, 'data'): m.bias.data.fill_(0.)\n", " return m" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = nn.Linear(4,5)\n", "tst.weight.data.uniform_(-1,1)\n", "tst.bias.data.uniform_(-1,1)\n", "tst = init_default(tst, func = lambda x: x.data.fill_(1.))\n", "test_eq(tst.weight, torch.ones(5,4))\n", "test_eq(tst.bias, torch.zeros(5))" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def cond_init(m, func):\n", "    \"Apply `init_default` with `func` to `m`, skipping normalization layers and modules whose first parameter doesn't require grad\"\n", "    if isinstance(m, norm_types) or not requires_grad(m): return\n", "    init_default(m, func)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = nn.Linear(4,5)\n", "tst.weight.data.uniform_(-1,1)\n", "tst.bias.data.uniform_(-1,1)\n", "cond_init(tst, func = lambda x: x.data.fill_(1.))\n", "test_eq(tst.weight, torch.ones(5,4))\n", "test_eq(tst.bias, torch.zeros(5))\n", "\n", "tst = nn.BatchNorm2d(5)\n", "init = [tst.weight.clone(), tst.bias.clone()]\n", "cond_init(tst, func = lambda x: x.data.fill_(1.))\n", "test_eq(tst.weight, init[0])\n", "test_eq(tst.bias, init[1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def apply_leaf(m, f):\n", "    \"Recursively call `f` on `m` (when it is an `nn.Module`) and then on each of its descendants.\"\n", "    kids = m.children()\n", "    if isinstance(m, nn.Module): f(m)\n", "    for child in kids: apply_leaf(child, f)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = nn.Sequential(nn.Linear(4,5), nn.Sequential(nn.Linear(4,5), nn.Linear(4,5)))\n", "apply_leaf(tst, partial(init_default, func=lambda x: x.data.fill_(1.)))\n", "for l in [tst[0], *tst[1]]: test_eq(l.weight, torch.ones(5,4))\n", "for l in [tst[0], *tst[1]]: test_eq(l.bias, torch.zeros(5))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def apply_init(m, func=nn.init.kaiming_normal_):\n", "    \"Initialize every non-normalization layer in `m` (recursively) with `func`, via `cond_init`.\"\n", "    init_one = partial(cond_init, func=func)\n", "    apply_leaf(m, init_one)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = nn.Sequential(nn.Linear(4,5), nn.Sequential(nn.Linear(4,5), nn.BatchNorm1d(5)))\n", "init = [tst[1][1].weight.clone(), tst[1][1].bias.clone()]\n", 
"apply_init(tst, func=lambda x: x.data.fill_(1.))\n", "for l in [tst[0], tst[1][0]]: test_eq(l.weight, torch.ones(5,4))\n", "for l in [tst[0], tst[1][0]]: test_eq(l.bias, torch.zeros(5))\n", "test_eq(tst[1][1].weight, init[0])\n", "test_eq(tst[1][1].bias, init[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## autograd jit functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def script_use_ctx(f):\n", "    \"Decorator: create jit script and pass everything in `ctx.saved_tensors` to `f`, after `*args`\"\n", "    sf = torch.jit.script(f)\n", "    # `saved_tensors` is the supported accessor; `saved_variables` is a deprecated alias\n", "    def _f(ctx, *args, **kwargs): return sf(*args, *ctx.saved_tensors, **kwargs)\n", "    return update_wrapper(_f,f)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def script_save_ctx(static, *argidx):\n", "    \"Decorator: create jit script and save args with indices `argidx` using `ctx.save_for_backward`\"\n", "    def _dec(f):\n", "        sf = torch.jit.script(f)\n", "        def _f(ctx, *args, **kwargs):\n", "            if argidx: ctx.save_for_backward(*[args[o] for o in argidx])\n", "            else: args = (ctx, *args) # nothing to save: pass `ctx` itself as first arg (`args` is a tuple, so no list concat)\n", "            return sf(*args, **kwargs)\n", "        # copy `f`'s metadata onto the plain function before optionally wrapping it in `staticmethod`\n", "        _f = update_wrapper(_f,f)\n", "        return staticmethod(_f) if static else _f\n", "    return _dec" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def script_fwd(*argidx):\n", "    \"Decorator: create static jit script and save args with indices `argidx` using `ctx.save_for_backward`\"\n", "    return script_save_ctx(True, *argidx)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def script_bwd(f):\n", "    \"Decorator: create static jit script and pass everything in `ctx.saved_tensors` to `f`, after `*args`\"\n", "    return staticmethod(script_use_ctx(f))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": 
[], "source": [ "#|export\n", "def grad_module(cls):\n", "    \"Decorator: wrap `cls` (an autograd function exposing `apply`) in an `nn.Module` whose `forward` calls `cls.apply`\"\n", "    class _c(nn.Module):\n", "        def forward(self, *inputs, **kw):\n", "            return cls.apply(*inputs, **kw)\n", "    return _c" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Torch Version Checks -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def ismin_torch(min_version):\n", "    \"Return True if the installed `torch.__version__` is at least `min_version` (compared with `packaging.version.parse`)\"\n", "    floor = parse(min_version)\n", "    return _torch_version >= floor" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def notmax_torch(max_version):\n", "    \"Return True if the installed `torch.__version__` is strictly below `max_version` (compared with `packaging.version.parse`)\"\n", "    ceiling = parse(max_version)\n", "    return _torch_version < ceiling" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## PyTorch 1.13 `__format__` workaround -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "# PyTorch 1.13 introduced a Tensor Subclass string formatting bug\n", "# Workaround from pending PyTorch PR: https://github.com/pytorch/pytorch/pull/82766\n", "if ismin_torch('1.13') and notmax_torch('1.14'):\n", "    from torch.overrides import has_torch_function_unary, handle_torch_function\n", "    @patch\n", "    def __format__(self:Tensor, format_spec):\n", "        # defer to any `__torch_function__` override first\n", "        if has_torch_function_unary(self):\n", "            return handle_torch_function(Tensor.__format__, (self,), self, format_spec)\n", "        # non-meta 0-d tensors (subclasses included) format like their Python scalar\n", "        is_scalar = self.dim() == 0 and not self.is_meta and issubclass(type(self), Tensor)\n", "        return self.item().__format__(format_spec) if is_scalar else object.__format__(self, format_spec)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|hide\n", "import nbdev; nbdev.nbdev_export()" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "python3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }