{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Annealing" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Installing packages:\n", "\t.package(path: \"/home/ubuntu/fastai_docs/dev_swift/FastaiNotebook_04_callbacks\")\n", "\t\tFastaiNotebook_04_callbacks\n", "With SwiftPM flags: []\n", "Working in: /tmp/tmpye6sse7m/swift-install\n", "Fetching https://github.com/mxcl/Path.swift\n", "Fetching https://github.com/JustHTTP/Just\n", "Completed resolution in 2.74s\n", "Cloning https://github.com/JustHTTP/Just\n", "Resolving https://github.com/JustHTTP/Just at 0.7.1\n", "Cloning https://github.com/mxcl/Path.swift\n", "Resolving https://github.com/mxcl/Path.swift at 0.16.2\n", "Compile Swift Module 'Just' (1 sources)\n", "Compile Swift Module 'Path' (9 sources)\n", "Compile Swift Module 'FastaiNotebook_04_callbacks' (8 sources)\n", "Compile Swift Module 'jupyterInstalledPackages' (1 sources)\n", "Linking ./.build/x86_64-unknown-linux/debug/libjupyterInstalledPackages.so\n", "Initializing Swift...\n", "Installation complete!\n" ] } ], "source": [ "%install '.package(path: \"$cwd/FastaiNotebook_04_callbacks\")' FastaiNotebook_04_callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('inline', 'module://ipykernel.pylab.backend_inline')\n" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import FastaiNotebook_04_callbacks\n", "%include \"EnableIPythonDisplay.swift\"\n", "IPythonDisplay.shell.enable_matplotlib(\"inline\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "// export\n", "import Path\n", "import TensorFlow" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "let data = mnistDataBunch(flat: 
true)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "let (n,m) = (60000,784)\n", "let c = 10\n", "let nHid = 50" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "var opt = SGD(learningRate: 1e-2)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "func modelInit() -> BasicModel {return BasicModel(nIn: m, nHid: nHid, nOut: c)}" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "let learner = Learner(data: data, lossFunction: softmaxCrossEntropy, optimizer: opt, initializingWith: modelInit)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "learner.delegates = [learner.makeTrainEvalDelegate(), learner.makeAvgMetric(metrics: [accuracy]),\n", " learner.makeNormalize(mean: mnistStats.mean, std: mnistStats.std)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//Crashes!\n", "//learner.delegates = [type(of: learner).TrainEvalDelegate(), type(of: learner).AvgMetric(metrics: [accuracy])]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: [0.28770122, 0.9175]\n", "Epoch 1: [0.23435025, 0.9308]\n" ] } ], "source": [ "learner.fit(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Recorder" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Recorder's role is to keep track of the loss and our scheduled learning rate. 
" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "// export\n", "import Python" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "// export\n", "/// Plots `arr2` against `arr1` through the `np`/`plt` Python objects; `logScale` puts the x axis on a log scale.\n", "public func plot<S1, S2>(_ arr1: [S1], _ arr2: [S2], logScale:Bool = false, xLabel: String=\"\", yLabel: String = \"\") \n", " where S1:PythonConvertible, S2:PythonConvertible{\n", " plt.figure(figsize: [6,4])\n", " let (npArr1, npArr2) = (np.array(arr1), np.array(arr2))\n", " if logScale {plt.xscale(\"log\")} \n", " if !xLabel.isEmpty {plt.xlabel(xLabel)}\n", " if !yLabel.isEmpty {plt.ylabel(yLabel)} \n", " let fig = plt.plot(npArr1, npArr2)\n", " plt.show(fig)\n", "}" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "// export\n", "extension Learner where Opt.Scalar: PythonConvertible{\n", " public class Recorder: Delegate {\n", " public var losses: [Loss] = []\n", " public var lrs: [Opt.Scalar] = []\n", " \n", " public override func batchDidFinish(learner: Learner) {\n", " if learner.inTrain {\n", " losses.append(learner.currentLoss)\n", " lrs.append(learner.optimizer.learningRate)\n", " }\n", " }\n", " \n", " public func plotLosses(){\n", " plot(Array(0.. Recorder {\n", " return Recorder()\n", " }\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let learner = Learner(data: data, lossFunction: softmaxCrossEntropy, optimizer: opt, initializingWith: modelInit)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Utility optional property to get back our `Recorder` if it was created by a utility function. It searches the delegates of the learner it is called on; an earlier version read the notebook-level `learner` variable instead, which is why it did not always work." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// Searches this learner's own `delegates`, not the notebook-level `learner` variable.\n", "extension Learner where Opt.Scalar: PythonConvertible{\n", " /// The first `Recorder` found among this learner's delegates, or `nil` if there is none.\n", " public var recorder: Learner.Recorder? 
{\n", " for callback in delegates {\n", " if let recorder = callback as? Learner.Recorder { return recorder }\n", " }\n", " return nil\n", " }\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learner.delegates = [learner.makeTrainEvalDelegate(), learner.makeAvgMetric(metrics: [accuracy]), \n", " learner.makeNormalize(mean: mnistStats.mean, std: mnistStats.std), learner.makeRecorder()]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: [0.30299202, 0.9118]\n", "Epoch 1: [0.24488889, 0.93]\n" ] } ], "source": [ "learner.fit(2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEKCAYAAAD9xUlFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xl8FPX9P/DXOzd3gASQcIQbEUQkIAooKCIC9bYerbfl13pXq8Vb64XVr1aL1eJRj1K13lRAUEEBOQNynwFB7hsSIAk53r8/ZmYzuzt7JZndjft6Ph55ZHd2dvadSTLv+dyiqiAiIgKApFgHQERE8YNJgYiIPJgUiIjIg0mBiIg8mBSIiMiDSYGIiDyYFIiIyINJgYiIPJgUiIjIIyXWAUQqKytLc3NzYx0GEVGdsnjx4n2qmh1qvzqXFHJzc5Gfnx/rMIiI6hQR2RLOfqw+IiIiDyYFIiLyYFIgIiIPJgUiIvJgUiAiIg8mBSIi8mBSICIij4RJCut2FeGF6euw70hprEMhIopbCZMUCvYcwcszCnDg6PFYh0JEFLcSJikkifG9UjW2gRARxbGESQoiRlaoqGRSICIKJGGSQrJZVGBBgYgosIRJClb1EUsKRESBJU5SMLMC2xSIiAJLnKQgVlKIcSBERHHMtaQgIhkislBElonIKhF53GGfdBH5UEQKRGSBiOS6FQ97HxERheZmSaEUwNmq2hvAKQBGiMgAn31uAnBQVTsDeBHAs24Fk2yVFFhUICIKyLWkoIYj5tNU88v3inwhgHfMxx8DOEesvqO1TFh9REQUkqttCiKSLCJLAewB8LWqLvDZJQfAVgBQ1XIAhwE0dyMWVh8REYXmalJQ1QpVPQVAGwD9RaSnzy5OpQK/q7aIjBGRfBHJ37t3b7ViSWbvIyKikKLS+0hVDwH4DsAIn5e2AWgLACKSAqAJgAMO75+gqnmqmpednV2tGDiimYgoNDd7H2WLSKb5uB6AYQDW+uw2CcB15uPLAMxQdedWniOaiYhCS3Hx2CcAeEdEkmEkn/+q6pci8hcA+ao6CcCbAN4TkQIYJYQr3QqGbQpERKG5lhRUdTmAPg7bH7E9LgFwuVsx2CWx+oiIKCSOaCYiIo/ESQrmT+pSkwUR0S9C4iQFq/qISYGIKKCESwqsPiIiCiyBkoLxnXM
fEREFljBJgSOaiYhCS5ikYFUfHSktj3EkRETxK2GSgjX36iNfrIptIEREcSxhkoJVfURERIElTFJIcmeZBiKiX5SESQrMCUREoSVMUkhmViAiCilhkgKrj4iIQkucpMCGZiKikBInKTAnEBGFlEBJgVmBiCiUhEkKHKdARBRawiQFFhSIiEJLmKTA6iMiotASJilwnAIRUWgpsQ4gWpKSBHntmyI9NWHyIBFRxBLqCikCVFbGOgoioviVYElBoOAiO0REgbiWFESkrYjMFJE1IrJKRO502GeIiBwWkaXm1yNuxQMAAq7RTEQUjJttCuUA7lHVJSLSCMBiEflaVVf77DdbVUe7GIdHkggqmBWIiAJyraSgqjtVdYn5uAjAGgA5bn1eOES4RjMRUTBRaVMQkVwAfQAscHj5dBFZJiJTReQkN+NIEmFSICIKwvUuqSLSEMAnAO5S1UKfl5cAaK+qR0RkJIDPAXRxOMYYAGMAoF27djWIBWxmJiIKwtWSgoikwkgIE1X1U9/XVbVQVY+Yj6cASBWRLIf9JqhqnqrmZWdn1yQeNjQTEQXhZu8jAfAmgDWq+kKAfVqZ+0FE+pvx7HcrpiQBwOojIqKA3Kw+GgjgGgArRGSpue0BAO0AQFVfA3AZgD+ISDmAYgBXqrp31WaXVCKi4FxLCqo6B8Z1ONg+4wGMdysGX0kcvEZEFFSCjWjmNBdERMEkWFJgl1QiomASKimkpSTheAWLCkREgSRUUkhPScLxciYFIqJAEiwpJKOUSYGIKKAESwpJKC2riHUYRERxK7GSQmoSSlhSICIKKKGSQlpyEsrY0ExEFFBCJYWUpCSogmsqEBEFkFhJIdkYYM3SAhGRs4RKCqlmUihnSYGIyFGCJQXjxy1jYzMRkaOESgopVlLgBEhERI4SKimkJpnVRxWsPiIicpJQScEqKTApEBE5S6ikYDU0s/qIiMhZQiWFlCTjx91dWBLjSIiI4lNCJYXScmPeo+vfWhTjSIiI4lNCJQVrJDPXVCAicpZQSUEk6JLRREQJL6GSgnIpTiKioBIsKcQ6AiKi+JZQSYGIiIJzLSmISFsRmSkia0RklYjc6bCPiMjLIlIgIstF5FS34gGAC05pDQC45NQcNz+GiKjOSnHx2OUA7lHVJSLSCMBiEflaVVfb9jkfQBfz6zQAr5rfXZGRmgwAXKeZiCgA10oKqrpTVZeYj4sArAHge4t+IYB31TAfQKaInOBWTJbJy3e6/RFERHVSVNoURCQXQB8AC3xeygGw1fZ8G/wTBxERRYnrSUFEGgL4BMBdqlro+7LDW/z6CInIGBHJF5H8vXv3uhEmERHB5aQgIqkwEsJEVf3UYZdtANranrcBsMN3J1WdoKp5qpqXnZ3tTrBERORq7yMB8CaANar6QoDdJgG41uyFNADAYVVlhT8RUYy42ftoIIBrAKwQkaXmtgcAtAMAVX0NwBQAIwEUADgG4AYX4wEAnHhCY6zZWYjKSkVSEqe9ICKycy0pqOocOLcZ2PdRALe6FYOTNTuNZo2VOw7j5DaZ0fxoIqK4l7AjmtNSEvZHJyIKKGGvjDsOFcc6BCKiuJNwSaFts3oAgBvfzo9xJERE8SfhksK1A3JjHQIRUdxKuKSQzB5HREQBJVxSYE4gIgos4ZIC19khIgos8ZICswIRUUCJlxRiHQARURxLvKTAogIRUUAJlxSIiCiwhEsKnO+IiCiwsJKCiNwpIo3NKa7fFJElIjLc7eDc0L9DM1zZry3nPiIichDulfFGc9W04QCyYUxxPc61qFzWolE6yioq2b5AROQj3KRgDfkaCeBfqroMIabFjmcpyUlQBSoqmRSIiOzCTQqLRWQ6jKQwTUQaAah0Lyx37T9SCgDYdpAzpRIR2YWbFG4CMBZAP1U9BiAVUVglzS0f5m8FAExcsCXGkRARxZdwk8LpANap6iER+S2AhwAcdi8sd7VsnAEASE1mYzMRkV24V8VXARwTkd4A7gO
wBcC7rkXlsleuPhUA0KZp/RhHQkQUX8JNCuXmesoXAnhJVV8C0Mi9sNzVpqmx0E5peUWMIyEiii8pYe5XJCL3A7gGwGARSYbRrlAnpZjVRmUVdbatnIjIFeGWFK4AUApjvMIuADkAnnMtKpelJhu9acsq2CWViMgurKRgJoKJAJqIyGgAJapaZ9sUUpOMH/t4OUsKRER24U5z8WsACwFcDuDXABaIyGUh3vOWiOwRkZUBXh8iIodFZKn59UikwVdXkrn82qLNB6L1kUREdUK4bQoPwhijsAcARCQbwDcAPg7ynrcBjEfwXkqzVXV0mDHUurkb96OsopJdU4mITOFeDZOshGDaH+q9qjoLQNzfin+w8OdYh0BEFDfCTQpficg0EbleRK4HMBnAlFr4/NNFZJmITBWRk2rheBF7+ItVnBiPiMgUVvWRqt4rIpcCGAhjIrwJqvpZDT97CYD2qnpEREYC+BxAF6cdRWQMgDEA0K5duxp+rL8DR4+jecP0Wj8uEVFdE26bAlT1EwCf1NYHm1NxW4+niMg/RCRLVfc57DsBwAQAyMvLq/Xbes6WSkRkCJoURKQIzmvdCwBV1cbV/WARaQVgt6qqiPSHUZW1v7rHq4nySsVP+46iUUYKslhiIKIEFjQpqGq1p7IQkfcBDAGQJSLbADwKcxS0qr4G4DIAfxCRcgDFAK7UGFXuV1Qqhj7/HdJSkrD+yfNjEQIRUVwIu/ooUqp6VYjXx8PoshpzI1+aDYCD2YiI2EEfQFFpeaxDICKKC0wKRETkkbBJISezXqxDICKKOwmbFDpmN4h1CEREcSdhk8J4c/U1IiKqkrBJoUm9OrtGEBGRaxI2KRARkT8mBSIi8mBSICIiDyYFIiLyYFIgIiIPJgUiIvJgUiAiIo+ETgoisY6AiCi+JHRSWPDAObEOgYgoriR0UmjRKMNv24tfr8e2g8diEA0RUewldFJw8tK3G3Dj24tiHQYRUUwkfFL48vZBftt2F5bGIBIiothL+KTQM6eJ37bisooYREJEFHsJnxSIiKgKk4KD4+WVsQ6BiCgmmBQCWL+7KNYhEBFFHZNCAMNfnIUjpeWxDoOIKKpcSwoi8paI7BGRlQFeFxF5WUQKRGS5iMRsfcyrT2vnuL3no9NQVsGqJCJKHG6WFN4GMCLI6+cD6GJ+jQHwqouxBHXzoA4BX3t++jr8b9mOKEZDRBQ7KW4dWFVniUhukF0uBPCuqiqA+SKSKSInqOpOt2IKpGN2w4Cv/fP7TQCAX/VuHa1wiIhiJpZtCjkAttqebzO3ERFRjMQyKTjNUaqOO4qMEZF8Ecnfu3evy2ERESWuWCaFbQDa2p63AeBYea+qE1Q1T1XzsrOzXQmmaf1UV45LRFSXxDIpTAJwrdkLaQCAw7FoT7BMvmMw3rmxf6w+nogoLrjZJfV9APMAdBORbSJyk4j8XkR+b+4yBcAmAAUAXgdwi1uxhKN1Zj2c1TUb//1/pzu+/kPBvihHREQUfW72ProqxOsK4Fa3Pr+6+ndo5rj9zg9+RP5D50Y5GiKi6OKI5jCVVSgK9hTh4c9XYuX2w7EOh4jIFUwKYSqvqMSwF2bhvflbMPrvc7xeq6hUHC4ui1FkRES1h0nBwUWn+A9UK6v07i2rqigx110YN3UNej8+nXMlEVGdx6TgoH+H5n7bfOdAmrjgZ3R/+CtsP1SM12f/BAAoKmFpgYjqNiYFBynJ/uPq1GdYnTUf0oTvNwbch4iormFScHCkJHQ10Ma9RwAA78zb4tnGnEBEdR2TgoP9R0tD7rPvyHG/bYXFZdhTWOJGSEREUeHaOIW6rKyievf85780GwCwedyo2gyHiChqWFJwwDWaiShRMSk46NMus8bH2HW4BK/MLICy9ZmI6hAmBQcXnpKDh0adWKNj3DJxMZ6btg4b9hyppagMZRWVqKxkoiEidzApBJCTWa/a7+14/2QcLTUGtlXYLuBbDxxDaXlFjeLq8uBUPPj5iho
dg4goECaFAMRpCaAwVWrV+63ao5KyCgz+60zc899lAIDXZ21C7tjJ1UoS7y/cGnonIqJqYFJwiZUM1By9YI2I/nbNHgDAq+agN98xEZv3HcXfv93g1RZxvLwSG/ceYfsEEbmOScEl63YXeT0vN7u5FpdV4EhpOQ4c9R/nAAC/fXMB/u/r9Vi/+whOfmwaVmw7jMf/twrn/N/32GUbA3HvR8vcCz6KNu87ityxk7FmZ2GsQyEiMCkEYdT/DOlWs+U/rZt7+9xJPR+dVvUpPvVUxceN6qQfCvahsKQcr83aiHmb9gMADh6tmlvpo8XbABjTbaz3SUAA6szkfNNW7QIAfPbj9hhHQkQAk0JAHbIaAADO6pqNjtkNany84xXhjX2wKojqpSUDAErLKj0b1WEijdvf/xHDX5wFwEgkXy7fgbkF+9Dz0WmYsyE+Vosrq6jE4WPBJwtk1RhRfGBSCKBbq0aYd//ZuP6MXKQlV52mThEmiMLiMqzfXeSpPvIV6GKYkWp8pr0hOtR18zdvLMBt//kRCzcfAAC8/O0G7DxcHFG8NVVWUek3+O/WiUvQ+y/THfevSYM+EdU+JoUgTmhSDyKCXjlNPNvuGtY1omNc/cYCDH9xFmZt2Ov4eqDrfFqyUVJYtaMQm/YdNfYN82b6b99sAAAs3HwApz8zI6J4Q5m0bAcWmNVZTs4YN8OregwApq/eXasxWLbsPxo0FiKKHJNCGJ64qKfncZeWDat1jJe/LXDc7ntnbZUcrDtoe4O0U/VRuA4ePY5jx2vWzjBh1kbc8f6PuGLC/ID77C0qDbuqzM5KeKXlFZi0bEdY1UlnPfdd0FhiYfuhYhytI+05RE6YFMKQkZrsedy9VWO0bJwe8TGyGzm/Z8TfZqPrQ1M9z61LYaXDRbG6A5nLKyrR54mv0eORafhg4c9er63fXYRdh8Ob2fXpKWu9nkc6xiKckdjPTl2HO97/ET8U1K0SQFFJGUrLKzBw3Axc/tq8WIdDvyBfLt+Bhz9fGbXPY1KohvvO6x7xe6ylO31Zazt/tdLohWPlgmVbD/nt+/065yqoUDo/WJV0xn7qPRp6+IuzMOCZbyM+5oJN+9Htoa8wb2P4F+9yh6Qg8G5U2HGo2PP9E7OHVV3Q67Hp+LWZDFazey3Votv+8yPem78l9I61hEkhSn4y2wUC+f2/FwOoShLWEp92L36z3uv5b99YUK1Y/j1/Cw4GGCcRLqub7LwAdfoFDnM+/bTvKJ6ZuiZo1ZBVQrrvk+W456NlWL7NPzmGQ1Wx/VB0G9mXbTsc1c+rrkWbD+D0Z76tM92Wa2r2hr2YuMD7orq7sAQrt9eN31e0uZoURGSEiKwTkQIRGevw+vUisldElppfN7sZT02c2i4TNwzMdXztPzefViuf8fSUNRHtP6egqsvpz/uPhf2+hz5fidvf/zGiz4rUze8swtpdhV4J4KZ3FuGf32/Cxr1Ww7l6LkxV1Wbex7GSZKQe+GwFBo6bgR9/PhhwH1XF+BkbsPVA+OfObW/O+Qkz1+7xPJ+7cR/+m1+705o899U67DzsfVF8btpa5I6dXKufEy+ueXMhHvzMu/pl0LMzMPrvc8I+xqLNB/CabendXzLXkoKIJAN4BcD5AHoAuEpEejjs+qGqnmJ+veFWPDX16S0D8eivTvLb3r55faSl1M5pnDBrU7Xfe+ZzMyPaf07BvmrfhQNV1VxHSsox4OlvsXiL98V38/5juP/TFfhgUdUFrdRsULcSxb/nb8FL3xo9paq6znpnhWPHK/DU5NV+d/2rdgS/y7Pmh7KXWP67aCs+tw2S236oGM9PX4+b3lkU9FhlIRrOa3PW2ie+XI0b3q6K5+rXF+C+j5fX2vGBqg4L9oq7V2YaF7x4noF31Y7DnrXRw1Hu83tTVXy7ZjcqKjWihbT2FJXg8tfmYdzUtdh2MH5uINziZkmhP4ACVd2kqscBfADgQhc/L2qsnkEDOjbD/24fVGfXZr5g/A9
+296YvQlfLA1/dPFbP/yEXYUleGbKGmzZ711FlpqUhO/WVd31WomkwnwwbVVVV9UpK7zbVCwLNh3A67N/wh8/XOq1fdTL4d3l2Q933yfLcZftONZnHTseuMF8+6FidHlwKj5Y+DO+WLrdcQGmkhrOfBtt1s/tO5oeAMoq43eBqVEvz/GUcFUVu4Msfbt4ywF0fnAq5m6sKk1PW7UbN72Tj9dnR3bz5dUDsJr/7KoaVpVtUUmZp/da8fGKqI8zAtxNCjkA7OXebeY2X5eKyHIR+VhE2roYT62x/pdaNc5A44xUNKmXCiBwD6O6QFVRUlaBJyevwZ0fLPVr2JodYJyFJX/LQZz13Hde21KSxeefyHgS6Lozc+0ebDvo/U9g9XCy7tZV1S9pTV+1y6sH1Q5bqSLckdIHjh5H7tjJ+KHAexT4pr1GSWPspytw5wdL/dp1AO/p0S3rdhUFbIRfteMw5pttMdsOHvOK186tUd7WUUWMDhBrd1U1jBcW1412hvfmb8FpT3/rif3uD5fisUmrPK/P32QM4PzY1llhT5HxNxLp3X6wX8OanYXIHTsZG/cGXzflXz9sRp8nvva7cfLV67Hp6Pvk1wCAq16fX+vjjMLhZlJwGqvqe3r/ByBXVU8G8A2AdxwPJDJGRPJFJH/v3ur1wHFT15aNMPHm0zD7vqF40hzTcEkfp/wXv8bPKED3h7/yPH/485VeF6Vr3lwY8THnbtzvNXDNOlylKlTVq00EAG54e5HfRILHPVVOxvNJy3bgzg+8Sw1j3luMy16b63l+xriqf6Rg/9CemWwVnqq0UPXGO20X8MpKxeItBxyT3Hl/m4WrXnceQzHq5Tm4csJ83DpxCQY9O9MrXrtPl4QusR0tLcfmEJ0YfHnGwgC4+79LMeJvsz2v9XvqG8f3LN92KOzxF4u3HEDPR6cFnPSxNizbalQfLv3Z+L19+uN2vD13s+d168bNfg4jHTy/p7AEs9bvdewebvncvEGZvir4AM1v1hiv+970OCkpM8YuLXXogRgNbiaFbQDsd/5tAHhVCKrqflUtNZ++DqCv04FUdYKq5qlqXnZ2zSaoc8vAzlnISE32/DGmp9atjl0fL/Hv/lnbN6r2MRjfrNkTdF+Lpx3CfL7/iPOFJtA/m/W+whL/Bmv7P7vV1uDbfuDbZXbG2j2eCQj/OWsTLn11nmca9EhNXrEz6OvLArT5vDKzALljJ6P4eAVu+NciDHn+u4g+115SCKdLcVFJGS4Y/0PYnRNe+34TjpSWY+FPkY01ufy1ufjzx8v92gKcNMpIAQAc9an6W73DKDkk1cL8KRf/Yy6ufWthjf8P9hSVYK55nsMNK5ZVkm5euRYB6CIiHUQkDcCVACbZdxCRE2xPLwAQWfebGOmc3QgA0L9Dc7/XzuxiJK3L8+pETZjHFofeS1truVHNunMsq6jEviOlIfY2+A6Qi/R/3fqHvsih/eTvM4xR5tsPFePzpcb9SpHP+ha+n1dYUu6ZgNBKDqFKFzPW7kbu2MmeGWHD5VjUVsUbZp34oeLjnnmuLBeMn4ObQzScW+fk2a/W4aDDRIW+jc1bDxgJN9w7V2uwZ0mZ98V95fbDQatPFm0+iA/zt+LJyVWXgdLyCjz55eqAvdB8z9HIl2djb1EpkoL8nYR7kbc6N9j3D1ZqCOTxSas9j62bjK0HjvmNRQrWRhJNKW4dWFXLReQ2ANMAJAN4S1VXichfAOSr6iQAd4jIBQDKARwAcL1b8dSmXm2aYN79Z6NV4wy/19o2q4/N40bFIKra98gXq/DOjf0dX6vJzdOlr4Y/4tfTS0QVL32zARNmBb4Ar9h22K9h7oHPVmDdrqr5oyxb9h/FJw6lI+t6OHHBFpSWVaJbq0aOn3XmX2fi5zC6stq7eb4wfT3OO6mV437Lth5C77aZXtucGoL/8d1Gz4Xc62JVqdh3tBTLg4yVeH7aOgztnu353S386YDjfrM27MWQbi08zy/+xw/m5wX
/rRfsOYK05CRkmL3xVu8sRHpKEs47qRWOlVV4dQF9cOSJ+N2ZHR2P882a3XjsgpOgqrjs1XlYsf2w1w3Kz/uPeVUV+er31De4sp/DTZl5Pu0/hao6nmc7e+P7+t1H0L65/6SYwdoU7G1OVrIa/Fejt6D9WnGWrQeh7+Jb0eRqHYeqTlHVrqraSVWfMrc9YiYEqOr9qnqSqvZW1aGqujb4EeOHNVleMPPvPwfjr+6DwV2yPNvy2jd1O7RaY93NO/Vfj1bXRWvxnWXbDuPFb9b7VRfY/Xv+Fox5b7Hf9nfm+Y8GfSjAtAFrdhZiyc8H8eBnK/GXL1cHHOcQTkLwVRHkonrhK/4lGbuP8rdi3sb9XonMftdaXqno/1TgkemqivEzC4yEHOLifv2/vEsapQ49rpwMe+F7nPncTKSYswpPmLUJf5i4BHd9uNRvksTXZ2/Cc9PWYvTfZzsdCoCxSuEKcyyFvafa6p1ViS/QTxJsIKG9bSTQn7G9C7S9x9nv3s13rHL7ePE2T0O2L/ucZTPX7cXZtuq+gbb2JHvJamiEVYK1qW5VfNcxrZpkYPTJrXFV/3aebf/53QBMvXMwurSo3sR60bRqRyF6P+485fX+o+FV/9TUzjDnZQLCX6jnSGk58jcHHtR2yT+qGq2fn+7f26i6QiVS3xHG9nuOez9e7tdwbW/gnrXeuwPGw5+vxN6iqt+RvV9+OCOvrd5c9iqvcO8DfKtuJjmMLRAxxkas3O4/JYiVs44GmMAxJanqsvXEl6sdb1qCreT3xdKqeMoDdIVbYht349vO9OxXtntX2zmZsWaPV5XQ0q2H/Ko/X/t+o1epNdCo+3ATsRtcqz6iKuf2aOl5nJaShBNPaIx+HZphg8NUEPEmUF2uNTgsnoQ7O6vvXWu0OM39ZNfvSe+eP//6YbPfPpv2Vl1Q7D/vze/me+333vwt2Hm4GG9c1w9vzN4U8cjwd+dtxu7CUq+SyeHiMmw9cAxtmtZDpRpdf4vLKtC1pXcVWziNvL4N+L4qKjXgNBzJydVrRHZ6V0WlYk9RCS59dS7evqE/OmUbN2uTl1d1AvBNCoHaVqx5xTaPG4Ut+4/iohClv3jFpBAFqcn+BbLfn9kJnyzeFtM7AoquStWgM9IW2yZNHPlS4GoVS6iR1t+s2YOtB455NdyGK0nEsc1l8F9n4vazO3sa6QH4taGFM3mbPW9UVqrX/8H2Q8V4bNKqgMeJZEqXQJ9pKa9UTF6+E1sPFOOduZtx4Smt0aVlI3xlKyEdL4+sqvTej5YhOVhLt4NwZyqOBiaFGGnXvD7WPXk+Oj8wJeQdJP0ybDtYHPaMtOHMtBpoNT+7QPXcoQS72bcnBMC4qCdFeBG07/2oQwL41CEh2fevLRUV6hnR/vXq3Xh33hbcfnZnr30iXR/ko2rM7nvLRP+2sFhhm0KMLX9seNDXX/h1byx84Jyg++Rk1qu1eFrU4VHZiSaci1V6SnLIfZxEMlVzxwemRDxZ3A7bnfFHi/2rIoN1KKgu30nxAGOSRmvsi9V+tXaX9wDKOwKMzzhSWo5/1mC+MktpeQWW/BybgWpOmBRiLD0lGVkN0zxTZfi65NQ2aOHQ9dXu/J7O3Ryro1+HZrV2LHLX/DCWIvWdqDBchxzGLwQzbmr1Ow76jmeIpiU/H/KrJgunSvdIaXmttU3ZOwQE49a0J76YFGIsOUmQ/9C5WPZoVYnhjWvz/PZ7eLTTBLOGK/vX3kC5ZvXTau1Y5K7npq0LuU9tVrX8Uvk2wgdaEMtuZ5TX6gCi1yOJSSFKZt83FPkPDQu6T1ZD44I8uGuW32s3DeqA+mlVVQEN0pKRZjZgpyVXr4rASU7T2quKIqqLSsNICtVd58OJ04SKTgKNraltTApR0rZZfWQ1DF5f/8GY0zHhmr4B64Hn/PlkTiRhAAATJ0lEQVRsTL5jEACze7TZWpeaUtVsN/7qPgGPf3KbJiHjvHFgh5D7VNfAzv7TghDFm3DGccx
YG97cXeFYEeYKcN9VczneSDEpxJHOLRpieIBpEACgWYM05JpD7FWN0gJQ1S88u1E6Rp/cGlPvHIwnLvRfEOit6/vhd4P9L/qvX5uHu4Z1weZxo0IuGDTzT0M8j28d2gkz7jkLHbIa4MvbB4X8+V6+MnDC+qW5rG+bWIdALvrHd7W3Cttt/wlvosFjAQbz1TYmhTrGSgCVqmhozhRZUlaB16/Nwxe3DgQAnHhCY1xzei42PHU+1j4xAgDQMasBshqm4w9DOuP8nq3w7T1neY55bo+WuGtY17A+v3FGCgZ1Nqq3rjsjFx2zG2Lmn4agZ07wUkh6ShIapAfuAX3iCY3D+vxAnjCnLI8XT10cX/FQ3RdOF+TawKRQx1j9xxXA85f1Ru+2mTihST2c26MlWvt0TU1NTkJGajI2jxuFGeYdfrMGaXj1t309IzcjlZwk+PfNp2HzuFFo0Sh4ryhfToP4LFPuCF3SCORPw7vi7O4tQu43tFvgaddfufrUan++k+p2BSUKJNLxEtXFpFDHWBfWawe0x2kdm+OLWwfW2hrRvq7q3w6XnupdDRJqEsBAROA4yvP1a/Mw5Y7Bjsft2tI5cZ1im030klNzcMuQzkgNY/BUXm5Vd9uVj5+H16/NQ05mPbRolI5RJ58Q5J3ui3DsV53Qvnn9WIdQI9efkRvrEPzUZgN3IBzRHKcuOqW14xS9yUmCjU+PjMpFZHCXLKza4d0IFklO6N2miafR7oxO/j2qAO95oXIy62H7oWKMu6QXzuiUhbSUJAx6dgbKKxX3ntcNAzo2w0mtmyAlSdD5wakAjGqxpCTxzMxp+cOQTnjVp97X3sujYXoKzu3R0uvzC546HyKCN+dswtNTojthb+N6qY5jA4Z2y8bMKDUw1rZ6qbEtLU2+Y1DYa3k7+eOwrkGn6I6FY8fLA45pqi1MCnHqb0EaZSOdV6W6BED9NO8/kXAmO3v8gpPw6CRjLYaiknKUlFWgTdOqu8ZzurfAM5f2QrHPqNVpfzwTxccrvNa6Xvn4efjn95sw5syOjtVP1vQKvlH9eUR3v6RgTSfStplzt1srsQSarK1XTpOAPUU6ZDXAT+bsl9VZT8OpW2LHrAZ+yc4NbZrWQ4esBqhUxfHySiwKMoOsk4dGneg3v1Lvtplhde2M1NMX98IDn60Ia9+TWjfBAyO7VzvBN6nv7sW3OqIx0I9JIYF9efugkCNXbxrUAaXllbiiX1us2HYIDYM0Fs/581AcL69Ex+yGuM4semf6DIZb/ZfzkJac5Hixa5ie4nf8jNRk3DmsS8DPHNzZaCdo2iAN943ohrz2zZAaYBbNispKLH3k3Gon1cvz2vglhVPbZeLtG/sjLTnJa41rwDi/9oVlAmmUnuI1rfZ7N/XH9FW7cd0Z7XHgaBm+Xh18/d+amvPnswEYI2Z3F5aGNT/TB2MG4MoJ85FZPxU3D+6Ibq0aea3jfcMZuejdNrPW1gVIS0nC8fJK9G5rlBTDnS/sxoEdMGPtHszfZCwoNOzElp71kuuiaCy+wzaFBNYzpwkGdfGv1hlh6xabkZqMu8/tipzMehjRM3i9e5um9dExRAN2/bSUWrn73fDU+djw1PnoZRt7ccuQzujfoRn6tPNeyMga9Ne1ZSNk1k9Do4zgd4CBCkPDTmzpt+2k1k3QOCMVGanJePRXPfDf/3e657WeOU2w9okRuHlQBwzv4f9ey9d3n+W5yL3w694Y3CUbT1zUE51bNEL/Wph25Ju7z8T4q/tghq3HmcU+kl5E/GZevcsnIVsDLK3kbZVwkm0n7bK+bTCiZyt0yGqA3ADtCoG22z1zSS/89MxIbB43Ch2zjKrUUFNuW6yfNSU5CReekuPZPuEax2XgPaKxzkmbGgwQ3RxkOdPawqRAfm4e3AEi3g2z8SY1OSlobyagqpvrfed1w/Q/nokLereO6DPO7t4C39nGZbTOrIfN40bh81sH4plLegGAV1XXDQM7+F3EM1KT8dDoHpj
gMHXJ2d1bYPO4UWjVJANPXdwL2Y3ScZHtAuakkU9J6poB7T1J7F/X98MVtrXBG6Wn4K+XnYzOLRph9Mmt0TG7IRbbRtXfOrSTX/207wDLu4Z1xfz7qyZktKoPrdKWVcKxz5L6/OW9Pes0O93Qt2qcgSv6tfPa5tt7rHurRriqf7tqd2yw35zY/05Czeb6Woik4STUoFS7f13fD7PuHeq3vU+7TMcxRL581xB3A6uPyE9ebjP89EzdX2d68u2DsHpnIU5q3Tiii0sn825xRM9WyM3yb+w/pW0murVshONmtVq4Fj80DHd88CN+KNiPW4d2wq1Dq6Zovqxvm7AGvDVtkIaZ9w5BkgiaNTDu2udt2o+CPUfQOrMenr3sZHyYb8w6uuLx8/ze39y8gJ3RqTnuPa+73+v10pLRvVUjrN1VhL9dcQoAYwXBd2/sj5KyChSVlOOej5Z5ZtO1lhhtkOZ8KblneFfc+cFS/Gl4Vzw/fT0yUpMw409noV5qMn43uAN+Nf4H5GRm4LELTkLLxum477zu6PPE17jNZ/pqi9Ov8aZBHTBz3R6vBYjsrOrEcJbCDdZVe3CXLMzduN+r/eeruwYju2E6+poLJJ3RqTnmbtyPAR2b4ZYhnXHtW1VVak5tTU3qpeJwcRnuPrcrvg/SocA6f+HMy1RTTAr0i5WUJCEH1TkZ2q0Fpt45GN1bGSuKTbimLxr73FHXS0v2tJuEq3nDdM80571yMv0a8YPpmNUAV5/WDued1MrvzrTHCY1RsOcIUsJckWzT0yOD9iKzkoK91HNm16oxHpf2beO5OFkXyJ45jfHUxT0xupd3aezCU3I81Tc3DOyAjNRkTykjJVkw9c7Bnn2fueRkAM4XT2uC0CQRnN6pOWZv2Ic+7TLx48+HcHGfHE/ycWp/sRJWoFaImwZ1wJtzfgp4PgDjnFmlDPvyn91bGaXRCdf0RcOMFPz9W2OtidvP7oKBnauqZn/lU0qd8+ehKC2vxN++2YD/LdvhtcSorwdHnohrTm9vJIVyJgWimLCPsA429UikHhrdAzmZ9b26woay9okRSBIJOB5l3KW9cEHv1mEPSAxVhfLMJSfjmtPb+w2GtEtPScLJbZrgD2d1AmC0R/zmtPZBjxtsRHsop3dqjnW7i9C0fipe+21fbD9U7LcMqFW11inbu3Q3uGsWrhnQHlef5l1lNbBzc7RsnIGHR/fAoC5ZOFZqXHA/u+UMv8GH9nP275tOw/+W7cDin6t6aVl/Iy9/uwGAf2+4v1/l3ZvQ6o335EU9cXJOEwzo2AytMzMwdeUuPHvpyVi+/RAu69sGSSLIapjumTabvY+IfmEaZ6QG7U3lJCNEf//6aSkYZksyjTNSUFiDuud6acno2z54e5KIYNJt1R+FHqkHR52I68/I9awt4psQAOA3A9rh0x+34983n+a1PT0l2WsalAXmolUtbeuUDO1W1abh21HB16AuWY4dNICq35U9iaQEScJN6qXid2d2BAC0b94AP4w92/MZdiKCdLMHltskWgs31Ja8vDzNz88PvSNRgiosKUNpWaVXIzhVz5qdhfjx50N+pYxA9hSV4N25W3D3uV2RlCT4Yul29MppErJXXjiqs+ypnYgsVlX/Hg+++7mZFERkBICXACQDeENVx/m8ng7gXQB9AewHcIWqbg52TCYFIqLIhZsUXOuSKiLJAF4BcD6AHgCuEhHf5cNuAnBQVTsDeBHAs27FQ0REobk5TqE/gAJV3aSqxwF8AOBCn30uBPCO+fhjAOdIdTsmExFRjbmZFHIAbLU932Zuc9xHVcsBHAbA5bmIiGLEzaTgdMfv24ARzj4QkTEiki8i+Xv31s0ZI4mI6gI3k8I2APbhnm0A7Ai0j4ikAGgC4IDvgVR1gqrmqWpednbghVKIiKhm3EwKiwB0EZEOIpIG4EoAk3z2mQTgOvPxZQBmaF3rI0tE9Avi2uA1VS0XkdsATIPRJfUtVV0lIn8BkK+qkwC8CeA
9ESmAUUK40q14iIgoNFdHNKvqFABTfLY9YntcAuByN2MgIqLw1bkRzSKyF8CWar49C8C+WgzHDYyx5uI9PiD+Y4z3+ID4jzHe4muvqiEbZetcUqgJEckPZ0RfLDHGmov3+ID4jzHe4wPiP8Z4jy8QLrJDREQeTApEROSRaElhQqwDCANjrLl4jw+I/xjjPT4g/mOM9/gcJVSbAhERBZdoJQUiIgoiYZKCiIwQkXUiUiAiY2MUQ1sRmSkia0RklYjcaW5/TES2i8hS82uk7T33mzGvExH/ldjdiXOziKwwY8k3tzUTka9FZIP5vam5XUTkZTPG5SJyqsuxdbOdp6UiUigid8X6HIrIWyKyR0RW2rZFfM5E5Dpz/w0icp3TZ9VyjM+JyFozjs9EJNPcnisixbbz+ZrtPX3Nv48C8+eolZmNA8QX8e/Vzf/1ADF+aItvs4gsNbdH/RzWClX9xX/BGFG9EUBHAGkAlgHoEYM4TgBwqvm4EYD1MNaaeAzAnxz272HGmg6gg/kzJEchzs0Asny2/RXAWPPxWADPmo9HApgKY3LDAQAWRPn3ugtA+1ifQwBnAjgVwMrqnjMAzQBsMr83NR83dTnG4QBSzMfP2mLMte/nc5yFAE43458K4HwX44vo9+r2/7pTjD6v/x+AR2J1DmvjK1FKCuGs7eA6Vd2pqkvMx0UA1sB/OnG7CwF8oKqlqvoTgAIYP0ss2Ne+eAfARbbt76phPoBMETkhSjGdA2CjqgYbzBiVc6iqs+A/mWOk5+w8AF+r6gFVPQjgawAj3IxRVaerMW09AMyHMXFlQGacjVV1nhpXt3dtP1etxxdEoN+rq//rwWI07/Z/DeD9YMdw8xzWhkRJCuGs7RBVIpILoA+ABeam28wi/FtWNQNiF7cCmC4ii0VkjLmtparuBIzkBsBa6TyW5/ZKeP8DxtM5BCI/Z7H+O70Rxl2rpYOI/Cgi34vIYHNbjhmXJRoxRvJ7jeU5HAxgt6pusG2Ll3MYtkRJCmGt2xAtItIQwCcA7lLVQgCvAugE4BQAO2EUQYHYxT1QVU+FsZTqrSJyZpB9YxKjGDPvXgDgI3NTvJ3DYALFFLNYReRBAOUAJpqbdgJop6p9ANwN4D8i0jgGMUb6e43l7/sqeN+kxMs5jEiiJIVw1naIChFJhZEQJqrqpwCgqrtVtUJVKwG8jqrqjZjErao7zO97AHxmxrPbqhYyv++JZYwwEtYSVd1txhpX59AU6TmLSaxmg/ZoAL8xqzNgVsvsNx8vhlFP39WM0V7F5GqM1fi9xuocpgC4BMCH1rZ4OYeRSpSkEM7aDq4z6xzfBLBGVV+wbbfXwV8MwOrZMAnAlSKSLiIdAHSB0UDlZowNRKSR9RhGQ+RKeK99cR2AL2wxXmv2qBkA4LBVZeIyr7uyeDqHNpGes2kAhotIU7OaZLi5zTUiMgLAnwFcoKrHbNuzRSTZfNwRxnnbZMZZJCIDzL/na20/lxvxRfp7jdX/+jAAa1XVUy0UL+cwYrFu6Y7WF4weH+thZOsHYxTDIBjFxOUAlppfIwG8B2CFuX0SgBNs73nQjHkdotBDAUavjWXm1yrrXMFYO/tbABvM783M7QLgFTPGFQDyohBjfQD7ATSxbYvpOYSRoHYCKINxJ3hTdc4ZjHr9AvPrhijEWACjDt76e3zN3PdS8/e/DMASAL+yHScPxsV5I4DxMAfBuhRfxL9XN//XnWI0t78N4Pc++0b9HNbGF0c0ExGRR6JUHxERURiYFIiIyINJgYiIPJgUiIjIg0mBiIg8mBQoYYnIXPN7rohcXcvHfsDps4jiHbukUsITkSEwZuIcHcF7klW1IsjrR1S1YW3ERxRNLClQwhKRI+bDcQAGm3Pe/1FEksVYZ2CRORHb/zP3HyLGehj/gTGgCiLyuTlx4Cpr8kARGQegnnm8ifbPMkcxPyciK8359K+wHfs7EflYjPUNJpqjXYmiKiXWARDFgbGwlRTMi/thVe0
nIukAfhCR6ea+/QH0VGO6ZgC4UVUPiEg9AItE5BNVHSsit6nqKQ6fdQmMyd16A8gy3zPLfK0PgJNgzIPzA4CBAObU/o9LFBhLCkT+hsOYm2gpjKnNm8OYtwYAFtoSAgDcISLLYKxF0Na2XyCDALyvxiRvuwF8D6Cf7djb1Jj8bSmMRVqIooolBSJ/AuB2VfWajM5sezjq83wYgNNV9ZiIfAcgI4xjB1Jqe1wB/n9SDLCkQAQUwVge1TINwB/Mac4hIl3NGWN9NQFw0EwI3WEsrWkps97vYxaAK8x2i2wYyztGa9ZWopB4J0JkzMBZblYDvQ3gJRhVN0vMxt69cF4u8SsAvxeR5TBm6pxve20CgOUiskRVf2Pb/hmMtXmXwZgx9z5V3WUmFaKYY5dUIiLyYPURERF5MCkQEZEHkwIREXkwKRARkQeTAhEReTApEBGRB5MCERF5MCkQEZHH/wdyQDKHKlHKiwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.recorder!.plotLosses()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Progress bar" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's nice to keep track of where we're at in the training with a progress bar." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "// export\n", "import Foundation" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "// export\n", "/// Formats a duration in seconds as `MM:SS`, or `HH:MM:SS` when the hour count is non-zero.\n", "/// The fractional part of `t` is dropped by the `Int` conversion.\n", "func formatTime(_ t: Float) -> String {\n", "    let t = Int(t)\n", "    let (h,m,s) = (t/3600, (t/60)%60, t%60)\n", "    return h != 0 ? String(format: \"%02d:%02d:%02d\", h, m, s) : String(format: \"%02d:%02d\", m, s)\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\"01:18\"\n" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "formatTime(78.23)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "// export\n", "/// Text progress bar redrawn in place on one terminal line (every print ends with a carriage return).\n", "public struct ProgressBar{\n", "    let total: Int\n", "    let length: Int = 50\n", "    let showEvery: Float = 0.02\n", "    let fillChar: Character = \"X\"\n", "    public var comment: String = \"\"\n", "    private var lastVal: Int = 0\n", "    private var waitFor: Int = 0\n", "    private var startTime: UInt64 = 0\n", "    private var lastShow: UInt64 = 0\n", "    private var estimatedTotal: Float = 0.0\n", "    private var bar: String = \"\"\n", "\n", "    public init(_ c: Int) { total = c }\n", "\n", "    /// Reports that `val` of `total` items are done. `val == 0` (re)starts the clock; later calls only\n", "    /// redraw once `waitFor` more items have completed since the last redraw, or when `val == total`.\n", "    public mutating func update(_ val: Int){\n", "        if val == 0 {\n", "            startTime = DispatchTime.now().uptimeNanoseconds\n", "            lastShow = startTime\n", "            waitFor = 1\n", "            update_bar(0)\n", "        } else if val >= lastVal + waitFor || val == total {\n", "            lastShow = DispatchTime.now().uptimeNanoseconds\n", "            let averageTime = Float(lastShow - startTime) / (1e9 * 
Float(val))\n", "            // NOTE(review): fastprogress computes update_every / (avg_t + eps); the operands look swapped\n", "            // here, which would make slow loops redraw less often. Behaviour left unchanged; please confirm.\n", "            waitFor = max(Int(averageTime / (showEvery + 1e-8)), 1)\n", "            estimatedTotal = Float(total) * averageTime\n", "            update_bar(val)\n", "        }\n", "    }\n", "\n", "    /// Rebuilds the bar text for `val` and prints it over the previous one.\n", "    public mutating func update_bar(_ val: Int){\n", "        lastVal = val\n", "        let prevLength = bar.count\n", "        // Clamp: integer division by `total == 0` and a negative `count:` both trap at runtime,\n", "        // which an empty data set or a `val` outside 0...total would otherwise trigger.\n", "        let safeTotal = max(total, 1)\n", "        let filled = min(max((val * length) / safeTotal, 0), length)\n", "        bar = String(repeating: fillChar, count: filled)\n", "        bar += String(repeating: \"-\", count: length - filled)\n", "        let pct = String(format: \"%.2f\", 100.0 * Float(val)/Float(safeTotal))\n", "        let elapsedTime = Float(lastShow - startTime) / 1e9\n", "        // Never report a negative time-left estimate.\n", "        let remainingTime = max(estimatedTotal - elapsedTime, 0)\n", "        bar += \" \\(pct)% [\\(val)/\\(total) \\(formatTime(elapsedTime))<\\(formatTime(remainingTime))\"\n", "        bar += comment.isEmpty ? \"]\" : \" \\(comment)]\"\n", "        // Pad with spaces so a shorter bar fully overwrites the longer one printed before it.\n", "        if bar.count < prevLength { bar += String(repeating: \" \", count: prevLength-bar.count) }\n", "        print(bar, terminator:\"\\r\")\n", "        fflush(stdout)\n", "    }\n", "\n", "    /// Blanks out the line the bar was printed on.\n", "    public func remove(){\n", "        print(String(repeating: \" \", count: bar.count), terminator:\"\\r\")\n", "        fflush(stdout)\n", "    }\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "var tst = ProgressBar(100)\n", "for i in 0...100{\n", " tst.update(i)\n", " usleep(50000)\n", "}\n", "tst.remove()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "// export\n", "extension Learner {\n", " public class ShowProgress: Delegate {\n", " var pbar: ProgressBar? 
= nil\n", "        var iter: Int = 0\n", "\n", "        /// A fresh bar, sized by `learner.data.train.count`, is created at the start of every epoch.\n", "        public override func epochWillStart(learner: Learner) {\n", "            pbar = ProgressBar(learner.data.train.count)\n", "        }\n", "\n", "        /// Swaps the current bar for one sized by `learner.data.valid.count`.\n", "        public override func validationWillStart(learner: Learner) {\n", "            pbar?.remove()\n", "            pbar = ProgressBar(learner.data.valid.count)\n", "        }\n", "\n", "        public override func epochDidFinish(learner: Learner) {\n", "            pbar?.remove()\n", "        }\n", "\n", "        // Optional chaining below: a missing bar is skipped instead of trapping on a force-unwrap.\n", "        public override func batchWillStart(learner: Learner) {\n", "            if learner.currentIter == 0 { pbar?.update(0) }\n", "        }\n", "\n", "        public override func batchDidFinish(learner: Learner) {\n", "            pbar?.update(learner.currentIter)\n", "        }\n", "\n", "        public override func trainingDidFinish(learner: Learner) {\n", "            pbar?.remove()\n", "        }\n", "    }\n", "\n", "    public func makeShowProgress() -> ShowProgress { return ShowProgress() }\n", "}" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "let learner = Learner(data: data, lossFunction: softmaxCrossEntropy, optimizer: opt, initializingWith: modelInit)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "learner.delegates = [learner.makeTrainEvalDelegate(), learner.makeShowProgress(), \n", " learner.makeAvgMetric(metrics: [accuracy]), learner.makeRecorder(),\n", " learner.makeNormalize(mean: mnistStats.mean, std: mnistStats.std)]" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: [0.29753134, 0.9164] \n", "Epoch 1: [0.2402218, 0.9305] \n", " \r" ] } ], "source": [ "learner.fit(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Annealing" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// export\n", "/// A non-generalized learning rate scheduler\n", "extension Learner where Opt.Scalar: BinaryFloatingPoint 
{\n", "    /// Delegate that assigns the optimizer's learning rate from `scheduler` before every batch.\n", "    public class ParamScheduler: Delegate {\n", "        public override var order: Int { return 1 }\n", "        public typealias ScheduleFunc = (Float) -> Float\n", "\n", "        /// Maps training progress, computed as `pctEpochs / epochCount`, to a learning rate.\n", "        public var scheduler: ScheduleFunc\n", "\n", "        public init(scheduler: @escaping ScheduleFunc) {\n", "            self.scheduler = scheduler\n", "        }\n", "\n", "        override public func batchWillStart(learner: Learner) {\n", "            let progress = learner.pctEpochs/Float(learner.epochCount)\n", "            learner.optimizer.learningRate = Opt.Scalar(scheduler(progress))\n", "        }\n", "    }\n", "\n", "    public func makeParamScheduler(scheduler: @escaping (Float) -> Float) -> ParamScheduler {\n", "        return ParamScheduler(scheduler: scheduler)\n", "    }\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// export\n", "/// Linear interpolation that yields `start` at `pct == 0` and moves towards `end` as `pct` grows to 1.\n", "public func linearSchedule(start: Float, end: Float, pct: Float) -> Float {\n", "    let span = end - start\n", "    return start + pct * span\n", "}\n", "\n", "/// Binds `start` and `end` of a three-argument schedule, leaving a function of `pct` only.\n", "public func makeAnnealer(start: Float, end: Float, schedule: @escaping (Float, Float, Float) -> Float) -> (Float) -> Float { \n", "    return { pct in schedule(start, end, pct) }\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.037\n" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "let annealer = makeAnnealer(start: 1e-2, end: 0.1, schedule: linearSchedule)\n", "annealer(0.3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let learner = Learner(data: data, lossFunction: softmaxCrossEntropy, optimizer: opt, initializingWith: modelInit)\n", "let recorder = learner.makeRecorder()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learner.delegates = [learner.makeTrainEvalDelegate(), learner.makeShowProgress(), \n", " learner.makeAvgMetric(metrics: [accuracy]), recorder,\n", " 
learner.makeNormalize(mean: mnistStats.mean, std: mnistStats.std),\n", " learner.makeParamScheduler(scheduler: annealer)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: [0.20679843, 0.9399] \n", "Epoch 1: [0.15625015, 0.9509] \n", " \r" ] } ], "source": [ "learner.fit(2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAEKCAYAAADjDHn2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xd4VAXa/vHvQ+i9d0LoHVEDiL2LroiIu6KuYlnR3XXrb4VgWbGDZV3fVdcX26pr2yWgWLH3RlBIo4XQQodACCWkPb8/Zth3ZIEQyORMkvtzXbkyc+bM5M6ZJHfOmZlnzN0RERE5mFpBBxARkdinshARkTKpLEREpEwqCxERKZPKQkREyqSyEBGRMqksRESkTCoLEREpk8pCRETKVDvoABWldevWnpCQEHQMEZEqZd68eZvdvU1Z61WbskhISCAlJSXoGCIiVYqZrTyU9XQYSkREyqSyEBGRMqksRESkTCoLEREpU1TLwsxGmtliM8sys6T9XH6ymX1vZsVmdvE+l403s6Xhj/HRzCkiIgcXtbIwszjgMeBcoD9wqZn132e1VcBVwEv7XLclcDswHBgG3G5mLaKVVUREDi6aexbDgCx3z3b3QuAVYHTkCu6+wt1TgdJ9rnsO8L6757r7VuB9YGQUs4qIyEFEsyw6AasjzueEl0X7uiIiNYK78+rcVXyQuSHqXyuaZWH7WXaob/h9SNc1swlmlmJmKZs2bSpXOBGRqmzVll1c/tS3TEpO47X5a6L+9aL5Cu4coEvE+c7A2nJc99R9rvvJviu5+3RgOkBiYuKhFpGISJVVUur846sVPDhnMXG1jHvGDOTSofFR/7rRLIu5QC8z6wasAcYBlx3idecA90Y8qH02MLniI4qIVB1LNuQzcUYq81dv4/S+bblnzEA6NGtQKV87amXh7sVmdiOhP/xxwDPunmFmdwIp7j7bzIYCs4AWwCgzu8PdB7h7rpndRahwAO5099xoZRURiWWFxaX8/ZNlPPrxUprUr8Mj44ZwwVEdMdvfEfvoMPfqcfQmMTHRNUhQRKqbBau3MSk5lUXr87ngqI7cPqo/rRrXq7DbN7N57p5Y1nrVZuqsiEh1sruwhIc/WMJTn2fTtkl9nroykTP7twssj8pCRCTGfL1sC5NnprJiyy4uHRbP5PP60rR+nUAzqSxERGLE9oIipr6ziJe+XUXXVg156brhHN+jddCxAJWFiEhM+HDhBm6Zlc7G/AImnNydP5zZmwZ144KO9R8qCxGRAG3ZsYc73shk9oK19GnXhCeuOJYhXZoHHeu/qCxERALg7sxesJY73sgkv6CIP5zZm1+e2oO6tWPznSNUFiIilWxd3m5unZXOh4s2clSX5tw/djB92jcJOtZBqSxERCpJaanzytzV3Pf2QopKS7n1J/24+oRuxNWqvBfXHS6VhYhIJVixeSdJM1P5JjuX43u0YupFg4lv1TDoWIdMZSEiEkXFJaU88+VyHnpvCXXjaj
H1okFcMrRLpY7qqAgqCxGRKFm0fjuTZqSyICePM/u14+4LB9K+Wf2gYx0WlYWISAXbU1zCYx8v4/GPs2jWoA5/u/Rozh/cocrtTURSWYiIVKAfVm1lUnIqSzbsYMzRnbjt/P60bFQ36FhHTGUhIlIBdhUW89B7S3jmy+W0b1qfZ68ayml92wYdq8KoLEREjtBXWZtJmpnGqtxd/Py4eCaN7EuTgAf/VTSVhYjIYcrbXcR9by/klbmr6da6Ea9OOI7h3VsFHSsqVBYiIofhvYz13PpaOpt37OH6U0KD/+rXiZ3BfxVNZSEiUg6bd+xhyuwM3kxdR9/2TXhqfCKDO8fe4L+KprIQETkE7s5r89dwxxuZ7NpTwv87qzc3nNqDOnGxOfivoqksRETKsHbbbm6ZlcbHizdxTHxzpo0dTK92sT34r6KpLEREDqC01Hnxu1VMfXshpQ63j+rPlSMSqsTgv4qmshAR2Y/sTTtISk7juxW5nNizNfddNIguLavO4L+KprIQEYlQXFLKU18s5+H3l1Cvdi3uv3gwPz22c5Ue1VERVBYiImGZa7czMXkB6Wu2c86Adtw1eiBtm1bNwX8VTWUhIjXenuISHv0oi79/sozmDevw+OXHcO7A9jV+byKSykJEarR5K3OZlJxG1sYdjD2mM7ed34/mDav+4L+KprIQkRpp555iHpizmOe+XkHHZg147pphnNK7TdCxYpbKQkRqnM+XbmLyzDRytu5m/Iiu3DSyL43r6c/hwWjriEiNkberiLvfyuTf83Lo3qYR/75hBEMTWgYdq0pQWYhIjfBu+npuez2d3J2F/OrUHvz2jF7VevBfRVNZiEi1tjG/gCmzM3g7bT39OzTl2auGMrBTs6BjVTkqCxGpltyd5O/XcNebmewuKuGmc/ow4eTuNWbwX0VTWYhItZOzdRc3z0rnsyWbSOzagqljB9OzbeOgY1VpKgsRqTZKS50XvlnJtHcXAXDHBQO44riu1KqBg/8qmspCRKqFZZt2MGlGKikrt3Jy7zbcO2YgnVvU3MF/FS2qB+/MbKSZLTazLDNL2s/l9czs1fDl35pZQnh5HTN7zszSzGyhmU2OZk4RqbqKSkp57OMszn3kc5Zu3MGDPz2K564eqqKoYFHbszCzOOAx4CwgB5hrZrPdPTNitWuBre7e08zGAdOAS4CfAvXcfZCZNQQyzexld18RrbwiUvWkr8lj4oxUMtdt57xB7ZlywQDaNtHgv2iI5mGoYUCWu2cDmNkrwGggsixGA1PCp2cAj1pocpcDjcysNtAAKAS2RzGriFQhBUUlPPLhUqZ/lk3LRnV54ufHMHJgh6BjVWvRLItOwOqI8znA8AOt4+7FZpYHtCJUHKOBdUBD4A/unhvFrCJSRcxdkcukGalkb97JT4/tzK0/6U+zhnWCjlXtRbMs9vf0Az/EdYYBJUBHoAXwuZl9sHcv5T9XNpsATACIj48/4sAiErt27Cnm/ncX8fzXK+ncogEvXDuMk3pp8F9liWZZ5ABdIs53BtYeYJ2c8CGnZkAucBnwrrsXARvN7EsgEfhRWbj7dGA6QGJi4r5FJCLVxKdLNnHzzDTW5u3mquMTuOmcPjTS4L9KFc1nQ80FeplZNzOrC4wDZu+zzmxgfPj0xcBH7u7AKuB0C2kEHAcsimJWEYlBW3cW8sd/zWf8M99Rv04tZtwwgikXDFBRBCBqWzz8GMSNwBwgDnjG3TPM7E4gxd1nA08DL5hZFqE9inHhqz8GPAukEzpU9ay7p0Yrq4jEFnfnnfT1/Pn1dLbtKuI3p/fk16f11OC/AFnoH/mqLzEx0VNSUoKOISJHaOP2Am57PZ05GRsY1KkZ08YOpn/HpkHHqrbMbJ67J5a1nvblRCQmuDv/npfD3W9msqe4lKRz+/KLE7tRW4P/YoLKQkQCtzp3F5NnpvFF1maGJbRk6thBdG+jwX+xRGUhIoEpKXWe/3oF97+7mFoGd104kMuHxWvwXw
xSWYhIIJZuyGdScirfr9rGqX3acM+YQXRq3iDoWHIAKgsRqVRFJaU88cky/vZRFo3qxfHXS4YwekhHQpN+JFapLESk0qTl5HHTjAUsWp/P+YM7MOWCAbRuXC/oWHIIVBYiEnUFRSU8/MESnvwsm9aN6zH9imM5e0D7oGNJOagsRCSqvs3eQtLMNJZv3sm4oV2YfF4/mjXQ4L+qRmUhIlGRX1DEtHcX8c9vVtGlZQNe/MVwTujZOuhYcphUFiJS4T5etJGbZ6WxYXsBvzixG388uzcN6+rPTVWme09EKkzuzkLufCOD1+avpVfbxjz+y+M5Or5F0LGkAqgsROSIuTtvpq5jyuwM8nYX8bszevGr03pQr7YG/1UXKgsROSIbthdwy6x0Pli4gcGdm/HidcPp216D/6oblYWIHBZ359W5q7nn7YUUFpdyy3n9uPqEBA3+q6ZUFiJSbiu37GTyzDS+WraF4d1aMm3sYBJaNwo6lkSRykJEDllJqfPsl8t58L3F1KlVi3vHDGLc0C4a/FcDqCxE5JAsXp/PxORUFqzexhl923L3mIF0aKbBfzWFykJEDqqwuJTHP8nisY+zaFK/Do+MG8IFR2nwX02jshCRA1qwehsTZ6SyeEM+o4d05M/n96eVBv/VSCoLEfkvuwtL+Mv7i3n6i+W0bVKfp65M5Mz+7YKOJQFSWYjIj3y1bDOTZ6axcssuLhseT9K5fWlaX4P/ajqVhYgAsL2giPveXsTL362ia6uGvHzdcYzo0SroWBIjVBYiwgeZG7jltTQ25e9hwsnd+cOZvWlQV6M65P+oLERqsC079nDHG5nMXrCWvu2bMP2KRI7q0jzoWBKDVBYiNZC7M3vBWqbMzmDHnmL+cGZvfnlqD+rW1qgO2T+VhUgNsy5vN7fOSufDRRsZ0qU59188mN7tmgQdS2KcykKkhigtdV6eu4r73l5ESalz2/n9uer4BOI0qkMOgcpCpAZYvnknScmpfLs8lxN6tuK+MYOJb9Uw6FhShagsRKqx4pJSnvlyOQ+9t4S6tWsxbewgfpbYRaM6pNxUFiLV1MJ125mUnEpqTh5n9W/H3RcOpF3T+kHHkipKZSFSzewpLuGxj5fx+MdZNGtQh0cvO5qfDOqgvQk5IioLkWrk+1VbmTQjlaUbdzDm6E78+fz+tGhUN+hYUg2oLESqgV2FxTw4ZwnPfrWcDk3r8+xVQzmtb9ugY0k1orIQqeK+zNpM0sxUVufu5orjujJxZB+aaPCfVDCVhUgVlbe7iHvfWsirKavp1roRr044juHdNfhPoiOqZWFmI4FHgDjgKXefus/l9YDngWOBLcAl7r4ifNlg4H+BpkApMNTdC6KZV6SqeC9jPbe+ls6WnYXccEoPfn9mL+rX0eA/iZ6olYWZxQGPAWcBOcBcM5vt7pkRq10LbHX3nmY2DpgGXGJmtYF/Ale4+wIzawUURSurSFWxKX8PU97I4K3UdfTr0JSnxw9lUOdmQceSGiCaexbDgCx3zwYws1eA0UBkWYwGpoRPzwAetdDz+84GUt19AYC7b4liTpGY5+7M+mENd76Zya49Jfzp7N5cf0oP6sRp8J9UjmiWRSdgdcT5HGD4gdZx92IzywNaAb0BN7M5QBvgFXe/P4pZRWLWmm27uWVWGp8s3sQx8aHBfz3bavCfVK5olsX+XgHkh7hObeBEYCiwC/jQzOa5+4c/urLZBGACQHx8/BEHFoklpaXOi9+uZOo7i3Bgyqj+XDFCg/8kGNEsixygS8T5zsDaA6yTE36cohmQG17+qbtvBjCzt4FjgB+VhbtPB6YDJCYm7ltEIlVW9qYdJCWn8d2KXE7q1Zp7xwyiS0sN/pPgRLMs5gK9zKwbsAYYB1y2zzqzgfHA18DFwEfuvvfw00QzawgUAqcAD0cxq0hMKC4p5cnPl/PwB0uoX7sWD1w8mIuP7axRHRK4qJVF+DGIG4E5hJ46+4y7Z5jZnUCKu88GngZeMLMsQn
sU48LX3WpmfyFUOA687e5vRSurSCzIWJvHpORU0tds55wB7bhr9EDaavCfxAhzrx5HbxITEz0lJSXoGCLlVlBUwt8+WsoTn2bTomFd7ho9gHMHdQg6ltQQ4ceDE8taT6/gFgnQvJW5TJyRyrJNOxl7TGduO78fzRtq8J/EHpWFSAB27inmgTmLee7rFXRs1oDnrhnGKb3bBB1L5IBUFiKV7LMlm5g8M421ebu58riu3DSyL43r6VdRYpt+QkUqSd6uIu56K5MZ83Lo3qYR/7p+BEMTWgYdS+SQqCxEKsG76eu47fUMcncW8qtTe/DbMzT4T6qWMsvCzGoRmtM0sBLyiFQrG/MLuP31DN5JX8+Ajk159qqhDOykwX9S9ZRZFu5eamYLzCze3VdVRiiRqs7dmTEvh7vfWsjuohImjuzDdSd11+A/qbIO9TBUByDDzL4Ddu5d6O4XRCWVSBW2OncXN89K4/Olmxma0IKpYwfTo03joGOJHJFDLYs7oppCpBooLXWe/3oF989ZjAF3jh7Az4d3pZYG/0k1cEhl4e6fRjuISFWWtXEHScmppKzcysm923DvmIF0bqHBf1J9HLQszCyf/x4rDqHR4u7uTaOSSqSKKCopZfpn2TzywVIa1I3joZ8exUXHdNLgP6l2DloW7q53WBE5gPQ1eUyckUrmuu38ZFAHplwwgDZN6gUdSyQq9DoLkXIqKCrhkQ+XMv2zbFo2qssTPz+WkQPbBx1LJKpUFiLlMHdFLpNmpJK9eSc/S+zMLef1p1nDOkHHEok6lYXIIdixp5j7313E81+vpHOLBvzz2uGc2Kt10LFEKo3KQqQMHy/eyC0z01i3vYCrT0jgT2f3oZEG/0kNo594kQPYurOQu97MZOYPa+jZtjEzbjieY7u2CDqWSCBUFiL7cHfeTlvP7bPT2bariN+e3pNfn96TerU1+E9qLpWFSISN2wu49bV03svcwKBOzXj+muH076iXE4moLEQI7U38OyWHu97KpLC4lMnn9uXaE7tRW4P/RACVhQirtoQG/32RtZlh3Voy9aJBdNfgP5EfUVlIjVVS6vzjqxU8OGcxcbWMuy8cyGXD4jX4T2Q/VBZSIy3dkM/E5FR+WLWN0/q04Z4xg+jYvEHQsURilspCapTC4lKe+HQZj36URaN6cfz1kiGMHtJRg/9EyqCykBojNWcbE2eksmh9PqOO6sjto/rTurEG/4kcCpWFVHsFRSU8/P4Snvw8mzZN6vHklYmc1b9d0LFEqhSVhVRr32RvISk5lRVbdnHpsC4knduPZg00+E+kvFQWUi3lFxQx9Z1FvPjtKuJbNuSlXwzn+J4a/CdyuFQWUu18tGgDt8xKZ8P2An5xYjf+eHZvGtbVj7rIkdBvkFQbuTsLufONDF6bv5be7Rrz+OXHc3S8Bv+JVASVhVR57s4bqeuYMjuD/IIifndGL359Wk/q1taoDpGKorKQKm19Xmjw3wcLN3BU52ZMu3g4fdtr8J9IRVNZSJXk7rwydzX3vrWQotJSbjmvH9ec2I04jeoQiQqVhVQ5K7fsJCk5ja+zt3Bc95ZMvWgwCa0bBR1LpFpTWUiVUVLqPPvlch58bzF1atXivosGcUliFw3+E6kEUX0E0MxGmtliM8sys6T9XF7PzF4NX/6tmSXsc3m8me0wsz9FM6fEvsXr87no719x91sLObFna97/4ylcqgmxIpUmansWZhYHPAacBeQAc81strtnRqx2LbDV3Xua2ThgGnBJxOUPA+9EK6PEvsLiUh7/JIvHPs6iSf06/M+lRzNqcAcN/hOpZNE8DDUMyHL3bAAzewUYDUSWxWhgSvj0DOBRMzN3dzO7EMgGdkYxo8Sw+au3MWlGKos35DN6SEduHzWAlo3qBh1LpEaKZll0AlZHnM8Bhh9oHXcvNrM8oJWZ7QYmEdor0SGoGmZ3YQkPvbeYZ75cTtsm9Xl6fCJn9NPgP5EgRbMs9necwA9xnTuAh919x8
EON5jZBGACQHx8/GHGlFjy1bLNJCWnsSp3F5cNjyfp3L40ra/BfyJBi2ZZ5ABdIs53BtYeYJ0cM6sNNANyCe2BXGxm9wPNgVIzK3D3RyOv7O7TgekAiYmJ+xaRVCHbC4q47+2FvPzdahJaNeTl645jRI9WQccSkbBolsVcoJeZdQPWAOOAy/ZZZzYwHvgauBj4yN0dOGnvCmY2Bdixb1FI9fFB5gZueS2NTfl7uP7k7vz+zN40qBsXdCwRiRC1sgg/BnEjMAeIA55x9wwzuxNIcffZwNPAC2aWRWiPYly08kjs2bJjD1PeyOSNBWvp274JT16ZyODOzYOOJSL7YaF/5Ku+xMRET0lJCTqGHAJ35/X5a7njjQx27CnmN6f34oZTemjwn0gAzGyeuyeWtZ5ewS2Vau223dz6WjofLdrIkC7Nuf/iwfRu1yToWCJSBpWFVIrSUuel71Yx9Z1FlJQ6t53fn6uOT9DgP5EqQmUhUbd8806SklP5dnkuJ/RsxX1jBhPfqmHQsUSkHFQWEjXFJaU8/cVy/vL+EurWrsX9Ywfz08TOGtUhUgWpLCQqMtduZ1JyKmlr8jirfzvuvnAg7ZrWDzqWiBwmlYVUqD3FJTz6URZ//2QZzRvW4bHLjuG8Qe21NyFSxakspMLMW7mVScmpZG3cwUVHd+K28/vTQoP/RKoFlYUcsV2FxTwwZzH/+GoFHZrW59mrh3Jan7ZBxxKRCqSykCPyxdLNJM1MJWfrbq4c0ZWJI/vSuJ5+rESqG/1Wy2HJ213EPW9l8q+UHLq1bsS/rh/BsG4tg44lIlGispBym5OxntteS2fLzkJ+eWoPfndGL+rX0eA/kepMZSGHbFP+HqbMzuCttHX069CUp8cPZVDnZkHHEpFKoLKQMrk7M79fw51vZrK7sISbzunDhJO7UydOg/9EagqVhRzUmm27uXlmGp8u2cSxXVswbewgerbV4D+RmkZlIftVWur889uVTHtnEQ5MGdWfK0ckUEuD/0RqJJWF/Jdlm3aQlJzK3BVbOalXa+4dM4guLTX4T6QmU1nIfxSVlPLk59n89YOl1K9diwcuHszFx2rwn4ioLCQsfU0ek5JTyVi7nZED2nPnhQNo20SD/0QkRGVRwxUUlfC3j5byxKfZtGhYl79ffgznDuoQdCwRiTEqixosZUUuE5NTyd60k4uP7cytP+lH84Ya/Cci/01lUQPt3BMa/Pfc1yvo2KwBz18zjJN7twk6lojEMJVFDfPpkk3cPDONtXm7GT8igZvO6UMjDf4TkTLor0QNsW1XIXe9uZDk73Po3qYR/75+BIkJGvwnIodGZVEDvJO2jttez2DrrkJ+fVoPfnO6Bv+JSPmoLKqxjdsL+PPrGbybsZ4BHZvy3DVDGdBRg/9EpPxUFtWQuzNjXg53vZlJQXEpk0b25bqTulFbg/9E5DCpLKqZ1bm7uHlWGp8v3czQhBZMHTuYHm0aBx1LRKo4lUU1UVLqPP/1Ch6YsxgD7ho9gMuHd9XgPxGpECqLaiBrYz6TktOYt3Irp/Ruwz1jBtK5hQb/iUjFUVlUYUUlpfzvp8v4nw+zaFgvjr/87CjGHN1Jg/9EpMKpLKqo9DV53DQjlYXrtvOTwR2YMmoAbZrUCzqWiFRTKosqpqCohL9+sJQnP8+mZaO6/O8Vx3LOgPZBxxKRak5lUYV8tzyXpORUsjfv5JLELtx8Xj+aNawTdCwRqQFUFlVAfkER97+7mBe+WUnnFg3457XDObFX66BjiUgNorKIcR8v3sgtM9NYt72Aa07oxp/O6U3DurrbRKRyRfUlvWY20swWm1mWmSXt5/J6ZvZq+PJvzSwhvPwsM5tnZmnhz6dHM2cs2rqzkD++Op+rn51Lw3q1mXHD8fx5VH8VhYgEImp/ecwsDngMOAvIAeaa2Wx3z4xY7Vpgq7v3NLNxwDTgEmAzMMrd15rZQGAO0ClaWWOJu/NW2jpufz2DvN1F/Pb0nvz69J7Uq6
3BfyISnGj+mzoMyHL3bAAzewUYDUSWxWhgSvj0DOBRMzN3/yFinQygvpnVc/c9UcwbuA3bC7jttXTey9zAoE7N+OcvhtOvQ9OgY4mIRLUsOgGrI87nAMMPtI67F5tZHtCK0J7FXmOBH6pzUbg7/0pZzd1vLaSwuJTJ5/bl2hM1+E9EYkc0y2J/LyP28qxjZgMIHZo6e79fwGwCMAEgPj7+8FIGbNWWXSTNTOWrZVsY1q0l08YOplvrRkHHEhH5kWiWRQ7QJeJ8Z2DtAdbJMbPaQDMgF8DMOgOzgCvdfdn+voC7TwemAyQmJu5bRDGtpNT5x1creHDOYuJqGXdfOJDLhsVr8J+IxKRolsVcoJeZdQPWAOOAy/ZZZzYwHvgauBj4yN3dzJoDbwGT3f3LKGYMxJIN+Uyckcr81ds4rU8b7hkziI7NGwQdS0TkgKJWFuHHIG4k9EymOOAZd88wszuBFHefDTwNvGBmWYT2KMaFr34j0BO4zcxuCy872903RitvZSgsLuWJT5fxt4+W0rhebR4ZN4QLjuqowX8iEvPMvUodvTmgxMRET0lJCTrGAS1YvY1JyaksWp/PqKM6MmVUf1o11uA/EQmWmc1z98Sy1tMrvKJsd2EJD3+whKc+z6ZNk3o8eWUiZ/VvF3QsEZFyUVlE0dfLtjB5Ziortuzi0mFdmHxeP5rW1+A/Eal6VBZRsL2giKnvLOKlb1cR37IhL/1iOMf31OA/Eam6VBYV7KNFG7h5Zjob8wu47qRu/PGsPjSoq1EdIlK1qSwqyJYde7jzzUxen7+WPu2a8MQVxzKkS/OgY4mIVAiVxRFyd2YvWMsdb2SSX1DE78/sxa9O7Und2hrVISLVh8riCKzL282ts9L5cNFGjurSnPvHDqZP+yZBxxIRqXAqi8NQWuq8Mnc19729kKLSUm79ST+uPqEbcRrVISLVlMqinFZs3knSzFS+yc5lRPdWTB07iK6tNPhPRKo3lcUhKil1nvliOQ+9v5g6tWpx30WDGDe0i0Z1iEiNoLI4BIvWb2fSjFQW5ORxZr+23H3hINo3qx90LBGRSqOyOIg9xSU89vEyHv84i2YN6vC3S4/m/MEdtDchIjWOyuIAfli1lUnJqSzZsIMLh3Tkz6MG0LJR3aBjiYgEQmWxj12FxTz03hKe+XI57ZvW55mrEjm9rwb/iUjNprKI8FXWZpJmprEqdxeXD48n6dy+NNHgPxERlQVA3u4i7nt7Ia/MXU1Cq4a8MuE4juveKuhYIiIxo8aXRWrONq57PoVN+Xu4/pTu/OHM3tSvo8F/IiKRanxZxLdsSO92TXjyykQGd9bgPxGR/anxZdG8YV1euHZ40DFERGKaRqOKiEiZVBYiIlImlYWIiJRJZSEiImVSWYiISJlUFiIiUiaVhYiIlEllISIiZTJ3DzpDhTCzTcDKI7iJ1sDmCooTDbGeD2I/Y6zng9jPGOv5QBnLq6u7tylrpWpTFkfKzFLcPTHoHAcS6/kg9jPGej6I/Yyxng+UMVp0GEpERMqkshARkTKpLP7P9KADlCHW80HsZ4z1fBD7GWM9HyhjVOgxCxERKZP2LEREpEw1vizMbKQJPf1FAAAGsklEQVSZLTazLDNLCihDFzP72MwWmlmGmf0uvHyKma0xs/nhj/MirjM5nHmxmZ1TSTlXmFlaOEtKeFlLM3vfzJaGP7cILzcz+59wxlQzO6YS8vWJ2FbzzWy7mf0+yO1oZs+Y2UYzS49YVu5tZmbjw+svNbPxlZDxATNbFM4xy8yah5cnmNnuiG35RMR1jg3/fGSFvw+LcsZy36/R+n0/QL5XI7KtMLP54eWBbMMj5u419gOIA5YB3YG6wAKgfwA5OgDHhE83AZYA/YEpwJ/2s37/cNZ6QLfw9xBXCTlXAK33WXY/kBQ+nQRMC58+D3gHMOA44NsA7tv1QNcgtyNwMnAMkH642wxoCWSHP7cIn24R5YxnA7XDp6dFZEyIXG+f2/kOGB
HO/w5wbpQzlut+jebv+/7y7XP5Q8Cfg9yGR/pR0/cshgFZ7p7t7oXAK8Doyg7h7uvc/fvw6XxgIdDpIFcZDbzi7nvcfTmQReh7CcJo4Lnw6eeACyOWP+8h3wDNzaxDJeY6A1jm7gd7oWbUt6O7fwbk7ufrlmebnQO87+657r4VeB8YGc2M7v6euxeHz34DdD7YbYRzNnX3rz30V+/5iO8rKhkP4kD3a9R+3w+WL7x38DPg5YPdRrS34ZGq6WXRCVgdcT6Hg/+RjjozSwCOBr4NL7oxfCjgmb2HKwgutwPvmdk8M5sQXtbO3ddBqPSAtgFn3GscP/7ljKXtWN5tFvS2vIbQf7l7dTOzH8zsUzM7KbysUzjXXpWVsTz3a1Db8SRgg7svjVgWS9vwkNT0stjf8cDAnh5mZo2BZOD37r4d+DvQAxgCrCO0KwvB5T7B3Y8BzgV+bWYnH2TdwLatmdUFLgD+HV4Ua9vxQA6UJ8hteQtQDLwYXrQOiHf3o4E/Ai+ZWdOAMpb3fg1qO17Kj/9xiaVteMhqelnkAF0izncG1gYRxMzqECqKF919JoC7b3D3EncvBZ7k/w6RBJLb3deGP28EZoXzbNh7eCn8eWOQGcPOBb539w3hvDG1HSn/NgskZ/iB9POBy8OHRQgf2tkSPj2P0GMAvcMZIw9VRT3jYdyvlb4dzaw2cBHwakTumNmG5VHTy2Iu0MvMuoX/Gx0HzK7sEOFjmk8DC939LxHLI4/xjwH2PtNiNjDOzOqZWTegF6EHxqKZsZGZNdl7mtADoOnhLHufnTMeeD0i45XhZ/gcB+TtPfRSCX70n1wsbceIr1uebTYHONvMWoQPtZwdXhY1ZjYSmARc4O67Ipa3MbO48OnuhLZZdjhnvpkdF/55vjLi+4pWxvLer0H8vp8JLHL3/xxeiqVtWC5BP8Ie9AehZ6AsIdTutwSU4URCu5upwPzwx3nAC0BaePlsoEPEdW4JZ15MJTxjgtAzSBaEPzL2biugFfAhsDT8uWV4uQGPhTOmAYmVtC0bAluAZhHLAtuOhEprHVBE6D/Haw9nmxF63CAr/HF1JWTMInR8f+/P4xPhdceG7/8FwPfAqIjbSST0B3sZ8CjhF/1GMWO579do/b7vL194+T+AG/ZZN5BteKQfegW3iIiUqaYfhhIRkUOgshARkTKpLEREpEwqCxERKZPKQkREyqSyENkPM/sq/DnBzC6r4Nu+eX9fSySW6amzIgdhZqcSmmx6fjmuE+fuJQe5fIe7N66IfCKVRXsWIvthZjvCJ6cCJ4Xfd+APZhZnofd6mBseYHd9eP1TLfSeJC8ReqEYZvZaeOhixt7Bi2Y2FWgQvr0XI79W+JXbD5hZevg9DS6JuO1PzGyGhd5j4sXwK3xFKk3toAOIxLgkIvYswn/089x9qJnVA740s/fC6w4DBnpoLDbANe6ea2YNgLlmluzuSWZ2o7sP2c/XuojQULyjgNbh63wWvuxoYAChWUFfAicAX1T8tyuyf9qzECmfswnNb5pPaIx8K0KzfQC+iygKgN+a2QJC7wfRJWK9AzkReNlDw/E2AJ8CQyNuO8dDQ/PmE3oDHZFKoz0LkfIx4Dfu/qNBfuHHNnbuc/5MYIS77zKzT4D6h3DbB7In4nQJ+t2VSqY9C5GDyyf0Vrd7zQF+GR4pj5n1Dk/h3VczYGu4KPoSepvUvYr2Xn8fnwGXhB8XaUPorTorYwquSJn034nIwaUCxeHDSf8AHiF0COj78IPMm9j/W1++C9xgZqmEJp9+E3HZdCDVzL5398sjls8i9P7LCwhNIZ7o7uvDZSMSKD11VkREyqTDUCIiUiaVhYiIlEllISIiZVJZiIhImVQWIiJSJpWFiIiUSWUhIiJlUlmIiEiZ/j8+Z9asAcq3CAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "recorder.plotLRs()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "More annealing functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// export\n", "public func constantSchedule(start: Float, end: Float, pct: Float) -> Float {\n", " return start\n", "}\n", "\n", "public func cosineSchedule(start: Float, end: Float, pct: Float) -> Float {\n", " return start + (1 + cos(Float.pi*(1-pct))) * (end-start) / 2\n", "}\n", "\n", "public func expSchedule(start: Float, end: Float, pct: Float) -> Float {\n", " return start * pow(end / start, pct)\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// export\n", "public func combineSchedules(pcts: [Float], schedules: [(Float) -> Float]) -> ((Float) -> Float){\n", " var cumPcts: [Float] = [0]\n", " for pct in pcts {cumPcts.append(cumPcts.last! + pct)}\n", " func inner(pct: Float) -> Float{\n", " if (pct == 0.0) { return schedules[0](0.0) }\n", " let i = cumPcts.firstIndex(where: {$0 >= pct})! 
- 1\n", " let actualPos = (pct-cumPcts[i]) / (cumPcts[i+1]-cumPcts[i])\n", " return schedules[i](actualPos)\n", " }\n", " return inner\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let mySchedule = combineSchedules(pcts: [0.3, 0.7], \n", " schedules: [makeAnnealer(start: 0.3, end: 0.6, schedule: cosineSchedule),\n", " makeAnnealer(start: 0.6, end: 0.2, schedule: cosineSchedule)])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let learner = Learner(data: data, lossFunction: softmaxCrossEntropy, optimizer: opt, initializingWith: modelInit)\n", "let recorder = learner.makeRecorder()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learner.delegates = [learner.makeTrainEvalDelegate(), learner.makeShowProgress(), \n", " learner.makeAvgMetric(metrics: [accuracy]), recorder,\n", " learner.makeNormalize(mean: mnistStats.mean, std: mnistStats.std),\n", " learner.makeParamScheduler(scheduler: mySchedule)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: [0.23531592, 0.9375] \n", "Epoch 1: [0.12807088, 0.9641] \n", " \r" ] } ], "source": [ "learner.fit(2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYsAAAEKCAYAAADjDHn2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xd8VfX9x/HXJ5uRQCBhJoGw9wygIq4qwwEoVHC0Wgdqpe5arKMtrVat1VbFgXXVioCjSiuKVEEQRRIg7JWwEvZeIfvz++Me+rumgSSQk3Nz7+f5eNxH7j33e27eOTe5n5zzPef7FVXFGGOMOZUwrwMYY4wJfFYsjDHGVMiKhTHGmApZsTDGGFMhKxbGGGMqZMXCGGNMhaxYGGOMqZAVC2OMMRWyYmGMMaZCEV4HqC4JCQnaunVrr2MYY0ytsnjx4r2qmlhRu6ApFq1btyYjI8PrGMYYU6uIyJbKtLPDUMYYYypkxcIYY0yFrFgYY4ypkBULY4wxFbJiYYwxpkKuFgsRGSoi60QkS0QmnKTN1SKyWkRWicgUv+U3iMgG53aDmzmNMcacmmunzopIODAJuATIBdJFZIaqrvZr0x54CBioqgdEpImzvBHwGyANUGCxs+4Bt/IaY4w5OTevs+gPZKnqRgARmQqMAFb7tbkVmHSiCKjqbmf5EGC2qu531p0NDAXeczGvcUlhcSkrth1iw64j7D1aQEkpREYI0RHhNK4XRaN6USTUjyalcV3qRwfNpT/GBBU3/zJbAjl+j3OBAWXadAAQkQVAOPBbVf38JOu2LPsNRGQcMA4gJSWl2oKb6pG+eT/vLtzCrFW7OF5UUql1msXF0CaxHl1bxNEnJZ4+reJpGhfjclJjTEXcLBZSzjIt5/u3By4AkoD5ItKtkuuiqpOByQBpaWn/87zxxpodh3li5hrmb9hLbEwEI3u35PwOCXRt0YAmcdGEi1BcqhwvLGF/XiH7jxWy+3ABm/cdI3vPUbJ3H+Xtb7fw2vxNALRqXJcLOzbhgo6JnNWmMTGR4R7/hMaEHjeLRS6Q7Pc4CdheTpuFqloEbBKRdfiKRy6+AuK/7lzXkppqUVRSystzs3n+yw3ExkTwyGWduW5AK+pE/e+He0Q4xESGE18virbljEpTUFzC6u2HWbL1IAuy9jI1fStvfbuZulHhDO3ajCv7tOSctgmEh5X3f4UxprqJqjv/kItIBLAe+BGwDUgHrlXVVX5thgLXqOoNIpIALAV64XRqA32cpkuAvif6MMqTlpamNjaUdw7mFfLzd5fwbfY+RvRqwe+Gd6Vh3ahqe/38ohK+27iPz1fsZObKHRzJL6ZJbDRj+iVz/Vmt7FCVMadJRBaralqF7dwqFk6IS4G/4OuPeENVHxeRiUCGqs4QEQH+jK/zugR4XFWnOuveBPzaeanHVfXNU30vKxbe2bz3GDe+uYjtB/N54qrujO6b5Or3yy8q4au1u/lgcS5z1u0mXIRLuzfnpnNT6ZXc0NXvbUywCYhiUZOsWHhj095jjJ38HYXFpbz20zTSWjeq0e+/Zd8x3v52C+9n5HCkoJjzOyRyz8Xt6Z0SX6M5jKmtrFgY123Zd4yrX/2OohJlyq0D6NQszrMsRwuKeee7LUyel82BvCIu6JjIg0M60aWFd5mMqQ2sWBhXHThWyFUvf8vBvEKmjjubjs1ivY4EwLGCYv7+3RZenZfN4eNFjOmXwv2DO5BQP9rraMYEpMoWCxsbylRZflEJ497JYNvB47z207SAKRQA9aIjuOOCtnz9wIXceE4q72fkcOGf5vLavI0Ul5R6Hc+YWsuKhamy33yyivTNB/jzj3vWeB9FZTWoG8ljV3Rh1r3n0S+1EY/PXMPIlxawctshr6MZUytZsTBV8uHiXKZl5HDnhW25omcLr+NUqG1ifV6/IY2Xr+vDrsMFjJi0gD9+tobjhZW7otwY42PFwlTahl1HeOTjlQxIbcS
9F3fwOk6liQjDujfnP/eez+g+Sbz69UYuf2G+7WUYUwVWLEylFBaX8ov3llIvOpwXrulNRHjt+9VpUDeSp0b34B83D+BoQTFXvrSAyfOyKS0NjpM8jHFT7fuLN5544asNrN15hKdH96BJLb9a+tz2CXx+93lc1KkJT8xcy0/e+J7dR/K9jmVMQLNiYSq0PPcgL83NZlSfJC7q1NTrONUivl4Ur1zflyev6s7iLQe4/PlvyNh80tFkjAl5VizMKRUUl/DL95eTUD+Kx67o4nWcaiUijO2fwsd3DqRuVDhjJy/kzQWbCJZrj4ypTlYszCn9bf4m1u06wh+v6k6DOpFex3FFp2ZxfDL+XC7omMjv/rWae6Zl2tlSxpRhxcKc1LaDx3nxqyyGdG0aNIefTqZBnUgm/ySNBwZ3YMay7Yx9baH1Yxjjx4qFOak//Hs1ivLo5cF1+OlkwsKE8Re155Xr+7J+5xGunPQt63Ye8TqWMQHBioUp17z1e/hs5U7GX9iOpPi6XsepUUO6NmP6bWdTVFLKqJe/5ev1e7yOZIznrFiY/1FcUsrEf6+mdeO63HpeG6/jeKJ7UgM+GT+Q5EZ1uemtdKan51S8kjFBzNViISJDRWSdiGSJyIRynr9RRPaISKZzu8XvuRK/5TPczGl+6P3FuWTtPspDl3YmOiJ057tu3qAOH9x+Nue0bcyDHy5n8rxsryMZ4xnX5uAWkXBgEnAJvjm100VkhqquLtN0mqqOL+cljqtqL7fymfLlFRbz7Oz19G0Vz+Auwd2pXRn1oiN4/YZ+3Ds9kydmrmX/sSJ+NbQjvkkejQkdrhULoD+QpaobAURkKjACKFssTAB5ff4m9hwp4JXr+9gHoiMqIoznx/amQZ1IXvk6m4N5hTx+ZXfCw2z7mNDh5mGoloD/gd5cZ1lZo0RkuYh8ICLJfstjRCRDRBaKyEgXcxrH3qMFvDpvI0O6NqVvq8Acetwr4WHC4yO7Mf7CdkxNz+GeaZk2P4YJKW7uWZT3b1fZS2P/BbynqgUicjvwNnCR81yKqm4XkTbAVyKyQlV/cNBYRMYB4wBSUlKqN30IemlONseLSnhwaCevowQkEeGBIR2pHxPBk5+tpVSVv47pVSsHVTSmqtz8Lc8F/PcUkoDt/g1UdZ+qFjgPXwP6+j233fm6EZgL9C77DVR1sqqmqWpaYmJi9aYPMbsP5/Pu91u4qndL2ibW9zpOQLv9/Lb8+tJOfLp8B3dPzaTI9jBMCHCzWKQD7UUkVUSigLHAD85qEpHmfg+HA2uc5fEiEu3cTwAGYn0drnrl640UlyrjL2rndZRaYdx5bXn40s58umIHd09dagXDBD3XDkOparGIjAdmAeHAG6q6SkQmAhmqOgO4S0SGA8XAfuBGZ/XOwKsiUoqvoD1ZzllUppr471W0alzP6zi1xq3ntUEE/vDpGlSX1tp5PoypDDf7LFDVmcDMMsse87v/EPBQOet9C3R3M5v5f7ZXcfpuGeS7aPEPn67hwQ+X88zonoTZWVImCLlaLEzg233Et1dxpe1VnLZbBrUhr7CEZ2evJzY6gt8O72qnHZugY8UixE0+sVdxoe1VnIlfXNSOowXFTJ63kfoxEfxyiJ1RZoKLFYsQdiiviCmLtnJFj+a0TrC9ijMhIjw0rBNH8ouZNCeb2JhIbj+/rdexjKk2VixC2D++30JeYQm32YdatRAR/jCyG0cLinnys7U0qBPJNf3t+h8THKxYhKj8ohLeXLCJ8zsk0rl5nNdxgkZ4mPDs1T05kl/Ew/9cQZPYaH7U2cbYMrWfnecXoj5aso29Rwu57fzQHILcTZHhYUy6tg9dWzRg/JSlZOYc9DqSMWfMikUIKilVJs/LpkdSA85u09jrOEGpXnQEb9zYj8TYaG56K53Ne495HcmYM2LFIgR9sWonm/flcdt5be0UTxclxkbz1s/6oarc8OYi9h4tqHglYwKUFYsQo6q8Mm8jKY3qMrRbM6/jBL0
2ifV5/cZ+7Dqcz81vpZNXWOx1JGNOixWLELNk6wGW5Rzk1kGpNh9DDemTEs8L1/RhxbZD3DdtGaWlZQdfNibwWbEIMW8u2ExsTASj+iZ5HSWkXNKlKQ9f1oXPV+3kz7PXeR3HmCqzU2dDyI5Dx/ls5U5uGtiaulH21te0mwa2Jmv3ESbNyaZdk/pc2dsKtqk9bM8ihPxj4RZUlZ+e3drrKCFJRJg4ohtntWnErz5YweIt+72OZEylWbEIEflFJUz5fisXd25KcqO6XscJWZHhYbxyfV9aNIxh3N8Xk7M/z+tIxlSKFYsQMSNzOwfyirhxYGuvo4S8hnWjeP3GfhSVlHLL2xkcLbAzpEzgs2IRAlSVN7/dTMemsXYRXoBom1ifl67rS9aeo9w7LdPOkDIBz9ViISJDRWSdiGSJyIRynr9RRPaISKZzu8XvuRtEZINzu8HNnMFu0ab9rNlxmJ8NbG0X4QWQc9sn8MhlnZm9ehcvzsnyOo4xp+TaKTEiEg5MAi4BcoF0EZlRzvSo01R1fJl1GwG/AdIABRY76x5wK28we3PBZhrWjWREr5ZeRzFl3HhOa1bkHuK5/6yna4s4G3TQBCw39yz6A1mqulFVC4GpwIhKrjsEmK2q+50CMRsY6lLOoLb94HG+WL2Tsf1SqBMV7nUcU4aI8MRV3enSPI57pmaycc9RryMZUy43i0VLIMfvca6zrKxRIrJcRD4QkeQqrmsqMD0jBwWuG2DzKgSqmMhwXv1JXyLChdveWWwd3iYguVksyjs4XrYX719Aa1XtAfwHeLsK6yIi40QkQ0Qy9uzZc0Zhg1FJqTItPYdB7RPtdNkAlxRfl0nX9iF7z1EemL4MVevwNoHFzWKRCyT7PU4Ctvs3UNV9qnpiKM7XgL6VXddZf7KqpqlqWmJiYrUFDxZfr9/NjkP5XNs/ueLGxnPntEvg15d25vNVO3lpbrbXcYz5ATeLRTrQXkRSRSQKGAvM8G8gIs39Hg4H1jj3ZwGDRSReROKBwc4yUwVTvt9KQn2bqa02ufncVK7o2YJnvljH/A22t2wCh2vFQlWLgfH4PuTXANNVdZWITBSR4U6zu0RklYgsA+4CbnTW3Q/8Hl/BSQcmOstMJe04dJyv1u7m6rQkIsPtcpraQkR4alR32jepzz1TM9l5KN/rSMYAIMFybDQtLU0zMjK8jhEw/vqfDTz3n/XM++WFpDS2/oraJmv3EYa/uICuLeKYcutZVvCNa0RksaqmVdTOfgODkK9jeyuD2idYoail2jWJ5Y9XdSd98wGemWVDmhvvWbEIQvPW72H7oXyu6W+ny9ZmI3q15LoBKbw6byOzV+/yOo4JcVYsgtCURVtJqB/FxdaxXes9enkXurWM4/7pmTZCrfGUFYsgs/NQPl+t3c3ovslERdjbW9vFRIbz0rV9UeDOKUsoKC7xOpIJUfZpEmTez8ihpFQZ28+urQgWKY3r8syPe7I89xCPf7qm4hWMcYEViyBSUqpMTc9hYLvGtE6o53UcU42GdG3GLeem8vfvtvDp8h1exzEhyIpFEJm/YQ/bDh63ju0g9athneiV3JAJHy23/gtT46xYBJH3Fm2lcb0oBndp5nUU44LI8DBeuKY3KNw1dSlFJaVeRzIhxIpFkNh9OJ//rNnN6L5J1rEdxJIb1eWJq7qzdOtBnpu93us4JoTYp0qQeH9xLiWlyhjr2A56V/RswZi0ZF7+OpsFWXu9jmNChBWLIFBaqry3aCtnt2lMm8T6XscxNeA3w7vQJqEe90zLZO/RgopXMOYMWbEIAt9k7SX3wHGutQmOQkbdqAhevLYPh44X8cD7yygtDY4x3kzgsmIRBN5btJVG9aIY3NWu2A4lnZvH8chlnZm7bg9vLNjkdRwT5KxY1HK7j+Qze/UuRvdNIjrC5tgONT85qxWDuzTlqc/XsiL3kNdxTBCzYlHLvZ+RS7FdsR2yRISnR/cgoX40v3hvic3fbVxjxaIWKy1VpqZ
v5aw2jaxjO4Q1rBvFX8f2Zuv+PH47Y5XXcUyQcrVYiMhQEVknIlkiMuEU7UaLiIpImvO4tYgcF5FM5/aKmzlrqwXZe8nZb1dsG+if2og7L2zHB4tzmbnChgMx1S/CrRcWkXBgEnAJkAuki8gMVV1dpl0svilVvy/zEtmq2sutfMHgvUVbia8byZCudsW2gbt+1J556/fw0Ecr6JMST7MGMV5HMkHEzT2L/kCWqm5U1UJgKjCinHa/B54GbLLhKthzpIAvVu1iVJ8kYiKtY9v4hgN5bkwvCotLuf/9TDud1lQrN4tFSyDH73Gus+y/RKQ3kKyq/y5n/VQRWSoiX4vIoPK+gYiME5EMEcnYs2dPtQWvDT5Y7HRs2yEo46dNYn0eu6ILC7L22em0plq5WSyknGX//VdHRMKA54D7y2m3A0hR1d7AfcAUEYn7nxdTnayqaaqalpiYWE2xA9+Jju3+qY1o18Q6ts0Pje2XzCVdmvL05+tYvf2w13FMkHCzWOQC/udzJgHb/R7HAt2AuSKyGTgLmCEiaapaoKr7AFR1MZANdHAxa63y3cZ9bNmXx7W2V2HKISI8NaoHDepGcs+0peQX2ex65sy5WSzSgfYikioiUcBYYMaJJ1X1kKomqGprVW0NLASGq2qGiCQ6HeSISBugPbDRxay1ypRFW2lQJ5Kh3axj25SvUb0onvlxT9bvOspTn6/1Oo4JAq4VC1UtBsYDs4A1wHRVXSUiE0VkeAWrnwcsF5FlwAfA7aq6362stcneowV8sWqndWybCp3fIZEbz2nNmws2M299aPXpmern2qmzAKo6E5hZZtljJ2l7gd/9D4EP3cxWW324OJeiEuWa/nbFtqnYhGGd+DZ7L/e/v4xZ95xHo3pRXkcytZRdwV2LqPqGIu/XOp72TWO9jmNqgZjIcP4ypjeH8oqY8OFyVO10WnN6rFjUIt9t3MfmfXl2xbapki4t4vjlkI58sXoX09JzKl7BmHJYsahF3luUQ1xMBJd2b+51FFPL3HxuKgPbNWbiv1ezdV+e13FMLWTFopbYd7SAWSt3cpV1bJvTEBYm/Gl0T8LDhPumZ1JiV3ebKrJiUUt8sDiXwpJSmw3PnLYWDeswcURXMrYc4LX5dia6qRorFrXAiTm201rF08E6ts0ZGNmrJcO6NePZL9azZodd3W0qz4pFLXCiY9v2KsyZEhEev7I7cXUiuXdaJgXFdnW3qRwrFrXAlO99V2xbx7apDo3qRfHUqO6s3XmE52Zv8DqOqSWsWAS4PUcKmGVXbJtq9qPOTRnbL5lX52WTvtkGRzAVs2IR4N5fnENxqXLtALti21SvRy7vQlJ8He6fvszm7jYVsmIRwEpLlamLchiQ2oh2Taxj21Sv+tER/PnHvcg5kMfjn67xOo4JcFYsAtg3WXvZut86to17+qc2Ytx5bXhv0Va+WrvL6zgmgFmxCGBTvt9Ko3pRNhS5cdV9l3SgU7NYHvxgBfuPFXodxwQoKxYBavfhfGav2cXovklER1jHtnFPdEQ4z17di0PHC3nk4xU22KAplxWLADU9I4eSUrVBA02N6NIijnsv6cDMFTv5JHN7xSuYkONqsRCRoSKyTkSyRGTCKdqNFhEVkTS/ZQ85660TkSFu5gw0JaXKe4tyOKdtY1IT6nkdx4SI285rS99W8Tz6yUq2HzzudRwTYFwrFs60qJOAYUAX4BoR6VJOu1jgLuB7v2Vd8E3D2hUYCrx0YprVUDBvwx62HTxuHdumRoWHCc9e3ZOSUuWXHyyj1AYbNH7c3LPoD2Sp6kZVLQSmAiPKafd74Gkg32/ZCGCqqhao6iYgy3m9kPDOd1tIqB/N4C7WsW1qVqvG9Xjksi4syNrH37/b7HUcE0AqLBYiEiYiK0/jtVsC/jOt5DrL/F+7N5Csqv+u6rrBasu+Y8xZt5trB6QQFWFdSqbmXdM/mQs6JvLHz9aStfuo13FMgKjw00hVS4FlIlLVYyJS3sv990m
RMOA54P6qruv3GuNEJENEMvbsCY4J6f+xcAvhIlxnh6CMR0SEp0f1oE5UOPdPz6SopNTrSCYAVPZf1+bAKhH5UkRmnLhVsE4u4D9GRRLgf5pFLNANmCsim4GzgBlOJ3dF6wKgqpNVNU1V0xITEyv5owSu44UlTEvPYUi3ZjSNi/E6jglhTeJieHxkd5blHmLSnCyv45gAEFHJdr87jddOB9qLSCqwDV+H9bUnnlTVQ0DCicciMhd4QFUzROQ4MEVEngVaAO2BRaeRoVb5OHMbh/OLufGc1l5HMYbLejTni9UteOGrLC7s2ISeyQ29jmQ8VKlioapfV/WFVbVYRMYDs4Bw4A1VXSUiE4EMVT3pnonTbjqwGigG7lTVoB54X1V5+9vNdG4eR1qreK/jGAPAxOHdWLRpP/dOz+TTXwyiTlTInJRoyjjlYSgROSIih8u5HRGRCqfZUtWZqtpBVduq6uPOssfKKxSqeoGqZvg9ftxZr6OqfnY6P1xtkr75AGt3HuGGs1shUl6XjTE1r0HdSJ75cU827jnGU5+v9TqO8dApi4WqxqpqXDm3WFWNq6mQoeDt7zYTFxPBiF4hcdKXqUUGtkvgZwNb89a3m5m/IThOJDFVZ+dmBoCdh/KZtXInY/ol226+CUi/GtqJdk3q88D7yziYZ4MNhiIrFgHgHwu3UKLK9We18jqKMeWKiQznuat7se9oIY9+ssrrOMYDViw8lldYzD++38IlnZvSqrGNA2UCV/ekBtz9o/b8a9l2Psnc5nUcU8OsWHjsw8W5HMwr4tbz2ngdxZgK3XFBW3qnNOTRj1ey45ANNhhKrFh4qKRUef2bTfRMbminy5paISI8jOeu7kVRifLgB8ttsMEQYsXCQ/9Zs4vN+/K4dVCqnS5rao3WCfV45PLOzN+w1wYbDCFWLDz0t/kbadmwDkO72uiypna5tn8KF9pggyHFioVHMnMOkr75ADedm0pEuL0NpnYREZ4a1YO6UeHcO80GGwwF9inlkdfmbyQ2JoIx/ZIrbmxMAGoSF8MTV3ZnxbZDvPCVDTYY7KxYeGDz3mN8tmIH1w5IoX50ZcdyNCbwDOvenKv6tGTSnCyWbj3gdRzjIisWHnh5bjaR4WHcfG6q11GMOWO/Hd6VZnEx3Dd9GXmFxV7HMS6xYlHDth08zodLchnbL5kmsTZnhan94mJ8gw1u3neMP860wQaDlRWLGjb562xEYNz5bb2OYky1ObttY24emMo7C7cwd91ur+MYF1ixqEG7j+TzXnoOV/VOomXDOl7HMaZaPTCkIx2a1ufBD5Zz4JgNNhhsrFjUoNfnb6K4pJQ7LrC9ChN8YiLDeW5MLw7kFfLIxytRtau7g4mrxUJEhorIOhHJEpEJ5Tx/u4isEJFMEflGRLo4y1uLyHFneaaIvOJmzpqw92gB7yzcwuU9WtA6wQYMNMGpa4sG3HtJBz5dsYNPMrd7HcdUI9eKhYiEA5OAYUAX4JoTxcDPFFXtrqq9gKeBZ/2ey1bVXs7tdrdy1pSX52aTX1TC3Re39zqKMa667by2pLWK59GPV5KzP8/rOKaauLln0R/IUtWNqloITAVG+DdQVf+pWesBQbnfuuPQcd5ZuIVRfZJom1jf6zjGuCo8THhuTC8A7p2WSbFd3R0U3CwWLYEcv8e5zrIfEJE7RSQb357FXX5PpYrIUhH5WkQGuZjTdc9/mYWq2l6FCRnJjery+5HdyNhygElzsr2OY6qBm8WivGFU/2fPQVUnqWpb4FfAI87iHUCKqvYG7gOmiMj/zPktIuNEJENEMvbsCcy5gTfvPcb7GTlc2z+FpPi6XscxpsaM7N2SK3u35K9frmfxlv1exzFnyM1ikQv4D3yUBJyqx2sqMBJAVQtUdZ9zfzGQDXQou4KqTlbVNFVNS0xMrLbg1enZ2euJCBfuvKid11GMqXETR3SlZXwd7p6ayeH8Iq/jmDPgZrFIB9qLSKqIRAFjgRn+DUTE/7jMZcA
GZ3mi00GOiLQB2gMbXczqiiVbDzBj2XZuObeNXa1tQlJsTCR/GdObHYfyefTjlV7HMWfAtWKhqsXAeGAWsAaYrqqrRGSiiAx3mo0XkVUikonvcNMNzvLzgOUisgz4ALhdVWvVfmxpqTLxX6tJjI226ypMSOvbKp67f9SeTzK388+luV7HMafJ1SFPVXUmMLPMssf87t99kvU+BD50M5vbZizbTmbOQf40ugf1bGRZE+LuvLAd32zYy6Mfr6JvSiNSGlv/XW1jV3C74HhhCU9+tpZuLeMY1SfJ6zjGeC48THhubC9E4K6pS22ypFrIioUL/vrlBnYezuexy7sSFmZzaxsD0LJhHZ64sjuZOQd5/ssNXscxVWTFopqt2XGY1+ZvZHTfJPqnNvI6jjEB5YqeLRjdN4lJc7JYtKlWdUOGPCsW1aikVHnooxU0qBPJw5d29jqOMQHpt8O7ktyoLvdMXcrBPBudtrawYlGN3v1+C5k5B3n08s7E14vyOo4xAal+dATPj+3NnqMFPPD+chudtpawYlFNNu31zRI2qH0CI3v9z6gmxhg/PZMbMmFYZ/6zZhdvLNjsdRxTCVYsqkFRSSn3TMskKiKMp0f3QMQ6tY2pyE0DW3Nx56Y8+dkaluUc9DqOqYAVi2rwwldZLMs5yBNXdqd5A5sBz5jKEBGe+XEPmsTGcOeUJRw6bsOBBDIrFmfomw17efGrDVzVuyWX9WjudRxjapWGdaN4/pre7DyUz4QPrf8ikFmxOAM5+/MY/94S2jWpz+9HdvM6jjG1Ut9W8fxySEc+W7mTdxZu8TqOOQkrFqfpWEExt72zmJJSZfJP0mxID2POwK2D2nBhx0T+8O81rNx2yOs4phxWLE5DUUkpd7y7hLU7D/P82N42p7YxZygsTPjz1b1oVC+KO6cs4YgNZx5wrFhUUUmp8sv3lzFv/R6euLI7F3Zq4nUkY4JCo3q+/ovcA8eZ8NEK678IMFYsqqCopJR7p2XyceZ2fjmkI2P7p3gdyZig0j+1EfcP7sCny3fwpl1/EVCsWFTS0YJibn9nMTOWbWfCsE7ceaHNfGeMG24/ry0Xd27KEzPXkL7Zxo8KFFYsKiF7z1GunLSAOesSWAotAAASgklEQVR284eR3bj9fJvMyBi3+PovepIUX4c7313C7iP5XkcyuFwsRGSoiKwTkSwRmVDO87eLyAoRyRSRb0Ski99zDznrrRORIW7mPJniklImz8vmsufns+9YIf+4eQDXn9XKiyjGhJQGdSJ55Sd9OZxfxPgpNv9FIHCtWDhzaE8ChgFdgGv8i4Fjiqp2V9VewNPAs866XfDN2d0VGAq8dGJO7pqQX1TCR0tyGfyXeTwxcy3ntktk5l2DOKddQk1FMCbkdWoWx5NX9WDRpv089dlar+OEPDcvDugPZKnqRgARmQqMAFafaKCqh/3a1wNOnP4wApiqqgXAJhHJcl7vu+oOWVBcQubWgxzIK2L7weMszTnI3LW7OVJQTKdmsbz6k74M7tLUxnsyxgMje7dk6dYD/O2bTfRKacjlPVp4HSlkuVksWgI5fo9zgQFlG4nIncB9QBRwkd+6C8us+z9DuYrIOGAcQErK6Z2ZdCS/mDGT//9bNYuLYXDXZozs3YKBbRNspjtjPPbwZV1Yse0QD36wnE7NYmnXJNbrSCHJzWJR3qfs/5w4raqTgEkici3wCHBDFdadDEwGSEtLO62TsuPrRvHuLQNoUCeSJrHRNImLOZ2XMca4JCoijJeu68vlL8xn3N8X8887B9KgTqTXsUKOmx3cuUCy3+MkYPsp2k8FRp7muqctPEwY2C6Bbi0bWKEwJkA1axDDpGv7sHV/Hne9t5SSUrtgr6a5WSzSgfYikioiUfg6rGf4NxCR9n4PLwNOzOI+AxgrItEikgq0Bxa5mNUYE+AGtGnMxBHd+Hr9Hp7+3Dq8a5prh6FUtVhExgOzgHDgDVVdJSITgQxVnQGMF5GLgSLgAL5DUDj
tpuPrDC8G7lTVEreyGmNqh2sHpLB252FenbeRjs1iuapPkteRQoYEy/graWlpmpGR4XUMY4zLikpK+enri1i89QDTxp1F75R4ryPVaiKyWFXTKmpnV3AbY2qVyPAwJl3Xh6Zx0dz2zmJ2HrIrvGuCFQtjTK3TqF4Uf/tpP2demQzyi+wotdusWBhjaqWOzWJ5bkwvlm87xP3Tl1FqZ0i5yoqFMabWGty1Gb8e1plPV+zgqVl2hpSbbC5QY0ytdsugVLbuz+PVrzeSHF/XBvt0iRULY0ytJiL85ooubDt4nMc+WUnLhnVsBksX2GEoY0ytFxEexgvX9KZz8zjunLKEldsOeR0p6FixMMYEhXrREbxxYz8a1onk5rfT2X7wuNeRgooVC2NM0GgaF8MbP+tHXkEJN7yxiAPHCr2OFDSsWBhjgkqnZnFM/mkaW/bn8bO30jlWUOx1pKBgxcIYE3TObtuYF6/pzfLcg9z+j8UUFNtFe2fKioUxJigN7tqMJ0f1YP6Gvdw3bZkNa36G7NRZY0zQujotmUN5RTw+cw0N6kby+MhuNkXyabJiYYwJaree14b9eYW8PDeb+tERPDSskxWM02DFwhgT9B4c0pGj+cVMnreR8DDhwSEdrWBUkRULY0zQExF+N7wrJaq8PDebcBHuH9zBCkYVuNrBLSJDRWSdiGSJyIRynr9PRFaLyHIR+VJEWvk9VyIimc5tRtl1jTGmKsLChD+M6MbYfsm8OCeLv365oeKVzH+5tmchIuHAJOASIBdIF5EZqrrar9lSIE1V80TkDuBpYIzz3HFV7eVWPmNM6AkLE564sjvFpcpf/rOBMBHu+lF7r2PVCm4ehuoPZKnqRgARmQqMwDevNgCqOsev/ULgehfzGGMMYWHCU6N6UKrKs7PXc7yoxPowKsHNYtESyPF7nAsMOEX7m4HP/B7HiEgGUAw8qaofl11BRMYB4wBSUlLOOLAxJjSEhwl/Gt2TmMhwXp6bzdH8Yn43vCthYVYwTsbNYlHeVi/3qhgRuR5IA873W5yiqttFpA3wlYisUNXsH7yY6mRgMkBaWppdcWOMqbTwMOHxkd2IjY7g1XkbOVZQzNOjexARbtcql8fNYpELJPs9TgK2l20kIhcDDwPnq2rBieWqut35ulFE5gK9geyy6xtjzOkSESYM60RsTATPfLGeowXFPH9Nb2Iiw72OFnDcLKHpQHsRSRWRKGAs8IOzmkSkN/AqMFxVd/stjxeRaOd+AjAQv74OY4ypLiLC+Iva89sruvDF6l389PVFHMyz0WrLcq1YqGoxMB6YBawBpqvqKhGZKCLDnWZ/AuoD75c5RbYzkCEiy4A5+PosrFgYY1xz48BUnr+mN5k5B7nq5W/Zui/P60gBRVSD41B/WlqaZmRkeB3DGFPLLdq0n1v/nkFkuPC3G/rRK7mh15FcJSKLVTWtonbWk2OMMX76pzbiwzvOoU5UOGMnf8dnK3Z4HSkgWLEwxpgy2jWpzz9/PpDOzeO4490lPDNrXcgPcW7FwhhjypFQP5qp485iTJpveJCb307n0PEir2N5xoqFMcacRHREOE+O6s7jV3ZjQdZeRrz4DWt3HvY6liesWBhjzCmICNcNaMV7t57FscISRry4gHcWbiFYTg6qLCsWxhhTCWmtGzHzrkEMaNOYRz9eyW3vLObAsdC5HsOKhTHGVFJibDRv3diPRy7rzJx1uxn21/ksyNrrdawaYcXCGGOqICxMuGVQGz66YyB1o8K57m/f89BHyzmcH9yd31YsjDHmNHRPasDMuwdx23ltmJaew+Bn5/Hlml1ex3KNFQtjjDlNMZHhPHRpZ/7584E0qBPJzW9ncOvfM9iy75jX0aqdFQtjjDlDPZMb8q9fnMuDQzuyIGsvlzw7j6c/X8uxgmKvo1UbKxbGGFMNoiLC+PkF7ZjzwAVc3qM5L83N5oJn5vLmgk3kF5V4He+MWbEwxphq1DQuhmfH9OKjn59D28R6/O5
fqzn/T3N457vNFBTX3qJho84aY4yLvs3ey7NfrCdjywES6kdx/VmtuG5AKxJjo72OBlR+1FkrFsYY4zJV5dvsffxt/kbmrNtDVHgYl/dszui+SZyV2tjTub8rWyzcnFYVERkK/BUIB/6mqk+Wef4+4BagGNgD3KSqW5znbgAecZr+QVXfdjOrMca4RUQY2C6Bge0SyN5zlLcWbOafS7fx0ZJttGxYhxG9WjCsW3O6tojztHCcimt7FiISDqwHLsE3H3c6cI3/jHciciHwvarmicgdwAWqOkZEGgEZQBqgwGKgr6oeONn3sz0LY0xtcrywhNlrdvHRklzmrd9DqfquEL+oYxMGdUigb6t4mjeo43qOQNiz6A9kqepGJ9BUYAR+c2mr6hy/9guB6537Q4DZqrrfWXc2MBR4z8W8xhhTY+pEhTO8ZwuG92zBvqMFzF23h6/W7Wbmih1My8gBoEWDGHomN6R9k/q0bVKfNgn1aRIXTaN6UUSG1+z5SW4Wi5ZAjt/jXGDAKdrfDHx2inVbVms6Y4wJEI3rRzOqbxKj+iZRVFLKmh2HWbzlAIu3HGDFtkPMWrWTsnMvNawbSb2oCKIiwujWsgEvXNPb1YxuFovyDryVe8xLRK7Hd8jp/KqsKyLjgHEAKSkpp5fSGGMCSGR4GD2SGtIjqSE/G5gKQH5RCVv25bFp7zH2Hi1g79EC9h0tJK+whMKSUpLj3T9c5WaxyAWS/R4nAdvLNhKRi4GHgfNVtcBv3QvKrDu37LqqOhmYDL4+i+oIbYwxgSYmMpyOzWLp2CzWswxuHvRKB9qLSKqIRAFjgRn+DUSkN/AqMFxVd/s9NQsYLCLxIhIPDHaWGWOM8YBrexaqWiwi4/F9yIcDb6jqKhGZCGSo6gzgT0B94H0RAdiqqsNVdb+I/B5fwQGYeKKz2xhjTM2zi/KMMSaEVfbUWRsbyhhjTIWsWBhjjKmQFQtjjDEVsmJhjDGmQlYsjDHGVChozoYSkT3AljN4iQRgbzXFcUOg54PAzxjo+cAyVodAzweBlbGVqiZW1ChoisWZEpGMypw+5pVAzweBnzHQ84FlrA6Bng9qR8ay7DCUMcaYClmxMMYYUyErFv9vstcBKhDo+SDwMwZ6PrCM1SHQ80HtyPgD1mdhjDGmQrZnYYwxpkIhXyxEZKiIrBORLBGZ4GGOZBGZIyJrRGSViNztLP+tiGwTkUzndqnfOg85udeJyJAayLhZRFY4OTKcZY1EZLaIbHC+xjvLRUSed/ItF5E+NZCvo992yhSRwyJyj9fbUETeEJHdIrLSb1mVt5uI3OC03yAiN7ic708istbJ8E8Raegsby0ix/225St+6/R1fj+ynJ+hvEnMqjNjld9Xt/7eT5Jvml+2zSKS6Sz3ZBueMVUN2Ru+odOzgTZAFLAM6OJRluZAH+d+LLAe6AL8FnignPZdnLzRQKrzc4S7nHEzkFBm2dPABOf+BOAp5/6l+KbJFeAs4HsP3tudQCuvtyFwHtAHWHm62w1oBGx0vsY79+NdzDcYiHDuP+WXr7V/uzKvswg428n+GTDM5W1YpffVzb/38vKVef7PwGNebsMzvYX6nkV/IEtVN6pqITAVGOFFEFXdoapLnPtHgDWcet7xEcBUVS1Q1U1AFr6fp6aNAN527r8NjPRb/nf1WQg0FJHmNZjrR0C2qp7qQs0a2YaqOg8oOx9LVbfbEGC2qu5X1QPAbGCoW/lU9QtVLXYeLsQ3W+VJORnjVPU79X3q/d3vZ3Il4ymc7H117e/9VPmcvYOrgfdO9Rpub8MzFerFoiWQ4/c4l1N/QNcIEWkN9Aa+dxaNdw4HvHHicAXeZFfgCxFZLL75zwGaquoO8BU8oImH+fyN5Yd/nIGyDU+o6nbzMutN+P7LPSFVRJaKyNciMshZ1tLJVNP5qvK+erUNBwG7VHWD37JA2oaVEurForzjgZ6eHiYi9YEPgXtU9TDwMtAW6AXswLc7C95kH6iqfYB
hwJ0ict4p2nq2bcU3je9w4H1nUSBtw4qcLJMnWUXkYaAYeNdZtANIUdXewH3AFBGJ8yhfVd9Xr97va/jhPy6BtA0rLdSLRS6Q7Pc4CdjuURZEJBJfoXhXVT8CUNVdqlqiqqXAa/z/YZIaz66q252vu4F/Oll2nTi85Hw9MZe6l9t2GLBEVXc5eQNmG/qp6nar8axOJ/rlwHXOYRGcQzv7nPuL8fUBdHDy+R+qqonfx6q+r15swwjgKmCaX+6A2YZVEerFIh1oLyKpzn+jY4EZXgRxjmu+DqxR1Wf9lvsf578SOHG2xQxgrIhEi0gq0B5f55hb+eqJSOyJ+/g6QFc6OU6cmXMD8Ilfvp86Z/ecBRw6cdilBvzgP7lA2YZlVHW7zQIGi0i8c7hlsLPMFSIyFPgVMFxV8/yWJ4pIuHO/Db5tttHJeEREznJ+l3/q9zO5lbGq76sXf+8XA2tV9b+HlwJpG1aJ1z3sXt/wnX2yHl91f9jDHOfi2+VcDmQ6t0uBd4AVzvIZQHO/dR52cq/D5bMm8J1Bssy5rTqxrYDGwJfABudrI2e5AJOcfCuAtBrajnWBfUADv2WebkN8hWsHUITvv8ebT2e74es7yHJuP3M5Xxa+4/snfhdfcdqOct7/ZcAS4Aq/10nD94GdDbyIc9Gvixmr/L669fdeXj5n+VvA7WXaerINz/RmV3AbY4ypUKgfhjLGGFMJViyMMcZUyIqFMcaYClmxMMYYUyErFsYYYypkxcKYcojIt87X1iJybTW/9q/L+17GBDI7ddaYUxCRC/CNbHp5FdYJV9WSUzx/VFXrV0c+Y2qK7VkYUw4ROercfRIY5Mw7cK+IhItvrod0ZwC725z2F4hvPpIp+C4UQ0Q+dgZdXHVi4EUReRKo47zeu/7fy7lq+08istKZ02CM32vPFZEPxDfHxLvOFb7G1JgIrwMYE+Am4Ldn4XzoH1LVfiISDSwQkS+ctv2BbuobFhvgJlXdLyJ1gHQR+VBVJ4jIeFXtVc73ugrfoHg9gQRnnXnOc72BrvjGCloADAS+qf4f15jy2Z6FMVUzGN/YTZn4hpBvjG9sH4BFfoUC4C4RWYZvPohkv3Yncy7wnvoGx9sFfA3083vtXPUNmpeJbwIdY2qM7VkYUzUC/EJVfzCIn9O3cazM44uBs1U1T0TmAjGVeO2TKfC7X4L97ZoaZnsWxpzaEXzT3J4wC7jDGU4eEengjMJbVgPggFMoOuGbIvWEohPrlzEPGOP0iyTim6qzpkbBNeaU7L8TY05tOVDsHE56C/grvkNAS5xO5j2UP/Xl58DtIrIc38inC/2emwwsF5Elqnqd3/J/4pt/eRm+EYgfVNWdTrExxlN26qwxxpgK2WEoY4wxFbJiYYwxpkJWLIwxxlTIioUxxpgKWbEwxhhTISsWxhhjKmTFwhhjTIWsWBhjjKnQ/wGqCY2hbWy5UgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "recorder.plotLRs()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "//Needs fixing \n", "//learner.recorder!.plotLRs()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Mixup attempt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TODO: move to 10b and adapt loss function." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "extension RandomDistribution {\n", " // Returns a batch of samples.\n", " func next(\n", " _ count: Int, using generator: inout G\n", " ) -> [Sample] {\n", " var result: [Sample] = []\n", " for _ in 0.. [Sample] {\n", " return next(count, using: &ThreefryRandomNumberGenerator.global)\n", " }\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "extension Learner{\n", " //TODO: fiw with change of loss function now that labels are Ints\n", " public class MixupDelegate: Delegate {\n", " private var distribution: BetaDistribution\n", " \n", " public init(alpha: Float = 0.4){\n", " distribution = BetaDistribution(alpha: alpha, beta: alpha)\n", " }\n", "\n", " override public func batchWillStart(learner: Learner) {\n", " if let xb = learner.currentInput {\n", " if let yb = learner.currentTarget as? Tensor{\n", " var lambda = Tensor(distribution.next(Int(yb.shape[0])))\n", " lambda = max(lambda, 1-lambda)\n", " let shuffle = Raw.randomShuffle(value: Tensor(0.. 
MixupDelegate {\n", " return MixupDelegate(alpha: alpha)\n", " }\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "let learner = Learner(data: data, lossFunction: softmaxCrossEntropy1, optimizer: opt, initializingWith: modelInit)\n", "let recorder = learner.makeRecorder()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learner.delegates = [learner.makeTrainEvalDelegate(), learner.makeShowProgress(), \n", " learner.makeAvgMetric(metrics: [accuracy]), recorder,\n", " learner.makeParamScheduler(scheduler: mySchedule),\n", " learner.makeMixupDelegate(alpha: 0.2)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learner.fit(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "notebookToScript(fname: (Path.cwd / \"05_anneal.ipynb\").string)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Swift", "language": "swift", "name": "swift" }, "language_info": { "file_extension": ".swift", "mimetype": "text/x-swift", "name": "swift", "version": "" } }, "nbformat": 4, "nbformat_minor": 1 }