{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\nCreating extensions using NumPy and SciPy\n=========================================\n**Author**: Adam Paszke\n\nIn this tutorial, we shall go through two tasks:\n\n1. Create a neural network layer with no parameters.\n\n   - This calls into **NumPy** as part of its implementation.\n\n2. Create a neural network layer that has learnable weights.\n\n   - This calls into **SciPy** as part of its implementation.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nfrom torch.autograd import Function\nfrom torch.autograd import Variable" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Parameter-less example\n----------------------\n\nThis layer does not do anything particularly useful or mathematically\ncorrect, so it is aptly named BadFFTFunction.\n\n**Layer Implementation**\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from numpy.fft import rfft2, irfft2\n\n\nclass BadFFTFunction(Function):\n\n    @staticmethod\n    def forward(ctx, input):\n        # detach from the autograd graph before converting to NumPy\n        numpy_input = input.detach().numpy()\n        result = abs(rfft2(numpy_input))\n        return torch.as_tensor(result, dtype=input.dtype)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        numpy_go = grad_output.detach().numpy()\n        result = irfft2(numpy_go)\n        return torch.as_tensor(result, dtype=grad_output.dtype)\n\n# since this layer does not have any parameters, we can\n# simply declare this as a function, rather than as an nn.Module class\n\n\ndef incorrect_fft(input):\n    return BadFFTFunction.apply(input)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Example usage of the created layer:**\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "input = 
Variable(torch.randn(8, 8), requires_grad=True)\nresult = incorrect_fft(input)\nprint(result.data)\nresult.backward(torch.randn(result.size()))\nprint(input.grad)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Parametrized example\n--------------------\n\nThis implements a layer with learnable weights.\n\nIt implements cross-correlation with a learnable kernel. In deep\nlearning literature, this operation is confusingly referred to as\nconvolution.\n\nThe backward pass computes the gradients with respect to the input and\nwith respect to the filter.\n\n**Implementation:**\n\n*Please note that the implementation serves as an illustration, and we\ndid not verify its correctness.*\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from scipy.signal import convolve2d, correlate2d\nfrom torch.nn.modules.module import Module\nfrom torch.nn.parameter import Parameter\n\n\nclass ScipyConv2dFunction(Function):\n\n    @staticmethod\n    def forward(ctx, input, filter):\n        # detach from the autograd graph before converting to NumPy\n        result = correlate2d(input.detach().numpy(), filter.detach().numpy(), mode='valid')\n        ctx.save_for_backward(input, filter)\n        return torch.as_tensor(result, dtype=input.dtype)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input, filter = ctx.saved_tensors\n        numpy_go = grad_output.detach().numpy()\n        # gradient wrt the input: full convolution of grad_output with the filter\n        grad_input = convolve2d(numpy_go, filter.detach().numpy(), mode='full')\n        # gradient wrt the filter: valid cross-correlation of the input with grad_output\n        grad_filter = correlate2d(input.detach().numpy(), numpy_go, mode='valid')\n        return (torch.as_tensor(grad_input, dtype=grad_output.dtype),\n                torch.as_tensor(grad_filter, dtype=grad_output.dtype))\n\n\nclass ScipyConv2d(Module):\n\n    def __init__(self, kh, kw):\n        super(ScipyConv2d, self).__init__()\n        self.filter = Parameter(torch.randn(kh, kw))\n\n    def forward(self, input):\n        return ScipyConv2dFunction.apply(input, self.filter)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Example usage:**\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "module = ScipyConv2d(3, 3)\nprint(list(module.parameters()))\ninput = Variable(torch.randn(10, 
10), requires_grad=True)\noutput = module(input)\nprint(output)\noutput.backward(torch.randn(8, 8))\nprint(input.grad)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.1" } }, "nbformat": 4, "nbformat_minor": 0 }