{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Installing packages:\n", "\t.package(path: \"/home/sgugger/git/course-v3/nbs/swift/FastaiNotebook_08a_heterogeneous_dictionary\")\n", "\t\tFastaiNotebook_08a_heterogeneous_dictionary\n", "With SwiftPM flags: []\n", "Working in: /tmp/tmpjnhcoy0k/swift-install\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: 
/home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", 
"/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by 
/home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)[1/13] Compiling FastaiNotebook_08_data_block 01_matmul.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[2/13] Compiling FastaiNotebook_08_data_block 03_minibatch_training.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[3/13] Compiling FastaiNotebook_08_data_block 02_fully_connected.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[4/13] Compiling FastaiNotebook_08_data_block 05b_early_stopping.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[5/13] Compiling FastaiNotebook_08_data_block 06_cuda.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[6/13] Compiling FastaiNotebook_08_data_block 05_anneal.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[7/13] Compiling FastaiNotebook_08_data_block 
02a_why_sqrt5.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[8/13] Compiling FastaiNotebook_08_data_block 00_load_data.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[9/13] Compiling FastaiNotebook_08_data_block 04_callbacks.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[10/13] Compiling FastaiNotebook_08_data_block 08_data_block.swift\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[11/13] Compiling FastaiNotebook_08_data_block 01a_fastai_layers.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[12/13] Compiling FastaiNotebook_08_data_block 07_batchnorm.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[13/14] Merging module FastaiNotebook_08_data_block\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)[14/15] Compiling FastaiNotebook_08a_heterogeneous_dictionary 08_data_block.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available 
(required by /home/sgugger/swift/usr/bin/swift)\n", "[15/25] Compiling FastaiNotebook_08a_heterogeneous_dictionary 03_minibatch_training.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[16/25] Compiling FastaiNotebook_08a_heterogeneous_dictionary 02_fully_connected.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[17/25] Compiling FastaiNotebook_08a_heterogeneous_dictionary 05b_early_stopping.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[18/25] Compiling FastaiNotebook_08a_heterogeneous_dictionary 06_cuda.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[19/25] Compiling FastaiNotebook_08a_heterogeneous_dictionary 05_anneal.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[20/25] Compiling FastaiNotebook_08a_heterogeneous_dictionary 02a_why_sqrt5.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[21/25] Compiling FastaiNotebook_08a_heterogeneous_dictionary 00_load_data.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[22/25] Compiling FastaiNotebook_08a_heterogeneous_dictionary 04_callbacks.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information 
available (required by /home/sgugger/swift/usr/bin/swift)\n", "[23/25] Compiling FastaiNotebook_08a_heterogeneous_dictionary 01a_fastai_layers.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[24/25] Compiling FastaiNotebook_08a_heterogeneous_dictionary 07_batchnorm.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[25/26] Merging module FastaiNotebook_08a_heterogeneous_dictionary\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)[26/27] Compiling jupyterInstalledPackages jupyterInstalledPackages.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[27/28] Merging module jupyterInstalledPackages\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift-autolink-extract: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift-autolink-extract)\n", "[28/28] Linking libjupyterInstalledPackages.so\n", "Initializing Swift...\n", "Installation complete!\n" ] } ], "source": [ "%install-location $cwd/swift-install\n", "%install '.package(path: 
\"$cwd/FastaiNotebook_08a_heterogeneous_dictionary\")' FastaiNotebook_08a_heterogeneous_dictionary" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// export\n", "import Path\n", "import TensorFlow" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import FastaiNotebook_08a_heterogeneous_dictionary" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('inline', 'module://ipykernel.pylab.backend_inline')\n" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%include \"EnableIPythonDisplay.swift\"\n", "IPythonDisplay.shell.enable_matplotlib(\"inline\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let path = downloadImagenette()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let il = ItemList(fromFolder: path, extensions: [\"jpeg\", \"jpg\"])\n", "let sd = SplitData(il, fromFunc: {grandParentSplitter(fName: $0, valid: \"val\")})\n", "var procLabel = CategoryProcessor()\n", "let sld = makeLabeledData(sd, fromFunc: parentLabeler, procLabel: &procLabel)\n", "let rawData = sld.toDataBunch(itemToTensor: pathsToTensor, labelToTensor: intsToTensor)\n", "let data = transformData(rawData, tfmItem: { openAndResize(fname: $0, size: 128) })" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "func modelInit() -> CNNModel { return CNNModel(channelIn: 3, nOut: 10, filters: [64, 64, 128, 256]) }" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Stateful optimizer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before we begin, we create this structure to contain the names of our hyper-parameters. 
This will give us some tab completion and typo-proof way of handling them." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "/// Namespace for the string keys used to index hyper-parameter dictionaries (`[String: Float]`).\n", "public struct HyperParams {\n", " /// Key of the learning rate.\n", " public static let lr = \"learningRate\"\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Like in the python version, we create `statDelegates` that will be responsible for computing/updating statistics in the state (like the moving average of gradients) and `stepDelegates` that will be responsible for performing a part of the update of the weights. \n", "\n", "In PyTorch we created a basic class with functions that needed to be implemented. In swift this is what protocols are for." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "/// A statistic kept in the optimizer state (e.g. a moving average of gradients).\n", "public protocol StatDelegate {\n", " /// Key under which this statistic is stored in the per-parameter state dictionary.\n", " var name: String {get}\n", " /// Default hyper-parameters this delegate contributes to every hyper-parameter group.\n", " var defaultHPs: [String:Float] {get}\n", " \n", " /// Updates `state` in place from the parameter `p` and its gradient `𝛁p`; may also write to `hps`.\n", " func update(_ state: inout [String:TF], p: TF, 𝛁p: TF, hps: inout [String:Float])\n", "}\n", "\n", "/// One part of the weight update applied to a parameter.\n", "public protocol StepDelegate {\n", " /// Default hyper-parameters this delegate contributes to every hyper-parameter group.\n", " var defaultHPs: [String:Float] {get}\n", " \n", " /// Mutates the parameter `p` and/or its gradient `𝛁p` using `state` and the hyper-parameters `hps`.\n", " func update(_ p: inout TF, 𝛁p: inout TF, state: [String:TF], hps: inout [String:Float])\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Those are helper functions to merge dictionaries that we'll use in the `StatefulOptimizer`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "/// Merges `newDict` into every dictionary of `dicts`; on a key clash the value from `newDict` wins.\n", "public func mergeDicts(_ dicts: inout [[String:Float]], with newDict: [String:Float]) {\n", " for i in dicts.indices { \n", " dicts[i].merge(newDict) { (_, new) in new } \n", " }\n", "}\n", "\n", "/// Merges `newDicts[i]` into `dicts[i]` for every index of `dicts`; on a key clash the new value wins.\n", "/// `newDicts` must have at least as many entries as `dicts` (it is indexed with the indices of `dicts`);\n", "/// any extra entries are ignored.\n", "public func mergeDicts(_ dicts: inout [[String:Float]], with newDicts: [[String:Float]]) {\n", " for i in dicts.indices { \n", " dicts[i].merge(newDicts[i]) { (_, new) in new } \n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Those two extensions are there to initialize dicts easily." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "extension Dictionary where Value == Int{\n", " /// Maps every key found in `arrays[i]` to the index `i`.\n", " /// Keys must be unique across all arrays: `uniqueKeysWithValues` traps on duplicates.\n", " public init(mapFromArrays arrays: [[Key]]){\n", " self.init(uniqueKeysWithValues: arrays.enumerated().flatMap { i, arr in arr.map { ($0, i) } })\n", " }\n", "}\n", "\n", "extension Dictionary {\n", " /// Maps every key of `keys` to the same `constant` value. `keys` must not contain duplicates.\n", " public init(constant: Value, keys: [Key]){\n", " self.init(uniqueKeysWithValues: keys.map { ($0, constant) })\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is the initial state of our StatefulOptimizer. It's a dictionary from keyPaths (see below) to dictionaries that map names to tensors of floats." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "/// Builds the initial optimizer state: every key path of `model.variables.keyPaths` is mapped to a\n", "/// dictionary that associates each entry of `names` with `TF(0)`.\n", "public func initState<Model: Layer>(for model: Model, names: [String]) \n", "-> [WritableKeyPath<Model.AllDifferentiableVariables, TF>: [String:TF]] {\n", " return [WritableKeyPath<Model.AllDifferentiableVariables, TF>: [String:TF]](\n", " constant: [String: TF](constant: TF(0), keys: names),\n", " keys: model.variables.keyPaths)\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And we can define the main `StatefulOptimizer`. 
It takes a model, hyperparameters for each parameter group, some `steppers` and `stats` and a `splitArray` that defines our different parameter groups.\n", "\n", "To understand how this works, you need to know a little bit about keyPaths. This is a tool in swift to access any element in a nested structure like our models: one model typically has a few attributes that are modules which in turn contain other modules and so forth until we reach the primitive layers like `Conv2d` or `Dense`. KeyPaths will allow us to index in that nested structure the objects of a particular type. \n", "\n", "For instance, the shortcut `keyPaths` you can apply to any `Layer` will find all the tensors of floats. If we apply it to a `Model.AllDifferentiableVariables` object, we will find all the parameters of the model (since `model.allDifferentiableVariables` only contains the trainable parameters).\n", "\n", "That's why the inner loop of our `StatefulOptimizer` is over `variables.keyPaths`. The same keyPath index will give us the gradients. Then we create `states`, a dictionary of such keyPaths to `[String:TF]`, and the `splitArray` we provide is an array of arrays of keyPaths (each inner array giving us a parameter group) from which we build a `splitDict` that maps our keyPaths to the index of the corresponding group." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "/// Optimizer assembled from pluggable delegates: `stats` maintain per-parameter state and\n", "/// `steppers` apply the actual update to each parameter.\n", "public class StatefulOptimizer<Model: Layer>\n", " where Model.AllDifferentiableVariables == Model.TangentVector {\n", " /// Key path from the model's differentiable variables to one tensor parameter.\n", " public typealias ModelKeyPath = WritableKeyPath<Model.AllDifferentiableVariables, TF>\n", " /// Maps a parameter key path to the index of its hyper-parameter group.\n", " public typealias SplitDict = [ModelKeyPath: Int]\n", " /// One hyper-parameter dictionary per parameter group.\n", " public var hpGroups: [[String:Float]]\n", " public var splitDict: SplitDict\n", " /// Per-parameter state, keyed by the names of the `stats` delegates.\n", " public var states: [ModelKeyPath: [String: TF]]\n", " public var stats: [StatDelegate]\n", " public var steppers: [StepDelegate]\n", " /// - Parameters:\n", " ///   - model: model whose parameter key paths are used to build the initial state.\n", " ///   - steppers: delegates applied, in order, to update each parameter.\n", " ///   - stats: delegates applied, in order, before the steppers to update the state.\n", " ///   - hpGroups: hyper-parameters of each group; they override the delegates' defaults.\n", " ///   - splitArray: the key paths belonging to each group, in the same order as `hpGroups`.\n", " public init( \n", " for model: __shared Model,\n", " steppers: [StepDelegate],\n", " stats: [StatDelegate],\n", " hpGroups: [[String:Float]],\n", " splitArray: [[ModelKeyPath]]\n", " ) {\n", " self.hpGroups = Array(repeating: [:], count: hpGroups.count)\n", " (self.steppers,self.stats) = (steppers,stats)\n", " self.splitDict = SplitDict(mapFromArrays: splitArray)\n", " // Placeholder so that every stored property is set before `self` is used below.\n", " states = [:]\n", " // Merge order matters: stepper defaults, then stat defaults, then the caller's values,\n", " // each later merge overriding earlier entries with the same key.\n", " steppers.forEach { mergeDicts(&self.hpGroups, with: $0.defaultHPs) }\n", " stats.forEach { mergeDicts(&self.hpGroups, with: $0.defaultHPs) }\n", " states = initState(for: model, names: stats.map { $0.name })\n", " mergeDicts(&self.hpGroups, with: hpGroups)\n", " }\n", " \n", " /// Updates every parameter of `variables` along `direction`: for each key path all `stats` run\n", " /// first, then all `steppers`, and hyper-parameter changes made by the delegates are written\n", " /// back to the parameter's group. Traps if a key path is missing from `splitDict` or `states`.\n", " public func update(\n", " _ variables: inout Model.AllDifferentiableVariables,\n", " along direction: Model.TangentVector\n", " ) {\n", " for kp in variables.keyPaths {\n", " let group = splitDict[kp]!\n", " var 𝛁p = direction[keyPath: kp]\n", " var hps = hpGroups[group]\n", " stats.forEach() { $0.update(&states[kp]!, p: variables[keyPath: kp], 𝛁p: 𝛁p, hps: &hps) }\n", " steppers.forEach() { $0.update(&variables[keyPath: kp], 𝛁p: &𝛁p, state: states[kp]!, hps: &hps) }\n", " hpGroups[group] = hps\n", " }\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make `StatefulOptimizer` conform to the `Optimizer` protocol, we need to add a `learningRate` property." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "extension StatefulOptimizer: Optimizer{\n", " /// Learning rate of the last hyper-parameter group; setting it writes the value to every group.\n", " public var learningRate: Float {\n", " get { return hpGroups.last![HyperParams.lr]! } \n", " set { \n", " hpGroups.indices.forEach { self.hpGroups[$0][HyperParams.lr] = newValue }\n", " }\n", " }\n", " /// One learning rate per hyper-parameter group (for discriminative learning rates).\n", " /// The assigned array is indexed with the indices of `hpGroups`.\n", " public var learningRates: [Float] {\n", " get { return hpGroups.map { group in group[HyperParams.lr]! } }\n", " set { \n", " hpGroups.indices.forEach { self.hpGroups[$0][HyperParams.lr] = newValue[$0] }\n", " }\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When we don't have any parameter groups, we just use one with all the `keyPaths`. This convenience init automatically does that for us." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "extension StatefulOptimizer{\n", " /// Builds an optimizer with a single hyper-parameter group `hps` covering every parameter of `model`.\n", " public convenience init (for model: __shared Model,\n", " steppers: [StepDelegate],\n", " stats: [StatDelegate],\n", " hps: [String:Float]) {\n", " let everyKeyPath: [ModelKeyPath] = model.variables.keyPaths\n", " self.init(for: model,\n", " steppers: steppers,\n", " stats: stats,\n", " hpGroups: [hps],\n", " splitArray: [everyKeyPath])\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are now ready to define `steppers` and `stats`. 
Let's begin with basic SGD:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "/// Plain SGD: subtracts the gradient scaled by the learning rate from the parameter.\n", "public struct SGDStep: StepDelegate {\n", " public var defaultHPs: [String: Float] { return [HyperParams.lr: 3e-3] }\n", " public init() {}\n", " public func update(_ p: inout TF, 𝛁p: inout TF, state: [String:TF], hps: inout [String:Float]) {\n", " p -= 𝛁p * hps[HyperParams.lr]!\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can check all is working and train:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let hps: [String:Float] = [HyperParams.lr: 0.01]\n", "func optFunc(_ model: CNNModel) -> StatefulOptimizer<CNNModel> {\n", " return StatefulOptimizer(for: model, steppers: [SGDStep()], stats: [], hps: hps)\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "var learner = Learner(data: data, lossFunc: softmaxCrossEntropy, optFunc: optFunc, modelInit: modelInit)\n", "var recorder = learner.makeDefaultDelegates(metrics: [accuracy])\n", "// NOTE(review): the Imagenette batches are normalized with `mnistStats` here - confirm this is intended.\n", "learner.delegates.append(learner.makeNormalize(mean: mnistStats.mean, std: mnistStats.std))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: [1.5975714, 0.45] \n", " \r" ] } ], "source": [ "learner.fit(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we can add weight decay and L2 regularization." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "public extension HyperParams {\n", " /// Key of the weight-decay coefficient.\n", " static let wd = \"weightDecay\"\n", "}\n", "\n", "/// Weight decay applied directly to the parameter: `p` is scaled by `1 - lr * wd`.\n", "/// Only `wd` has a default here; `lr` is force-unwrapped, so it must be supplied by the\n", "/// hyper-parameters or by another delegate's defaults.\n", "public struct WeightDecay: StepDelegate {\n", " public var defaultHPs: [String: Float] { return [HyperParams.wd: 0] }\n", " public init() {}\n", " public func update(_ p: inout TF, 𝛁p: inout TF, state: [String:TF], hps: inout [String:Float]) {\n", " p *= 1 - hps[HyperParams.lr]! * hps[HyperParams.wd]!\n", " }\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "/// L2 regularization: adds `wd * p` to the gradient and leaves the parameter itself untouched.\n", "/// The optimizer runs all stats before any stepper, so the modified gradient is only seen by\n", "/// steppers listed after this one that read `𝛁p`, not by the statistics kept in the state.\n", "public struct L2Regularization: StepDelegate {\n", " public var defaultHPs: [String: Float] { return [HyperParams.wd: 0] }\n", " public init() {}\n", " public func update(_ p: inout TF, 𝛁p: inout TF, state: [String:TF], hps: inout [String:Float]) {\n", " 𝛁p += hps[HyperParams.wd]! * p\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next step is SGD with momentum. For this we need a statistic that keeps track of the moving average of the gradients." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "// Extensible namespace (a struct of static constants, grown through extensions) that gives\n", "// tab completion and typo-proof names for the state variables.\n", "public struct StateKeys {\n", " public static let avgGrad = \"averageGrad\"\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "public extension HyperParams {\n", " /// Key of the momentum coefficient.\n", " static let mom = \"momentum\"\n", " /// Key of the dampening factor applied to the incoming gradient.\n", " static let momDamp = \"dampening\"\n", "}\n", "\n", "/// Running average of the gradients, stored under `StateKeys.avgGrad`.\n", "/// Each update computes `avg = mom * avg + damp * 𝛁p`, where `damp` is `1 - mom` when\n", "/// `dampened` is true and `1` otherwise.\n", "public struct AverageGrad: StatDelegate {\n", " public var defaultHPs: [String: Float] { return [HyperParams.mom: 0.9] }\n", " public let dampened: Bool\n", " public init(dampened: Bool = false) { self.dampened = dampened }\n", " public var name: String { return StateKeys.avgGrad }\n", " public func update(_ state: inout [String: TF], p: TF, 𝛁p: TF, hps: inout [String:Float]) {\n", " state[StateKeys.avgGrad]! *= hps[HyperParams.mom]!\n", " // The dampening actually used is written to the hyper-parameters so later delegates can read it.\n", " hps[HyperParams.momDamp] = 1.0 - (dampened ? hps[HyperParams.mom]! : 0.0)\n", " state[StateKeys.avgGrad]! += hps[HyperParams.momDamp]! * 𝛁p\n", " }\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "/// SGD with momentum: subtracts `lr` times the running average of the gradients.\n", "/// Force-unwraps `state[StateKeys.avgGrad]` and `hps[HyperParams.lr]`, so it needs a stat that\n", "/// fills that state entry (such as `AverageGrad`) and a learning rate in the hyper-parameters.\n", "public struct MomentumStep: StepDelegate {\n", " public var defaultHPs: [String: Float] = [:]\n", " public init() {}\n", " public func update(_ p: inout TF, 𝛁p: inout TF, state: [String: TF], hps: inout [String:Float]) {\n", " p -= state[StateKeys.avgGrad]! * hps[HyperParams.lr]!\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And we can check it trains properly." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let hps: [String:Float] = [HyperParams.lr: 0.01]\n", "// SGD with momentum: `AverageGrad` feeds the running average consumed by `MomentumStep`.\n", "func optFunc(_ model: CNNModel) -> StatefulOptimizer<CNNModel> {\n", " return StatefulOptimizer(for: model, steppers: [MomentumStep()], stats: [AverageGrad()], hps: hps)\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "var learner = Learner(data: data, lossFunc: softmaxCrossEntropy, optFunc: optFunc, modelInit: modelInit)\n", "var recorder = learner.makeDefaultDelegates(metrics: [accuracy])\n", "learner.delegates.append(learner.makeNormalize(mean: mnistStats.mean, std: mnistStats.std))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: [1.335211, 0.544] \n", " \r" ] } ], "source": [ "learner.fit(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The hyper-parameters have taken the default values provided (except for learning rates)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "▿ 3 elements\n", " ▿ 0 : 2 elements\n", " - key : \"dampening\"\n", " - value : 1.0\n", " ▿ 1 : 2 elements\n", " - key : \"momentum\"\n", " - value : 0.9\n", " ▿ 2 : 2 elements\n", " - key : \"learningRate\"\n", " - value : 0.01\n" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.opt.hpGroups[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next step is Adam. For that we need to keep track of the averages of the gradients squared." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "public extension HyperParams {\n", " /// Key of the momentum coefficient used for the squared gradients.\n", " static let ²mom = \"momentumSquares\"\n", " /// Key of the dampening factor used for the squared gradients.\n", " static let ²momDamp = \"dampeningSquares\"\n", "}\n", "\n", "public extension StateKeys {\n", " static let avgSqr = \"averageSquaredGrad\"\n", "}\n", "\n", "/// Running average of the squared gradients, stored under `StateKeys.avgSqr`.\n", "/// Each update computes `avg = ²mom * avg + ²damp * 𝛁p.squared()`, where `²damp` is `1 - ²mom`\n", "/// when `dampened` is true and `1` otherwise.\n", "public struct AverageSquaredGrad: StatDelegate {\n", " // Public, like the matching property of `AverageGrad`.\n", " public let dampened: Bool\n", " public init(dampened: Bool = true) { self.dampened = dampened }\n", " public var name: String { return StateKeys.avgSqr }\n", " public var defaultHPs: [String: Float] { return [HyperParams.²mom: 0.99] }\n", " public func update(_ state: inout [String: TF], p: TF, 𝛁p: TF, hps: inout [String:Float]) {\n", " state[StateKeys.avgSqr]! *= hps[HyperParams.²mom]!\n", " // The dampening actually used is written to the hyper-parameters so later delegates can read it.\n", " hps[HyperParams.²momDamp] = 1.0 - (dampened ? hps[HyperParams.²mom]! : 0.0)\n", " state[StateKeys.avgSqr]! += hps[HyperParams.²momDamp]! * 𝛁p.squared()\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And we also need to keep track of the number of iterations we did." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "public extension StateKeys {\n", " static let step = \"stepCount\"\n", "}\n", "\n", "public struct StepCount: StatDelegate {\n", " public var name: String { return StateKeys.step }\n", " public var defaultHPs: [String:Float] = [:]\n", " public init() {}\n", " public func update(_ state: inout [String: TF], p: TF, 𝛁p: TF, hps: inout [String:Float]) {\n", " state[StateKeys.step]! 
+= 1.0\n", " }\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "public extension HyperParams {\n", " /// Key of the small constant added to the denominator of adaptive steps.\n", " static let eps = \"epsilon\"\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "/// Adam update. Reads the step count, the running average of the gradients and the running\n", "/// average of the squared gradients from `state`, and the matching momentum/dampening values\n", "/// from `hps`; all of them are force-unwrapped, so the corresponding stats must be present.\n", "public struct AdamStep: StepDelegate {\n", " public var defaultHPs: [String: Float] { return [HyperParams.eps: 1e-5] }\n", " public init() {}\n", " public func update(_ p: inout TF, 𝛁p: inout TF, state: [String: TF], hps: inout [String:Float]) {\n", " let step = state[StateKeys.step]!\n", " let (mom,damp) = (hps[HyperParams.mom]!,hps[HyperParams.momDamp]!)\n", " let debias1 = damp * (1 - pow(mom, step)) / (1 - mom)\n", " // Bias correction divides the running average by the debias term (with dampening it equals\n", " // 1 - mom^step, which is below 1 during the first steps); multiplying would shrink the\n", " // estimate instead of correcting it.\n", " let num = state[StateKeys.avgGrad]! / debias1\n", " \n", " let (²mom,²damp) = (hps[HyperParams.²mom]!,hps[HyperParams.²momDamp]!)\n", " let debias2 = ²damp * (1 - pow(²mom, step)) / (1 - ²mom)\n", " let denom = sqrt(state[StateKeys.avgSqr]!/debias2) + hps[HyperParams.eps]!\n", " \n", " p -= hps[HyperParams.lr]! * num / denom\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Again let's check it's all training properly." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "func optFunc(_ model: CNNModel) -> StatefulOptimizer {\n", " return StatefulOptimizer(\n", " for: model,\n", " steppers: [AdamStep()], \n", " stats: [AverageGrad(dampened: true), AverageSquaredGrad(), StepCount()], \n", " hps: [HyperParams.lr: 1e-3])\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let learner = Learner(data: data, lossFunc: softmaxCrossEntropy, optFunc: optFunc, modelInit: modelInit)\n", "let recorder = learner.makeDefaultDelegates(metrics: [accuracy])\n", "learner.delegates.append(learner.makeNormalize(mean: mnistStats.mean, std: mnistStats.std))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: [1.208637, 0.596] \n", " \r" ] } ], "source": [ "learner.fit(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also check the values of the hyper-parameters have been set properly." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "▿ 6 elements\n", " ▿ 0 : 2 elements\n", " - key : \"dampeningSquares\"\n", " - value : 0.00999999\n", " ▿ 1 : 2 elements\n", " - key : \"momentumSquares\"\n", " - value : 0.99\n", " ▿ 2 : 2 elements\n", " - key : \"dampening\"\n", " - value : 0.100000024\n", " ▿ 3 : 2 elements\n", " - key : \"learningRate\"\n", " - value : 0.001\n", " ▿ 4 : 2 elements\n", " - key : \"epsilon\"\n", " - value : 1e-05\n", " ▿ 5 : 2 elements\n", " - key : \"momentum\"\n", " - value : 0.9\n" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.opt.hpGroups[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Defining the Lamb optimizer is as easy as before." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "/// LAMB stepper: computes an Adam-style step (with weight decay added), then rescales it by the\n", "/// ratio of the parameter RMS to the step RMS, capped at 10, before applying the learning rate.\n", "/// Reads the `avgGrad`, `avgSqr` and `step` entries of `state`; only `p` is modified.\n", "public struct LambStep: StepDelegate {\n", " public var defaultHPs: [String: Float] { return [HyperParams.eps: 1e-6, HyperParams.wd: 0.0] }\n", " // Explicit public initializer, matching AdamStep, so the struct can be created from other modules.\n", " public init() {}\n", " public func update(_ p: inout TF, 𝛁p: inout TF, state: [String: TF], hps: inout [String:Float]) {\n", " let stepCount = state[StateKeys.step]!\n", " let (mom,damp) = (hps[HyperParams.mom]!,hps[HyperParams.momDamp]!)\n", " // Sum of the weights held by the running average after `stepCount` updates (assuming AverageGrad\n", " // accumulates like AverageSquaredGrad). The average is divided by it, exactly as avgSqr is\n", " // divided by debias2 below; multiplying would shrink the early steps instead of correcting them.\n", " let debias1 = damp * (1 - pow(mom, stepCount)) / (1 - mom)\n", " let num = state[StateKeys.avgGrad]! / debias1\n", " \n", " let (²mom,²damp) = (hps[HyperParams.²mom]!,hps[HyperParams.²momDamp]!)\n", " let debias2 = ²damp * (1 - pow(²mom, stepCount)) / (1 - ²mom)\n", " let denom = sqrt(state[StateKeys.avgSqr]!/debias2) + hps[HyperParams.eps]!\n", " \n", " let step = num / denom + hps[HyperParams.wd]! * p\n", " // r1 is the root mean square of the parameters, r2 that of the proposed step.\n", " let r1 = sqrt((p * p).mean())\n", " let r2 = sqrt((step * step).mean())\n", " let factor = min(r1 / r2, Float(10.0))\n", " p -= hps[HyperParams.lr]! * factor * step\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Making convenience functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To easily create our optimizers, we have two convenience functions." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// export\n", "public func sgdOpt(lr: Float, mom: Float = 0.9, wd: Float = 0.0, dampening: Bool = false\n", " ) -> ((Model) -> StatefulOptimizer) {\n", " var steppers: [StepDelegate] = (mom != 0) ? [MomentumStep()] : [SGDStep()]\n", " if wd != 0 { steppers.append(WeightDecay()) }\n", " let stats = (mom != 0) ? 
[AverageGrad(dampened: dampening)] : []\n", " var hps: [String: Float] = [HyperParams.lr: lr]\n", " if mom != 0 { hps[HyperParams.mom] = mom }\n", " if wd != 0 { hps[HyperParams.wd ] = wd }\n", " return {model in \n", " return StatefulOptimizer(for: model, steppers: steppers, stats: stats, hps: hps)}\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// export\n", "public func adamOpt(lr: Float, mom: Float = 0.9, beta: Float=0.99, wd: Float = 0.0, eps: Float = 1e-5\n", " ) -> ((Model) -> StatefulOptimizer) {\n", " var steppers: [StepDelegate] = [AdamStep()]\n", " if wd != 0 { steppers.append(WeightDecay()) }\n", " let stats: [StatDelegate] = [AverageGrad(dampened: true), AverageSquaredGrad(), StepCount()]\n", " var hps: [String: Float] = [HyperParams.lr: lr]\n", " hps[HyperParams.mom] = mom\n", " hps[HyperParams.²mom] = beta\n", " hps[HyperParams.eps] = eps\n", " if wd != 0 { hps[HyperParams.wd ] = wd }\n", " return {model in \n", " return StatefulOptimizer(for: model, steppers: steppers, stats: stats, hps: hps)}\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Schedule the hyperparams" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next thing is that we need to schedule our hyper-parameters. The following function allows us to schedule any of them, as long as they are present in the `hpGroups` dictionaries." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// export\n", "public extension StatefulOptimizer {\n", " func setParam(_ hp: String, _ val: Float) {\n", " for i in 0.. 
Float\n", "\n", " // A learning rate schedule from step to float.\n", " public var scheduler: ScheduleFunc\n", " public let hp: String\n", " \n", " public init(scheduler: @escaping (Float) -> Float, hp: String) {\n", " (self.scheduler,self.hp) = (scheduler,hp)\n", " }\n", " \n", " override public func batchWillStart(learner: Learner) {\n", " let val = scheduler(learner.pctEpochs/Float(learner.epochCount))\n", " (learner.opt as! StatefulOptimizer).setParam(hp, val)\n", " }\n", " }\n", " \n", " public func makeParamScheduler(_ scheduler: @escaping (Float) -> Float, hp: String) -> ParamScheduler {\n", " return ParamScheduler(scheduler: scheduler, hp: hp)\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can then define a helper function to schedule a 1cycle policy." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// export \n", "public func oneCycleSchedulers(_ lrMax: Float, pctStart:Float=0.25, divStart: Float = 10, divEnd: Float = 1e5, \n", " moms: (Float,Float,Float) = (0.95,0.85,0.95)) \n", "-> ((Float) -> Float, (Float) -> Float){\n", " let lrSched = combineSchedules(\n", " pcts: [pctStart, 1-pctStart], \n", " schedules: [makeAnnealer(start: lrMax/divStart, end: lrMax, schedule: cosineSchedule),\n", " makeAnnealer(start: lrMax, end: lrMax/divEnd, schedule: cosineSchedule)])\n", " let momSched = combineSchedules(\n", " pcts: [pctStart, 1-pctStart], \n", " schedules: [makeAnnealer(start: moms.0, end: moms.1, schedule: cosineSchedule),\n", " makeAnnealer(start: moms.1, end: moms.2, schedule: cosineSchedule)])\n", " return (lrSched, momSched)\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// export\n", "extension Learner where Opt.Scalar: BinaryFloatingPoint, \n", " Opt.Model.AllDifferentiableVariables == Opt.Model.TangentVector{\n", "\n", " public func addOneCycleDelegates(_ lrMax: Float, pctStart:Float=0.25, divStart: Float = 10, 
divEnd: Float = 1e5, \n", " moms: (Float,Float,Float) = (0.95,0.85,0.95)) {\n", " let scheds = oneCycleSchedulers(lrMax, pctStart: pctStart, divStart: divStart, divEnd: divEnd, moms: moms)\n", " addDelegates([makeParamScheduler(scheds.0 , hp: HyperParams.lr), \n", " makeParamScheduler(scheds.1 , hp: HyperParams.mom)])\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And check it's all training properly." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let optFunc: (CNNModel) -> StatefulOptimizer = adamOpt(lr: 1e-3, mom: 0.9, beta: 0.99, wd: 1e-2, eps: 1e-6)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let learner = Learner(data: data, lossFunc: softmaxCrossEntropy, optFunc: optFunc, modelInit: modelInit)\n", "let recorder = learner.makeDefaultDelegates(metrics: [accuracy])\n", "learner.delegates.append(learner.makeNormalize(mean: mnistStats.mean, std: mnistStats.std))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: [1.3358655, 0.582] \n", " \r" ] } ], "source": [ "learner.addOneCycleDelegates(1e-3)\n", "learner.fit(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "recorder.plotLRs()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Differential learning rates" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To train at differential learning rates (or freeze part of the model) we need to pass to our optimizer arrays of KeyPaths (which will define our layer groups). For instance, we can begin with the first 9 keyPaths (which correspond to the first three ConvBNs):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "func modelInit() -> CNNModel { return CNNModel(channelIn: 3, nOut: 10, filters: [64, 64, 128, 256]) }" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "var model = modelInit()\n", "let splitArray = [Array(model.variables.keyPaths[0..<9]), Array(model.variables.keyPaths[9...])]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let hpGroups: [[String: Float]] = [[HyperParams.lr: 0], [HyperParams.lr: 0.1]]\n", "func optFunc(_ model: CNNModel) -> StatefulOptimizer {\n", " return StatefulOptimizer(for: model, steppers: [SGDStep()], stats: [], hpGroups: hpGroups, splitArray: splitArray)\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let learner = Learner(data: data, lossFunc: softmaxCrossEntropy, optFunc: optFunc, modelInit: modelInit)\n", "let recorder = learner.makeDefaultDelegates(metrics: [accuracy])\n", "learner.delegates.append(learner.makeNormalize(mean: mnistStats.mean, std: mnistStats.std))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First parameter shouldn't change since the corresponding layer group has a LR of 0, second should."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]\n" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.model.convs[0].norm.scale" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,\n", " 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,\n", " 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]\n" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.model.convs[3].norm.scale" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: [1.3130397, 0.534] \n", " \r" ] } ], "source": [ "learner.fit(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[0.99947506, 1.0001966, 0.9900738, 0.98156047, 0.99899787, 1.0270612, 0.9824792, 1.0029291,\n", " 1.0091724, 1.0046912, 1.0073055, 1.0055526, 0.97067755, 0.9816713, 1.016569, 0.9986565]\n" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.model.convs[0].norm.scale" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[ 1.0017759, 1.0072277, 0.9770908, 1.0054005, 1.0019907, 1.0047579, 1.0007318, 1.0017214,\n", " 0.97550786, 1.0026606, 1.0116018, 1.0013392, 0.9883987, 0.98546517, 0.9893817, 1.0137036,\n", " 0.999046, 1.0305443, 1.005647, 1.0052477, 1.0025676, 0.99569345, 0.9984126, 1.0053662,\n", " 0.97381794, 0.9985082, 0.9871409, 0.99648446, 
0.9966728, 1.0009459, 0.9931889, 1.0144972,\n", " 1.0004792, 0.9952349, 1.0133064, 0.9818646, 0.99662995, 0.9993304, 1.001136, 0.985157,\n", " 0.99845725, 0.98352814, 0.98484004, 0.97534674, 0.98715824, 0.9976744, 0.9670125, 0.97067595,\n", " 0.9903402, 1.0086904, 0.9866701, 1.0071692, 1.0245737, 0.9928977, 1.009554, 1.0141283,\n", " 1.0014772, 0.97363013, 1.0075214, 0.9731512, 0.9915053, 1.001776, 1.0038495, 1.0007021]\n" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.model.convs[3].norm.scale" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another way to get those keyPaths is to use the keyPaths to certain layers then append the keyPaths of all the parameters inside. This function takes a model, a layer and the keyPath that points from `model.variables` to `layer.variables` and returns the keyPaths of all the parameters of that layer. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "public func parameterKeyPaths(\n", " _ model: M1,\n", " _ kp: WritableKeyPath,\n", " _ layer: M2) -> [WritableKeyPath]\n", "where M1: Layer, M2: Layer {\n", " return model.variables[keyPath: kp].keyPaths.map { kp.appending(path: $0) }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To access a keyPath directly, we use \\ commands. 
Here is the keyPath to the array of convs, which lets us easily get the keyPaths for all the body of our CNN:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let kp = \\(CNNModel.AllDifferentiableVariables).convs\n", "let conv = model.convs\n", "let bodyKeyPaths = parameterKeyPaths(model, kp, conv)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we could split body and head:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "▿ 2 elements\n", " - 0 : 21\n", " - 1 : 2\n" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "let splitArray = [bodyKeyPaths, model.variables.keyPaths.filter { return !bodyKeyPaths.contains($0) }]\n", "splitArray.map { $0.count }" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we want to refine this a bit and split our body between the first 4 convs and the last 3 we can proceed like this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1 4\r\n", "2 5\r\n", "3 6\r\n" ] }, { "data": { "text/plain": [ "▿ 3 elements\n", " - 0 : 0 elements\n", " - 1 : 0 elements\n", " - 2 : 0 elements\n" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "let x = [1,2,3]\n", "let y = [4,5,6]\n", "zip(x,y).map { print($0, $1) }" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let deepBody = (0..<4).map { parameterKeyPaths(\n", " model, \n", " \\(CNNModel.AllDifferentiableVariables).convs.base[$0], \n", " model.convs[$0]\n", ") }.reduce([], +)\n", "\n", "let upperBody = (4..<7).map { parameterKeyPaths(\n", " model, \n", " \\(CNNModel.AllDifferentiableVariables).convs.base[$0], \n", " model.convs[$0]\n", ") }.reduce([], +)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [ { "data": { "text/plain": [ "▿ 3 elements\n", " - 0 : 12\n", " - 1 : 9\n", " - 2 : 2\n" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "let splitArray = [deepBody, upperBody, model.variables.keyPaths.filter { return !bodyKeyPaths.contains($0) }]\n", "splitArray.map { $0.count }" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's say we want a parameter group with all the batchnorm layers. KeyPaths allow us to get all the batchnorm layers by saying `(to: FABatchNorm.self)`. The `.keyPaths` method we have been using is just a shortcut for `recursivelyAllWritableKeyPaths(to: TF.self)`, which grabs all the keypaths to all the tensors." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let bns = model.recursivelyAllWritableKeyPaths(to: FABatchNorm.self).map { model[keyPath: $0] }" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we need the keypaths from model.variables to those batchnorms."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let bnKeyPaths = model.variables.recursivelyAllWritableKeyPaths(to: FABatchNorm.AllDifferentiableVariables.self)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "14\n" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "let bnParameters = zip(bnKeyPaths, bns).map { parameterKeyPaths(model, $0, $1) }.reduce([], +)\n", "bnParameters.count" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "success\r\n" ] } ], "source": [ "import NotebookExport\n", "let exporter = NotebookExport(Path.cwd/\"09_optimizer.ipynb\")\n", "print(exporter.export(usingPrefix: \"FastaiNotebook_\"))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Swift", "language": "swift", "name": "swift" } }, "nbformat": 4, "nbformat_minor": 1 }