{ "cells": [ { "cell_type": "markdown", "source": [ "# \"PyTorch Broadcasting\"\n", "> Running computations on tensors of different ranks with PyTorch\n", "\n", "- toc: true\n", "- badges: true\n", "- comments: false\n", "- categories: [jupyter, fastai, pytorch]" ], "metadata": {} }, { "cell_type": "code", "execution_count": 77, "source": [ "#hide\n", "from fastai.vision.all import *\n", "\n", "matplotlib.rc('image', cmap='Greys')" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 78, "source": [ "#hide\n", "path = untar_data(URLs.MNIST_SAMPLE)\n", "Path.BASE_PATH = path\n", "path.ls()" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(#3) [Path('valid'),Path('labels.csv'),Path('train')]" ] }, "metadata": {}, "execution_count": 78 } ], "metadata": {} }, { "cell_type": "code", "execution_count": 79, "source": [ "#hide\n", "threes_train = (path/'train'/'3').ls().sorted()\n", "sevens_train = (path/'train'/'7').ls().sorted()\n", "threes_valid = (path/'valid'/'3').ls().sorted()\n", "sevens_valid = (path/'valid'/'7').ls().sorted()\n", "\n", "threes_train" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(#6131) [Path('train/3/10.png'),Path('train/3/10000.png'),Path('train/3/10011.png'),Path('train/3/10031.png'),Path('train/3/10034.png'),Path('train/3/10042.png'),Path('train/3/10052.png'),Path('train/3/1007.png'),Path('train/3/10074.png'),Path('train/3/10091.png')...]" ] }, "metadata": {}, "execution_count": 79 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "# Introduction\n", "\n", "Recently, I stumbled upon [fastai](https://github.com/fastai/fastai), a deep learning library build upon PyTorch with a very eye catching slogan: *Making neural nets uncool again*.\n", "As it turns out, there is also a very [good course](https://course.fast.ai/) on this library from one of its creators, Jeremy Howard, that is very much hands on.\n", "\n", "The task of Lesson 3 is to build a Digit Classifier from hand written images based on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/).\n", "But before training any models, Jeremy explains that one should always create a simplistic baseline to later put the performance of the fancy machine learning models into perspective.\n", "The baseline should be easy to implement, e.g., by only relying on simple arithmetic operations.\n", "The key idea of the lessons' baseline is to average the pixel values for each digit (in our small example below only we only use the digits 3 and 7) and then compute the difference between a given digit image to these average.\n", "\n", "During the last few years, I did a lot of typical data scientist tasks as part of my PhD.\n", "One thing that I always found quite fascinating is how NumPy treats arrays with different dimensions/ranks during arithmetic operations through a mechanism called **broadcasting**.\n", "PyTorch tensors also support this mechanism and I really liked how Jeremy used it for creating the baseline.\n", "In addition, I think that Jeremys explained very well how broadcasting works so I decided to summarize it in this blog post.\n" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "# Training and Validation Data\n", "\n", "We have four datasets with digit images available: A training and a validation dataset of threes, and a training and a validation dataset of fours.\n", "For example, the second three from the training dataset looks like so:" ], "metadata": {} }, { "cell_type": "code", "execution_count": 
80, "source": [ "#collapse-hide\n", "show_image(Image.open(threes_train[1]))" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 80 }, { "output_type": "display_data", "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAJHElEQVR4nO2bXXMSZxuAL1jYXYQsiCYxIWIIjImJ0TaVTu2H41FnnGmPPOtMf0NP+i/6H9oDx+OOOtMjW6cdGz/Sk1ZqxkRDhCTQEAjfsLDsvgeWbbMmxgrEzDtcR5ln+bi5eJ5n7/t+iM0wDPr8g/1tB3DY6Aux0BdioS/EQl+IBcc+1/+fb0G23Qb7M8RCX4iFvhALfSEW+kIs9IVY6Aux0BdiYb/ErGMMw6DVatFsNqlUKmiaRrPZpNls0mg0Xnq8LMtIkoSu6+i6jtPpxOFw4HQ6EQQBWZZxOHoXds+FaJpGvV4nHo/zww8/kMlkSKVSxONxnjx5wr/7MTabjffee4+pqSmq1SqqqjIyMsKxY8eIRCIMDw8zOzuLz+frWbxdF6LrujkDisUilUqFbDbLn3/+yeLiIoVCgfX1dTKZDJVKBVEUEUWRer1OrVZjZWUFh8NhCqlUKmxublIulzl+/DgTExM9FWLbp2P2n2sZVVXZ3t5mdXWV77//nlQqxaNHj8jlcqTTaQzDwDAMZFnG7XYzODhIIBDgyZMnPH/+HLvdjt1uN2eOzWbDZrPh8/nwer1cv36daDT6hh93B7vWMh3PEE3TqNVq1Go1UqkUlUqFVCrFysoKT58+pVQqIQgCExMTRKNRnE4nkiQhiiIulwtFUfB6vczOzpLJZFheXubZs2eUy2VqtZr5Po1GA1VV0XW905BfScdCVFUlFovx22+/8c0331CpVGg2m7RaLRqNBoFAgGg0ysWLF7l69Soejwe3220+3263Y7PZ0HUdwzC4ceMG165dIxaLkUwmOw3vP9OxEMMwaDQaVKtVSqUS1WoVXdex2+2IokggEGBubo5z587h8/lwOp04nU7z+e0l8e+lJAgCNtvOGe33+wkGg8iy3GnIr6QrQqrVKrVaDVVV0TQNAFEUOXr0KHNzc3zxxRf4fD4URXnpg7ZpjzudTux2+0vXpqenmZmZQVGUTkN+JR0LcTqdhEIhRFEkn8/TbDaBF0Lcbjdzc3N4vV5EUdxTBrzYizRNI5/Pk8/nzRzFZrMhCALDw8OEw2FcLlenIb+SjoVIksTp06cJh8N88MEH5nh7KQiCgCiK+75Oo9GgVCqxtrZGMpmkWq0CmM+PRCJEo9Ed+08v6FhI+1sXBGHH3tC+Zp3+e7G+vs7PP//M77//TqlUQtM0BEEgHA4zPj7OzMwMJ06ceOk9uk1XErP2bHidmbAXd+7c4auvvkLTNHRdx+FwIIoiFy9eJBqN8u6773LixIluhPtKep66WzEMA13XqdfrlMtlM2FbWFgwN2S73U4kEiESifDhhx8SjUZ7vpm2OXAhuq6jaRrZbJbHjx/z448/cvPmTbLZrHm7djgcXLhwgY8//phPP/2UkydPHlh8B1LtGoZBsVhkdXXVrGWSySSrq6ssLS2Ry+Wo1+vAi3zD7/dz+vRpzp8/z8DAQK9D3MGBlf/JZJLvvvuO5eVlHjx4gKqqO1LzNkNDQ0xPT3PhwgUmJyd7fpu1ciBLRtd1CoUCjx49Yn19HVVVzXzFSiaTIRaLcevWLeLxOMeOHcPj8TA2NobX62VwcLCnkg5syWxtbfHgwQM0TTM31t3IZDJkMhnW19dxu90oioLH4+HKlSucP3+eS5cuIcvyK5O8Tuh6+f/SC/y9ZDY2Nrh9+zblcplCoUCxWCSXy5mPi8fjLC0tmT0UURRxOp1mv2RqaorR0VE++ugjzpw5w9mzZ/F6vQiC8Nq5joVdjfZciJVarWbKWFtbM8d/+uknfvnlF1ZXV0mn07s+t91Ri0QifP3110xOTiJJEoIgvEkovemH/FecTieKoiDL8o7O19DQEJcuXWJtbY10Ok0mkyGXy3H//n3i8TjwYrYlEgmq1SrPnz9ncHCQ48ePv6mQXTlwIQ6HA4fDgcvlwuv1muPDw8NMT09TrVbNW/TKygrZbNYUArC5uUkulyMejxMKhTh69Gh34+vqq3VAuxBsd9VlWSYYDBKPx/nrr79IJBJsb28DL2bKxsYG8XicU6dOdTWOQ3Mu0y4EJUkye63BYJDZ2VmmpqZ2pO6GYZgzR1XVrsZxaITsRbtwtI55vd6eVL+HXgjAbndCt9uN3+/v6oYKh2gPsVIsFtne3mZhYYGHDx/uyFlsNhunTp0iEol01HLYjUMppF0MJpNJEokEiURiR2YrCAJDQ0P4/f6uH2seOiHVapVKpcK9e/e4c+cOf/zxh3lEATA5Ocn4+DiBQABZlt80S92TQyOk/YHr9TrZbJZYLMb8/DypVMq8ZrfbCQaDRCIRvF4vDoej6zXNoRFSKBRIp9Pcvn2bu3fv8vjxY5LJpNknGRgYwO12c/XqVS5fvszo6Oiu5zed8laFtCthwzDI5/M8ffqUhw8fcuPGDbO3Ci820SNHjjA4OMi5c+cIhUI9kQFvUYiqqqiqysbGBktLS9y9e5f5+XkSiQTNZtNcJi6XC1mW+fLLL/nkk08Ih8M9kwFdFvLvb9yaULXH2383Gg0KhQLxeJz5+XkWFha4d++e+fj2me/AwACKonD27FneeecdPB5Pz2RAF4Vomka1WqVcLrO2toaiKIyMjJiH3pVKha2tLUqlEpubmywvL7O4uGjWKfl8HvgnM41EIoTDYa5cuUI0GiUUCqEoSk9/PQRdFNJqtSiVSmxtbRGLxRgdHUWSJPMXRLlcjuXlZba2tsxl0u6tqqq645RPFEXGx8cJh8NEo1FmZmbMhlGv6ZqQfD7Pt99+SzKZ5P79++Zhd6vVotVqmecwqqqaM6ZSqZgb5/DwMGNjY7z//vtmk/nkyZMoioIkSV3PN/aia0IajQaJRIJnz56xuLj4Wj9ssdlsSJKEJEmMjY0RiUQ4c+YM0WiUiYkJ/H5/t8J7bbomxOPxcPnyZTweD7/++uu+QmRZ5siRI3z++ed89tlnhEIhRkZGcLlcB7Y8dqNrQhwOB8FgkHQ6TSAQMI8l98LlcuHz+ZicnGR2dpahoaEdHbS3RdeazK1WyzxvKRaL+7/x3w0ht9ttdsm6XcrvF8KugwfddT9E9P+j6nXoC7HQF2KhL8TCfrfd3lVRh5T+DLHQF2KhL8RCX4iFvhALfSEW/gcMlBno19ugeQAAAA
BJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" } } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "The first thing we want to do is to load the training images into tensors." ], "metadata": {} }, { "cell_type": "code", "execution_count": 81, "source": [ "three_train_tensors = [tensor(Image.open(o)) for o in threes_train]\n", "seven_train_tensors = [tensor(Image.open(o)) for o in sevens_train]\n", "# lets see how many tensors we have\n", "len(three_train_tensors), len(seven_train_tensors)" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(6131, 6265)" ] }, "metadata": {}, "execution_count": 81 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Each tensor contains a pixel matrix, the values in the matrix describe the color of each pixel on a grey scale.\n", "A 0 indicates a white pixel and a 255 indicates a black pixel.\n", "For example, row and column 4 to 10 from the pixel matrix of the three above looks like so." ], "metadata": {} }, { "cell_type": "code", "execution_count": 82, "source": [ "three_train_tensors[1][4:10,4:10]" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[ 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 29],\n", " [ 0, 0, 0, 48, 166, 224],\n", " [ 0, 93, 244, 249, 253, 187],\n", " [ 0, 107, 253, 253, 230, 48],\n", " [ 0, 3, 20, 20, 15, 0]], dtype=torch.uint8)" ] }, "metadata": {}, "execution_count": 82 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Now, for calculating the mean of all pixel values, i.e., the *ideal* three or seven, we need to stack the image matrices.\n", "In other words, we need to create a cube.\n", "Since we will later on calculate the mean of each pixel based on the cube, we will end up with float values.\n", "Therefore, we already convert integers to floats and transform them to be between 0 and 1. " ], "metadata": {} }, { "cell_type": "code", "execution_count": 83, "source": [ "threes_train_stacked = torch.stack(three_train_tensors).float()/255\n", "sevens_train_stacked = torch.stack(seven_train_tensors).float()/255\n", "stacked_threes.shape, stacked_sevens.shape" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(torch.Size([6131, 28, 28]), torch.Size([6265, 28, 28]))" ] }, "metadata": {}, "execution_count": 83 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "The results are two cube, i.e., two rank-3 tensors, with the above shapes.\n", "The rank of each tensor is the number of axes while the shape is this size of each axis." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "# A Metric to Determine the Similarity\n", "\n", "With the rank-3 tensors we can now calculate the mean three digit and the mean seven digit.\n", "We need both to compare how similar a given digit is to the *ideal* three or seven." 
], "metadata": {} }, { "cell_type": "code", "execution_count": 84, "source": [ "mean3 = stacked_threes.mean(0)\n", "mean7 = stacked_sevens.mean(0)\n", "\n", "show_image(mean3), show_image(mean7)" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(, )" ] }, "metadata": {}, "execution_count": 84 }, { "output_type": "display_data", "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAJtUlEQVR4nO1b2XLiWhJM7QsChDG22x3h//+qfnKzWVhoX5HmoaNqDufK9jRge2aCiiCEAS0nVUtWlqz0fY+r/dvU776A/za7AiLZFRDJroBIdgVEMv2D7/+fS5Ay9OHVQyS7AiLZFRDJroBI9lFS/RQ7tV1QlME8eFH7dEDkxdPf4udDAMmLVxQFfd8Pfn5JuyggQ4vs+/7o1XUdfy6+l01RFCiKAlX9E9WqqvJn9JJ/fwm7CCDy4ruu423XdTgcDjgcDqiqCm3boq5rtG2LpmlwOBzQti3vr6oqVFWFaZrQNA2WZUHXdViWBU3ToOs6NE3j3xFQZOcCcxYgQ15AIHRdh7Zt0bYtqqpCXdcoigJlWaIoCtR1jTzPUdc1yrLk4+i6DlVVMRqNYFnW0dY0Tdi2DcMwYBgGNE1jEGRgTrWTARHBkD2haRpUVYU8z1EUBcIwRJIk2Gw2CMMQu90OaZoiiiKUZYk8z3E4HNB1HUzThGEY8DwPo9EIi8UCs9kMj4+PmM1mmM/ncF0X4/EYlmUdeY4YYqeCc7aHDAFSliWqqkKSJMjzHLvdDmEYYrPZII5jBEGALMsQRRGyLEOWZRw6dOdnsxk8z0PXdSiKAoqioKoq6LqOw+EAXdfR9z17CYXPUOL9dEDkXCHmiKqqEMcx0jTFdrtFGIZ4fn7Gfr/HZrNBmqYIggBJkiCKIlRVhbIs0bYtDocDn8NxHFiWhYeHB9ze3iIMQ9zc3CDLMtze3qJtW0wmEyiKAsdxjpLvOXZ2UhXzBy1K3LZty+GlaRpM08RoNIKiKNA0jQFpmgZt27Knkaf0fY+6rlHXNYdhlmX8N53DNE2+ji/3EAJiKHfQhdZ1zVVEVVVYlsVxb9s270P7U/WhvNM0DZqmga7rnJgp76iqivl8Dk3TMB6Poes6uq7jkKEbcAowJ4eMaHRiimNd19kTyHNs24Zt27xQApRAITAJkKIokOc5NE07SpbyfqKHXsIuRszoonVdh+M4nOw8z4Pv++wB8j60QAJgv98jjmOuTMRZdP3PpYqe+R4Q31Jl6MQiGADQdR2Tp6ZpOESImYpGn2dZBl3XUVUViqJgPkILo5xD5zEM44ikib87x04CRD45ubVpmnyRXdfBcZyjaiTfTSJvTdPAsiyYpvmPUAFwlBNM02SCRpxlCLxT7SwPES9AVVVehBgKsnuLpZq25BVhGGK/32O/3yNNU+R5zhVLVVUYhnHEXonWEykb6nG+HBDxLhIQcqKTgaBS3Pc9V4+Xlxes12usViu8vLwgiiLs93tesKZpsG0bk8kEvu9jNBrBdV0GhRL6uXYyIDIQiqKg67pBUESeIvKJJEmYwa7Xa2y3WwRBgN1uhyzLkOc5JpMJl2rXdeF5HsbjMRzHOQpRsRv+FkBEUAgEIlJvtfvU4NHdXy6XWK1W7BUESBiGHIau60LTNDiOg8lkwpTecRw4jnPkHd+WQwgA8b28JRAOhwN3tHEcI4oiDo3lcsmhst1usdlsUBQFqqri5EkVi7jNUHW5VIU5GZD/BBSRxRZFgSRJsN1usVqt8OvXL6zXazw/P+P3799YLpeI4xhJkvAxPc+D53kA/lQxSqhi6y+LRnQt59jFRGbxQihcyDuKouCmbrPZYLPZIAgCLJdLBEGANE1R1zULRCLPAHBUiaiBpLZAZqvnMtazc8iQZirS67qukWUZ4jjGer1mQJbLJZbLJZIkQZIkXIYVRTnyAjpWVVUsFbiue9QMUh/zrSHznhFIIg+h5EpyoO/7WCwWGI1GmEwmDKSmaZxEKUwo7KIoQhAE3NRR9yxrr8A3UnfRZJFZJmZEvy3Lgud5uLu7Y3WNTNRLAcAwDHRdhzzPoaoqXl9foWkabm5uoOs6bNvm4367h7w3SlAUhXOB67po2xY/fvxghlkUBbIs49AiI/BkQbrve65UqqoiiiJomgbXdTnviJ5C1/C3dramKr8XL0am2/P5HI7jwHVdFn1kSk+9DVH3OI4ZOGK1qqpiv9/DsixUVcVeJHriqXaWHjK0FUsxjRNc14VhGDBNE3VdYz6fs2fIgJCC9vr6ijiOmXiRStY0DVet0WiEsixhmib3O8SWvyyHvFVVxO+o+pD7ip0pdbhDs5yu61igtm2bQ4tAAoC6rqFpGqv1ojInMmU69t8Cc5aHyG39EDAU3zRzeav5I05BJZd6HgLGMAxOvqJsQFLkECf58hwytDBxS6CQejaUa2QPIbNtm0uvyErl38tgnGt/BYjsGVQd3tM236PW4gLp+LRYyhUkWFOYDc13L8FQyU7OISIAlBxliVAeWL8Finx8keWSQCQn7KF9ZftS1V2Me3FoTQuStVZRURMBot9TcozjmGn+arXCbrfjkSclTmK7NBCn7ve9pwM+DRBZ6xABaZrmKBcQ42zbli9cFKMVRWFQReEoTVNW37MsQ1EURyFD9F7WU7+NqYqA0BCpbVu+8Kqq+Ht5VkPAkFG1oMZts9mwRvL6+ookSVAUBatjJBT5vo/pdMojTzr2uU3eSUlVBoW8g2a0RVFwDiAKL48OyGhARSradrvFbrdDEAQcKsRGiehRBSJuQ99dwlP+ChBZFBJPTqSqLEsEQcC0m7xIfIaDRo9Ex8uyRJqmSNOU5YA8z1GWJfMQz/MwnU5xf3+Pu7s73N3dYTKZwPM82LY9mEc+HRARGPmkYnIkgTiKIvYcAk0UpEV5kYbYSZLw4xF93zMpcxwHnudhMplgMpnAcRxmwEO66qn214CIYFBiI+2TdFAAR3khDEPUdY04jo9yDok8IuOkhOn7PsbjMR4fH+H7Pp6ennB/f4+npydMp1PMZjMGhfb58pAZAoUSpvwiLxCbsd1ux5WE8g6Vb3J3Yqeu68L3ffi+j/l8jsVigdvb26MwuVQiPRkQOinxCHERNL60bRsAMJ1Ooes69vs9DMNAkiQwDIPlRNI66O7SvGU+n2M8HuP+/h6z2Qw/f/7EYrHAzc0NC88iGEPc5ssAkcERgSHd
AwCr5VmWQVGOH4Ui8IjIkUfNZjOMRiPMZjP4vs9PDj08PGA8HnPeoFmMqKxdcgyhfNADDH451NOQuEP6Z9M0XCnSNOVHrUg9pypDi6PRpLilZ0rkofZQiT0BjMEdzgJkiLVSmaXRgchPaEu5gzQTEotN02SSRS+x2x16NvUMr7gcIPzlAFED/tn9fjQ7kZO03JO81aOcGSKDO19ktiv/LbblQ9uPjvfW9q3zXtLO8pCP7BIaxScu/vIe8uEZP/FOfpZd/4FIsisgkn0UMv97Pn+mXT1Esisgkl0BkewKiGRXQCS7AiLZvwBtCZqwAvXF1QAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" } }, { "output_type": "display_data", "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAI6klEQVR4nO1baVPiTBc9ZF/IhoOWOjUf5v//KgelNEZICCErvB+eutemjaNC8N04VanG7H1y99uOdrsdzniF8u9+gf80nAmRcCZEwpkQCWdCJGgfHP9fdkGjvp1nCZFwJkTCmRAJZ0IknAmRcCZEwpkQCWdCJJwJkfBRpHoQPlNjkc8ZjXoDxzf47HmHYhBCxMnR74/Gz0IkYDQa8d/y2Hf+ITiKEHmS2+2Wx91uxxv9TcfFY+L1BHGy9FtRFIxGo96RjsvXH4KDCBEnIk6aJt51HbbbLbquQ9d1aNuWR9pP58n3kUET1zQNqqrCMAyoqgrTNKGqKjRNg6IoexvhEGK+TIj44kQCEdA0DZqmQVmWaJoGm80GdV1jvV6jrmvkeY6qqrDZbNA0Deq6Rtu2TFSfJKmqCkVR4LouTNPExcUFXNdFFEVwHAe+78M0TViWxQSNRiOoqordbvdlUr5EiCwZ9KVJAoqiQNM0WK/XKMsSWZahKAokSYLNZoMsy1CWJRPVNA1fS6SKKgaACRmPxzBNE9PpFJ7n4devX/B9H5qmYbfbMRHb7RaKohxExpcI6VMPmgxNcLVaoSgKxHGMNE0xn8+RZRniOEae51gsFlitVlgulywpoiqJ6kQbffUgCOB5Hn7//o0wDJGmKS4vLwEAQRBA0/6ZCqnYyQnpI0ckhlSFJCHLMqRpymOe50iSBOv1GqvVCm3boq7rvfuJIHLatkVVVRiNRmiaBlEUQVEU5HkO13VZ0kjCjsWXVUYkous6NE2DqqpQliXyPEeapojjGMvlEnEcI8sy3N/fY7VaIUkS1HWNsiyZAEVR2DCS1xCfQXamrmvoug7TNFHXNS4vL2EYBvI8h23bLK3vGeeTEPI3kIskEdd1HbquwzAM+L4PVVUBgL+6fA55ETKyy+US6/UaWZZhvV7zMwDskSkSSb9Fd/1VfIqQjxgXyaAJmqYJ27aZhPF4jDAMoaoqu03HcfhcXdehqip7qiRJkGUZ7u7u8PT0xAabDCdBdrnHkPFpQshIyQQoisISsd1uYZomuq5DEARQFAV1XcO2bei6ziqg6zosy4Jt23BdF5ZlwbIslpDNZoOyLHl/nucoioLVgQik0bIsjk2+jRCRiD5CDMMAALiuC1VV0XUdTNOEoiioqgphGLKtoNjB8zy4rssTo3sSIZ7nYTabIc9zrNdrVFWFtm1hWRYcx4Ft27Btm4kjCRNV5tu8jKizAKDrOn89ALBtm41j13WoqgqapsE0Tbiuy5LhOA50XedYYrfb7UWmwD/qRt5IVVU4jgPXdTEejxEEAUuIaJiPwacJER+kKAoHQPTypNuqqrL6ULS42+1YVSzLYsmwLIuJJZUiIk3TZEKqquKYxHEcOI4Dz/M4WqUoVbQjJydEJIaCHjHxAv6RFABMhghR98muEJF0Pbnbuq6RZRkHcuRlDMPAeDyG7/sIw5CjV13X3+Qxh+JglQFeJYV0V9d1Jqxt2z2dFr0P6TtNgK6h2KYsS6RpisVigcViwUEYqVwQBIii6A0hx7rcLxMiehvRsJLEkOEkkogQ2i8TQRCDvDzP8fz8jMfHRzw/PyNNU9R1zaG77/sIggCu68K27T3SAey938mTO3oQPVj0OuRxSBrETJU2MVUX70PGl/Kh+XyOp6cnzGYzpGmKpmmgaRo8z4PneQjDkCWGpGMoHBypytICvNoSsh80igUdoL+EUBQF0jTF/f097u7uEMcx4jhG0zRQVRVhGGIymfBIrvZYFZFxVOjeV84Tv5ZMmLyfJKNtW2w2G86QHx4eMJ/PkaYpezFys7TJrnYoUg42qrItod8A9ryGDLmMQMlekiT48+cPZrMZ5vM5FosFmqZhb3J1dYXLy0tMp1OMx2OOP0TJG4KUoyVEtCVkYGkTiZOLS2RI67pm6Xh4eEAcx3h4eECe5+i6DoZhIIoihGGIi4sLNqhiZErvMgSOtiF9Ve/tdvsm/wH2SaGSY1EUeHl5wWw2w2w2w8vLC6uK53m4vb3Fz58/cX19jZubGwRBwEmhHIx9u9v9G+RIVm5NyG6RCktUR0mSBEmSYLFYoCxLqKoK27bx48cPXFxcYDKZYDKZwHEcmKb5xn70kfBtucxnH0xSIhdtKHCjuuv9/T3HHGVZQtM0RFGEIAhwfX2N29tb3NzcIIoi2LbNkXBfyn+sCg1iQ+S/33sZUWWoClYUBZbLJfI8R5Zl7GZ938d0OkUURZhMJvB9H7Zt73mX98g4BkdLSB8phPdUheqvaZri6ekJSZJwi4LynNvbW0ynU5YQMqaidPTZjW/Ldv8GedLv7SPVIemgEiGRQV5FzGajKILneRyIDW1EZQza7O7zLMCrVxGN6HK5xOPjI6vLdruF4zgIggC2bePq6gpXV1fchxErY3LkKz7/WJx8OURfy4J6Mnmec08HANdIKGchcvq8CtAfFB6LQSVEVg+5b0M9mcVigefnZywWC2w2G4xGI1aJyWSCKIrY1ZLdoJrr0KG6jJOtD/lbD2ez2aAoCpRlia7rOF/RdZ2Lz5TeExGniEr7cJL1IWJoTipSFAVWqxWyLMPLywvyPOe2AnkWUUJEdZELQKfE4CrTJx3Ua6mqipO5uq65qKxpGtsPalGIlfTvIgMYgJC+tSLUZyXpoJ7ver1mQ0pVNbHOats2wjCE7/uwLOvDmOMUGNyG9NkOMqo00nIHsTBMTScav8uIyhhsSdV7RpQWxtBGrQbTNN8siKEikOhqRYP6X6EyIvoW1IgrgyiUp8YUqYSqqiwdhmFwi2KIPstXMWhy99451AR3XZcnKa8CIBsitiffa1mI49A4SRwC7Lc7adLU6gReWw+apnGb0zAM3khy/lbzOAUpow++8KdWnvTZEHHJFdkSMqxt2+6pEEkRkUNumEj5qDJ2IDG9Fw1eMaNqmdinESdMC+xEQoDXxXWiIZXvcYp0/808hpAQPrknYhX391XPel+qJ4EbqiImPqZ355CE8EXvlADeO953ft/EB5aKgwj5v8P530MknAmRcCZEwpkQCWdCJJwJkfAv6ObhbeIGuNEAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" } } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "For comparing how similar a given digit is to each of these means, a good metric is the L1 norm, i.e., the *mean absolute value of differences*.\n", "This might sound complicated, but it is actually very easy to understand how the L1 norm is calculated just by looking at some code." ], "metadata": {} }, { "cell_type": "code", "execution_count": 85, "source": [ "# let's load the second image again\n", "i_3 = threes_train_stacked[1]\n", "\n", "# then calculate the list of absolute differences for three and sevens\n", "list_of_differences_3 = (i_3 - mean3).abs()\n", "list_of_differences_7 = (i_3 - mean7).abs()\n", "\n", "# then calculate the l1 norm which is the mean of all of these differences\n", "l1_3 = list_of_differences_3.mean()\n", "l1_7 = list_of_differences_7.mean()\n", "\n", "l1_3, l1_7" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(tensor(0.1114), tensor(0.1586))" ] }, "metadata": {}, "execution_count": 85 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "The L1 norm, i.e., the mean absolute value of differences, is smaller for the mean three so it is a three.\n", "This is correct, great!" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "# Computing the Metric with Broadcasting\n", "\n", "We now come to the most intersting part of this blog post: We will use broadcasting to calculate the L1 norm for both validation datasets.\n", "First, we again create two rank-3 tensors from the validation data." ], "metadata": {} }, { "cell_type": "code", "execution_count": 86, "source": [ "# read in validation data (same approach as above)\n", "threes_valid_stacked = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'3').ls()]).float()/255\n", "sevens_valid_stacked = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'7').ls()]).float()/255\n", "\n", "threes_valid_stacked.shape,sevens_valid_stacked.shape" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(torch.Size([1010, 28, 28]), torch.Size([1028, 28, 28]))" ] }, "metadata": {}, "execution_count": 86 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Then we define a function that computes the L1 norm with a similar approach as above." ], "metadata": {} }, { "cell_type": "code", "execution_count": 87, "source": [ "def l1_norm(a,b): return (a-b).abs().mean((-1,-2))\n", "\n", "l1_norm(i_3, mean3)" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor(0.1114)" ] }, "metadata": {}, "execution_count": 87 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "This seems to work! 
But did you notice the (-1,-2) in the mean function?\n", "With these parameters, we instruct the function to only use the last two axes for calculating the mean.\n", "Since we passed in a single rank-2 tensor, we could also have omitted the parameters, as the check above just showed.\n", "But let's see what happens if we supply the complete stack of validation threes, i.e., a rank-3 tensor, instead of a single three, i.e., a rank-2 tensor.\n" ], "metadata": {} }, { "cell_type": "code", "execution_count": 88, "source": [ "result = l1_norm(threes_valid_stacked, mean3)\n", "result, result.shape" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(tensor([0.1634, 0.1145, 0.1363, ..., 0.1105, 0.1111, 0.1640]),\n", " torch.Size([1010]))" ] }, "metadata": {}, "execution_count": 88 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Wow! The same function now returns a rank-1 tensor instead of a single value.\n", "And this rank-1 tensor contains the L1 norm for each three in the validation dataset.\n", "But why can we subtract a rank-2 tensor (the b in the l1_norm function) from a rank-3 tensor (the a in the l1_norm function)?\n", "\n", "Jeremy explains it like so:\n", "\n", "> The magic trick is that PyTorch, when it tries to perform a simple subtraction operation between two tensors of different ranks, will use broadcasting. That is, it will automatically expand the tensor [actually PyTorch only pretends to expand the tensor, it does not allocate any extra memory] with the smaller rank to have the same size as the one with the larger rank. Broadcasting is an important capability that makes tensor code much easier to write. After broadcasting so the two argument tensors have the same rank, PyTorch applies its usual logic for two tensors of the same rank: it performs the operation on each corresponding element of the two tensors, and returns the tensor result.\n", "\n", "This is also why we supplied (-1,-2) to the mean function: we always want the mean of the last two axes, no matter what rank the provided tensors have." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "# Testing the Accuracy of the Baseline Model\n", "\n", "For testing the accuracy of the baseline model, we need a simple function that checks whether the l1_norm of a given image is smaller for the mean three or the mean seven." ], "metadata": {} }, { "cell_type": "code", "execution_count": 89, "source": [ "def is_3(x): return l1_norm(x, mean3) < l1_norm(x, mean7)\n", "\n", "is_3(i_3)" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor(True)" ] }, "metadata": {}, "execution_count": 89 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Again, we can use broadcasting to run this function on the whole validation dataset!" ], "metadata": {} }, { "cell_type": "code", "execution_count": 90, "source": [ "is_3(threes_valid_stacked)" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([True, True, True, ..., True, True, True])" ] }, "metadata": {}, "execution_count": 90 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "With all the pieces together, we can now calculate how many digit images are correctly identified as threes or sevens, i.e., the accuracy of our baseline model.\n", "First, though, let's make the broadcasting rule concrete one more time with a small standalone sketch."
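 ], "metadata": {} }, { "cell_type": "markdown", "source": [ "The toy tensors below are my own example (they are not part of the lesson): a rank-2 tensor b is subtracted from a rank-3 tensor a, and PyTorch virtually expands b along a new leading axis so that the shapes match." ], "metadata": {} }, { "cell_type": "code", "execution_count": null, "source": [ "a = torch.arange(12.).reshape(3, 2, 2)  # rank-3: a stack of three 2x2 matrices\n", "b = torch.ones(2, 2)                    # rank-2: a single 2x2 matrix\n", "\n", "# b is broadcast (virtually expanded, without allocating extra memory) to shape (3, 2, 2),\n", "# then the subtraction is applied elementwise\n", "(a - b).shape" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "Now we can compute the accuracy on both validation sets:"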
], "metadata": {} }, { "cell_type": "code", "execution_count": 91, "source": [ "accuracy_3s = is_3(valid_3_tens).float() .mean()\n", "accuracy_7s = (1 - is_3(valid_7_tens).float()).mean()\n", "average_accuracy = (accuracy_3s+accuracy_7s)/2\n", "\n", "average_accuracy" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor(0.9511)" ] }, "metadata": {}, "execution_count": 91 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "The average accuracy is well above 90%. It will be interesting to see whether we can top this with deep learning 😉." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "# Closing Remarks\n", "\n", "In this blog post, I used Jeremys approach for creating a baseline Digit Classifier model to provide a broadcasting example.\n", "By why should one use broadcasting instead of, for example, just creating some loops?\n", "I give you two reasons:\n", "1. A lot less code that is also much easier to read (even though the concepts are a lot more complex).\n", "2. PyTorch calculations are run in C (or CUDA if using a GPU), which makes it thousands of times faster than pure Python (or up to millions of times faster on a GPU)." ], "metadata": {} } ], "metadata": { "orig_nbformat": 4, "language_info": { "name": "python", "version": "3.8.11", "mimetype": "text/x-python", "codemirror_mode": { "name": "ipython", "version": 3 }, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py" }, "kernelspec": { "name": "python3", "display_name": "Python 3.8.11 64-bit ('jupyter-blog': conda)" }, "interpreter": { "hash": "19d3752f5da779b9d184a1d3beb29a8abc3512516ead272e7e4eaf4a929c6215" } }, "nbformat": 4, "nbformat_minor": 2 }