{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%install '.package(path: \"$cwd/FastaiNotebook_06_cuda\")' FastaiNotebook_06_cuda" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import FastaiNotebook_06_cuda\n", "%include \"EnableIPythonDisplay.swift\"\n", "IPythonDisplay.shell.enable_matplotlib(\"inline\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "import Path\n", "import TensorFlow\n", "import Python" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start by building our own batchnorm layer from scratch. Eventually we want something like this to work:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class AlmostBatchNorm { // : Layer\n", " // Configuration hyperparameters\n", " let momentum, epsilon: Scalar\n", " // Running statistics\n", " var runningMean, runningVariance: Tensor\n", " // Trainable parameters\n", " var scale, offset: Tensor\n", " \n", " init(featureCount: Int, momentum: Scalar = 0.9, epsilon: Scalar = 1e-5) {\n", " (self.momentum, self.epsilon) = (momentum, epsilon)\n", " (scale, offset) = (Tensor(ones: [featureCount]), Tensor(zeros: [featureCount]))\n", " (runningMean, runningVariance) = (Tensor(0), Tensor(1))\n", " }\n", "\n", " func call(_ input: Tensor) -> Tensor {\n", " let mean, variance: Tensor\n", " switch Context.local.learningPhase {\n", " case .training:\n", " mean = input.mean(alongAxes: [0, 1, 2])\n", " variance = input.variance(alongAxes: [0, 1, 2])\n", " runningMean += (mean - runningMean) * (1 - momentum)\n", " runningVariance += (variance - runningVariance) * (1 - momentum)\n", " case .inference:\n", " (mean, variance) = (runningMean, runningVariance)\n", " }\n", " let normalizer = rsqrt(variance + epsilon) * scale\n", " return (input - mean) * normalizer + offset\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But there are some automatic differentiation limitations (lack of support for classes and control flow) that make this impossible for now, so we'll need a few workarounds. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "But there are some automatic differentiation limitations (lack of support for classes and control flow) that make this impossible for now, so we'll need a few workarounds. A `Reference` will let us update running statistics without making the layer a class or marking its `call` method `mutating`:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "public class Reference<T> {\n", "    public var value: T\n", "    public init(_ value: T) { self.value = value }\n", "}" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The following snippet will let us differentiate a layer's `forward` method (the method that `FALayer`'s `call` invokes) as long as its training and inference implementations are each differentiable:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "public protocol LearningPhaseDependent: FALayer {\n", "    associatedtype Input\n", "    associatedtype Output\n", "    @differentiable func forwardTraining (_ input: Input) -> Output\n", "    @differentiable func forwardInference(_ input: Input) -> Output\n", "}\n", "\n", "extension LearningPhaseDependent {\n", "    public func forward(_ input: Input) -> Output {\n", "        switch Context.local.learningPhase {\n", "        case .training: return forwardTraining(input)\n", "        case .inference: return forwardInference(input)\n", "        }\n", "    }\n", "\n", "    @differentiating(forward)\n", "    func gradForward(_ input: Input) ->\n", "        (value: Output, pullback: (Self.Output.TangentVector) ->\n", "            (Self.TangentVector, Self.Input.TangentVector)) {\n", "        switch Context.local.learningPhase {\n", "        case .training: return valueWithPullback(at: input) { $0.forwardTraining($1) }\n", "        case .inference: return valueWithPullback(at: input) { $0.forwardInference($1) }\n", "        }\n", "    }\n", "}" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now we can implement a BatchNorm that we can use in our models:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "public protocol Norm: FALayer where Input == TF, Output == TF {\n", "    init(_ featureCount: Int, epsilon: Float)\n", "}\n", "\n", "public struct FABatchNorm: LearningPhaseDependent, Norm {\n", "    // Configuration hyperparameters\n", "    @noDerivative var momentum, epsilon: Float\n", "    // Running statistics\n", "    @noDerivative let runningMean, runningVariance: Reference<TF>\n", "    // Trainable parameters\n", "    public var scale, offset: TF\n", "    \n", "    public init(_ featureCount: Int, momentum: Float, epsilon: Float = 1e-5) {\n", "        self.momentum = momentum\n", "        self.epsilon = epsilon\n", "        self.scale = Tensor(ones: [featureCount])\n", "        self.offset = Tensor(zeros: [featureCount])\n", "        self.runningMean = Reference(Tensor(0))\n", "        self.runningVariance = Reference(Tensor(1))\n", "    }\n", "    \n", "    public init(_ featureCount: Int, epsilon: Float = 1e-5) {\n", "        self.init(featureCount, momentum: 0.9, epsilon: epsilon)\n", "    }\n", "\n", "    @differentiable\n", "    public func forwardTraining(_ input: TF) -> TF {\n", "        let mean = input.mean(alongAxes: [0, 1, 2])\n", "        let variance = input.variance(alongAxes: [0, 1, 2])\n", "        runningMean.value += (mean - runningMean.value) * (1 - momentum)\n", "        runningVariance.value += (variance - runningVariance.value) * (1 - momentum)\n", "        let normalizer = rsqrt(variance + epsilon) * scale\n", "        return (input - mean) * normalizer + offset\n", "    }\n", "    \n", "    @differentiable\n", "    public func forwardInference(_ input: TF) -> TF {\n", "        let (mean, variance) = (runningMean.value, runningVariance.value)\n", "        let normalizer = rsqrt(variance + epsilon) * scale\n", "        return (input - mean) * normalizer + offset\n", "    }\n", "}" ] },
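{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check (a hypothetical example, not part of the exported module), we can push one random batch through `FABatchNorm` in training mode and confirm that the running statistics drift from their initial values toward the batch statistics:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// Hypothetical sanity check of the running statistics update.\n", "Context.local.learningPhase = .training\n", "let checkNorm = FABatchNorm(8)\n", "// A batch with mean ~2 and variance ~9 per channel.\n", "let checkBatch = TF(randomNormal: [16, 4, 4, 8]) * 3 + 2\n", "let _ = checkNorm(checkBatch)\n", "print(checkNorm.runningMean.value)      // drifts from 0 toward the batch mean\n", "print(checkNorm.runningVariance.value)  // drifts from 1 toward the batch variance" ] },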
{ "cell_type": "markdown", "metadata": {}, "source": [ "Here is a generic `ConvNorm` layer that combines a conv2d layer with a norm layer (batchnorm, running batchnorm, etc.)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "public struct ConvNorm<NormType: Norm>: FALayer\n", "    where NormType.AllDifferentiableVariables == NormType.TangentVector {\n", "    public var conv: FANoBiasConv2D<Float>\n", "    public var norm: NormType\n", "    \n", "    public init(_ cIn: Int, _ cOut: Int, ks: Int = 3, stride: Int = 2){\n", "        self.conv = FANoBiasConv2D(cIn, cOut, ks: ks, stride: stride, activation: relu)\n", "        self.norm = NormType(cOut, epsilon: 1e-5)\n", "    }\n", "\n", "    @differentiable\n", "    public func forward(_ input: Tensor<Float>) -> Tensor<Float> {\n", "        return norm(conv(input))\n", "    }\n", "}" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "public struct CnnModelNormed<NormType: Norm>: FALayer\n", "    where NormType.AllDifferentiableVariables == NormType.TangentVector {\n", "    public var convs: [ConvNorm<NormType>]\n", "    public var pool = FAGlobalAvgPool2D<Float>()\n", "    public var linear: FADense<Float>\n", "    \n", "    public init(channelIn: Int, nOut: Int, filters: [Int]){\n", "        let allFilters = [channelIn] + filters\n", "        convs = Array(0..<filters.count).map { i in\n", "            return ConvNorm<NormType>(allFilters[i], allFilters[i+1], ks: 3, stride: 2)\n", "        }\n", "        linear = FADense<Float>(filters.last!, nOut)\n", "    }\n", "    \n", "    @differentiable\n", "    public func forward(_ input: TF) -> TF {\n", "        // TODO: Work around https://bugs.swift.org/browse/TF-606\n", "        return linear.forward(pool.forward(convs(input)))\n", "    }\n", "}" ] },
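{ "cell_type": "markdown", "metadata": {}, "source": [ "For example (a hypothetical instantiation, with filter sizes chosen only for illustration), we can now build the normed CNN with `FABatchNorm` as its norm layer:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// Hypothetical instantiation: a small CNN normalized with FABatchNorm,\n", "// for 28x28 single-channel inputs and 10 output classes.\n", "let normedModel = CnnModelNormed<FABatchNorm>(channelIn: 1, nOut: 10, filters: [8, 16, 32, 64])" ] },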
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's benchmark this batchnorm implementation!" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "func benchmark(forward: () -> (), backward: () -> ()) {\n", "    print(\"forward:\")\n", "    time(repeating: 10, forward)\n", "    print(\"backward:\")\n", "    time(repeating: 10, backward)\n", "}" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let input = TF(randomUniform: [64, 28, 28, 32])\n", "let norm = FABatchNorm(32)\n", "let pb = pullback(at: input) { x in norm(x) }\n", "benchmark(forward: { norm(input) }, backward: { pb(input) })" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Yikes, that's pretty bad. Luckily, TensorFlow has a built-in fused batchnorm op. Let's see how the performance looks for that:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let input = TF(randomUniform: [64, 28, 28, 32])\n", "let norm = FABatchNorm(32)\n", "let bnresult = Raw.fusedBatchNormV2(\n", "    input, scale: norm.scale, offset: norm.offset,\n", "    mean: TF([] as [Float]), variance: TF([] as [Float]),\n", "    epsilon: Double(norm.epsilon))\n", "benchmark(\n", "    forward: {\n", "        Raw.fusedBatchNormV2(\n", "            input, scale: norm.scale, offset: norm.offset,\n", "            mean: TF([] as [Float]), variance: TF([] as [Float]),\n", "            epsilon: Double(norm.epsilon))\n", "    },\n", "    backward: {\n", "        Raw.fusedBatchNormGradV2(\n", "            yBackprop: input, input, scale: TF(norm.scale),\n", "            reserveSpace1: bnresult.reserveSpace1,\n", "            reserveSpace2: bnresult.reserveSpace2,\n", "            epsilon: Double(norm.epsilon))\n", "    })" ] },
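{ "cell_type": "markdown", "metadata": {}, "source": [ "Another angle worth trying: write the whole batchnorm training step as one standalone function and compile both it and its pullback with XLA. The helper below is a sketch built on the experimental `_graph(_:useXLA:)` entry point; it wraps a differentiable function so that the forward computation and the pullback each run as a single compiled graph." ] },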
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "struct PullbackArgs<T: TensorGroup, U: TensorGroup>: TensorGroup {\n", "    let input: T\n", "    let cotangent: U\n", "}\n", "\n", "class CompiledFunction<Input: Differentiable, Output: Differentiable> {\n", "    let f: @differentiable (Input) -> Output\n", "    init(_ f: @escaping @differentiable (Input) -> Output) {\n", "        self.f = f\n", "    }\n", "}\n", "\n", "func xlaCompiled<T: Differentiable & TensorGroup, U: Differentiable & TensorGroup>(\n", "    _ fn: @escaping @differentiable (T) -> U) -> CompiledFunction<T, U>\n", "    where T.TangentVector: TensorGroup, U.TangentVector: TensorGroup {\n", "    let xlaCompiledFn: (T) -> U = _graph(fn, useXLA: true)\n", "    let xlaCompiledPullback = _graph(\n", "        { (pbArgs: PullbackArgs<T, U.TangentVector>) in\n", "            pullback(at: pbArgs.input, in: fn)(pbArgs.cotangent) },\n", "        useXLA: true\n", "    )\n", "    return CompiledFunction(differentiableFunction { x in\n", "        (value: xlaCompiledFn(x), pullback: { v in\n", "            xlaCompiledPullback(PullbackArgs(input: x, cotangent: v))})\n", "    })\n", "}" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "struct TrainingKernelInput: TensorGroup, Differentiable, AdditiveArithmetic {\n", "    var input, scale, offset, runningMean, runningVariance, momentum, epsilon: TF\n", "}\n", "\n", "struct TrainingKernelOutput: TensorGroup, Differentiable, AdditiveArithmetic {\n", "    var normalized, newRunningMean, newRunningVariance: TF\n", "}\n", "\n", "@differentiable\n", "func trainingKernel(_ input: TrainingKernelInput) -> TrainingKernelOutput {\n", "    let mean = input.input.mean(alongAxes: [0, 1, 2])\n", "    let variance = input.input.variance(alongAxes: [0, 1, 2])\n", "    let invMomentum = TF(1) - input.momentum\n", "    let newRunningMean = input.runningMean * input.momentum + mean * invMomentum\n", "    let newRunningVariance = input.runningVariance * input.momentum + variance * invMomentum\n", "    let normalizer = rsqrt(variance + input.epsilon) * input.scale\n", "    let normalized = (input.input - mean) * normalizer + input.offset\n", "    return TrainingKernelOutput(\n", "        normalized: normalized,\n", "        newRunningMean: newRunningMean,\n", "        newRunningVariance: newRunningVariance\n", "    )\n", "}" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let input = TF(randomUniform: [64, 28, 28, 32])\n", "let norm = FABatchNorm(32)\n", "let compiledTrainingKernel = xlaCompiled(trainingKernel)\n", "let kernelInput = TrainingKernelInput(\n", "    input: input,\n", "    scale: norm.scale,\n", "    offset: norm.offset,\n", "    runningMean: norm.runningMean.value,\n", "    runningVariance: norm.runningVariance.value,\n", "    momentum: Tensor(norm.momentum),\n", "    epsilon: Tensor(norm.epsilon))\n", "let pb = pullback(at: kernelInput) { x in compiledTrainingKernel.f(x) }\n", "let kernelOutput = compiledTrainingKernel.f(kernelInput)\n", "\n", "benchmark(\n", "    forward: { compiledTrainingKernel.f(kernelInput) },\n", "    backward: { pb(kernelOutput) })" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Swift", "language": "swift", "name": "swift" } }, "nbformat": 4, "nbformat_minor": 2 }