{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.6.1\n", "IPython 6.0.0\n", "\n", "tensorflow 1.2.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p tensorflow" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Zoo -- Logistic Regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Implementation of *classic* logistic regression for binary class labels." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAa4AAACqCAYAAAD1E6s4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAFexJREFUeJzt3X+MVNd1B/DvYbKOVkq0hHjlyMsuUMddCxkqxMom4g8r\nxhWOCzGxGxTcWqG1hCI5ShMiWpAtZCNXpkIijVX3D1QsWtkhosLeONgVsaGKVStQ75oEsAmRHdew\nm0jBtSCRuhLL7ukfb4edmX1v5r1598398b4fCa3nMTtzWL+75829550rqgoiIiJfzLMdABERURZM\nXERE5BUmLiIi8goTFxEReYWJi4iIvMLERUREXmHiIiIirzBxERGRV5i4iIjIK5+w8aY33nijLl68\n2MZbExkxOjr6kar22o6jimOKQpB2XFlJXIsXL8bIyIiNtyYyQkQ+tB1DLY4pCkHaccWpQiIi8goT\nFxEReYWJi4iIvMLEFbLTh4Dv3Q48MT/6evqQ7YiIisVzvhSsFGdQB5w+BPz4W8DkRPT4ysXoMQAs\n32gvLqKi8JwvDX7iCtWxXbMDuGpyIjpOFCKe86XBxBWqK2PZjhP5jud8aTBxhapnYbbjRL7jOV8a\nTFyhWrMT6OquP9bVHR0nChHP+dJg4grV8o3A+meAnn4AEn1d/wwXqckdpisAec6XBqsKQ7Z8Iwct\nuamoCkCe86XAT1xE1HmsAKQcmLiIqPNYAUg5MHERUeexApByYOIios5jBSDlwMRFRJ3HCkDKgVWF\nRI4RkX4A/wbgJgAKYJ+qft9uVAVgBSC1iYmLyD3XAHxXVd8WkU8DGBWR11T1XduBEbmAU4VEjlHV\n36rq2zP//QcA5wD02Y2KyB1MXEQOE5HFAFYAOBnzd1tEZERERi5dutTp0IisYeIicpSIfArAYQDf\nVtXfN/69qu5T1SFVHert7e18gESWMHEROUhEuhAlrRdU9UXb8RC5hImLyDEiIgD2Azinqnttx0Pk\nGiauMj
LdlZtMWw3gYQB3i8jPZ/7cZzsoIlfkLocvzT0noSiqKzcZo6r/BUBsx0HkKhOfuKr3nCwF\nsArAoyKy1MDrUhHYlZuIPJc7cfGeE8+wKzcRec5o54xW95wA2AIAAwMDJt+WsuhZGE0Pxh0nMunI\nVmD0AKBTgFSAlZuBdaw1ofyMFWfwnhPDjmwFnlwAPNETfT2y1czrsis3dcKRrcDI/ihpAdHXkf3m\nzmMqNSOJi/ecGFbkoGdXbuqE0QPZjhNlYKKqkPecmNZs0JuYamFXbipa9aIr7XGiDEx84uI9J6Zx\n0JPvpJLtOFEGuT9x8Z6TAkglPklx0JMvVm6OprfjjhPlxM4ZppnoSpE0uLMOenbIIFvW7QWGHpm9\n2JJK9NjEVDfP69LjRpImmepKUR3ceUqJ2SGDbFu313z5O89rAhOXWc26UmQdVHkHvclYiFzB85rA\nqUKzXOpK4VIsRKbwvCYwcZmV1H3CRlcKl2IhMoXnNYGJy6w1O4F5XfXH5nU170pR1EIzO2RQiHhe\nE7jGZZ5I88e1ilxorn7/sV3RNErPwmhwcx2AfMbzmsDEZdaxXcDU1fpjU1eTF46LXmhmhwwKEc/r\n0uNUoUlZF4650ExElBkTl0lZF4650ExElFnYiavIO+zjXjvrwjEXmimBiDwnIr8TkbO2YyFyTbiJ\nq1r4cOUiAJ0tfDCRvJJeG8i2ZQi3GKFkBwDcazsIIheFW5xRZOFDs9f+ztlsr8+FZoqhqm/M7Cju\nr9OHOl/9Z+M9qePCTVxFFj6wqIIcICJbAGwBgIGBAcvRNLDRU5B9DEsj3KnCIgsfWFRBDlDVfao6\npKpDvb29tsOp12xWIqT3JCvCTVxFFj6s2QnMa9gba14lOp5UEMKtGKhMbMxKcCakNMKdKizyDvsL\nJ4Dpho0ep6eAU88DY/89d6riwgngFz/gFAaVR8/CmeKlmOMhvSdZEe4nLiBKCt85CzxxOXvRRDOj\nB+KPf/DT+KmK0QOcwqBMROQggJ8BGBSRMRF5xHZMmdi41YO3l5RGuJ+4iqRTrZ+T5vmcwqAEqrrJ\ndgy5LN8YzTTUbob6Jw+ZuXhMqhxkH8PSYOJqh1SyJa+k53MKg0J1+lA0PV4973UqejywKl8iaVU5\nyNtLSiHsqcKsBRFHtgJPLgCe6Im+Htka/7yVm+OPL7krfqpi5WZ3tjtpw/CpcazefRxLtr+C1buP\nY/jUuLVYyBNFVfixcpAQcuLK2jnjyFZgZH/9FeLI/uTkFeezn4/vhDGwqr3tToro+pHR8Klx7Hjx\nDMYvT0ABjF+ewI4XzzB5UXNFVfixcpAQcuLKemWWVHARd7zZc+MKQpptd2Ii9gLtOXoeE5P105wT\nk1PYc/R8x2MhjxR1ryPvoSSEnLiyXpklrVnFHc/y3HZiceiq8jeXJzIdJwJQXIUfKwcJISeurFdm\nUkl/PMtz24nFoavKm+d3ZzpOBKC4BtJsTA2A687hJq6sV2ZJBRdxx5s9N7DtTratHUR3V31C7u6q\nYNvawY7HQp4p6j7Kol7XE1x3DjlxZb0yW7cXGHpk9lOTVKLH6/bOfe7AqrmfrqqPA9vuZMOKPjz9\nwDL0ze+GAOib342nH1iGDSv6Oh4LEXHdGQBEVTv+pkNDQzoyMtLx9zXme7fHt5ZJvF+rP7oypGCI\nyKiqDtmOo8r7MUWpLdn+CuJ+awuAD3b/WafDMSrtuAr3E1eRshZ4sFSXiAzhujMTV3uyFniwVJeI\nDOG6s28tn7Lubpr0/Ly7pK7ZCfzo0fp7syo3ACseru8CD3hVqjt8ahx7jp7Hby5P4Ob53di2drB0\na1lB/wy4O3AQqudj0nka9Dk8w5/ElXV306Tnm9pipHFtUDUq2hhY5eUvh2qlUnXRt1qpBCC4kz5J\n0D8D7g4clA0r+mLPyaDP4Rr+TBVm7SaR9HwTW4wc2wVMT9Yfm56Mjnta
qstKpcB/Bg51Y6HiBH0O\n1/AncZnqMmGigMKhzhamsENG4D+DAM9Zmivoc7iGP4nLVJcJEwUUDnW2MIWVSoH/DAI8Z2muoM/h\nGv4kLlPdJ1Zuzt+VwqHOFqawUinwn0GA5yzNFfQ5XMNIcYaI3Avg+wAqAP5FVXebeN06WXc3TdqB\ndd3e5AKKpKqrf/0y8MFPZ197yV1RJwsPizCStKpUcp2JSirffwZNcXfgUijqHM4zvoqocszdOUNE\nKgB+BeBPAYwBeAvAJlV9N+l7OnKXf2MVFRBdYSa1Tkp6fs8i4KNfzn3+kruAr79sPm7KrLGSCoiu\nMotsTVV054ysF4PsnEFFyTO+sn5vJztn3AHgPVX9tapeBfBDAPcbeN18TFUhxiUtoP4TGFkVWiXV\nzMXgswC+BGApgE0istRuVFRWecZXUWPTROLqA1DbuG9s5lgdEdkiIiMiMnLp0iUDb9uCB3tdkRkB\nVlK5eTFIpZRnfBU1NjtWnKGq+1R1SFWHent7i39DD/a6IjMCrKRy82KQSinP+CpqbJoozhgH0F/z\neOHMsfaZaNW0Zmf8mlWzKsSXvlF/n5dUgM/eGj9deONtM13i3Vjofnz4DA6evIgpVVREsOnOfgwt\nWpBpUTTrIqqN1jJx77lt7SC2/fsvMDk9u17bNU+Cq6RqpKr7AOwDojUuy+FQoLatHYxdp0ozvvJ8\nbzMmEtdbAG4VkSWIEtbXADzU9quZatWUtYrqwom5NyfrFPDpm4CPzgONGwl8/P5s9wzL7XMeHz6D\n509cuP54ShXPn7iAH5y4gOmZY61av2RtFWOjtUzSez64si/a06FW42O/mL8YJGpTnkrFoqocjezH\nJSL3AfhHRBVQz6nq3zd7ftMKKFt7XT25ILmrRlqW9t26ZcermEr5/7Fvfjfe3H73nOOrdx/HeMy8\ns6nnm5D0nhWR2H9/kbEUWVUoIp9AVKm7BlHCegvAQ6r6TtL3lKmqsAxNZMsq7bgych+Xqr4K4FUT\nr2Vtr6u8SQuwVuCRNmkB2RdLTR03Iem1k/79vhZnqOo1EfkmgKOYvRhMTFplUpYmstSce50zbO11\nlfT6WVgq8KhI+nmxrIulpo6bkPTaSf9+j4szoKqvquofq+otrWYwyiS0Wx+oPe4lrmatmuZ11R+f\n12WuZc3KzfHHl9w1N57KDXNjsdg+Z9Od/a2fhOYFC1lbxWxbO4iuefUJo92CiOFT41i9+ziWbH8F\nq3cfx/Cp+OWcbWsH0VVpeM9KVIhShjY3FOStD9QG9xLX8o1Rd4uefgASfV3/TNSmqfHKOsMnjZbW\n7QWGHpn95CWV6PHXX54bz/3PAhv+eW6MlqoKhxYtQKUhicyT6E+dJj+uDSv68PQDy9A3vxuCaH2o\n5Z3xBgoiqlM/45cnoJid+klKXo01MtDo3585dvJSgLc+UBuMFGdk1dZCclLRhqWCCJckFS3EMVWw\nYKo4I8vr2CgISVJ0y6esylKcYaO9F3VOR4szOoIdLxJlmSYxNaViasomy+twmoh8a4TcbgUkKyeb\n8ydx9SxM+MTFjhc3z+9O/YnL1JRK0ntmff0sr2PqPclvSdvWu6bdCkhWTrbm3hpXEu4nlCiusKKr\nInOWnCoGu0mYKs5IKgr54m29cwo22tlrKG3hB5Fp7VZAsnKyNX8SV1LRBvcTii2suGPxZ+bUMUxN\nK0Y+/NjcGxsozoiL/cGVfTg8Oj6nYANApiKMzIUfRAa1O7XNKfHW/JkqBKIkxUQVq3H65JYd8feD\nHzx5EU9tWJb7/fYcPY/JqfrUODml2HP0fFsbONZ+z+rdxxOvON/cfnfq12925copFypau1PbnBJv\nzZ9PXJRJUjeJLF02minyqtBG4QeRae1Mbef5vjLx6xMXpZbUvy9Ll41mirwqtFH4QWRaqwrIpMpB\n3yonbWDi6pCiy1sbX3/VH30Gb74/
dz1r0539RrYvKWq7AsDcVghFxkiURlIFZKvKQV8qJ23hVGEH\nFF0kEPf6b1+4gtW3LLj+Casigr9cNYChRQsyxZIUO5CtUCKLtrp4FPg6RKaxcjAffuLqgKKLBJJe\n/3/+dwLvP31f3fFmhQ9xsTSLPUuhRFamrjh55Uou4vprPvzE1QFFn6RFdp/gACMyjz0X82Hi6oCi\nT9Isr+/D9iVEoWPlYD5MXB1g8iSN6wRRZPcJDjAi8x1YuP6ajz/d4T1noqqwWWdsoL589ou39eLw\n6Hiq57ZTVVj2Acbu8OXBjvSdk3ZcMXF5xNctQELExFUeHEudk3ZccarQI9wCJHwi8lUReUdEpkXE\nmcRYZhxL7mHi8kiRRRjkjLMAHgDwhu1AKMKx5B4mLsOK3EZj29pBdFUathKpxG8l4mJRBbcYaU1V\nz6kq70J1iItjqeyYuAzqyDYajUuSCUuUrlUtcYsR80Rki4iMiMjIpUuXbIcTrA0r+vDgyr66LjQP\nruSN7Taxc4ZBneiQMTndsJXIdPJWIi51jeAWI7NE5HUAn4v5q8dU9UdpX0dV9wHYB0TFGYbCowbD\np8ZxeHT8etPqKVUcHh3H0KIFpTt3XcHEZZBLHTJc43PspqnqPbZjoPR40eUeThUa5FKHDNf4HDuV\nGy+63MPElULaooKiF3F9XiT2OfZOEpGviMgYgC8AeEVEjtqOqex40eUeJq4WshQVFF0Q4VrBRRY+\nx95JqvqSqi5U1U+q6k2qutZ2TGXHiy73sHNGC7xrnuKwc4Z/8rQuY9uzzkg7rlic0QLnt4n812rH\n4VZcqtAlThW2xPltIv9xx+GwhPGJ6/Qh4Ngu4MoY0LMQWLMTWL7RyEtvWzsY2xna1vy2z1MWPsdO\nfuPMSVj8T1ynDwE//hYwOXMCXrkYPQaMJK/qL1YXfuHmne6wyefYyX83z++OXavmzImf/E9cx3bN\nJq2qyYnouKFPXa7Mb/t8I6TPsZP/XJs5oXz8T1xXxrId95jP0x0+x07+eHz4DA6evIgpVVREsOnO\nfjy1YZlTMyeUn/+Jq2dhND0YdzwwPk93+Bw7+eHx4TN4/sSF64+nVK8/riYvJqow5KoqFJE9IvJL\nETktIi+JyHxTgaW2ZifQ1fDLr6s7Oh4YX26EjOs04kvs5K+DJ2MuYJscJ3/lLYd/DcDtqrocwK8A\n7MgfUkbLNwLrnwF6+gFI9HX9M8bWt1ziQ/eJpE4jAJyPnfw2ldBMIek4+SvXVKGq/qTm4QkAf54v\nnDYt3xhkoorj+nRHsyKMN7ff7XTs5LeKSGySqu6jReEweQPyXwP4j6S/5KZ35cAiDLJl0539mY6T\nv1p+4kqz6Z2IPAbgGoAXkl6Hm96VA4swysWlm8qf2rAMAGKrCiksLRNXq03vRGQzgHUA1qiNjr3k\nFN4vUx4u3lT+1IZlTFQlkLeq8F4Afwvgy6r6f2ZCIp/5UEBCZrD/H9mS9z6ufwLwSQCvSbQAekJV\nv5E7KvKa6wUkZAbXM8mWvFWFnzcVCBH5heuZZAu3NSGitvCmcrLF/5ZPlrhUTUXhEJE9ANYDuArg\nfQB/paqX7UYVj/3/WuPviWIwcbXBxWoqCsZrAHao6jUR+QdE3Wj+znJMibiemYy/J4rDqcI2sJqK\niqKqP1HVazMPTwAIr1t0SfD3RHGYuNrAairqEHaj8Rh/TxSHiasNSVVTrKaiNETkdRE5G/Pn/prn\npOpGo6pDqjrU29vbidApA/6eKA4TVxtYTUV5qOo9qnp7zJ9qC7XNiLrR/AW70fiLvyeKw+KMNrCa\niopS043mLnaj8Rt/TxSHiatNrKaigrAbTUD4e6IYTFxEDmE3GqLWuMZFREReERtrvyJyCcCHHX/j\n5m4E8JHtIBK4GluZ41qkqs6U8uUcU67+f0yDsdtRVOypxpWVxOUiERlR1SHbccRxNTbGFQaff16M\n
3Q7bsXOqkIiIvMLERUREXmHimrXPdgBNuBob4wqDzz8vxm6H1di5xkVERF7hJy4iIvIKExcREXmF\niauGiHxVRN4RkWkRsV6mKiL3ish5EXlPRLbbjqdKRJ4Tkd+JyFnbsdQSkX4R+U8ReXfm/+Pf2I7J\nB66d92m4OjbScHX8tOLS+GLiqncWwAMA3rAdiIhUADwL4EsAlgLYJCJL7UZ13QEA99oOIsY1AN9V\n1aUAVgF41KGfmcucOe/TcHxspHEAbo6fVpwZX0xcNVT1nKq6sj3pHQDeU9Vfq+pVAD8EcH+L7+kI\nVX0DwMe242ikqr9V1bdn/vsPAM4BYIfTFhw779Nwdmyk4er4acWl8cXE5a4+ABdrHo+Bv4RTE5HF\nAFYAOGk3EioAx4ZltsdX6brDi8jrAD4X81ePVTfyI7+JyKcAHAbwbVX9ve14XMDznkxxYXyVLnGp\n6j22Y0hpHEB/zeOFM8eoCRHpQjSoXlDVF23H4wqPzvs0ODYscWV8carQXW8BuFVElojIDQC+BuBl\nyzE5TaKdF/cDOKeqe23HQ4Xh2LDApfHFxFVDRL4iImMAvgDgFRE5aisWVb0G4JsAjiJaBD2kqu/Y\niqeWiBwE8DMAgyIyJiKP2I5pxmoADwO4W0R+PvPnPttBuc6l8z4Nl8dGGg6Pn1acGV9s+URERF7h\nJy4iIvIKExcREXmFiYuIiLzCxEVERF5h4iIiIq8wcRERkVeYuIiIyCv/DzTREOtIuGjEAAAAAElF\nTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from io import BytesIO\n", "\n", "##########################\n", "### DATASET\n", "##########################\n", "\n", "ds = np.lib.DataSource()\n", "fp = ds.open('https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data')\n", "\n", "x = np.genfromtxt(BytesIO(fp.read().encode()), delimiter=',', usecols=range(2), max_rows=100)\n", "y = np.zeros(100)\n", "y[50:] = 1\n", "\n", "np.random.seed(1)\n", "idx = np.arange(y.shape[0])\n", "np.random.shuffle(idx)\n", "x_test, y_test = x[idx[:25]], y[idx[:25]]\n", "x_train, y_train = x[idx[25:]], y[idx[25:]]\n", "mu, std = np.mean(x_train, axis=0), np.std(x_train, axis=0)\n", "x_train, x_test = (x_train - mu) / std, (x_test - mu) / std\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(7, 2.5))\n", "ax[0].scatter(x_train[y_train == 1, 0], x_train[y_train == 1, 1])\n", "ax[0].scatter(x_train[y_train == 0, 0], x_train[y_train == 0, 1])\n", "ax[1].scatter(x_test[y_test == 1, 0], x_test[y_test == 1, 1])\n", "ax[1].scatter(x_test[y_test == 0, 0], x_test[y_test == 0, 
1])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "##########################\n",
    "### HELPER FUNCTIONS\n",
    "##########################\n",
    "\n",
    "def iterate_minibatches(arrays, batch_size, shuffle=False, seed=None,\n",
    "                        drop_last=True):\n",
    "    \"\"\"Yield aligned mini-batches from a list of equal-length arrays.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    arrays : sequence of numpy arrays\n",
    "        Arrays indexed along their first axis; all must have the same length.\n",
    "    batch_size : int\n",
    "        Number of rows per mini-batch (must be >= 1).\n",
    "    shuffle : bool\n",
    "        If True, permute the row order before batching.\n",
    "    seed : int or None\n",
    "        Seed for the `np.random.RandomState` used when `shuffle=True`.\n",
    "    drop_last : bool\n",
    "        If True (default), skip a trailing batch smaller than `batch_size`;\n",
    "        if False, yield it as a shorter final batch.\n",
    "\n",
    "    Yields\n",
    "    ------\n",
    "    tuple\n",
    "        One slice per input array, in the same order as `arrays`, so\n",
    "        `for x_batch, y_batch in iterate_minibatches([x, y], ...)` works.\n",
    "    \"\"\"\n",
    "    if batch_size < 1:\n",
    "        raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))\n",
    "    n_rows = arrays[0].shape[0]\n",
    "    if any(ary.shape[0] != n_rows for ary in arrays):\n",
    "        raise ValueError('all arrays must have the same length along axis 0')\n",
    "\n",
    "    indices = np.arange(n_rows)\n",
    "    if shuffle:\n",
    "        np.random.RandomState(seed).shuffle(indices)\n",
    "\n",
    "    # With drop_last, the last start index is the one that still leaves\n",
    "    # room for a full batch.\n",
    "    stop = n_rows - batch_size + 1 if drop_last else n_rows\n",
    "    for start_idx in range(0, stop, batch_size):\n",
    "        index_slice = indices[start_idx:start_idx + batch_size]\n",
    "        yield tuple(ary[index_slice] for ary in arrays)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "n_features = x_train.shape[1]\n",
    "# The cost below is averaged over each mini-batch. For full batches of\n",
    "# `batch_size` examples its gradient is `batch_size` times smaller than\n",
    "# that of the per-batch sum, so a step size of 0.5 here equals plain\n",
    "# gradient descent with learning rate 0.05 (= 0.5 / 10) on the summed cost.\n",
    "learning_rate = 0.5\n",
    "training_epochs = 15\n",
    "batch_size = 10\n",
    "random_seed = 123\n",
    "\n",
    "\n",
    "##########################\n",
    "### GRAPH DEFINITION\n",
    "##########################\n",
    "\n",
    "g = tf.Graph()\n",
    "with g.as_default():\n",
    "\n",
    "    # Input data\n",
    "    tf_x = tf.placeholder(dtype=tf.float32,\n",
    "                          shape=[None, n_features], name='inputs')\n",
    "    tf_y = tf.placeholder(dtype=tf.float32,\n",
    "                          shape=[None], name='targets')\n",
    "\n",
    "    # Model parameters\n",
    "    params = {\n",
    "        'weights': tf.Variable(tf.zeros(shape=[n_features, 1],\n",
    "                                        dtype=tf.float32), name='weights'),\n",
    "        'bias': tf.Variable([[0.]], dtype=tf.float32, name='bias')}\n",
    "\n",
    "    # Logistic Regression\n",
    "    linear = tf.matmul(tf_x, params['weights']) + params['bias']\n",
    "    pred_proba = tf.sigmoid(linear, name='predict_probas')\n",
    "\n",
    "    # Loss and optimizer: mean binary cross-entropy computed from the logits.\n",
    "    # Working on logits avoids log(0) = -inf once the sigmoid saturates\n",
    "    # to exactly 0. or 1. in float32.\n",
    "    logits = tf.reshape(linear, [-1])\n",
    "    cost = tf.reduce_mean(\n",
    "        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf_y, logits=logits),\n",
    "        name='cost')\n",
    "    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)\n",
    "    train = optimizer.minimize(cost, name='train')\n",
    "\n",
    "    # Class prediction\n",
    "    pred_labels = tf.round(tf.reshape(pred_proba, [-1]), name='predict_labels')\n",
    "    correct_prediction = tf.equal(tf_y, pred_labels)\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')\n",
    "\n",
    "\n",
    "##########################\n",
    "### TRAINING & EVALUATION\n",
    "##########################\n",
    "\n",
    "with tf.Session(graph=g) as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "\n",
    "    for epoch in range(training_epochs):\n",
    "\n",
    "        # Re-seed per epoch so each epoch sees a new shuffle (and drops a\n",
    "        # different incomplete final batch) while staying reproducible.\n",
    "        batch_costs = []\n",
    "        for x_batch, y_batch in iterate_minibatches(arrays=[x_train, y_train],\n",
    "                                                    batch_size=batch_size,\n",
    "                                                    shuffle=True,\n",
    "                                                    seed=random_seed + epoch):\n",
    "            _, c = sess.run([train, cost], feed_dict={tf_x: x_batch,\n",
    "                                                      tf_y: y_batch})\n",
    "            batch_costs.append(c)\n",
    "\n",
    "        # Report the mean cost of this epoch's batches and the accuracy\n",
    "        # of the parameters reached at the end of the epoch.\n",
    "        avg_cost = np.mean(batch_costs) if batch_costs else np.nan\n",
    "        train_acc = sess.run(accuracy, feed_dict={tf_x: x_train,\n",
    "                                                  tf_y: y_train})\n",
    "        valid_acc = sess.run(accuracy, feed_dict={tf_x: x_test,\n",
    "                                                  tf_y: y_test})\n",
    "\n",
    "        print(\"Epoch: %03d | AvgCost: %.3f\" % (epoch, avg_cost), end=\"\")\n",
    "        print(\" | Train/Valid ACC: %.2f/%.2f\" % (train_acc, valid_acc))\n",
    "\n",
    "    weights, bias = sess.run([params['weights'], params['bias']])\n",
    "    print('\\nWeights:\\n', weights)\n",
    "    print('\\nBias:\\n', bias)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}