{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Theano 实例:Logistic 回归" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using gpu device 0: GeForce GTX 850M\n" ] } ], "source": [ "%matplotlib inline\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import theano\n", "import theano.tensor as T" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## sigmoid 函数" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "一个 `logistic` 曲线由 `sigmoid` 函数给出:\n", "$$s(x) = \\frac{1}{1+e^{-x}}$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "我们来定义一个 `elementwise` 的 sigmoid 函数:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [], "source": [ "x = T.matrix('x')\n", "s = 1 / (1 + T.exp(-x))\n", "sigmoid = theano.function([x], s, allow_input_downcast=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "这里 `allow_input_downcast=True` 的作用是允许输入 `downcast` 成定义的输入类型:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[ 0.5 , 0.7310586 ],\n", " [ 0.26894143, 0.11920293]], dtype=float32)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sigmoid([[ 0, 1],\n", " [-1,-2]])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "其图像如下所示:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAArUAAAE4CAYAAABFZgvaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XeYVOXZ+PHvWTqCCKgYAQXFrgEswRhUNFYsRBNBjQVs\nxK6vsaFG9IcdlURf0RiixIKoMYqKIGoWX6KIwRZRUTREigGUIiB19/z+eFiWssC2mTNn5vu5ruea\nObMzww33PMu9Z+/zPFEcx0iSJElpVpR0AJIkSVJNWdRKkiQp9SxqJUmSlHoWtZIkSUo9i1pJkiSl\nnkWtJEmSUs+iVpI2IYqiP0dRNCuKon9t5Dl/iKLoiyiKPoyiqHM245MkWdRKUmU8Ahy1oS9GUdQd\n6BDH8U7AecDgbAUmSQosaiVpE+I4/j9g3kaecjwwdNVz3wG2iKKoVTZikyQFFrWSVHOtgWlrHE8H\n2iQUiyQVJItaSaod0TrH7kEuSVlUt7ovHDRoUNypUye6desGQHFxMYDHKTguu58r8Xhs/tJwPGzY\nMG655RaAeN2vH3fccXTq1OlYVmnbti2DBg2aWHa85vOjKOLMM88EoF27dqvfI+m/n8ebPi57LFfi\n8bhqx2WP5Uo8lT1+/fVili+Hffbpxg8/hK8vXw577NGNJUtgwoRwvMMO3Vi6FP71r3DcunU4njIl\nHLds2Y1ly2DatGJWrICmTcPxnDnFrFwJDRp0Y/lymD8/HEdROF66NByXlnZb9a9Y9u+ZzeNDLo/j\neBCbEMVx9U4m9O/fP+7fv3+1Xqtk9e/fH3OXXuYvGVOnTuW4447jX/9afwGEkSNHcv/99zNy5EjG\njx/PZZddxvjx4yt8nyiKqO73XSXLuZdu2cjfihXw/ffrj0WLYOHCMMruL1oUxuLFFY8ffghj2bKM\nhlwl9epB/frhds1Rt27Fx3Xrrn2/Tp21Hy8bdeqEsebxmo/fdFN0UxzH/TcVX7XP1E6dOrW6L1XC\nzF26mb/sO+WUUxg7dizffvstbdu25aabbmLFihUA9O3bl+7duzNy5Eg6dOjAZpttxiOPPJJwxMoE\n5166VSZ/cRwKzblzKx7z5689Fiwov//997BkSWZib9y4fDRqtPZo2HDt+xWNBg3Kb+vXD7frjvr1\n1x/16oWvlRWi0bpNVlly002Ve161i1pJKhTDhg3b5HPuv//+LEQiqSpKS2HePJg1C/77Xxg+HGbP\nhjlzwvj22/XHypXV//OKiqBZM9h88/LRtGkYTZqsf9ukCWy2WfkoO27cuPy2YcPkism0qXZR27t3\n71oMQ9lk7tLN/EnJcO7ljtLSUJxOnw4zZ4bxzTdhrHl/zpw1i9TejB696fdu0gSaN4cWLcpH8+bl\nj22xxfpj881DMdu4sQVohhRX5knV7qnFK3slqcrsqZU2bdEimDq1fEybVj6mT4cZM0L/amU0awat\nWoWx9dbhdqutwthyyzDK7rdsGX7drpxTqR8Vqn2mtri4ePWVeUoXc5du5k9KhnOv9pSWhuJ0ypTy\n8dVX5UXsd99t+j1atoQ2bWDbbcP40Y/Wvt1mm1DAlhWp5i//2VMrSZIy4rvv4LPPysfkyeUF7Mau\n6m/QANq1C2P77cNo0wbatg23bdqEC6OkNdl+IElZZPuB8tGcOfCvf4Xx8cflRey33274NdtsAx06\nhLHjjrDDDtC+fShkW7UKF11Jq2S2/UCSJBWWlSvh00/hvffgo49CEfvRR2F1gYo0aQK77lo+dtkF\ndtopFLFNmmQ3duW/jPfURl4GqJTK1bNp9oVJySi0ubd8OUyaBBMnhiL2vffgww9h6dL1n9ukCey5\nZxh77QW77x6K2Natc2c1gELLXyHKypnaXC0OpA3xhzFJhWb6dBg/Ht5+O9xOnFhx3+uOO0LnztCp\nUyhg99or9LzaLqCkZbyn1v4xpZGfW2WKny3lgpKS0DYwdiz84
x+hkJ0xY/3n7bwz7LMP7L13uO3U\nKazXKmWZPbWSJCms6fr++6GIHTsWxo0LW7yuqVkz2H//8tGliwWs0sV1aqWUce5JyUjT3IvjsHzW\nmDHw6qtQXBw2NFhTu3Zw8MFw4IFwwAHhIq58biFIU/5UPZ6plSQpD3z3Hbz+eihiX3017L61pp12\nCkXswQfDQQfBdtslE6eUKfbUShXwc6tM8bOl2jR5MowYEcZbb4WduspstRUcfnj5aN06uTilGrKn\nVpKkfLJyZSheR4yAF1+Ezz8v/1q9etCtGxxxRBgdO+Z3O4G0rmp/3IuLi2sxDEmV5dyTkpHU3Fu5\nEt54A/r2hR/9KLQP3H13KGhbtIDTT4ennw67d73+Olx9dVhyy4J2bX7vzH+eqVXOWbhwIX369OHe\ne++lbdu2SYcjSVlXUhJWKBg+HP76V5g9u/xrHTpAjx5w/PHhAq+6/k8uAfbU5rySkhLq1KlT5det\nXLmSuin8TjdkyBCmT5/OTTfdxNSpU9kuoSsZ/NwqU/xsaWM++ACGDoWnnoL//rf88Q4doFevMPbc\nM3d26ZKypFKfeIvaHPbcc8/x/fff07t37yq/9uabb+awww7jgAMOqP3AsqCoqMiiVnnJz5bWNXs2\nPPFEKGY//LD88R12gJ49QyHbsaOFrApapT799tTmqLFjx/Lmm29Wq6AF6NevH7fccguffvpp7Qam\nxDn3pGTU5txbvhz+9rfQRtC6NfzP/4SCtkULuOgimDABpkyB224Lu3hZ0Nac3zvzX/p+P10Avv/+\ne6666irGjh1b7feoW7cugwcP5pe//CVvv/12KlsRJCnfTJsGf/wjPPwwzJoVHqtTB447Ds48E449\nFho0SDZGKa1sP8hB11xzDdtuuy2XXHJJjd+rT58+/OxnP+Occ86phciyx/YD5Ss/W4WntDSsSvDA\nA2EprrK1ZHffHc4+G379a2jVKtkYpRxnT20aLV68mO22244pU6bQvBY23X7vvffo1asXX3zxRS1E\nlz0WtcpXfrYKx4IF8Oc/w+DBUPYtuG5d+OUv4YILwva0thVIlWJPbRq9/PLLtG/fvlYKWoBOnTrx\n3Xff8f7779fK+2WT//FXzLknJaOyc2/6dLjySmjbNvTKfvEFtGkD/+//hfaDp54K29Ra0GaX3zvz\nn42WWfb1118zYMAAGjduTKNGjWjYsCFXXnkljRs3BmDMmDEbXbFg4sSJPPbYY9SpU4epU6fypz/9\niYceeoj58+czY8YMbrrpJnbYYYfVzy8qKqJr166MHj2azp07Z/zvV1NPPvkk48aNI4oirrnmGrp2\n7cqFF16YdFiStEkffQQDB8KwYWHDBAg7fF16aeiV9dIGKbNsP8iiJUuW0LFjR1566SV23nlnvvrq\nK7p06cKQIUM4/vjjAdhvv/0477zzOPfcc9d7/VdffcU999zD/fffD0Dv3r15++23GTp0KKWlpRx4\n4IEMHDiQyy+/fK3X/fa3v2XGjBkMGzaswrjOPvts3nvvvSr9XX7/+99z0EEHVek1aeLnVpniZyu/\nxHHY7euuu2D06PBYURGcdBL89rew777JxifliUr9XiMnf25M6lcymf5/ZsyYMcyZM2f1LllbbLEF\nl19+OYcffvjq50ydOpUtttiiwtfffffd3HnnnauPFy9eTIsWLdh///2ZPn06V1xxRYVLgDVv3pxx\n48ZtMK4hQ4ZU828kSYUpjsPFX/37wz/+ER5r3Dhc+HX55dC+faLhSQXJntosat68OQsWLGCvvfbi\nkksu4ZNPPqFfv340atRo9XMWLFiwwaL2yiuvZLPNNlt9/Pbbb3PYYYcB0KZNG+68884Ke3FbtmzJ\nggULavlvo6Q496RkFBcXry5mDzoIDj88FLQtWpT3y/7hDxa0ucrvnfmv2kVtJsVxMiPTDjzwQG6/\n/XaWLVvG/fffz0EHHcQDD
zyw1nOiKKK0bL2XdbRr1271/cmTJzNz5kwOOeSQTf65paWl/rpTkmog\njmHixFDMHnYYjBsXitlbb4WpU+H668OxpOTYU5uQTz75hLPOOouFCxcyadKk1Y+3atWKQYMGccop\np2z09YMHD+byyy9n/vz5NGzYEAg9t2teJFZmwIABvPTSS4wfP77C9zrvvPOqvDrC3Xffvbqntqgo\nJ3822qgoiigpKdno1/3cKhP8bKXPu++G1QzK9sNp0SL0y150ETRtmmxsUoFIb09tPjrxxBP58ssv\n+XDVxt677747p512Gq+++upaz2vfvj3ffffdeq9fsmQJN954I2eccQZ77rknY8aMoWPHjqsL2tLS\nUu666y4GDx683mvnzp1bYbFb5o9//GNN/mobPLMsSWn21Vdw3XVhCS6A5s3Li9nNN082Nknrs6c2\nSyZOnLjWBWGzZ89m2LBh/O53v1vreV27duWTTz5Z7/UjR45k4MCBTJo0ic8++4wpU6bQYI29FG+5\n5ZYKLxKDcFZ4n332qZ2/iBLn3JMya+7csL7srruGgrZBA7j6avjLX4rp18+CNq383pn/PFObJY89\n9hhjx46lX79+LFy4kMWLF/P73/+efddZ7+Woo47isssuW+/13bp1o3fv3kycOJH333+f8ePHc8EF\nF/Cb3/yG+vXr06NHD7p06bLe61auXMlbb7211qoJ+W7BggWMGzeOTz/9lK222orXXnuNBx54gKb+\nnlDSRixdCvfdF/pk588PK/GccUa4CGy77cCaSMpt9tTmmGXLltG6dWs++ugjtt122xq/31tvvcV5\n553Hxx9/XAvRbdzChQvp06cP99577+plyzJh0aJFDB8+fL3Ht912W44++mimTp1Ku3btuPXWWzni\niCPYZZddqlzQ+rlVpvjZyk2vvAIXXwxffhmODzsM7rwTUrBnjVQIKtVTa1Gbg/r378+SJUu44447\navxeJ510Et27d6dPnz61ENmGDRkyhOnTp3PTTTcxdepUtttuu0q9bvDgwZx//vm1GsvChQtp1KgR\nJ5xwAiNGjGDatGmVjqeMn1tlip+t3DJ1alhX9vnnw/Huu8M998ARR7iNrZRDKjUb7anNQVdddRWv\nvPIK8+bNq9H7TJ48ma+//nqDvba16eyzz+bGG2+s8utmz55d67HccccdPP744+y44468/vrrzJgx\no9b/jCQ596SaW7oUBgyA3XYLBW2TJnD33fDBB3DkkRUXtM69dDN/+c+e2hzUuHFjhgwZwrnnnssz\nzzxDVI3TBUuXLuXiiy/mySefrNbrc0VpaSn33nsvdevWZYsttmDu3LnrbQO8rgEDBmQpOklpNGpU\naDWYMiUcn3IKDBwItdDxJSlB1T5T261bt1oMQ+vab7/96Nu3L/fdd1+1Xn/rrbdy6623suOOO9Zy\nZNnVt29fSkpKuPTSS/nVr37FnDlzkg4pcc49qXrmzAkF7NFHh4J2t93gjTfgyScrV9A699LN/OU/\ne2pVq4qKijbYUzt79mwGDRq01udh3LhxdO3adfVxkyZNuO666wD47LPP2HvvvXnooYcoKipiyZIl\n9OzZk82zsJ6On1tlip+t7ItjGD48nJ399lto3Bj694fLLoN69ZKOTlIlZHbzheLiYn/qUZVsvfXW\n3HrrrWs9dtNNN22wF/fDDz9kjz324PTTT89GeKnh3JMqb+ZMuOACeOGFcHzoofDww7CR/Wg2yLmX\nbuYv/6Vvf1PlvNo6C7XLLrus3jGt7H2HDBlSK+8tKb/FMTzySFjN4IUXwna2f/wjvPZa9QpaSbnP\n9gPViieffJJx48bx0EMP0bNnT7p27cqFF164yddt7EwtwKBBg4iiiC233JIffviB448/nlatWtVm\n6BXyc6tM8bOVeTNnwllnwejR4fiYY+DBB6FNm2TjklRtrlOr3HfnnXdy1VVXJR3GevzcKlP
8bGXW\n3/4G55wTtrpt0QL+8Ac49VTXnJVSznVqlftysaDNdc49aX2LFsG558KJJ4aC9qijYNIk+PWva6+g\nde6lm/nLf65TK0lKtXffDcXrF19AgwZw111w0UWenZUKje0HUgX83CpT/GzVnpISuOMOuPFGWLkS\n9torrDm7555JRyaplmV2SS9JkpLyzTdhI4WxY8Px5ZfDrbfCGgumSCow9tRKKePcU6ErLobOnUNB\nu802Ydvbe+7JfEHr3Es385f/XKdWkpQKpaWh3eDnP4dZs+CQQ+D99+HII5OOTFIusKdWqoCfW2WK\nn63qmTcPzjwTXnwxHF97Ldx8M9S1iU4qBPbUSpLS77334Fe/gn//G7bYAh57DI49NumoJOUae2ql\nlHHuqZAMGQIHHBAK2n32CQVuUgWtcy/dzF/+s6dWkpRzVq6Eyy4Lu4MtWwa/+Q2MGwft2ycdmaRc\nZU+tVAE/t8oUP1ubNm8e9OoFY8ZAvXoweDCcfXbSUUlKUO701EZu6yJJqoTJk+H44+Hzz2GrreC5\n56Br16SjkpQGGe+pjePYkWPj73//e+IxpGHkKvvClK9Gj4YuXUJB27Fj2P42lwpa5166mb/8Z0+t\nJClRcQyDBkH37rBgAZxwQuif3X77pCOTlCYZ76mVJJWzp3ZtK1fCxRfDgw+G4xtugP79ochTLpLK\n5U5PrSRJ61q8GE45JWyo0KABDB0aLhCTpOpwndoCZO7SzfwpH8yZA4ceGgra5s3h9ddzv6B17qWb\n+ct/nqmVJGXVlClw9NHhdvvtYdQo2HXXpKOSlHb21EpSFhV6T+2ECWFHsDlzYO+94eWXYZttko5K\nUo6rVE+trfiSpKx48UXo1i0UtEceCcXFFrSSao89tQXI3KWb+VMaDR0Kv/gFLFkCffqEArdp06Sj\nqhrnXrqZv/znmVpJUkbddx/07g2lpWHJriFDwva3klSb7KmVpCwqpJ7aOIZbbgmFLMC998JllyUb\nk6RUcp1aSVIy4hiuvBLuvjtspPDww3DWWUlHJSmf2VNbgMxdupk/5bqSEujbNxS09erBU0/lR0Hr\n3Es385f/PFMrSao1K1bA6afD8OHQsCE891xYk1aSMs2eWknKonzuqV26FE46CV56Kaxs8NJLcNBB\nSUclKQ/YUytJyo6lS+HEE+GVV6Bly7BL2L77Jh2VpEJiT20BMnfpZv6Ua5YuhRNOCAXtllvCG2/k\nZ0Hr3Es385f/PFMrSaq2soJ21KjygnavvZKOSlIhsqdWkrIon3pqly4Nu4SNHm1BKymjKtVT645i\nkqQqW7Og3Wor+PvfLWglJcue2gJk7tLN/ClpS5ZAjx7lBe0bb8CeeyYdVeY599LN/OU/z9RKkipt\n2bLQQ/vqq7D11uEMbSEUtJJynz21kpRFae6pXbECevaE558vbznYY4+ko5JUAOyplSTVjpIS6N07\nFLRbbAFjxljQSsot9tQWIHOXbuZP2RbHcP758OST0KRJWL6rY8eko8o+5166mb/855laSdIGxTH8\nz//Aww9Dw4Zh69suXZKOSpLWZ0+tJGVR2npqb7gBBgyAevVgxAg46qikI5JUgOyplSRV3+23h4K2\nTh0YPtyCVlJus6e2AJm7dDN/yobBg+HaayGK4NFHwzJehc65l27mL/95plaStJZnnoELLwz3Bw+G\n005LNh5Jqgx7aiUpi3K9p/b11+Hoo8OatLfcAv36JR2RJFWup9aiVpKyKJeL2okToVs3WLQILrkE\nBg0K7QeSlLDMXihmb0p6mbt0M3/KhC++CGdoFy2CU06Be++1oF2Xcy/dzF/+s6dWkgrcN9/AEUfA\nnDnh9tFHocj/HSSljO0HkpRFudZ+MH8+HHwwfPQR/OQnoae2SZOko5KktbhOrSRpw5YuhR49QkG7\nyy7w8ssWtJLSy57aAmTu0s38qTaUlsLpp8Obb0Lr1jB
6NGy5ZdJR5TbnXrqZv/znmVpJKkBXXQXP\nPgubbw6vvALbb590RJJUM/bUSlIW5UJP7X33hSW76taFUaPg5z9PNBxJ2hTXqZWkXJN0UfvCC2HL\n2ziGoUPhjDMSC0WSKst1alUxc5du5k/V9c47YQ3aOIabb7agrSrnXrqZv/xnT60kFYAvv4TjjoMl\nS+Dss+H665OOSJJql+0HkpRFSbQffPstHHBA2DXsyCPhxRehXr2shiBJNeE6tZJU6JYuhV/8IhS0\nnTrBM89Y0ErKT/bUFiBzl27mT5UVx3DWWfCPf0CbNmFzhaZNk44qvZx76Wb+8p9naiUpT918Mwwb\nFnYJe/ll2HbbpCOSpMyxp1aSsihbPbXDhsGpp0JREYwYAccck/E/UpIyxZ5aSSpEb70FffqE+/fc\nY0ErqTDYU1uAzF26mT9tzL//HS4MW7YMzj8/7Bym2uHcSzfzl/88UytJeWLBAjj2WJgzB444Av7w\nB4gq9Us7SUo/e2olKYsy1VO7cmVoM3j1Vdh999CC0KxZrf8xkpQEe2olqVBcemkoaLfaCl56yYJW\nUuGxp7YAmbt0M39a1wMPhNGgAbzwArRvn3RE+cm5l27mL/95plaSUuyNN8ovBvvTn+CnP002HklK\nij21kpRFtdlTO2UK/OQnMG8eXHMN3HZbrbytJOWaSvXUWtRKUhbVVlG7YEE4K/vpp3DccfD882Gj\nBUnKQ5m9UMzelPQyd+lm/lRSAqecEgraPfeEJ56woM0G5166mb/857dBSUqZq6+GV16Bli3DFrhN\nmyYdkSQlz/YDScqimrYfPPpo2AK3bl147TU4+ODai02ScpTr1EpSPnnrLejbN9x/4AELWklakz21\nBcjcpZv5K0zTpsGJJ8Ly5XDxxXDuuUlHVHice+lm/vKfZ2olKcctWQInnACzZsGhh8I99yQdkSTl\nHntqJSmLqtpTG8dw+ulhhYP27eHdd8MFYpJUQOyplaS0u/vuUNButlnYAteCVpIqZk9tATJ36Wb+\nCseoUWH5LoC//AX22ivZeAqdcy/dzF/+80ytJOWgzz+Hk0+G0lK48cZwkZgkacPsqZWkLKpMT+33\n30OXLvDZZ/CLX8Bf/+qOYZIKmj21kpQ2paXw61+HgnaPPULbgQWtJG2aPbUFyNylm/nLbzfeCC+9\nBM2bhwvD3AI3dzj30s385T9//pekHPHcczBgQDgzO3w47Lhj0hFJUnrYUytJWbShntpJk0If7eLF\nMHAgXHFFAsFJUm6qVE+tRa0kZVFFRe28efCTn8CUKXDqqfD44xBV6lu4JBWEzF4oZm9Kepm7dDN/\n+aWkJBSyU6ZAp07w8MMWtLnKuZdu5i//2VMrSQm64YawycKWW8Lzz0PjxklHJEnpZPuBJGXRmu0H\nTz8NvXpBnTowZgwcckjCwUlSbnKdWknKVR99BH36hPt3321BK0k1ZU9tATJ36Wb+0m/uXDjhBPjh\nBzjjDLjkkqQjUmU499LN/OU/z9RKUiWMGjWKXXfdlZ122ok77rhjva8XFxfTrFkzOnfuTOfOnRkw\nYMAG3+uUU+Crr2CffeDBB70wTJJqgz21krQJJSUl7LLLLrz22mu0bt2a/fbbj2HDhrHbbrutfk5x\ncTH33HMPI0aM2Oh7RVEExGy1FUycCG3bZjh4SUo/e2olqTZMmDCBDh060K5dO+rVq8fJJ5/MCy+8\nsN7zNnWSYPjwcFu3LjzzjAWtJNUme2oLkLlLN/OXfTNmzKDtGhVomzZtmDFjxlrPiaKIt956i44d\nO9K9e3c++eSTtb7+0Udw1lnh/j33wMEHZzxs1TLnXrqZv/xXt7ov/OCDDwDo1q0bUP5h8dhjjz3O\nt+NJkybxzTffUObTTz9dq6gtLi7mhx9+YNq0aTRu3Jg77riDI488kmnTpgEwYkQxffvCDz+E9/vn\nP3vTuze0a9du9Z+
RS39fjys+LpMr8XhcteMyuRKPx5U/PuSQQ7rFcRy+sBH21ErSJowfP57+/fsz\natQoAG677TaKioq4+uqrN/ia9u3bM3HiRJo1a0H37vDqq+HCsIkT198mV5K0UfbUSlJt2Hffffni\niy+YOnUqy5cvZ/jw4Rx//PFrPWfWrFmri9UJEyYQxzEtWrSgX79Q0G61Ffztb0lEL0mFodpF7bqn\n85Ue5i7dzF/21a1bl/vvv58jjzyS3XffnV69erHbbrvx0EMP8dBDDwHw7LPPstdee9GpUycuu+wy\nnnrqKZ5+Gu680wvD8oVzL93MX/6rdk+tJBWSo48+mqOPPnqtx/r27bv6/oUXXsiFF164+vjDD8t3\nDPPCMEnKPHtqJamWffcd7Lcf/PvfYcewRx8t32AhiuyplaQqqlRPrUWtJNWilSvh6KPhtddg333h\n//4PGjYs/7pFrSRVWWYvFLM3Jb3MXbqZv9x27bWhoN16a3juubULWqWbcy/dzF/+c/UDSaolTz4J\nAweGC8OefdYLwyQpm2w/kKRa8P77cMABsHQp/O//wgUXVPw82w8kqcrsqZWkbJgzJ/TPfv01nH02\nPPxw+YVh67KolaQqs6dWFTN36Wb+csuKFdCrVyho998/nKXdUEGrdHPupZv5y3/21EpSDVx5Jfz9\n77DNNvDXv0KDBklHJEmFyfYDSaqmRx6Bs86CevWguDj01G6K7QeSVGWZbT+QpEI2fjz85jfh/uDB\nlStoJUmZY09tATJ36Wb+kjdzJpx4IixfDhddFC4OU/5z7qWb+ct/nqmVpCpYuhROOAG++Qa6dYN7\n7kk6IkkS2FMrSZUWx9CnDwwdCttvD//8J2y5ZdXew55aSaoye2olqTb9/vehoG3cGF54oeoFrSQp\nc+ypLUDmLt3MXzJeew2uuCLcf/RR6Ngx0XCUAOdeupm//OeZWknahC+/hJ49obQUrrsOTjop6Ygk\nSeuyp1aSNmLBAvjpT+HTT+HYY0PbQVENTgfYUytJVWZPrSTVREkJnHpqKGj32AOeeKJmBa0kKXPs\nqS1A5i7dzF/2XHMNjBwJLVvCiBGw+eZJR6QkOffSzfzlP885SFIFhg6FgQOhbl149lnYYYekI5Ik\nbYw9tZK0jrfegkMOCTuGPfgg9O1be+9tT60kVVmlemotaiVpDV9/DfvtB7Nnhy1w77uvdt/folaS\nqiyzF4rZm5Je5i7dzF/mLF4MPXqEgvaww+Dee5OOSLnEuZdu5i//2VMrSYQ1aM88Ez74ADp0gOHD\nQz+tJClTdr1sAAAUfklEQVQdbD+QJOD66+GWW6BZMxg/HnbdNTN/ju0HklRl9tRKUmX85S/hLG2d\nOvDyy3DkkZn7syxqJanK7KlVxcxdupm/2vXmm3DOOeH+ffdltqBVujn30s385T97aiUVrClT4IQT\nYMUKuPRSOP/8pCOSJFWX7QeSCtK8ebD//vD553DMMfDCC6H9INNsP5CkKsts+4EkpdWKFfCrX4WC\n9sc/hmHDslPQSpIyx57aAmTu0s381UwchzaDN96AVq3gxRehadOko1IaOPfSzfzlP8/USiooAwfC\nkCHQqFEoaLfbLumIJEm1wZ5aSQXjmWegZ8/y+7/6VfZjsKdWkqrMnlpJKjNuHJx+erh/++3JFLSS\npMyxp7YAmbt0M39VN3ky9OgBy5aFftqrrko6IqWRcy/dzF/+80ytpLw2axYcfTTMnQvHHgt/+ANE\nlfpFliQpTeyplZS3Fi+GQw6Bd9+FffeF4mLYbLNkY7KnVpKqzJ5aSYWrpAROPTUUtO3awUsvJV/Q\nSpIyx57aAmTu0s38bVocwyWXwIgR0Lw5vPJKWJNWqgnnXrqZv/znmVpJeeeuu+CBB6B+/bD97a67\nJh2RJCnT7KmVlFeGDoXevcP9p56CXr0SDWc99tRKUpXZUyupsLz0Epx9drg/aFDuF
bSSpMyxp7YA\nmbt0M38Ve+utsFtYSQn06weXXpp0RMo3zr10M3/5zzO1klLv44/hmGNgyRI45xwYMCDpiCRJ2WZP\nraRU+89/4IADYOZM+MUv4JlnoG7dpKPaMHtqJanKKtVTa1ErKbW+/Ra6dg3b4B54IIweDY0aJR3V\nxlnUSlKVZfZCMXtT0svcpZv5CxYtCi0HkyfDj38c1qTN9YJW6ebcSzfzl//sqZWUOkuXQo8eMGFC\n2C1s1CjYYouko5IkJcn2A0mpsnw5nHgivPwybLMNvPkm7LRT0lFVnu0HklRlrlMrKb+UlMBpp4WC\ntkULGDMmXQWtJClz7KktQOYu3Qo1f6WlYbmuZ56BzTeHV1+FPfdMOioVkkKde/nC/OU/z9RKynlx\nDJdcAo8+Co0bhzO1++yTdFSSpFxiT62knBbHcO21cMcdUL9+KGgPOyzpqKrPnlpJqjJ7aiWl3623\nhoK2bl149tl0F7SSpMyxp7YAmbt0K6T83XEHXH89RBE8/jgcd1zSEamQFdLcy0fmL/95plZSTrr9\ndrjmmlDQ/vnP0KtX0hFJknKZPbWScs5tt0G/fuUFbe/eSUdUe+yplaQqs6dWUvrcemt5QfvII/lV\n0EqSMsee2gJk7tItn/N3yy1w3XXlBe2ZZyYdkVQun+deITB/+c8ztZJywoAB5ReFPfqoBa0kqWrs\nqZWUuAED4IYbQkE7dCicfnrSEWWOPbWSVGWV6qmtm+koJGlD1txYoagoFLSnnZZ0VJKkNLKntgCZ\nu3TLl/yVlsIFF5RvrPDEExa0ym35MvcKlfnLf56plZR1K1aEVQ2efBIaNgw7hR1zTNJRSZLSzJ5a\nSVm1dCn07AkvvghNmoTbbt2Sjip77KmVpCqzp1ZSblm4EHr0gL//HVq0gFdegZ/8JOmoJEn5wJ7a\nAmTu0i2t+Zs7Fw4/PBS022wDY8da0Cpd0jr3FJi//OeZWkkZ9/XX0L07TJoE7drBa6/BjjsmHZUk\nKZ/YUyspoz78MBS0M2fC7rvD6NHQpk3SUSXHnlpJqrJK9dS6o5ikjHntNTjwwFDQHnwwjBtX2AWt\nJClz7KktQOYu3dKSv8ceg6OPDheH9eoVztA2b550VFL1pWXuqWLmL/95plZSrYpjuO02OOMMWLkS\nrrgirEfboEHSkUmS8pk9tZJqzcqVcPHF8OCDEEVw771w6aVJR5Vb7KmVpCpznVpJ2TN/Ppx8cmgz\naNAgbHv7y18mHZUkqVDYU1uAzF265WL+vvgC9t8/FLQtW4YLxCxolW9yce6p8sxf/rOnVlKNjBkT\nNlGYPBn23BPefRe6dk06KklSobGnVlK1xDHcfz9cfjmUlITtbx97DJo2TTqy3GZPrSRVmevUSsqM\n5cuhb1+45JJQ0PbrB889Z0ErSUqOPbUFyNylW9L5mzULDj8cHn4YGjYMF4TdcgsU+SOy8lzSc081\nY/7yn6sfSKq0ceOgZ0/45hvYdlt4/nnYb7+ko5IkyZ5aSZUQx3DPPXD11aHd4MADYfhw+NGPko4s\nfeyplaQqs6dWUs0tWBCW5/rtb0NBe+WV8MYbFrSSpNxiT20BMnfpls38ffgh7LMP/O1v0KxZuL3z\nTqhr45IKkN8708385T/P1EpaTxzDn/8cNlT48kvo1AkmToRf/CLpyCRJqpg9tZLWMm8e/OY38PTT\n4fjss+G++6BRo2Tjyhf21EpSlVWqp9ZfIkparbgYTj8dpk+HJk3C5gpnnpl0VJIkbZo9tQXI3KVb\nJvK3fDlccw0cemgoaLt0gQ8+sKCV1uT3znQzf/nPM7VSgZs8GU49Fd57L2ygcMMNcP31UK9e0pFJ\nklR59tRKBaq0FB56CK64ApYsgXbt4PHH4Wc/Szqy/GZPrSRVmT21kio2ZQqccw6MHRuOTz89XAzW\nrFmycUmSVF321BYgc5duNclfSUnYGezHPw4F7
dZbh1UO/vIXC1ppU/zemW7mL/95plYqEJ98Amed\nBe+8E45POw0GDYKWLZONS5Kk2mBPrZTnVqwIu4DdfHNY5aB1a3jwQTj22KQjK0z21EpSldlTKxW6\n4mK48MJwlhbg3HPhrrtsNZAk5R97aguQuUu3yuRv5kz49a/hkENCQbvjjjBmDPzxjxa0UnX5vTPd\nzF/+q3ZRKyn3rFgRLgTbZRd48klo2DC0HXz8MRx2WNLRSZKUOfbUSnli7NjQajBpUjju0QPuvRfa\nt082Lq3NnlpJqrJK9dR6plZKuc8/hxNPhG7dQkG7ww7w0kvw/PMWtJKkwmFPbQEyd+lWlr/Zs+Gi\ni2CPPeBvf4NGjaB//1DYHnNMoiFKecnvnelm/vKfqx9IKbN0KdxyC9xxByxcCEVFcPbZoXd2222T\njk6SpGTYUyulxMqVYeev3/0OZswIj3XvHorbPfdMNjZVnj21klRlrlMr5YOVK+Hxx2HAAPjyy/BY\n584wcCAcemiysUmSlCvsqS1A5i4dVq6EoUNh112hT59Q0O60E/TrV8w//2lBK2Wb3zvTzfzlP1c/\nkHJMWZvBbrtB796hmO3QITz2ySdw+OGhj1aSJJWzp1bKEYsXwyOPhLVlv/oqPNahA9xwA5x6KtS1\nWSgv2FMrSVVmT62UBrNmwf33wwMPwNy54bEOHeD668NWtxazkiRtmj21Bcjc5YbJk6FvX9h++3AR\n2Ny50KULPPssfPYZnHlmxQWt+ZOS4dxLN/OX/zwHJGVRSQmMGhXOyr7yCpT9Fvr44+HKK+FnP4Oo\nUr9kkSRJa7KnVsqCOXNgyBB48EH4z3/CYw0awOmnwxVXhBUOVBjsqZWkKrOnVkpSHMNbb8HgwfDM\nM7B8eXi8fXs4//ywTNeWWyYboyRJ+cKe2gJk7jJrxgy4/fawJFfXrvDEE7BiBRx3HIwcCVOmhFaD\n6ha05k9KhnMv3cxf/vNMrVQLliyB55+HRx+FMWPKe2VbtYKzzoLzzoN27ZKMUJKk/GZPrVRNK1dC\ncTEMHw5PPw3ffx8er18fevQIGycccYRLcmlt9tRKUpXZUyvVtpISePPNUMT+9a/hArAy++0XCtmT\nT4YWLRILUZKkgmRPbQEyd1VTdkb2oougdWs49NCwisGcObDTTmGThEmTYMIEuOCCzBe05k9KhnMv\n3cxf/vNMrVSBBQtg9GgYMSJc3DVvXvnXdtwRevaEXr3gxz92XVlJknKBPbXSKl9+GQrYESNg7Niw\nYkGZXXcNfbI9e0Lnzhayqj57aiWpyuyplTZm/nx44w149dUw/v3v8q8VFcFBB4Wdvo47DnbeObk4\nJUnSptlTW4AKNXdLloTe2P794YADoGVL+OUv4aGHQkHbvDmcdBI89hjMnh3O1l5xRe4VtIWaPylp\nzr10M3/5zzO1yluLF4cdvcaODWPChPJdvSAstdW1a1h264gjYO+9oU6d5OKVJEnVZ0+t8kIcw9df\nw/jxYbz9NkycGFYuKBNF0LFjaCs47DDo1g2aNk0sZBUoe2olqcrsqVX+WrAA3n8/nH0tK2L/+9+1\nn1NUBPvuCwcfHEbXrqHFQJIk5Z9qF7XFxcV069atFkNRtqQtd3PnwnvvhTFxYridMmX95zVvDvvv\nDz/9abjt0gU23zz78WZa2vIn5QvnXrqZv/znmVrljOXL4bPP4F//WntMm7b+c+vXD2vE7rtveSG7\n004utSVJUqGyp1ZZt3gxTJ4cCtjPPoNPPw1j8uS1e2DLNGoUemH32SdczLX33rDHHlCvXvZjl2rK\nnlpJqjJ7apWcpUvhq69Cm8CXX4bbL74IhevXX1f8miiCDh1gr73WHh06uCqBJEnaOHtqC1Bt5G7l\nSpgxA6ZOXXv8+9+hmJ0+PaxIUJF69UKrwG67hZ26ysZuu8Fmm9UorILg3JOS4dxLN/OX/zxTq/Ws\nWAGzZsHMm
aGfdfr0cLvm/ZkzoaRkw+9Rpw60axfOsq45dtkF2rcPa8RKkiTVFntqC0RpKcybF4rV\n2bPXvv3mmzBmzgy3c+Zs+CzrmrbdNhSu64727WH77e15lSpiT60kVZk9tfmqpATmzw9LXa05vvsO\nvv127TFnTvn9ii7CqkhREbRqBT/6EbRtG0abNuW3bdpA69bQsGFm/56SJEmVZU9tlq1YAYsWwcKF\nYXz/fdhI4Pvvy8eCBWHMn7/2WLAgnG1dsKB6f3azZqFYbdCgmJ137karVuF4m23CWdcf/SjcbrWV\n7QG5zLmXjFGjRnHZZZdRUlLCOeecw9VXX73ecy655BJeeeUVGjduzKOPPkrnzp0TiFSZ4txLN/OX\n/6pdunzwwQd59+EoLQ1X7S9ZEsaa98vGDz9UPBYvLh+LFq19vHBheSG7bFnN44wi2GILaNEijObN\ny+9vtRVsuWXFo+zM6qBBH3DZZd1qHogSkY9zL9eVlJRw0UUX8dprr9G6dWv2228/jj/+eHbbbbfV\nzxk5ciRTpkzhiy++4J133uH8889n/PjxCUat2ubcSzfzl15RFHWL47h4U8+rdlE7f/78jX69tDT8\nuntDY8WK9e+vWLHxsXz5+sfrjmXL1r6/7ih7fOnS9ceKFdX916i8oiJo2jSMzTeveDRtGorWDY1m\nzWq2xNWmcqfcZv6yb8KECXTo0IF27doBcPLJJ/PCCy+sVdSOGDGCM888E4AuXbowf/58Zs2aRatW\nrZIIWRng3Es385dq3YDiTT2p2kXtwIHwv/8b+jtLSkJRuub9tGrYMCz236jR2vcbNYLGjcNY837Z\ncZMmYTmqsrHmcVkR26RJeE93vZLSZcaMGbRt23b1cZs2bXjnnXc2+Zzp06db1EpSllS7qF28eCqL\nF2/460VF4Wxi3brlo06dcEV82fGa98uONzbq11//fv360KBBuF13NGiw4dGwYfkoO65fvzAKzqlT\npyYdgmrA/GVfVMlvDOuualDZ1ykdnHvpZv7yX7WX9IqiyDVpJEmSlHFxHG/yLEFN1qmVpIIQRVFd\nYDLwc2AmMAE4JY7jT9d4TnfgojiOu0dRtD8wKI7j/RMJWJIKkAs3SdImxHG8Moqii4DRQB1gSBzH\nn0ZR1HfV1x+K43hkFEXdoyiaAiwG+iQYsiQVHM/USpIkKfWKavLiKIoujqLo0yiKPo6i6I7aCkrZ\nE0XRFVEUlUZR1CLpWFR5URTdtWrufRhF0XNRFDVLOiZtXBRFR0VR9FkURV9EUbT+zg3KWVEUtY2i\n6O9RFE1a9f/dJUnHpKqJoqhOFEXvR1H0YtKxqGqiKNoiiqJnV/2f98mq9q4KVbuojaLoEOB44Mdx\nHO8JDKzueykZURS1BQ4H/pN0LKqyV4E94jjuCHwOXJtwPNqIKIrqAPcDRwG7A6dEUbTbxl+lHLIC\nuDyO4z2A/YELzV/qXAp8Avjr6fT5PTAyjuPdgB8Dn27oiTU5U3s+cFscxysA4jieU4P3UjLuAa5K\nOghVXRzHY+I4Ll11+A7QJsl4tEk/AabEcTx11ffMp4AeCcekSorj+L9xHH+w6v4iwn+q2yYblSor\niqI2QHfgT4Dr7KXIqt9CHhjH8Z8hXN8Qx/GCDT2/JkXtTsBBURSNj6KoOIqifWvwXsqyKIp6ANPj\nOP4o6VhUY2cBI5MOQhvVGpi2xvH0VY8pZaIoagd0JvwwqXS4F7gSKN3UE5Vz2gNzoih6JIqi96Io\nejiKosYbevJGVz+IomgMsE0FX7pu1Wubx3G8fxRF+wFPAzvUIHDVsk3k71rgiDWfnpWgVGkbyV+/\nOI5fXPWc64DlcRw/mdXgVFX+yjMPRFHUBHgWuHTVGVvluCiKjgVmx3H8fhRF3ZKOR1VWF9ibsFzi\nu1EUDQKuAX63oSdvUBzHh2/oa1EUnQ88t+p576662KhlHMffVTt01aoN5S+
Koj0JP/18uGrHozbA\nxCiKfhLH8ewshqiN2Nj8A4iiqDfhV2o/z0pAqokZQNs1jtsSztYqJaIoqgf8FXg8juPnk45HlXYA\ncPyqdaQbAptHUfSXOI7PSDguVc50wm+V3111/CyhqK1QTdoPngcOBYiiaGegvgVtOsRx/HEcx63i\nOG4fx3F7wodmbwva9Iii6CjCr9N6xHG8NOl4tEn/BHaKoqhdFEX1gV7AiIRjUiVF4af/IcAncRwP\nSjoeVV4cx/3iOG676v+6k4E3LGjTI47j/wLTVtWZAIcBkzb0/JpsvvBn4M9RFP0LWA74IUkvfzWa\nPvcB9YExq862vx3H8QXJhqQN2dDmDQmHpcr7GXAa8FEURe+veuzaOI5HJRiTqsf/79LnYuCJVScE\nvmQjG9u4+YIkSZJSr0abL0iSJEm5wKJWkiRJqWdRK0mSpNSzqJUkSVLqWdRKkiQp9SxqJUmSlHoW\ntZIkSUo9i1pJkiSl3v8H7tR/QglpVT0AAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "X = np.linspace(-6, 6, 100)\n", "X = X[np.newaxis,:]\n", "\n", "plt.figure(figsize=(12,5))\n", "\n", "plt.plot(X.flatten(), sigmoid(X).flatten(), linewidth=2)\n", "\n", "# 美化图像的操作\n", "#=========================\n", "plt.grid('on')\n", "plt.yticks([0,0.5,1])\n", "\n", "ax = plt.gca()\n", "ax.spines['right'].set_color('none')\n", "ax.spines['top'].set_color('none')\n", "\n", "ax.yaxis.set_ticks_position('left')\n", "ax.spines['left'].set_position(('data', 0))\n", "\n", "plt.legend([r'$s(x)=\\frac{1}{1+e^{-x}}$'], loc=0, fontsize=20)\n", "#=========================\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## sigmoid 函数与 tanh 函数的关系" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`sigmoid` 函数与 `tanh` 之间有如下的转化关系:\n", "$$s(x)=\\frac{1}{1+e^{-x}}=\\frac{1+\\tanh(x/2)}{2}$$" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[ 0.5 , 0.7310586 ],\n", " [ 0.26894143, 0.11920291]], dtype=float32)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "s2 = (1 + T.tanh(x / 2)) / 2\n", "\n", "sigmoid2 = theano.function([x], s2)\n", "\n", "sigmoid2([[ 0, 1],\n", " [-1,-2]])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## logistic 回归" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "简单的二元逻辑回归问题可以这样描述:我们要对数据点 $x = (x_1, \\dots, x_n)$ 进行 0-1 分类,参数为 $w = (w_1, \\dots, w_n), b$,我们的假设函数如下:\n", "\n", "$$\n", "\\begin{aligned}\n", "h_{w,b}(x) & = P(Y=1 \\mid X=x) \\\\\n", "& = \\mathrm{sigmoid}(z) \\\\\n", "& = \\frac{1}{1 + e^{-z}}\n", "\\end{aligned}\n", "$$\n", "\n", "其中\n", "\n", "$$\n", "\\begin{aligned}\n", "z & = x_1w_1 + \\dots + x_nw_n + b \\\\\n", "& = w^T x + b\n", "\\end{aligned}\n", "$$\n", "\n", "对于一个数据点 $(x, y), y\\in \\{0,1\\}$ 来说,我们的目标是希望 $h_{w,b}(x)$ 的值尽量接近于 $y$。\n", "\n", "由于数值在 0-1 之间,我们用交叉熵来衡量 $h_{w,b}(x)$ 和 $y$ 的差异:\n", "\n", "$$- y \\log(h_{w,b}(x)) - (1-y) \\log(1-h_{w,b}(x))$$\n", "\n", "对于一组数据,我们定义损失函数为所有差异的均值,然后通过梯度下降法来优化损失函数,得到最优的参数 $w, b$。\n", "\n", "在下面的实现中,损失函数还额外加上了 L2 正则项 $0.01 \\sum_i w_i^2$,用来防止过拟合。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 实例" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "生成随机数据:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "rng = np.random\n", "\n", "# 数据大小和规模\n", "N = 400\n", "feats = 784\n", "\n", "# D = (X, Y)\n", "D = (rng.randn(N, feats), rng.randint(size=N, low=0, high=2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "定义 `theano` 变量:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [], "source": [ "x = T.matrix('x')\n", "y = T.vector('y')\n", "\n", "# 要更新的变量:\n", "w = theano.shared(rng.randn(feats), name='w')\n", "b = theano.shared(0., name='b')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "定义模型:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "h = 1 / (1 + T.exp(-T.dot(x, w) - b))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "当 $h > 0.5$ 时,认为该类的标签为 1:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "prediction = h > 0.5" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "损失函数和梯度:" ] }, { "cell_type": "code", 
"execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "cost = - T.mean(y * T.log(h) + (1 - y) * T.log(1 - h)) + 0.01 * T.sum(w ** 2) # 正则项,防止过拟合\n", "gw, gb = T.grad(cost, [w, b])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "编译训练和预测函数:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [], "source": [ "train = theano.function(inputs=[x, y],\n", " outputs=cost,\n", " updates=[[w, w - 0.1 * gw], [b, b - 0.1 * gb]], \n", " allow_input_downcast=True)\n", "\n", "predict = theano.function(inputs=[x],\n", " outputs=prediction,\n", " allow_input_downcast=True)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "iter 0, error 19.295896\n", "iter 1000, error 0.210341\n", "iter 2000, error 0.126124\n", "iter 3000, error 0.124872\n", "iter 4000, error 0.124846\n", "iter 5000, error 0.124845\n", "iter 6000, error 0.124845\n", "iter 7000, error 0.124845\n", "iter 8000, error 0.124845\n", "iter 9000, error 0.124845\n", "iter 10000, error 0.124845\n" ] } ], "source": [ "for i in xrange(10001):\n", " err = train(D[0], D[1])\n", " if i % 1000 == 0:\n", " print 'iter %5d, error %f' % (i, err)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "查看结果:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0 0 0 1 1 0 1 1 1 0 0 1 1 1 1 1 0 1 1 1 0 1 1 0 1 0 1 0 0 1 1 0 0 0 0 1 0\n", " 1 1 0 0 0 0 1 0 0 1 1 0 1 1 1 0 0 0 1 0 0 1 1 0 1 0 1 1 1 0 0 0 0 0 1 0 0\n", " 0 1 0 1 0 0 1 1 0 1 0 0 0 1 0 1 1 1 0 1 1 0 1 0 0 1 0 0 1 0 1 1 1 1 0 1 0\n", " 0 0 0 1 0 1 0 1 1 0 1 0 1 0 1 0 0 0 1 0 0 0 1 0 1 1 0 1 0 1 1 0 0 0 0 0 1\n", " 1 1 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 0 1 0 0 1 0 1 1 0 1 1 0 0 1 1 1 1 1\n", " 1 0 1 0 0 1 0 0 1 1 1 1 0 1 0 1 0 1 1 1 1 0 1 0 0 1 1 1 0 0 0 1 0 0 0 1 0\n", " 1 
0 1 0 0 0 0 0 1 1 1 0 0 1 1 0 1 1 0 0 1 0 1 1 1 1 1 1 0 0 0 1 1 1 0 1 1\n", " 0 1 0 0 1 0 1 0 1 0 0 1 0 1 0 0 0 0 0 0 0 1 1 1 1 1 0 1 1 0 0 1 1 1 1 1 1\n", " 0 0 1 1 0 0 0 1 0 1 1 0 1 0 0 1 0 0 0 1 0 1 1 1 0 1 0 1 0 1 0 1 0 1 0 1 0\n", " 0 1 0 0 0 0 0 1 1 0 1 1 0 1 1 0 0 1 0 1 1 0 0 1 1 1 0 1 0 1 1 1 0 1 0 0 1\n", " 0 1 0 1 0 0 1 0 0 1 1 1 0 1 1 0 0 1 0 1 1 0 1 0 1 0 0 1 1 0]\n" ] } ], "source": [ "print D[1]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0 0 0 1 1 0 1 1 1 0 0 1 1 1 1 1 0 1 1 1 0 1 1 0 1 0 1 0 0 1 1 0 0 0 0 1 0\n", " 1 1 0 0 0 0 1 0 0 1 1 0 1 1 1 0 0 0 1 0 0 1 1 0 1 0 1 1 1 0 0 0 0 0 1 0 0\n", " 0 1 0 1 0 0 1 1 0 1 0 0 0 1 0 1 1 1 0 1 1 0 1 0 0 1 0 0 1 0 1 1 1 1 0 1 0\n", " 0 0 0 1 0 1 0 1 1 0 1 0 1 0 1 0 0 0 1 0 0 0 1 0 1 1 0 1 0 1 1 0 0 0 0 0 1\n", " 1 1 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 0 1 0 0 1 0 1 1 0 1 1 0 0 1 1 1 1 1\n", " 1 0 1 0 0 1 0 0 1 1 1 1 0 1 0 1 0 1 1 1 1 0 1 0 0 1 1 1 0 0 0 1 0 0 0 1 0\n", " 1 0 1 0 0 0 0 0 1 1 1 0 0 1 1 0 1 1 0 0 1 0 1 1 1 1 1 1 0 0 0 1 1 1 0 1 1\n", " 0 1 0 0 1 0 1 0 1 0 0 1 0 1 0 0 0 0 0 0 0 1 1 1 1 1 0 1 1 0 0 1 1 1 1 1 1\n", " 0 0 1 1 0 0 0 1 0 1 1 0 1 0 0 1 0 0 0 1 0 1 1 1 0 1 0 1 0 1 0 1 0 1 0 1 0\n", " 0 1 0 0 0 0 0 1 1 0 1 1 0 1 1 0 0 1 0 1 1 0 0 1 1 1 0 1 0 1 1 1 0 1 0 0 1\n", " 0 1 0 1 0 0 1 0 0 1 1 1 0 1 1 0 0 1 0 1 1 0 1 0 1 0 0 1 1 0]\n" ] } ], "source": [ "print predict(D[0])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }