{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Installing packages:\n", "\t.package(path: \"/home/sgugger/git/course-v3/nbs/swift/FastaiNotebook_09_optimizer\")\n", "\t\tFastaiNotebook_09_optimizer\n", "With SwiftPM flags: []\n", "Working in: /tmp/tmp5t4bwpzq/swift-install\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "Updating https://github.com/mxcl/Path.swift\n", "Updating https://github.com/saeta/Just\n", "Updating https://github.com/latenitesoft/NotebookExport\n", "Completed resolution in 2.36s\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: 
/home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: 
/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by 
/home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)[1/15] Compiling FastaiNotebook_09_optimizer 08a_heterogeneous_dictionary.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[2/15] Compiling FastaiNotebook_09_optimizer 01_matmul.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[3/15] Compiling FastaiNotebook_09_optimizer 03_minibatch_training.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[4/15] Compiling FastaiNotebook_09_optimizer 02_fully_connected.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", 
"[5/15] Compiling FastaiNotebook_09_optimizer 05b_early_stopping.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[6/15] Compiling FastaiNotebook_09_optimizer 06_cuda.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[7/15] Compiling FastaiNotebook_09_optimizer 05_anneal.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[8/15] Compiling FastaiNotebook_09_optimizer 02a_why_sqrt5.swift\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[9/15] Compiling FastaiNotebook_09_optimizer 00_load_data.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[10/15] Compiling FastaiNotebook_09_optimizer 08_data_block.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[11/15] Compiling FastaiNotebook_09_optimizer 04_callbacks.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[12/15] Compiling FastaiNotebook_09_optimizer 09_optimizer.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[13/15] Compiling FastaiNotebook_09_optimizer 01a_fastai_layers.swift\n", 
"/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[14/15] Compiling FastaiNotebook_09_optimizer 07_batchnorm.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[15/16] Merging module FastaiNotebook_09_optimizer\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)[16/17] Compiling jupyterInstalledPackages jupyterInstalledPackages.swift\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "[17/18] Merging module jupyterInstalledPackages\n", "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n", "/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n", "/home/sgugger/swift/usr/bin/swift-autolink-extract: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift-autolink-extract)\n", "[18/18] Linking libjupyterInstalledPackages.so\n", "Initializing Swift...\n", "Installation complete!\n" ] } ], "source": [ "%install-location $cwd/swift-install\n", "%install '.package(path: \"$cwd/FastaiNotebook_09_optimizer\")' FastaiNotebook_09_optimizer" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// export\n", "import Path\n", "import 
TensorFlow" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import FastaiNotebook_09_optimizer" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('inline', 'module://ipykernel.pylab.backend_inline')\n" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%include \"EnableIPythonDisplay.swift\"\n", "IPythonDisplay.shell.enable_matplotlib(\"inline\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**TODO:** switch to Imagenette once it is possible to train on it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let data = mnistDataBunch(flat: true)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let (n,m) = (60000,784)\n", "let c = 10\n", "let nHid = 50" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "func modelInit() -> BasicModel {return BasicModel(nIn: m, nHid: nHid, nOut: c)}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let learner = Learner(data: data, lossFunc: softmaxCrossEntropy, optFunc: sgdOpt(lr: 0.1), modelInit: modelInit)\n", "let recorder = learner.makeDefaultDelegates(metrics: [accuracy])\n", "learner.delegates.append(learner.makeNormalize(mean: mnistStats.mean, std: mnistStats.std))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: [0.35363615, 0.9073] \n", " \r" ] } ], "source": [ "learner.fit(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Mixup" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "extension RandomDistribution {\n", " // Returns a batch 
of samples.\n", " func next(\n", " _ count: Int, using generator: inout G\n", " ) -> [Sample] {\n", " var result: [Sample] = []\n", " for _ in 0.. [Sample] {\n", " return next(count, using: &ThreefryRandomNumberGenerator.global)\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Mixup requires one-hot encoded targets since we don't have a loss function with no reduction." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "extension Learner {\n", " public class MixupDelegate: Delegate {\n", " private var distribution: BetaDistribution\n", " \n", " public init(alpha: Float = 0.4){\n", " distribution = BetaDistribution(alpha: alpha, beta: alpha)\n", " }\n", " \n", " override public func batchWillStart(learner: Learner) {\n", " if let xb = learner.currentInput {\n", " if let yb = learner.currentTarget as? Tensor{\n", " var lambda = Tensor(distribution.next(Int(yb.shape[0])))\n", " lambda = max(lambda, 1-lambda)\n", " let shuffle = Raw.randomShuffle(value: Tensor(0.. 
MixupDelegate {\n", " return MixupDelegate(alpha: alpha)\n", " }\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let (n,m) = (60000,784)\n", "let c = 10\n", "let nHid = 50" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to one-hot encode the targets" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// Training batches with the labels replaced by their one-hot encoding.\n", "var train1 = data.train.innerDs.map { batch in\n", "    DataBatch(\n", "        xb: batch.xb,\n", "        yb: Raw.oneHot(indices: batch.yb, depth: TI(10), onValue: TF(1), offValue: TF(0)))\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// Validation batches, re-encoded the same way as the training batches.\n", "var valid1 = data.valid.innerDs.map { batch in\n", "    DataBatch(\n", "        xb: batch.xb,\n", "        yb: Raw.oneHot(indices: batch.yb, depth: TI(10), onValue: TF(1), offValue: TF(0)))\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// Same lengths and batch size as `data`, but built from the one-hot datasets.\n", "let data1 = DataBunch(\n", "    train: train1,\n", "    valid: valid1,\n", "    trainLen: data.train.dsCount,\n", "    validLen: data.valid.dsCount,\n", "    bs: data.train.bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "/// Builds a `BasicModel` from the notebook constants `m`, `nHid` and `c`.\n", "func modelInit() -> BasicModel {\n", "    return BasicModel(nIn: m, nHid: nHid, nOut: c)\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "/// Mean agreement between the arg-max over axis 1 of `out` and of `targ`.\n", "func accuracyFloat(_ out: TF, _ targ: TF) -> TF {\n", "    let predicted = out.argmax(squeezingAxis: 1)\n", "    let expected = targ.argmax(squeezingAxis: 1)\n", "    return TF(predicted .== expected).mean()\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let learner = Learner(data: data1, lossFunc: softmaxCrossEntropy, optFunc: sgdOpt(lr: 0.1), modelInit: modelInit)\n", "let recorder = learner.makeRecorder()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learner.delegates = [learner.makeTrainEvalDelegate(), learner.makeShowProgress(), \n", " learner.makeAvgMetric(metrics: [accuracyFloat]), recorder,\n", " 
learner.makeMixupDelegate(alpha: 0.2)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: [0.5300334, 0.9163] \n", "Epoch 1: [0.507523, 0.926] \n", " \r" ] } ], "source": [ "learner.fit(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Labelsmoothing" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//export\n", "/// Cross-entropy with label smoothing: `(1-ε)` times the cross-entropy against `targ`\n", "/// plus `ε` times the cross-entropy against a uniform distribution over the classes.\n", "///\n", "/// - Parameters:\n", "///   - out: logits handed to `softmaxCrossEntropy(logits:labels:)` and `logSoftmax`;\n", "///     assumed to be `[batch, classes]` (TODO confirm against callers).\n", "///   - targ: integer class labels for `out`.\n", "///   - ε: weight given to the uniform distribution.\n", "/// - Returns: the smoothed loss.\n", "@differentiable(wrt: out)\n", "public func labelSmoothingCrossEntropy(_ out: TF, _ targ: TI, ε: Float = 0.1) -> TF {\n", "    let loss = softmaxCrossEntropy(logits: out, labels: targ)\n", "    let logPreds = logSoftmax(out)\n", "    // `logPreds.mean()` averages over every element, classes included, so `-logPreds.mean()`\n", "    // is already the batch mean of `-(sum over classes of logPreds) / classCount`.\n", "    // Scaling by `ε / classCount` on top of that would divide by the class count twice.\n", "    return (1-ε) * loss - ε * logPreds.mean()\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@differentiable(wrt: out)\n", "func lossFunc(_ out: TF, _ targ: TI) -> TF { return labelSmoothingCrossEntropy(out, targ, ε: 0.1) }" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let learner = Learner(data: data, lossFunc: lossFunc, optFunc: sgdOpt(lr: 0.1), modelInit: modelInit)\n", "let recorder = learner.makeDefaultDelegates(metrics: [accuracy])\n", "learner.delegates.append(learner.makeNormalize(mean: mnistStats.mean, std: mnistStats.std))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: [0.2789237, 0.9376] \n", "Epoch 1: [0.28466254, 0.9394] \n", " \r" ] } ], "source": [ "learner.fit(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "success\r\n" ] } ], "source": [ "import NotebookExport\n", "let exporter = NotebookExport(Path.cwd/\"10_mixup_ls.ipynb\")\n", 
"print(exporter.export(usingPrefix: \"FastaiNotebook_\"))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Swift", "language": "swift", "name": "swift" } }, "nbformat": 4, "nbformat_minor": 2 }