{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Installing packages:\n", "\t.package(path: \"/home/jekbradbury/git/fastai_docs/dev_swift/FastaiNotebook_06_cuda\")\n", "\t\tFastaiNotebook_06_cuda\n", "With SwiftPM flags: []\n", "Working in: /tmp/tmpp95uqama/swift-install\n", "/home/jekbradbury/swift/usr/bin/swift-build: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/lib/swift/linux/libFoundation.so)\n", "/home/jekbradbury/swift/usr/bin/swift-build: /home/jekbradbury/anaconda3/lib/libcurl.so.4: no version information available (required by /home/jekbradbury/swift/usr/lib/swift/linux/libFoundation.so)\n", "Fetching https://github.com/mxcl/Path.swift\n", "Fetching https://github.com/JustHTTP/Just\n", "Fetching https://github.com/latenitesoft/NotebookExport\n", "Completed resolution in 3.65s\n", "Cloning https://github.com/latenitesoft/NotebookExport\n", "Resolving https://github.com/latenitesoft/NotebookExport at 0.5.0\n", "Cloning https://github.com/JustHTTP/Just\n", "Resolving https://github.com/JustHTTP/Just at 0.7.1\n", "Cloning https://github.com/mxcl/Path.swift\n", "Resolving https://github.com/mxcl/Path.swift at 0.16.2\n", "/home/jekbradbury/swift/usr/bin/swiftc: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swiftc)\n", "Compile Swift Module 'Just' (1 sources)\n", "Compile Swift Module 'Path' (9 sources)\n", "/home/jekbradbury/swift/usr/bin/swiftc: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swiftc)\n", "\n", "/home/jekbradbury/swift/usr/bin/swiftc: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swiftc)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: 
/home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "Compile Swift Module 'NotebookExport' (2 
sources)\n", "/home/jekbradbury/swift/usr/bin/swiftc: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swiftc)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "Compile Swift Module 'FastaiNotebook_00_load_data' (1 sources)\n", "/home/jekbradbury/swift/usr/bin/swiftc: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swiftc)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "Compile Swift Module 'FastaiNotebook_01_matmul' (2 sources)\n", "/home/jekbradbury/swift/usr/bin/swiftc: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swiftc)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", 
"/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "Compile Swift Module 'FastaiNotebook_01a_fastai_layers' (3 sources)\n", "/home/jekbradbury/swift/usr/bin/swiftc: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swiftc)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "Compile Swift Module 'FastaiNotebook_02_fully_connected' (4 sources)\n", "/home/jekbradbury/swift/usr/bin/swiftc: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swiftc)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", 
"/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "Compile Swift Module 'FastaiNotebook_02a_why_sqrt5' (5 sources)\n", "/home/jekbradbury/swift/usr/bin/swiftc: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swiftc)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "Compile Swift Module 'FastaiNotebook_03_minibatch_training' (6 sources)\n", "/home/jekbradbury/swift/usr/bin/swiftc: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swiftc)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: 
/home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "Compile Swift Module 'FastaiNotebook_04_callbacks' (7 sources)\n", "/home/jekbradbury/swift/usr/bin/swiftc: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swiftc)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by 
/home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "Compile Swift Module 'FastaiNotebook_05_anneal' (8 sources)\n", "/home/jekbradbury/swift/usr/bin/swiftc: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swiftc)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: 
/home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "Compile Swift Module 'FastaiNotebook_05b_early_stopping' (9 sources)\n", "/home/jekbradbury/swift/usr/bin/swiftc: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swiftc)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by 
/home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "Compile Swift Module 'FastaiNotebook_06_cuda' (10 sources)\n", "/home/jekbradbury/swift/usr/bin/swiftc: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swiftc)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: 
/home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "Compile Swift Module 'jupyterInstalledPackages' (1 sources)\n", "/home/jekbradbury/swift/usr/bin/swiftc: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swiftc)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift)\n", "\n", "Linking ./.build/x86_64-unknown-linux/debug/libjupyterInstalledPackages.so\n", "/home/jekbradbury/swift/usr/bin/swiftc: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swiftc)\n", "\n", "/home/jekbradbury/swift/usr/bin/swift-autolink-extract: /home/jekbradbury/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jekbradbury/swift/usr/bin/swift-autolink-extract)\n", "\n", "Initializing Swift...\n", "Installation complete!\n" ] } ], "source": [ "%install '.package(path: \"$cwd/FastaiNotebook_06_cuda\")' FastaiNotebook_06_cuda" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, 
"outputs": [ { "data": { "text/plain": [ "('inline', 'module://ipykernel.pylab.backend_inline')\n" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import FastaiNotebook_06_cuda\n", "%include \"EnableIPythonDisplay.swift\"\n", "IPythonDisplay.shell.enable_matplotlib(\"inline\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "//export\n", "import Path\n", "import TensorFlow\n", "import Python" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start by building our own batchnorm layer from scratch. Eventually we want something like this to work:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class AlmostBatchNorm { // : Layer\n", " // Configuration hyperparameters\n", " let momentum, epsilon: Scalar\n", " // Running statistics\n", " var runningMean, runningVariance: Tensor\n", " // Trainable parameters\n", " var scale, offset: Tensor\n", " \n", " init(featureCount: Int, momentum: Scalar = 0.9, epsilon: Scalar = 1e-5) {\n", " (self.momentum, self.epsilon) = (momentum, epsilon)\n", " (scale, offset) = (Tensor(ones: [featureCount]), Tensor(zeros: [featureCount]))\n", " (runningMean, runningVariance) = (Tensor(0), Tensor(1))\n", " }\n", "\n", " func call(_ input: Tensor) -> Tensor {\n", " let mean, variance: Tensor\n", " switch Context.local.learningPhase {\n", " case .training:\n", " mean = input.mean(alongAxes: [0, 1, 2])\n", " variance = input.variance(alongAxes: [0, 1, 2])\n", " runningMean += (mean - runningMean) * (1 - momentum)\n", " runningVariance += (variance - runningVariance) * (1 - momentum)\n", " case .inference:\n", " (mean, variance) = (runningMean, runningVariance)\n", " }\n", " let normalizer = rsqrt(variance + epsilon) * scale\n", " return (input - mean) * normalizer + offset\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But there are some automatic differentiation 
limitations (lack of support for classes and control flow) that make this impossible for now, so we'll need a few workarounds. A `Reference` will let us update running statistics without making the layer a class or declaring the `forward` method `mutating`:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "//export\n", "public class Reference {\n", " public var value: T\n", " public init(_ value: T) { self.value = value }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following snippet will let us differentiate a layer's `forward` method if it's composed of training and inference implementations that are each differentiable:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "//export\n", "public protocol LearningPhaseDependent: FALayer {\n", " associatedtype Input\n", " associatedtype Output\n", " @differentiable func forwardTraining(to input: Input) -> Output\n", " @differentiable func forwardInference(to input: Input) -> Output\n", "}\n", "\n", "extension LearningPhaseDependent {\n", " public func forward(_ input: Input) -> Output {\n", " switch Context.local.learningPhase {\n", " case .training: return forwardTraining(to: input)\n", " case .inference: return forwardInference(to: input)\n", " }\n", " }\n", "\n", " @differentiating(forward)\n", " func gradForward(_ input: Input) ->\n", " (value: Output, pullback: (Self.Output.CotangentVector) ->\n", " (Self.CotangentVector, Self.Input.CotangentVector)) {\n", " switch Context.local.learningPhase {\n", " case .training: return valueWithPullback(at: input) { $0.forwardTraining(to: $1) }\n", " case .inference: return valueWithPullback(at: input) { $0.forwardInference(to: $1) }\n", " }\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can implement a BatchNorm that we can use in our models:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], 
"source": [ "//export\n", "public protocol Norm: FALayer where Input == Tensor, Output == Tensor {\n", " init(featureCount: Int, epsilon: Float)\n", "}\n", "\n", "public struct FABatchNorm: LearningPhaseDependent, Norm {\n", " // Configuration hyperparameters\n", " @noDerivative var momentum, epsilon: Float\n", " // Running statistics\n", " @noDerivative let runningMean, runningVariance: Reference>\n", " // Trainable parameters\n", " public var scale, offset: Tensor\n", " \n", " public init(featureCount: Int, momentum: Float, epsilon: Float = 1e-5) {\n", " self.momentum = momentum\n", " self.epsilon = epsilon\n", " self.scale = Tensor(ones: [featureCount])\n", " self.offset = Tensor(zeros: [featureCount])\n", " self.runningMean = Reference(Tensor(0))\n", " self.runningVariance = Reference(Tensor(1))\n", " }\n", " \n", " public init(featureCount: Int, epsilon: Float = 1e-5) {\n", " self.init(featureCount: featureCount, momentum: 0.9, epsilon: epsilon)\n", " }\n", "\n", " @differentiable\n", " public func forwardTraining(to input: Tensor) -> Tensor {\n", " let mean = input.mean(alongAxes: [0, 1, 2])\n", " let variance = input.variance(alongAxes: [0, 1, 2])\n", " runningMean.value += (mean - runningMean.value) * (1 - momentum)\n", " runningVariance.value += (variance - runningVariance.value) * (1 - momentum)\n", " let normalizer = rsqrt(variance + epsilon) * scale\n", " return (input - mean) * normalizer + offset\n", " }\n", " \n", " @differentiable\n", " public func forwardInference(to input: Tensor) -> Tensor {\n", " let (mean, variance) = (runningMean.value, runningVariance.value)\n", " let normalizer = rsqrt(variance + epsilon) * scale\n", " return (input - mean) * normalizer + offset\n", " }\n", " \n", " // Things that should probably be synthesized/inherited, but aren't\n", " public typealias Input = Tensor\n", " public typealias Output = Tensor\n", "}" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "//export\n", 
"public struct ConvNorm: FALayer\n", " where NormType.AllDifferentiableVariables == NormType.CotangentVector {\n", " public var conv: FANoBiasConv2D\n", " public var norm: NormType\n", " \n", " public init(_ cIn: Int, _ cOut: Int, ks: Int = 3, stride: Int = 2){\n", " self.conv = FANoBiasConv2D(\n", " filterShape: (ks, ks, cIn, cOut), \n", " strides: (stride, stride), \n", " padding: .same, \n", " activation: relu)\n", " self.norm = NormType(featureCount: cOut, epsilon: 1e-5)\n", " }\n", "\n", " @differentiable\n", " public func forward(_ input: Tensor) -> Tensor {\n", " return norm(conv(input))\n", " }\n", "}" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "//export\n", "public struct CnnModelNormed: FALayer\n", " where NormType.AllDifferentiableVariables == NormType.CotangentVector {\n", " public var convs: [ConvNorm]\n", " public var pool = FAGlobalAvgPool2D()\n", " public var flatten = Flatten()\n", " public var linear: FADense\n", " \n", " public init(channelIn: Int, nOut: Int, filters: [Int]){\n", " convs = []\n", " let allFilters = [channelIn] + filters\n", " for i in 0..(allFilters[i], allFilters[i+1])) }\n", " linear = FADense(filters.last!, nOut)\n", " }\n", " \n", " @differentiable\n", " public func forward(_ input: TF) -> TF {\n", " return input.sequenced(through: convs, pool, flatten, linear)\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's benchmark this batchnorm implementation!" 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "func benchmark(forward: () -> (), backward: () -> ()) {\n", " print(\"forward:\")\n", " time(repeating: 10, forward)\n", " print(\"backward:\")\n", " time(repeating: 10, backward)\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let input = Tensor(randomUniform: [64, 28, 28, 32])\n", "let norm = FABatchNorm(featureCount: 32)\n", "let pb = pullback(at: input) { x in norm(x) }\n", "benchmark(forward: { norm(input) }, backward: { pb(input) })" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Yikes, that's pretty bad. Luckily, TensorFlow has a built-in fused batchnorm layer. Let's see how the performance looks for that:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let input = Tensor(randomUniform: [64, 28, 28, 32])\n", "let norm = FABatchNorm(featureCount: 32)\n", "let bnresult = Raw.fusedBatchNormV2(\n", " input, scale: norm.scale, offset: norm.offset, \n", " mean: Tensor([] as [Float]), variance: Tensor([] as [Float]), \n", " epsilon: Double(norm.epsilon))\n", "benchmark(\n", " forward: {\n", " Raw.fusedBatchNormV2(\n", " input, scale: norm.scale, offset: norm.offset, \n", " mean: Tensor([] as [Float]), variance: Tensor([] as [Float]), \n", " epsilon: Double(norm.epsilon))\n", " },\n", " backward: {\n", " Raw.fusedBatchNormGradV2(\n", " yBackprop: input, input, scale: Tensor(norm.scale), \n", " reserveSpace1: bnresult.reserveSpace1, \n", " reserveSpace2: bnresult.reserveSpace2, \n", " epsilon: Double(norm.epsilon))\n", " })" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "struct PullbackArgs : TensorGroup {\n", " let input: T\n", " let cotangent: U\n", "}\n", "\n", "class CompiledFunction {\n", " let f: @differentiable (Input) -> Output\n", " init(_ f: @escaping @differentiable (Input) -> Output) {\n", " 
self.f = f\n", " }\n", "}\n", "\n", "func xlaCompiled(\n", " _ fn: @escaping @differentiable (T) -> U) -> CompiledFunction\n", " where T.CotangentVector : TensorGroup, U.CotangentVector : TensorGroup {\n", " let xlaCompiledFn: (T) -> U = _graph(fn, useXLA: true)\n", " let xlaCompiledPullback = _graph(\n", " { (pbArgs: PullbackArgs) in\n", " pullback(at: pbArgs.input, in: fn)(pbArgs.cotangent) },\n", " useXLA: true\n", " )\n", " return CompiledFunction(differentiableFunction { x in\n", " (value: xlaCompiledFn(x), pullback: { v in\n", " xlaCompiledPullback(PullbackArgs(input: x, cotangent: v))})\n", " })\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "struct TrainingKernelInput: TensorGroup, Differentiable, AdditiveArithmetic {\n", " let input: Tensor\n", " let scale: Tensor\n", " let offset: Tensor\n", " let runningMean: Tensor\n", " let runningVariance: Tensor\n", " let momentum: Tensor\n", " let epsilon: Tensor\n", "}\n", "\n", "struct TrainingKernelOutput: TensorGroup, Differentiable, AdditiveArithmetic {\n", " let normalized: Tensor\n", " let newRunningMean: Tensor\n", " let newRunningVariance: Tensor\n", "}\n", "\n", "@differentiable\n", "func trainingKernel(_ input: TrainingKernelInput) -> TrainingKernelOutput {\n", " let mean = input.input.mean(alongAxes: [0, 1, 2])\n", " let variance = input.input.variance(alongAxes: [0, 1, 2])\n", " let invMomentum = Tensor(1) - input.momentum\n", " let newRunningMean = input.runningMean * input.momentum + mean * invMomentum\n", " let newRunningVariance = input.runningVariance * input.momentum + variance * invMomentum\n", " let normalizer = rsqrt(variance + input.epsilon) * input.scale\n", " let normalized = (input.input - mean) * normalizer + input.offset\n", " return TrainingKernelOutput(\n", " normalized: normalized,\n", " newRunningMean: newRunningMean,\n", " newRunningVariance: newRunningVariance\n", " )\n", "}" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": {}, "outputs": [], "source": [ "let input = Tensor(randomUniform: [64, 28, 28, 32])\n", "let norm = FABatchNorm(featureCount: 32)\n", "let compiledTrainingKernel = xlaCompiled(trainingKernel)\n", "let kernelInput = TrainingKernelInput(\n", " input: input,\n", " scale: norm.scale,\n", " offset: norm.offset,\n", " runningMean: norm.runningMean.value,\n", " runningVariance: norm.runningVariance.value,\n", " momentum: Tensor(norm.momentum),\n", " epsilon: Tensor(norm.epsilon))\n", "let pb = pullback(at: kernelInput) { x in compiledTrainingKernel.f(x) }\n", "let kernelOutput = compiledTrainingKernel.f(kernelInput)\n", "\n", "benchmark(\n", " forward: { compiledTrainingKernel.f(kernelInput) },\n", " backward: { pb(kernelOutput) })" ] } ], "metadata": { "kernelspec": { "display_name": "Swift", "language": "swift", "name": "swift" }, "language_info": { "file_extension": ".swift", "mimetype": "text/x-swift", "name": "swift", "version": "" } }, "nbformat": 4, "nbformat_minor": 2 }