{ "cells": [ { "cell_type": "markdown", "id": "varied-fossil", "metadata": {}, "source": [ "# Circular Convolutional Layers in the Fourier Domain\n", "--------------------------------------------------------\n", "This is a tutorial accompanying the ICLR 2021 paper *\"Orthogonalizing Convolutional Layers with the Cayley Transform\"* by Asher Trockman and Zico Kolter.\n", "\n", "This Jupyter notebook is best viewed with [nbviewer (click here)](https://nbviewer.jupyter.org/github/locuslab/orthogonal-convolutions/blob/main/FFT%20Convolutions.ipynb)." ] }, { "cell_type": "code", "execution_count": 1, "id": "military-interstate", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "plt.rcParams['figure.figsize'] = [3, 3]" ] }, { "cell_type": "markdown", "id": "institutional-graduation", "metadata": {}, "source": [ "### Quick start" ] }, { "cell_type": "code", "execution_count": 2, "id": "balanced-narrative", "metadata": {}, "outputs": [], "source": [ "# Try changing these parameters and running the notebook.\n", "cin = 3 # input channels\n", "cout = 3 # output channels\n", "n = 5 # spatial (input) size\n", "k = 3 # conv. kernel size\n", "batches = 1 # batches" ] }, { "cell_type": "markdown", "id": "loose-aside", "metadata": {}, "source": [ "We will implement the following circular convolution using FFT functions:" ] }, { "cell_type": "code", "execution_count": 3, "id": "worst-boulder", "metadata": {}, "outputs": [], "source": [ "x = torch.randn(batches, cin, n, n)\n", "conv = nn.Conv2d(cin, cout, k, bias=False)" ] }, { "cell_type": "markdown", "id": "positive-finance", "metadata": {}, "source": [ "Notice that we make the convolution \"circular\" using circular padding (below). The left/right and top/bottom paddings differ to account for even kernel sizes (they will all be the same for odd kernel sizes)." 
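, "\n", "\n", "In index form (a sketch in our own notation, which is an assumption rather than something taken from the paper): write $s = \\lfloor (k-1)/2 \\rfloor$ for the left/top padding and $W$ for `conv.weight`, with zero-based indices and a single batch element. Since `nn.Conv2d` computes a cross-correlation, circular padding followed by the convolution should amount to\n", "\n", "$$y_{d,i,j} = \\sum_{c=0}^{c_{in}-1} \\sum_{a=0}^{k-1} \\sum_{b=0}^{k-1} W_{d,c,a,b} \\, x_{c,\\,(i+a-s) \\bmod n,\\,(j+b-s) \\bmod n}$$\n", "\n", "For example, with $k = 3$ we get $s = 1$, so output pixel $(0, 0)$ reads input rows and columns $n-1$, $0$, and $1$."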
] }, { "cell_type": "code", "execution_count": 4, "id": "located-thing", "metadata": {}, "outputs": [], "source": [ "y1 = conv(F.pad(x, ((k - 1) // 2, k // 2, (k - 1) // 2, k // 2), mode=\"circular\"))" ] }, { "cell_type": "markdown", "id": "favorite-palestine", "metadata": {}, "source": [ "First, we compute the 2D FFT of the input (applied only to the last two dimensions):" ] }, { "cell_type": "code", "execution_count": 5, "id": "lyric-nicaragua", "metadata": {}, "outputs": [], "source": [ "xfft = torch.fft.fft2(x)" ] }, { "cell_type": "markdown", "id": "still-return", "metadata": {}, "source": [ "Then we compute the 2D FFT of the weights, also applied to only the last two dimensions.\n", "We have to do a couple of things here, however. First, we have to zero-pad the weights (on the right and bottom) to have the\n", "same spatial size as the inputs." ] }, { "cell_type": "code", "execution_count": 6, "id": "checked-membrane", "metadata": {}, "outputs": [], "source": [ "wpad = F.pad(conv.weight, (0, n - k, 0, n - k))" ] }, { "cell_type": "code", "execution_count": 7, "id": "physical-killer", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAL4AAADSCAYAAAD0Qnq8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAKu0lEQVR4nO3dbYxU5RnG8f/FCgFBCyJRYRFMtLSERGwRTbCt8SVFq9X4xZdK/GDSpNEGW5PGpo3V1vZTozbR1phqqfGtVq01RmtpizUmviGiEdEEjVSQCgIbAYm83f0wZ+3pdoWDO+fMLPf1SyY5M8+c57mZvebhnJnd5ygiMMtmRKcLMOsEB99ScvAtJQffUnLwLSUH31Jy8NtA0nWS7t5L+zuSzviMfX/mfe3TpQ5+EartkrZKel/SIknjOl1XVUW9NzQwzqmS1tQ9TpNSB79wbkSMA74EzAF+3OF6rAEOfiEi1gJPALMkTZD0mKQNkjYX2739z5V0jKR/StoiaTFweLkvSQskrZa0UdKPBrSNkHSNpLeK9gckHVZl372RNF1SSLpM0r8kfVDevzgce1DSH4q6l0k6vtQeko4t3V8k6QZJY4vXZXLxP+NWSZOr1tWtHPyCpKnA2cDLtF6X3wHTgKOB7cAtpaffC7xEK/A/Ay4r9TMT+A2wAJgMTAR6S/t+Fzgf+FrRvhm4teK+VZwCzABOB66V9MVS23nAH4HDin/DI5JG7q2ziNgGnAW8FxHjitt7+1lT94mItDfgHWAr0AesBn4NjBnkebOBzcX20cAuYGyp/V7g7mL7WuD+UttYYAdwRnF/JXB6qf0oYCdw0L72HaSuRcANxfZ0IIDeUvsLwEXF9nXAc6W2EcA64CvF/QCO/ZS+TwXWdPrn1c7bQUN4zxwozo+Iv5UfkHQwcBMwH5hQPHyIpB6KWTpaM2G/1cDUYnsy8G5/Q0Rsk7Sx9NxpwJ8k7Sk9ths4osK+Vfy7tP0RUD5ZL/e9pzhhHfaHLZ+FD3UGdzWtw4WTIuJQ4KvF46I1S04ojn37HV3aXsd/3wT9b6KJpfZ3gbMiYnzpNjpa5xj72neoyn2PoHUY1X/Y8hFwcOm5R5a2D7hf4XXwB3cIreP6vuLE8yf9DRGxGlgKXC9plKRTgHNL+z4InCPpFEmjgJ/yv6/zbcDPJU0DkDRJ0nkV9x2qL0u6QNJBwFXAx8BzRdty4BJJPZLm0zoH6fc+MFHS59pYS0c5+IO7GRgDfEArGH8Z0H4JcBKwidab4q7+hohYAVxB67h/Ha2T1/Jn4L8CHgX+KmlL0f9JFfcdqj8DFxb9LgAuiIidRdtCWm/gPuBbwCOlf9MbwH3A25L6DoRPdVScvNgBTtJ1tE5eL+10Ld3AM76l5OBbSj7UsZQ841tKDr6lVMs3t6PGj4kxRx5aR9eV7d4wqqPjH9f7fkfHB1i58YhOl9BxOzdvYve2bRr4eC3BH3Pkocy7/cI6uq5s023TOjr+E7+8qaPjA5x41/c6XULHrbll8J+DD3UsJQffUnLwLSUH31Jy8C0lB99ScvAtJQffUnLwLSUH31Jy8C2lSsGXNF/Sm5JWSbqm7qLM6rbP4BdrydxKazWtmcDFxYpfZsNWlRl/LrAqIt6OiB3A/bSWojMbtqoEfwqlFbhoLXcxZeCTJH1b0lJJS3f0bW9XfWa1aNvJbUTcHhFzImLOqPFj2tWtWS2qBH8tpaXnaC07t7aecsyaUSX4LwLHFWvCjwIuorUSmNmwtc8/PYyIXZKuBJ4EeoA7i6XuzIatSn9zGxGPA4/XXItZY/zNraXk4FtKDr6l5OBbSg6+peTgW0oOvqXk4FtKDr6l5OBbSrUsE74nxPZdI+vourIPp3f2PX3iXd/v6Pi2d57xLSUH31Jy8C0lB99ScvAtJQffUnLwLSUH31Jy8C0lB99ScvAtJQffUqqyTPidktZLeq2JgsyaUGXGXwTMr7kOs0btM/gR8TSwqYFazBrjY3xLqW3
BL18YYmffR+3q1qwWtVwYYuT4g9vVrVktfKhjKVX5OPM+4FlghqQ1ki6vvyyzelW5MMTFTRRi1iQf6lhKDr6l5OBbSg6+peTgW0oOvqXk4FtKDr6l5OBbSg6+peTgW0q1XBhixug+/j7z0Tq6ruzrZ8zu6PgjZn2ho+MDvHXJhE6X0LU841tKDr6l5OBbSg6+peTgW0oOvqXk4FtKDr6l5OBbSg6+peTgW0oOvqVUZSW1qZKWSHpd0gpJC5sozKxOVX47cxdwdUQsk3QI8JKkxRHxes21mdWmyoUh1kXEsmJ7C7ASmFJ3YWZ12q9jfEnTgROA5wdp+2R9/A0bd7epPLN6VA6+pHHAQ8BVEfHhwPby+viTJva0s0aztqsUfEkjaYX+noh4uN6SzOpX5VMdAXcAKyPixvpLMqtflRl/HrAAOE3S8uJ2ds11mdWqyoUhngHUQC1mjfE3t5aSg28pOfiWkoNvKTn4lpKDbyk5+JaSg28pOfiWkoNvKdWyPv5rGyfx+UXfqaPr6n7R2eGtu3nGt5QcfEvJwbeUHHxLycG3lBx8S8nBt5QcfEvJwbeUHHxLycG3lBx8S6nKSmqjJb0g6ZViffzrmyjMrE5VfjvzY+C0iNharKH5jKQnIuK5mmszq02VldQC2FrcHVncos6izOpWdbXkHknLgfXA4oj4v/XxzYaTSsGPiN0RMRvoBeZKmjXwOeULQ+zetq3NZZq11359qhMRfcASYP4gbZ9cGKJn7Ng2lWdWjyqf6kySNL7YHgOcCbxRc11mtaryqc5RwO8l9dB6ozwQEY/VW5ZZvap8qvMqrQu+mR0w/M2tpeTgW0oOvqXk4FtKDr6l5OBbSg6+peTgW0oOvqXk4FtKDr6l5OBbSg6+peTgW0oOvqXk4FtKDr6l5OBbSg6+peTgW0oOvqXk4FtKDr6l5OBbSpWDX6yY/LIkr6Jmw97+zPgLgZV1FWLWpKrr4/cC3wB+W285Zs2oOuPfDPwA2PNpT/D6+DacVFkm/BxgfUS8tLfneX18G06qzPjzgG9Kege4HzhN0t21VmVWs30GPyJ+GBG9ETEduAj4R0RcWntlZjXy5/iWUpUronwiIp4CnqqlErMGeca3lBx8S8nBt5QcfEvJwbeUHHxLycG3lBx8S8nBt5QcfEvJwbeUFBHt71TaAKweQheHAx+0qZzhOL5raN/40yJi0sAHawn+UElaGhFzso7vGuof34c6lpKDbyl1a/BvTz4+uIZax+/KY3yzunXrjG9Wq64KvqT5kt6UtErSNR0Y/05J6yW91vTYpRqmSloi6XVJKyQtbHj80ZJekPRKMf71TY4/oJbalq3smuBL6gFuBc4CZgIXS5rZcBmLgPkNjznQLuDqiJgJnAxc0fDr8DFwWkQcD8wG5ks6ucHxy2pbtrJrgg/MBVZFxNsRsYPWGj7nNVlARDwNbGpyzEFqWBcRy4rtLbR+8FMaHD8iYmtxd2Rxa/xEsO5lK7sp+FOAd0v319DgD7wbSZoOnAA83/C4PZKWA+uBxRHR6PiFm9nHspVD0U3BtxJJ44CHgKsi4sMmx46I3RExG+gF5kqa1eT4VZetHIpuCv5aYGrpfm/xWDqSRtIK/T0R8XCn6oiIPmAJzZ/31L5sZTcF/0XgOEnHSBpFa7nCRztcU+MkCbgDWBkRN3Zg/EmSxhfbY4AzgTearKGJZSu7JvgRsQu4EniS1gndAxGxoskaJN0HPAvMkLRG0uVNjl+YByygNcstL25nNzj+UcASSa/SmowWR8QBdxUcf3NrKXXNjG/WJAffUnLwLSUH31Jy8C0lB99ScvAtJQffUvoPP3aC0AYsZjoAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.title('Padded Input')\n", "plt.imshow(wpad.detach()[0, 0]); plt.show()" ] }, { "cell_type": "markdown", "id": "cooperative-cooper", "metadata": {}, "source": [ "Next, in order to agree with the circular convolution implemented in PyTorch, we have to shift the kernel to the \"center\" of the input field, which is the top left corner. We could implement this using `torch.roll` operations, but we found that it is easier and more efficient to center the kernel\n", "in the Fourier domain using the [shift theorem](https://ccrma.stanford.edu/~jos/st/Shift_Theorem.html).\n", "\n", "We construct a \"shift matrix\" and then use it in an elementwise product in the Fourier domain. The *shift amount* ensures that the center of the kernel is in the top left of the input field; the remainder of the kernel wraps around as if the input were \"circular\"." ] }, { "cell_type": "code", "execution_count": 8, "id": "permanent-possession", "metadata": {}, "outputs": [], "source": [ "def fft_shift_matrix(n, shift_amount):\n", " shift = torch.arange(0, n).repeat((n, 1))\n", " shift = shift + shift.T\n", " return torch.exp(1j * 2 * np.pi * shift_amount * shift / n)\n", " \n", "shift_amount = (k - 1) // 2\n", "shift_matrix = fft_shift_matrix(n, -shift_amount)" ] }, { "cell_type": "markdown", "id": "elementary-irrigation", "metadata": {}, "source": [ "Then, we also must take the complex conjugate. This is equivalent to flipping the kernel horizontally and vertically--\n", "convolution as implemented in PyTorch is actually \"cross-correlation\" mathematically.\n", "The difference is just that the kernel is flipped.\n", "Conjugation in the Fourier domain is equivalent to flipping the signal in the spatial domain.\n", "Refer to the [flip theorems](https://ccrma.stanford.edu/~jos/sasp/Flip_Theorems.html)." 
] }, { "cell_type": "code", "execution_count": 9, "id": "thousand-steam", "metadata": {}, "outputs": [], "source": [ "wfft = shift_matrix * torch.fft.fft2(wpad).conj()" ] }, { "cell_type": "markdown", "id": "dying-there", "metadata": {}, "source": [ "The result of our Fourier-domain shifting and flipping operations can be seen below. Compare to the padded input above." ] }, { "cell_type": "code", "execution_count": 10, "id": "textile-security", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAN0AAADSCAYAAADOksXPAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAPWUlEQVR4nO3de7BdZX3G8e+T5ITcuIMISSDItZERsMhlRGUi0EBR6EynBhHBSztaUmGgtIiVi4CjbbnYAYsImFpuImKlKEIsCRrljoEKwZlIExMIhJCkkMMtCb/+8b4nruyeyz6clXefvfN8ZvbM2vtd633fvdZ61u3ss5YiAjMrZ0SrO2C2uXHozApz6MwKc+jMCnPozApz6MwKKxY6SRdIuqGf8kWSjnybdb/tadudpJC0Z6vblnS1pC+3oA+nSppXut2hGDB0eYV+TdIaSS9ImiVpQonObSqSPi7pkfydlkm6S9LhNdTb74al1ZRcKuml/LqtiWnmSno9z6ue12GN40XE5yLiok3T87dH0pS8YRhVoK25kj7bzLjN7uk+EhETgPcCBwH/8HY712qSzgSuAL4K7ATsCnwTOL6F3QKgwMpxNPAJYH9gF+BbTU43MyImVF73b7IebgYGdXgZEc8CdwH7SdpW0p2SXpS0Kg9P6hlX0u6S7pP0iqTZwA7VuiSdLGlx3uJ+qaFshKRzJP0ul98qabtmpu2PpK2BrwCnRcTtEdEdEWsj4j8j4uyB2q5sOU+R9HtJK3ralzQdOBf4WN4bPN7TpqTr8h71WUkXSxqZy06V9EtJl0t6CbhA0haS/jnX/0I+bBtb+Q5n57qek/TpZr97thZ4DXg+It6IiNmDnL5P+Qjo4jx8hKSlks7N82iRpJMaxr1a0uy8ftwnabdK+b65bKWk30r6i0rZ9pLukPSypIeAPQbZx6sk/Ti3+6CkPSrlIekLkp7J/f4nSSNy2UZHMdW9qKRLgA8AV+Zlf2V//RhU6CRNBo4Ffp2n/Q6wG2lv8RpQbewm4FFS2C4CTqnUMxX4V+Bk0hZ3e2BSZdq/AU4APpTLVwFXNTltfw4DxgA/7GecPtuuOBzYB/gwcJ6kP4qIn5L2nt/Le4P987izgHXAnsCBpL1N9TDkEOAZ0l73EuBrwN7AAXmaicB5+btPB/4WOArYCxjseezTwHbAtT0r0yb0TtKyn0ha9tdI2qdSfhJpvdgBmA/cCCBpPDCbtP68A5gBfDMvd0jL4nVgZ+DT+TUYM4ALgW2BhaR5XvVnpKO595KOfgasPyK+BPyCPxwRzBxogn5fwCJgDbAaWEw6FBvby3gHAKvy8K6kFW18pfwm4IY8fB5wS6VsPPAmcGR+vwD4cKV8Z9JWetRA0w7wXU4ibeX7G6e/tqcAAUyqlD8EzMjDF/R8x/x+J+CN6vwCTgTm5OFTgd9XygR0A3tUPjsM+J88fD3wtUrZ3rk/ezbx3buA/yYdXv4o1zUil80jnUL0Nt1c4NW8
/FcDj1XKNrRN2rhcnIeP6GX53wp8uTJudRlOANYDk4GPAb9o6MO3gPOBkXlZ7Fsp+yowr4++9yyvUZV2r62UHws83fB9plfe/zXwX30s28a65wKfHWg5RATNnkOcEBE/q34gaRxwOTCdtNUA2DIfOu1CCmB3ZZLFeaaSy5f0FEREdz686rEb8ENJb1U+W09aiQeatj8vATtIGhUR6/oYp7+2ezxfGX6VtNL0VVcXsExSz2cjqv1vGN4RGAc8WhlfpJUN0nd/tDL+4j7a7c00YHRE3KB0AeUu0h7vDGBfUvD68oWIuHYQbUHvy3+XyvvqMlwjaWUu3w04RNLqyrijgH8nzZ9RbDzPBjMPYOBl11j3LtRsKIcYZ5EOsQ6JiK2AD+bPBSwDts2HCj12rQwv4w8B7Anw9pXyJcAxEbFN5TUm0jnlQNP2537SnueEfsbpr+2BNP7LxpLc3g6VuraKiHf3Mc0K0mH6uyvjbx3pIhY0fHc2nqcDGUXaABARrwMfBd4DPEza66waRF3N6G35P1d5X12GE0iHvc+R5tl9DfN/QkR8HniRtAd9u/OgGY119/S5m7RB7PHOhuma/nedoYRuS9IKsjpfaDh/Q+sRi4FHgAsljVa6HP+RyrS3AcdJOlzSaNLFjWpfrgYu6Tm5lrSjpOObmTafxPc6AyLif0mHp1dJOkHSOEldko6R9I9NtD2QF4ApPedLEbEMuAe4VNJWShdp9pD0oT769xbwbeBySe/I7U+U9Cd5lFuBUyVNzRub86vT5wszi/ro2zxgjKSv5AszI4A5pEPUV5v8foPVs/w/ABwHfL9SdmxlGV4EPBARS4A7gb2VLpZ15df78nnzeuB20gWncfk875TGRofobKWLhJOB04Hv5c/nAx+UtKvSBbkvNkz3AvCuZhoYSuiuAMaSts4PAD9tKP846SLBStLK8d2egoh4EjiNdJ63jHSxYmll2m8AdwD3SHol139Ik9NOBn7VV6cj4lLgTNKfPV4kbVlnAv8xUNtN6FmpXpL0WB7+JDAaeCr39TbSeWJf/p50gv+ApJeBn5GOKIiIu0jz/d48zr0N004GftlbpXmDczRwKGnr/TvSEcLBwKck/WWT37FZz5O+73OkiySfi4inK+U3kdaLlcAfk841iYhXcj9n5GmfB74ObJGnm0k6JHyedI72nZr7/SPSIfx84MfAdblfs0kBfCKX39kw3TeAP1e6kv8v/TWg6LB/YpV0LfD9iLi71X0pTdI9wOkRsaDF/TiCdNGh16vKkmYBSyNiWP29Nx8h7RURCzdlO5v8L/WlRURTvwroRBFxdKv7YAPzD57NCuu4w0uz4c57OrPCHDqzwlpyIWXUuPHRtfV2A4+4Kfuw1dqWtr+2u6ul7VuydtVK1nd3a+Ax69OS0HVtvR27f+rMVjS9wU5HLh14pE1oyUMTW9q+JUuvvLx4mz68NCvMoTMrzKEzK8yhMyvMoTMrzKEzK8yhMyvMoTMrzKEzK8yhMyvMoTMrrJbQSZqe78S7UNI5ddRp1qmGHLp8n8urgGOAqcCJlbvxmlmDOvZ0BwMLI+KZiHgTuIVh8DAOs+GqjtBNZOO74i7Nn21E0l8pPZ7qkXWvdjcWm202il1IiYhrIuKgiDho1LjxA09g1qHqCN2zbHwr6kn5MzPrRR2hexjYS+l5dKNJd+a9o4Z6zTrSkG/XEBHrJM0E7iY9Xeb6fOtzM+tFLfdIiYifAD+poy6zTudfpJgV5tCZFebQmRXm0JkV5tCZFebQmRXm0JkV5tCZFebQmRXm0JkV1pJHZY18E7Za9FYrmt5g7KjWPp/Okoc/eVlL25928/LibXpPZ1aYQ2dWmENnVphDZ1aYQ2dWmENnVphDZ1aYQ2dWmENnVphDZ1aYQ2dWmENnVlhdz6e7XtJySb+poz6zTlbXnm4WML2musw6Wi2hi4ifAyvrqMus0/mczqywYqGrPhRy7et+KKRtvlryUMiuMX4opG2+fHhp
VlhdfzK4Gbgf2EfSUkmfqaNes05U1/PpTqyjHrPNgQ8vzQpz6MwKc+jMCnPozApz6MwKc+jMCnPozApz6MwKc+jMCnPozApz6MwKa8lDIdeNhRXvUSua3mDFr6a0tH1L3vfdM1va/tKXLi/epvd0ZoU5dGaFOXRmhTl0ZoU5dGaFOXRmhTl0ZoU5dGaFOXRmhTl0ZoU5dGaFOXRmhQ05dJImS5oj6SlJT0o6vY6OmXWqOv7LYB1wVkQ8JmlL4FFJsyPiqRrqNus4Q97TRcSyiHgsD78CLAAmDrVes05V6zmdpCnAgcCDvZRteD7d+m4/n842X7WFTtIE4AfAGRHxcmN59fl0I8f7+XS2+arrUVldpMDdGBG311GnWaeq4+qlgOuABRFx2dC7ZNbZ6tjTvR84GZgmaX5+HVtDvWYdach/MoiIeUBr7zJk1kb8ixSzwhw6s8IcOrPCHDqzwhw6s8IcOrPCHDqzwhw6s8IcOrPCHDqzwhw6s8IcOrPCHDqzwhw6s8IcOrPCHDqzwhw6s8IcOrPCHDqzwhw6s8IcOrPCHDqzwhw6s8Lquq36GEkPSXo8P6PuwjrqNetEdTyfDuANYFpErMnPNZgn6a6IeKCm+s06Ri2hi4gA1uS3XfkVddRt1mnqfFTWSEnzgeXA7Ij4f8+oM7MaQxcR6yPiAGAScLCk/arlfiikWVL71cuIWA3MAaY3fO6HQppR39XLHSVtk4fHAkcBT9dRt1mnqevq5c7Av0kaSQryrRFxZ011m3WUuq5ePgEcWEddZp3Ov0gxK8yhMyvMoTMrzKEzK8yhMyvMoTMrzKEzK8yhMyvMoTMrzKEzK8yhMyusrh88D8roZ7vZ/dz7W9H0Bnc/N7+l7e896/MtbX+42OOmVS1t/8WV64u36T2dWWEOnVlhDp1ZYQ6dWWEOnVlhDp1ZYQ6dWWEOnVlhDp1ZYQ6dWWEOnVlhDp1ZYXU/tefXknxnZ7N+1LmnOx1YUGN9Zh2prgeITAL+FLi2jvrMOllde7orgL8D3uprhOrz6dbyRk3NmrWfIYdO0nHA8oh4tL/xqs+n62KLoTZr1rbq2NO9H/iopEXALcA0STfUUK9ZRxpy6CLiixExKSKmADOAeyPiE0PumVmH8t/pzAqr9cZEETEXmFtnnWadxns6s8IcOrPCHDqzwhw6s8IcOrPCHDqzwhw6s8IcOrPCHDqzwhw6s8IcOrPCFBHlG5VeBBYPoYodgBU1dcd9aN/26+jDbhGxY12daUZLQjdUkh6JiIPch9b2odXtD5c+DJYPL80Kc+jMCmvX0F3T6g7gPgyH9mF49GFQ2vKczqydteuezqxttV3oJE2X9FtJCyWd04L2r5e0XNJvSred258saY6kpyQ9Ken0FvRhjKSHJD2e+3Bh6T7kfrTlrfzbKnSSRgJXAccAU4ETJU0t3I1ZwPTCbVatA86KiKnAocBpLZgHbwDTImJ/4ABguqRDC/cB2vRW/m0VOuBgYGFEPBMRb5Lus3l8yQ5ExM+BlSXbbGh/WUQ8lodfIa10Ewv3ISJiTX7blV9FLw6086382y10E4EllfdLKbzCDSeSpgAHAg+2oO2RkuYDy4HZEVG6D1cwwK38h6t2C51lkiYAPwDOiIiXS7cfEesj4gBgEnCwpP1Ktd3srfyHq3YL3bPA5Mr7SfmzzYqkLlLgboyI21vZl4hYDcyh7HluW9/Kv91C9zCwl6TdJY0m3cb9jhb3qShJAq4DFkTEZS3qw46StsnDY4GjgKdLtd/ut/Jvq9BFxDpgJnA36QLCrRHxZMk+SLoZuB/YR9JSSZ8p2T5pK38yaes+P7+OLdyHnYE5kp4gbQhnR0RbXbZvJf8ixaywttrTmXUCh86sMIfOrDCHzqwwh86sMIfOrDCHzqwwh86ssP8Dy/K9LKsBS0gAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.title('Padded, Centered, & Flipped Input')\n", "plt.imshow(torch.fft.ifft2(wfft)[0, 0].real.detach()); plt.show()" ] }, { "cell_type": "markdown", "id": "composed-equity", "metadata": {}, "source": [ "Then, we implement the product described by this figure, using `einsum`.
\n", "\n", "*[Figure not shown; its alt text begins \"Multi-Channel\".] Colored slices: pixels; arrows: dot products; left: Fourier-domain weights; middle: input tensor; right: output tensor; bottom-right: the matrix-vector product for each pixel.*\n", "
\n", "That is: Each channel of each pixel of the output is the dot product of the corresponding input pixel and weight pixel.\n", "\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "induced-inspector", "metadata": {}, "outputs": [], "source": [ "yfft = torch.einsum('dchw, bchw -> bdhw', wfft, xfft)" ] }, { "cell_type": "markdown", "id": "pursuant-services", "metadata": {}, "source": [ "Now we apply the inverse Fourier transform, and we're done." ] }, { "cell_type": "code", "execution_count": 12, "id": "specialized-walker", "metadata": {}, "outputs": [], "source": [ "y2 = torch.fft.ifft2(yfft)" ] }, { "cell_type": "markdown", "id": "banned-compiler", "metadata": {}, "source": [ "The result is the same as the circular convolution we implemented with `nn.Conv2d`:" ] }, { "cell_type": "code", "execution_count": 13, "id": "catholic-maldives", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.5329876532632625e-06" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(y1 - y2).norm().item()" ] }, { "cell_type": "markdown", "id": "heated-employee", "metadata": {}, "source": [ "------------------------------------------------------------------------\n", "### Matrix Block-Diagonalization Perspective\n", "And how to orthogonalize FFT convolutions." ] }, { "cell_type": "code", "execution_count": 14, "id": "tribal-toner", "metadata": {}, "outputs": [], "source": [ "plt.rcParams['figure.figsize'] = [5, 5]" ] }, { "cell_type": "markdown", "id": "liable-joseph", "metadata": {}, "source": [ "We can make a block diagonal Fourier-domain matrix out of the $n^2$ pixels of `wfft`.\n", "Then, we can implement circular convolution in the Fourier domain by multiplying the Fourier-domain\n", "inputs with the block diagonal matrix." 
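, "\n", "\n", "Written out (a sketch in our own notation, an assumption): index the $n^2$ Fourier-domain pixels by $p$, and let $\\hat{W}_p \\in \\mathbb{C}^{c_{out} \\times c_{in}}$ and $\\hat{x}_p \\in \\mathbb{C}^{c_{in}}$ be the weight slice and the input channel vector at pixel $p$. Then\n", "\n", "$$D = \\begin{bmatrix} \\hat{W}_1 & & \\\\ & \\ddots & \\\\ & & \\hat{W}_{n^2} \\end{bmatrix}, \\qquad \\hat{y}_p = \\hat{W}_p \\hat{x}_p \\quad \\text{for } p = 1, \\ldots, n^2$$\n", "\n", "so $D$ has shape $(n^2 c_{out}) \\times (n^2 c_{in})$, and multiplying it with the stacked vectors $\\hat{x}_p$ gives the same per-pixel products as the `einsum` above."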
] }, { "cell_type": "code", "execution_count": 15, "id": "minute-fifteen", "metadata": {}, "outputs": [], "source": [ "D = torch.block_diag(*[wfft.reshape(cout, cin, n**2)[..., i] for i in range(n**2)])" ] }, { "cell_type": "code", "execution_count": 16, "id": "unusual-specific", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(D.real.detach())\n", "plt.title(r'$\\bf{D}$: Block Diagonal Matrix')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "fuzzy-monte", "metadata": {}, "source": [ "We have to reshape the input into a batch of vectors, where we lay out the channels first. That is,\n", "the first `cin` elements of the vector belong to the first pixel, the next `cin` to the second pixel,\n", "and so on.\n", "\n", "Then, we multiply by the diagonal matrix `D` and convert back to the original input shape." ] }, { "cell_type": "code", "execution_count": 17, "id": "refined-bedroom", "metadata": {}, "outputs": [], "source": [ "yfft_matrix = (D @ xfft.reshape(batches, cin, n**2).permute(2, 1, 0).reshape(-1, batches)) \\\n", " .reshape(n, n, cin, batches).permute(3, 2, 0, 1)" ] }, { "cell_type": "markdown", "id": "fallen-evening", "metadata": {}, "source": [ "Again, this is the same as doing circular convolution using `nn.Conv2d`, as well as our previous implementation of Fourier convolutions. We apply the inverse 2D FFT and compare:" ] }, { "cell_type": "code", "execution_count": 18, "id": "verified-carter", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.5019836610008497e-06" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y3 = torch.fft.ifft2(yfft_matrix)\n", "(y1 - y3).norm().item()" ] }, { "cell_type": "markdown", "id": "exciting-start", "metadata": {}, "source": [ "-----------------------------------------\n", "We will now illustrate the \"block diagonalization\" part of this figure by constructing DFT matrices that convert the previous block diagonal matrix `D` into a block matrix of convolution matrices `C`:\n", "
\n", "*[Figure not shown; its alt text begins \"Block-Diagonalization\".]*\n", "
" ] }, { "cell_type": "code", "execution_count": 19, "id": "usual-cartoon", "metadata": {}, "outputs": [], "source": [ "# Somewhat easier to build these matrices with np.block\n", "import numpy as np \n", "\n", "# Perfect shuffle matrices, implemented as described in the appendix of our paper\n", "def shuffle_matrix(p, q):\n", " r = p * q\n", " Ir = np.eye(r)\n", " S = []\n", " for i in range(q):\n", " for x in Ir[i:r:q]:\n", " S.append([x])\n", " return np.block(S)\n", "\n", "# The \"script F\" matrix from Corollary A.1.1 of our paper\n", "def script_F(c):\n", " global n\n", " S = shuffle_matrix(c, n**2)\n", " F = np.fft.fft(np.eye(n), norm=\"ortho\")\n", " FF = np.kron(F, F)\n", " return S @ np.kron(np.eye(c), FF)" ] }, { "cell_type": "code", "execution_count": 20, "id": "colonial-action", "metadata": {}, "outputs": [], "source": [ "Fs_cout = torch.tensor(script_F(cout), dtype=torch.complex64)\n", "Fs_cin = torch.tensor(script_F(cin), dtype=torch.complex64)" ] }, { "cell_type": "code", "execution_count": 21, "id": "polar-worship", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "C = Fs_cout.conj().T @ D @ Fs_cin\n", "plt.imshow(C.real.detach())\n", "plt.title(r'$\\bf{C}$: Block Matrix of Doubly-Block-Circulant Matrices')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "superior-arrow", "metadata": {}, "source": [ "This matrix `C` represents the circular convolution in the spatial domain.\n", "We reshape the input vector such that the $n^2$ pixels of the first channel come first,\n", "then the second channel, and so on:" ] }, { "cell_type": "code", "execution_count": 22, "id": "finnish-victory", "metadata": {}, "outputs": [], "source": [ "y_matrix = (C.real @ x.reshape(batches, cin, n**2).permute(1, 2, 0).reshape(-1, batches)) \\\n", " .reshape(cin, n, n, batches).permute(3, 0, 1, 2)" ] }, { "cell_type": "code", "execution_count": 23, "id": "helpful-standard", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.3579970072896685e-06" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(y_matrix - y3).norm().item()" ] }, { "cell_type": "markdown", "id": "fallen-travel", "metadata": {}, "source": [ "We will now make the claim in the figure above explicit: applying the Cayley transform to `C` is equivalent to applying the Cayley transform to `D` and then applying the DFT matrices:" ] }, { "cell_type": "code", "execution_count": 24, "id": "permanent-cricket", "metadata": {}, "outputs": [], "source": [ "assert cin == cout, \"This section only applies for cin == cout, since we need square matrices\"\n", "cayley = lambda B: (torch.eye(cin * n**2) - B + B.conj().T) @ torch.inverse(torch.eye(cin * n**2) + B - B.conj().T)" ] }, { "cell_type": "code", "execution_count": 25, "id": "lucky-martial", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2.071995368169155e-06" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(cayley(C) - Fs_cin.conj().T @ 
cayley(D) @ Fs_cin).norm().item()" ] }, { "cell_type": "markdown", "id": "mathematical-shame", "metadata": {}, "source": [ "What does an orthogonal convolution look like, anyways?" ] }, { "cell_type": "code", "execution_count": 26, "id": "artificial-amendment", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "C_ortho = cayley(C)\n", "plt.imshow(C_ortho.detach().numpy().real)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "resistant-ceiling", "metadata": {}, "source": [ "As an aside, note that the imaginary part of `C_ortho` (and `C`) is zero:" ] }, { "cell_type": "code", "execution_count": 27, "id": "heard-legislation", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(2.275218548675184e-06, 1.0707211686167284e-06)" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "C_ortho.imag.norm().item(), C.imag.norm().item()" ] }, { "cell_type": "markdown", "id": "dominican-member", "metadata": {}, "source": [ "Let's illustrate that `C_ortho` is indeed orthogonal:" ] }, { "cell_type": "code", "execution_count": 28, "id": "complicated-entity", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAATEAAAFDCAYAAABSjzmiAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAUEUlEQVR4nO3de6xlZXnH8e+vwwCCFxilkykz6dA4QqmRQadUI22nIBSsgdpYK7ZmbGmmaazRtKZAmzTatA2mF0vShmTqpTS1gqAIpd5wCrU2BhllVGCEQQpl6DCDFgLahAI+/WOvwcP2XPa57Mt7zveTnOy11r6sJ3M2P5733evdJ1WFJLXqh8ZdgCQthiEmqWmGmKSmGWKSmmaISWqaISapaYaYJkaStye5Ocl9SW7vtn933HVpssXrxDRpkrwf+LuqumXctWjy2YlpEp0M3DnuItQGQ0yT6HlV9fi4i1AbDLEVKMnzkvxZknuSPJ7kP5P8TZLjxn3uJBuAB4Zdx6CSHJukkvzouGvR9AyxFSbJMcC/AycB51bV84CfBlYDQ/0PdcBz/wRwxzDrmKfNwCNVdf+4C9H0Dht3ARq59wH/A7yhqr4HUFX7gN+akHP/BJM1H7YZ2D3mGjQLQ2wF6YZqbwF++lCITNq5q+ovh1zHDcDpM9z9hap6Xd+xUzHEJpohtrK8Bni4qr64kCcn2Qi8pKo+2+1vBV5XVe8a9rmXyjQhNZfNwJ8PoRQtEefEVpa1wH8t4vk/Bpw9pnOTZGOSs7vtrUn+YjGvN8D5jgB+HDuxiWaIrSz/BRyfZNrfe5JVSf4xyb8l+Zfuk7mtSf45ybXAVcCvdFfSr+me9tIk1yb5apKXTvcag5x7Lt3zNrLwED30Op9K8p0Zfj7V9/CXAk8DexZzTg2XIbay3NDdXprk+QBJXpLksiQvAl4P7KuqnwWuBN7ePf4FwC8BvwxcVVVbq+p/uvtWV9XrgYuB35jlNWY99wAB+lbgt+lCFFjD4AH6jKo6t6qeO8PPuX0PPxW4vaqeWtC/tkbCEFtBquox
4AzgJcDeJI8C1wLfrapvAS8Gbu0efiuwqdveVTOvT9vd3T4AHDvTawxw7lkDtKo+CFxOF6L0PuUcNEAXajMOJSeeE/srTFXdDfziDHffA5wGfAz4SWBvd/zQp4lPAqv6X3LKdmZ5jbnO3R9+ZwOfZ3EBuqihZ1X9zmKer9GwE9NUnwA2JPk8cAHwN333fx14RZKruwtXF/IaMzkUfjB9gMIPhuhMAdr/GlrG/BYLTYQkhwH/AKwHvgP8GvAyplzC0c2l3QAcAD4MnF5V70ryUuBdwG/2v8aUuTstU4aYpKY5nJTUtEWFWJJzktzVfSPBxUtVlCQNasHDySSrgLuBs4B99D4NuqCqJmnxrqRlbjGXWJwG3FNV9wIkuRI4n1m+geBFa1bVxg2rn9m/+2tHLeL0klaSx3nkW1X1A995t5gQO55nf3ndPuCnZnvCxg2r+dJnNjyz//M/snkRp5e0knyurpn2O92GPrGfZHuSXUl2Pfztp4d9OkkrzGJC7EFgw5T99d2xZ6mqHVW1paq2HPfC/ou9JWlxFjOcvBXYlOQEeuH1JuDNsz3h7q8d9awh5Gf+e/cz2w4tJS3EgkOsqp5K8jvAZ+gtBflgVU3Sd6NLWgEWtQC8qj4JfHKJapGkeRvrt1jMNLTsv0+SZuKyI0lNM8QkNc0Qk9S0iflm1/45MOfIJA3CTkxS0wwxSU2bmOFkv9mGlw4tJR1iJyapaYaYpKYZYpKaNrFzYv1coiRpOnZikppmiElqmiEmqWnNzIlN5TVkkg6xE5PUNENMUtOaHE728/ILaeWyE5PUNENMUtMMMUlNWxZzYlP5DbHSymInJqlphpikphlikpq27ObE+rlESVre5uzEknwwycEkt085tibJjUn2drfHDrdMSZreIMPJvwfO6Tt2MbCzqjYBO7t9SRq5OYeTVfX5JBv7Dp8PbO22rwBuBi5aysKGxSVK0vKy0In9tVW1v9t+CFi7RPVI0rws+tPJqiqgZro/yfYku5LsepInFns6SXqWhYbYgSTrALrbgzM9sKp2VNWWqtqymiMWeDpJmt5CL7G4HtgGXNrdXrdkFY2QS5Sk9g1yicVHgC8CJybZl+RCeuF1VpK9wGu6fUkauUE+nbxghrvOXOJaJGneXHYkqWnLftnRfLhESWqPnZikphlikprmcHIWLlGSJp+dmKSmGWKSmmaISWqac2ID8vILaTLZiUlqmiEmqWkOJxfIyy+kyWAnJqlphpikphlikprmnNgS8BtipfGxE5PUNENMUtMMMUlNc05sCFyiJI2OnZikphlikprmcHIEXKIkDY+dmKSmGWKSmmaISWqac2Ij5hIlaWnZiUlq2pwhlmRDkpuS3JnkjiTv6I6vSXJjkr3d7bHDL1eSnm2QTuwp4Peq6mTglcDbkpwMXAzsrKpNwM5uX5JGas45saraD+zvth9Psgc4Hjgf2No97ArgZuCioVS5jLlESVqceU3sJ9kInArcAqztAg7gIWDtDM/ZDmwHOJKjFlyoJE1n4In9JM8FPga8s6oem3pfVRVQ0z2vqnZU1Zaq2rKaIxZVrCT1G6gTS7KaXoB9uKo+3h0+kGRdVe1Psg44OKwiVxKXKEnzM8inkwE+AOypqr+actf1wLZuextw3dKXJ0mzG6QTezXwFuDrSXZ3x/4AuBT4aJILgfuBNw6lQkmaxSCfTn4ByAx3n7m05UjS/LjsaIJ5+YU0N5cdSWqaISapaQ4nG+LlF9IPshOT1DRDTFLTDDFJTXNOrFF+Q6zUYycmqWmGmKSmGWKSmuac2DLhEiWtVHZikppmiElqmsPJZcolSlop7MQkNc0Qk9Q0Q0xS05wTWwFcoqTlzE5MUtMMMUlNM8QkNc05sRXIJUpaTuzEJDXNEJPUNIeTcomSmmYnJqlpc4ZYkiOTfCnJV5PckeQ93fETktyS5J4kVyU5fPjlStKzDdKJPQGcUVWnAJuBc5K8Engv8L6qejHwCHDh0KqUpBnM
OSdWVQV8p9td3f0UcAbw5u74FcC7gcuXvkSNkpdfqDUDzYklWZVkN3AQuBH4JvBoVT3VPWQfcPxQKpSkWQwUYlX1dFVtBtYDpwEnDXqCJNuT7Eqy60meWFiVkjSDeV1iUVWPJrkJeBVwTJLDum5sPfDgDM/ZAewAeH7W1CLr1Yh5+YUm3SCfTh6X5Jhu+znAWcAe4CbgDd3DtgHXDalGSZrRIJ3YOuCKJKvohd5Hq+qGJHcCVyb5E+A24ANDrFOSpjXIp5NfA06d5vi99ObHJGlsXHakgfkNsZpELjuS1DRDTFLTDDFJTXNOTAvmEiVNAjsxSU0zxCQ1zeGkloxLlDQOdmKSmmaISWqaISapac6JaShcoqRRsROT1DRDTFLTDDFJTXNOTCPhEiUNi52YpKYZYpKa5nBSY+ESJS0VOzFJTTPEJDXNEJPUNOfENHZefqHFsBOT1DRDTFLTHE5q4nj5hebDTkxS0wYOsSSrktyW5IZu/4QktyS5J8lVSQ4fXpmSNL35dGLvAPZM2X8v8L6qejHwCHDhUhYmSYMYaE4syXrgF4A/BX43SYAzgDd3D7kCeDdw+RBq1ArmN8RqLoN2Yn8N/D7wvW7/hcCjVfVUt78POH5pS5Okuc0ZYkleBxysqi8v5ARJtifZlWTXkzyxkJeQpBkNMpx8NXBektcCRwLPBy4DjklyWNeNrQcenO7JVbUD2AHw/KypJalakjpzhlhVXQJcApBkK/CuqvrVJFcDbwCuBLYB1w2vTKnHJUrqt5jrxC6iN8l/D705sg8sTUmSNLh5XbFfVTcDN3fb9wKnLX1JkjQ4lx2paS5RksuOJDXNEJPUNENMUtOcE9Oy4RKllclOTFLTDDFJTTPEJDXNOTEtWy5RWhnsxCQ1zRCT1DSHk1oxXKK0PNmJSWqaISapaYaYpKY5J6YVycsvlg87MUlNM8QkNc3hpISXX7TMTkxS0wwxSU0zxCQ1zTkxqY/fENsWOzFJTTPEJDXNEJPUNOfEpDm4RGmyDRRiSe4DHgeeBp6qqi1J1gBXARuB+4A3VtUjwylTkqY3n+Hkz1XV5qra0u1fDOysqk3Azm5fkkZqMcPJ84Gt3fYVwM3ARYusR5p4LlGaLIN2YgV8NsmXk2zvjq2tqv3d9kPA2iWvTpLmMGgndnpVPZjkh4Ebk3xj6p1VVUlquid2obcd4EiOWlSxktRvoE6sqh7sbg8C1wKnAQeSrAPobg/O8NwdVbWlqras5oilqVqSOnN2YkmOBn6oqh7vts8G/hi4HtgGXNrdXjfMQqVJ5BKl8RtkOLkWuDbJocf/U1V9OsmtwEeTXAjcD7xxeGVK0vTmDLGquhc4ZZrj3wbOHEZRkjQolx1JaprLjqQl5BKl0bMTk9Q0Q0xS0xxOSkPkEqXhsxOT1DRDTFLTDDFJTXNOTBoRL78YDjsxSU0zxCQ1zeGkNCZefrE07MQkNc0Qk9Q0Q0xS05wTkyaA3xC7cHZikppmiElqmiEmqWnOiUkTyCVKg7MTk9Q0Q0xS0xxOSg1widLM7MQkNc0Qk9Q0Q0xS05wTkxrjEqVnsxOT1LSBQizJMUmuSfKNJHuSvCrJmiQ3Jtnb3R477GIlqd+gndhlwKer6iTgFGAPcDGws6o2ATu7fUkaqTnnxJK8APgZ4K0AVfV/wP8lOR/Y2j3sCuBm4KJhFClpZit9idIgndgJwMPAh5LcluT9SY4G1lbV/u4xDwFrp3tyku1JdiXZ9SRPLE3VktQZJMQOA14OXF5VpwLfpW/oWFUF1HRPrqodVbWlqras5ojF1itJzzLIJRb7gH1VdUu3fw29EDuQZF1V7U+yDjg4rCIlDW6lLVGasxOrqoeAB5Kc2B06E7gTuB7Y1h3bBlw3lAolaRaDXuz6duDDSQ4H7gV+nV4AfjTJhcD9wBuHU6IkzWygEKuq3cCWae46c0mrkaR5ctmRtIythMsvXHYkqWmGmKSmOZyUVpDlePmFnZikphlikppm
iElqmnNi0gq1XL4h1k5MUtMMMUlNM8QkNc05MUlAu0uU7MQkNc0Qk9Q0h5OSptXKEiU7MUlNM8QkNc0Qk9Q058QkzWmSlyjZiUlqmiEmqWmGmKSmOScmad4maYmSnZikphlikprmcFLSoo1ziZKdmKSmzRliSU5MsnvKz2NJ3plkTZIbk+ztbo8dRcGSNNWcIVZVd1XV5qraDLwC+F/gWuBiYGdVbQJ2dvuSNFLznRM7E/hmVd2f5Hxga3f8CuBm4KKlK01Si0Z9+cV8Q+xNwEe67bVVtb/bfghYO90TkmwHtgMcyVELqVGSZjTwxH6Sw4HzgKv776uqAmq651XVjqraUlVbVnPEgguVpOnMpxM7F/hKVR3o9g8kWVdV+5OsAw4ufXmSWjfsyy/mc4nFBXx/KAlwPbCt294GXLfoaiRpngYKsSRHA2cBH59y+FLgrCR7gdd0+5I0UgMNJ6vqu8AL+459m96nlZI0Ni47kjQyw/iGWJcdSWqaISapaYaYpKY5JyZpbJZiiZKdmKSmGWKSmuZwUtLEmG2J0qp10z/HTkxS0wwxSU0zxCQ1Lb2vAhvRyZKHgfuBFwHfGtmJ52Y9s5u0emDyarKe2S1FPT9aVcf1HxxpiD1z0mRXVW0Z+YlnYD2zm7R6YPJqsp7ZDbMeh5OSmmaISWrauEJsx5jOOxPrmd2k1QOTV5P1zG5o9YxlTkySlorDSUlNG2mIJTknyV1J7kkylr8YnuSDSQ4muX3KsTVJbkyyt7s9doT1bEhyU5I7k9yR5B3jrCnJkUm+lOSrXT3v6Y6fkOSW7nd3Vfcn/EYmyaoktyW5Ydz1JLkvydeT7E6yqzs2tvdQd/5jklyT5BtJ9iR51RjfQyd2/zaHfh5L8s5h1TOyEEuyCvhben/67WTggiQnj+r8U/w9cE7fsYuBnVW1CdjZ7Y/KU8DvVdXJwCuBt3X/LuOq6QngjKo6BdgMnJPklcB7gfdV1YuBR4ALR1TPIe8A9kzZH3c9P1dVm6dcNjDO9xDAZcCnq+ok4BR6/1Zjqamq7ur+bTYDrwD+F7h2aPVU1Uh+gFcBn5myfwlwyajO31fLRuD2Kft3Aeu67XXAXeOoqzv/dfT+stTYawKOAr4C/BS9CxUPm+53OYI61ndv+jOAG4CMuZ77gBf1HRvb7wt4AfCfdHPck1DTlBrOBv5jmPWMcjh5PPDAlP193bFJsLaq9nfbDwFrx1FEko3AqcAt46ypG7rtpvcHkW8Evgk8WlVPdQ8Z9e/ur4HfB77X7b9wzPUU8NkkX06yvTs2zvfQCcDDwIe6Iff7uz+zOAnv6zfx/b9XO5R6nNjvU73/TYz8I9skzwU+Bryzqh4bZ01V9XT1hgLrgdOAk0Z17n5JXgccrKovj6uGaZxeVS+nNzXytiQ/M/XOMbyHDgNeDlxeVacC36VvqDaO93U3T3kecHX/fUtZzyhD7EFgw5T99d2xSXAgyTqA7vbgKE+eZDW9APtwVR36A8VjrQmgqh4FbqI3XDsmyaHvnxvl7+7VwHlJ7gOupDekvGyM9VBVD3a3B+nN9ZzGeH9f+4B9VXVLt38NvVAb93voXOArVXWg2x9KPaMMsVuBTd2nSofTazOvH+H5Z3M9sK3b3kZvXmokkgT4ALCnqv5q3DUlOS7JMd32c+jNz+2hF2ZvGHU9VXVJVa2vqo303jP/WlW/Oq56khyd5HmHtunN+dzOGN9DVfUQ8ECSE7tDZwJ3jrOmzgV8fyjJ0OoZ8STfa4G76c2x/OGoJxm7Gj4C7AeepPd/sAvpzbHsBPYCnwPWjLCe0+m11V8Ddnc/rx1XTcDLgNu6em4H/qg7/mPAl4B76A0PjhjD724rcMM46+nO+9Xu545D7+Nxvoe6828GdnW/t08Ax475fX008G3gBVOODaUer9iX1DQn9iU1zRCT1DRDTFLTDDFJTTPEJDXNEJPUNENMUtMMMUlN+39fmirXKsoBIgAAAABJRU5ErkJggg==
\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow((C_ortho @ C_ortho.T).detach().numpy().real)\n", "plt.title(r\"$C_\\mathsf{ortho} C_\\mathsf{ortho}^T = I$\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "worst-entity", "metadata": {}, "source": [ "--------------------------------------------------------\n", "## Orthogonalizing FFT Convolutions\n", "\n", "Putting the ideas illustrated above together, we want to orthogonalize\n", "a convolution in the Fourier domain without explicitly constructing a large matrix.\n", "Essentially, we just want to operate on the blocks of `D`.\n", "Recall that we extracted those blocks from the convolution weight tensor earlier:\n", "*we will just rearrange the tensor to make working with those blocks easier.*\n", "\n", "In particular, we will use the fact that each pixel of the output is given by a matrix vector product:\n", "we will arrange the input and weight tensors so that we can compute the convolution\n", "with a batch of such matrix multiplications." 
] }, { "cell_type": "code", "execution_count": 29, "id": "minor-jason", "metadata": {}, "outputs": [], "source": [ "xfft = torch.fft.fft2(x)\n", "xfft = xfft.permute(2, 3, 1, 0) # batches, cin, n, n -> n, n, cin, batches\n", "xfft = xfft.reshape(n**2, cin, batches) # n**2 input pixels" ] }, { "cell_type": "code", "execution_count": 30, "id": "infinite-liability", "metadata": {}, "outputs": [], "source": [ "wfft = shift_matrix * torch.fft.fft2(wpad).conj() # cout, cin, n, n\n", "wfft = wfft.reshape(cout, cin, n**2) # cout, cin, n**2 pixels\n", "wfft = wfft.permute(2, 0, 1) # n**2, cout, cin" ] }, { "cell_type": "markdown", "id": "specific-breakdown", "metadata": {}, "source": [ "Now, similarly to the block diagonal matrix `D` earlier,\n", "we have `xfft` of shape ($n^2$, `cin`, `batches`) and\n", "`wfft` of shape ($n^2$, `cout`, `cin`).\n", "The last two dimensions of each are compatible for matrix multiplication, giving an output of\n", "shape ($n^2$, `cout`, `batches`).\n", "\n", "We can do a batched matrix multiplication and reshape again:" ] }, { "cell_type": "code", "execution_count": 31, "id": "insured-botswana", "metadata": {}, "outputs": [], "source": [ "yfft = wfft @ xfft # n**2, cout, batches\n", "yfft = yfft.reshape(n, n, cout, batches)\n", "yfft = yfft.permute(3, 2, 0, 1) # batches, cout, n, n\n", "y4 = torch.fft.ifft2(yfft)" ] }, { "cell_type": "markdown", "id": "therapeutic-vietnamese", "metadata": {}, "source": [ "Again, this is equivalent to the original circular convolution, up to floating-point error:" ] }, { "cell_type": "code", "execution_count": 32, "id": "short-malaysia", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.5329876532632625e-06" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(y1 - y4).norm().item()" ] }, { "cell_type": "markdown", "id": "quality-retreat", "metadata": { "tags": [] }, "source": [ "But this time, it's very easy to orthogonalize the convolution by orthogonalizing\n", "all of the \"blocks\":" 
] }, { "cell_type": "code", "execution_count": 33, "id": "invisible-quantity", "metadata": {}, "outputs": [], "source": [ "# \"Batched\" Cayley transform, applied to each of the n**2 blocks at once.\n", "# Requires cin == cout so that every block is square.\n", "# For complex blocks, W - W^H is skew-Hermitian (the complex analogue of skew-symmetric).\n", "wfft_skew_sym = wfft - wfft.conj().transpose(1, 2)\n", "I = torch.eye(cin, dtype=wfft.dtype)[None, :, :]\n", "wfft_ortho = (I - wfft_skew_sym) @ torch.inverse(I + wfft_skew_sym)" ] }, { "cell_type": "code", "execution_count": 34, "id": "heated-ranking", "metadata": {}, "outputs": [], "source": [ "yfft_ortho = wfft_ortho @ xfft # n**2, cout, batches\n", "yfft_ortho = yfft_ortho.reshape(n, n, cout, batches)\n", "yfft_ortho = yfft_ortho.permute(3, 2, 0, 1) # batches, cout, n, n\n", "y_ortho = torch.fft.ifft2(yfft_ortho)" ] }, { "cell_type": "markdown", "id": "lyric-pocket", "metadata": {}, "source": [ "We can see that this implements an orthogonal convolution because\n", "it preserves the norm of input tensors:" ] }, { "cell_type": "code", "execution_count": 35, "id": "completed-welsh", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y_ortho norm: 8.151487350463867\n", "input (x) norm: 8.151487350463867\n" ] } ], "source": [ "print('y_ortho norm: ', y_ortho.norm().item())\n", "print('input (x) norm: ', x.norm().item())" ] }, { "cell_type": "markdown", "id": "understood-casting", "metadata": {}, "source": [ "---------------------------------------------------------------------\n", "### Optimizations and More" ] }, { "cell_type": "markdown", "id": "prime-efficiency", "metadata": {}, "source": [ "Using the conjugate symmetry of the Fourier transform of real-valued inputs, we can implement FFT-based convolutions\n", "even more efficiently. We can also handle strided FFT-based convolutions using the [*aliasing theorem*](https://ccrma.stanford.edu/~jos/st/Downsampling_Theorem_Aliasing_Theorem.html).\n", "\n", "We provide a more fully-featured FFT-based convolution layer in `extras/fftconv.py`." 
] }, { "cell_type": "code", "execution_count": 38, "id": "outside-declaration", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.0860642305488e-06" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from extras.fftconv import FFTConv\n", "\n", "fft_conv = FFTConv(cin, cout, k, bias=False)\n", "fft_conv.weight.data = conv.weight.data\n", "\n", "(fft_conv(x) - y1).norm().item()" ] }, { "cell_type": "markdown", "id": "conceptual-calcium", "metadata": {}, "source": [ "### Orthogonal Convolutions when `cin` != `cout`" ] }, { "cell_type": "markdown", "id": "norwegian-dictionary", "metadata": {}, "source": [ "In the appendix of our paper, we describe a method to efficiently \"orthogonalize\" convolutions with different numbers of input and output channels. This is actually called *semi-orthogonalization*.\n", "More details can be found in `Orthogonal Convolutions.ipynb`." ] }, { "cell_type": "markdown", "id": "adaptive-bristol", "metadata": {}, "source": [ "-----------------------------------------------------------------------\n", "You can cite our work as follows:\n", "\n", "
\n",
    "@inproceedings{trockman2021ortho,\n",
    "    title={Orthogonalizing Convolutional Layers with the Cayley Transform},\n",
    "    author={Asher Trockman and J. Zico Kolter},\n",
    "    booktitle={International Conference on Learning Representations},\n",
    "    year={2021},\n",
    "    url={https://openreview.net/forum?id=Pbj8H_jEHYv}\n",
    "}\n",
    "
" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 5 }