{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import os\n", "os.chdir('../')\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 一.最大熵原理\n", "最大熵的思想很朴素，即将已知事实以外的未知部分看做“等可能”的，而熵是描述“等可能”大小很合适的量化指标，熵的公式如下： \n", "\n", "$$\n", "H(p)=-\\sum_{i}p_i log p_i\n", "$$ \n", "\n", "这里分布$p$的取值有$i$种情况，每种情况的概率为$p_i$，下图绘制了二值随机变量的熵：" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "p=np.linspace(0.1,0.9,90)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def entropy(p):\n", " return -np.log(p)*p-np.log(1-p)*(1-p)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xl4VPXZ//H3nT2BJIQkQEgCCRiWsEtAEFFxQUAFLFahVrF1bV3ap62/qrXaR32etnaxy0OriFrrhkhbREURFURQliAESCAQwpKQAGHLQsh+//7IaKcYyASSnJnM/bquXM6c+Z7MhxE+OTnL94iqYowxxj8EOB3AGGNM+7HSN8YYP2Klb4wxfsRK3xhj/IiVvjHG+BErfWOM8SNW+sYY40es9I0xxo9Y6RtjjB8JcjrAqeLi4jQlJcXpGMYY41M2bNhwWFXjmxvndaWfkpJCZmam0zGMMcaniMheT8bZ7h1jjPEjHpW+iEwSkVwRyRORB5t4/WkR2eT62iEix91emy0iO11fs1szvDHGmJZpdveOiAQCc4ArgUJgvYgsVtWcL8eo6n+5jb8PGOF63BV4DMgAFNjgWvdYq/4pjDHGeMSTLf3RQJ6q5qtqDTAfmHaG8bOA112PrwKWqepRV9EvAyadS2BjjDFnz5PSTwQK3J4XupZ9jYj0BlKBj1u6rjHGmLbnSelLE8tOd+eVmcBCVa1vyboicqeIZIpIZklJiQeRjDHGnA1PSr8QSHZ7ngQUnWbsTP69a8fjdVV1rqpmqGpGfHyzp5kaY4w5S56cp78eSBORVGA/jcX+rVMHiUh/IAb43G3xUuB/RSTG9Xwi8NA5JTamjdTWN1BSXs3BsioOllVTVlXLieo6TlTXUVPX8O+BIoQHB9I5NJBOoUHEdAqhe2QY3aNCiYkIISCgqV9wjfEOzZa+qtaJyL00Fngg8IKqZovI40Cmqi52DZ0FzFe3m+6q6lEReYLGHxwAj6vq0db9IxjTMtV19ew4UMHm/cf
ZebCCXSUV5JecoKj0JGe6ZbS4uvxMY0ICA+gdG0Gf+E70je/MgIQohiZG0zs2AhH7YWCcJ952Y/SMjAy1K3JNayqrqmX97qN8vusI6/YcZXtxOTX1jVvunUIC6RPfmT7xnegd24mE6MYt9m6RYUSHB9MpNIhOoYGEBgV+9f1UlZO19VRU13Giup6jJ2o4VFbFwbIqikqryC85Qf7hCvYdqaSuofHfV1RYEMOSuzCmTyxj+8YyJDGa4EC7NtK0HhHZoKoZzY3zumkYjDlXqsqukgqW5Rziw20H2bjvGA0KIUEBnN+rC9+9KJWhSdEMSYwmKSa8xVvgIkJESBARIUEQCalxnZocV1vfwI6D5WwpLGXz/lK+2HuM3yzNBRp/2FyUFseV6T24bEA3unYKOec/tzGesC1902Hkl1SwaFMRb2cVsfvwCQAGJ0YxoX83xvaN5fxeMYQFBzbzXdrWkYpq1u4+yqq8w3y87RAHyqoIEBiV0pXpIxKZMjiB6IhgRzMa3+Tplr6VvvFpFdV1LNq4nzczC8gqLEUExvaJZfKQBK4Y2I2E6HCnI56WqrJ1fxnLcg7wzpZi8ktOEBIYwIQB8cwc3YtL0uLtoLDxmJW+6dC2FZfxypq9LNq4nxM19QzoEcmM85O4dlhPekSHOR2vxVSVLftLWbSxiMVZ+zlcUUNy13Bmje7FjRnJxHYOdTqi8XJW+qbDUVU+23WEZz7Zxac7DxMaFMA1Q3vy7TG9GJ7cpcOcHVNT18DS7AO8smYva3cfJTQogBsykrl9fCq9Y5s+fmCMlb7pMFSVD3IO8n8f57FlfylxnUP5zrgUvjW6FzEd/ADozoPlzPt0N//cWEh9gzJ5SAI/uDyNft0jnY5mvIyVvvF5qsqKHSX8/oMdbNlfSkpsBHdd0pfrRiQ6fkC2vR0sq+LF1Xt4+fM9VNbWM3VYT354Rb/Tnjlk/I+VvvFpWQXHefLdHNbvOUZSTDj3X57GN0YkEuTn57YfO1HDsyvzeemzPdTUN3BDRjI/ntiPONvn7/es9I1PKi49yVPv5/KvjfuJ6xzKD69I44aMZEKC/LvsT1VSXs2c5Xm8smYvYcGB3DPhPL4zLsXvfgMy/2alb3xKTV0D81bl86ePdtKgcPtFqXx/wnl0DrXrB89kV0kFv1yyjQ+3HSK5aziPTxvMhP7dnI5lHGClb3zGut1HeWTRFnYcrOCqQd155Op0krtGOB3Lp6zaeZhHF28lv+QEVw9J4NFr0+ke5XunrpqzZ6VvvF55VS3/u2Qbr68rILFLOI9PG8TlA7s7HctnVdfVM/eTfP68PI+QwAAenjKQWaOTO8yprObMrPSNV1udd5j/t3AzxaUnuX18H354RVrjXDbmnO05fIKH/7WFz3YdYXxaHL+eMZSeXbz3ymTTOjwtfTs6ZtrVyZp6fr5oKzfNW0toUABv3n0hD08ZaIXfilLiOvHKbRfwxPTBbNh7jKueXsnCDYV42waecYb9SzPtZltxGfe/vpG8kgpuuyiVB67qb2ebtJGAAOHmMb25JC2enyzM4idvZrFyRwlPXjeYqDCb0M2fWembNqeqvLxmL0++u43o8GBe/u4FXJQW53Qsv9ArNoLX7xjDM5/s4vfLdrCx4Bh/mjmCEb1iml/ZdEi2e8e0qfKqWr7/6hc8+lY24/rG8v4Pxlvht7PAAOGeCeex4K6xNDTAN5/5nOdX7bbdPX7KSt+0mZ0Hy5k2ZzUf5Bzk4SkDeOHWUTZbpING9o5hyQ/Gc9mAbjzxTg73vb6RE9V1Tscy7cyj0heRSSKSKyJ5IvLgacbcICI5IpItIq+5La8XkU2ur8VNrWs6nnc3FzNtzmrKTtby6u0XcOfFfe3UQS8QHR7MszeP5KeTBrBkSzHT56xmV0mF07FMO2r2lE0RCQR2AFcChTTe5HyWqua4jUkDFgCXqeoxEemmqodcr1WoamdPA9kpm76toUH5w0c7+dNHOxnZO4Y53zrfJ+e39we
r8w5z3+sbqa1v4C83nc/4tHinI5lz0JqnbI4G8lQ1X1VrgPnAtFPG3AHMUdVjAF8WvvEvJ2vque/1jfzpo518c2QSr98xxgrfi407L47F944jsUs4t764npc/3+N0JNMOPCn9RKDA7Xmha5m7fkA/EVktImtEZJLba2EikulaPr2pNxCRO11jMktKSlr0BzDe4VBZFTc8+zlLthbz8JQBPHX9UJskzQckxUSw8HsXcmm/eH7+VjaPvrWVuvoGp2OZNuTJKZtN7Yg9dZ9QEJAGXAokAZ+KyGBVPQ70UtUiEekDfCwiW1R11398M9W5wFxo3L3Twj+DcVjeoQpmv7COY5U1PHdzBlek21QKvqRzaBBzb8ngV+9t47lPd1N0vIo/zxpBeIhdQ9ERebIpVggkuz1PAoqaGPOWqtaq6m4gl8YfAqhqkeu/+cAKYMQ5ZjZeZMPeo1z/zGdU19Xzxp1jrfB9VGCA8LOr0/nvqYP4aPtBbpq3hmMnapyOZdqAJ6W/HkgTkVQRCQFmAqeehbMImAAgInE07u7JF5EYEQl1Wz4OyMF0CB9kH+Bbz60lJiKEf35vHEOSop2OZM7R7AtT+OtN57O1qIwZf/2MgqOVTkcyrazZ0lfVOuBeYCmwDVigqtki8riITHUNWwocEZEcYDnwgKoeAQYCmSKS5Vr+K/ezfozv+ucXhXzv1S8YkBDFwrvH0ivWpkLuKCYNTuDV2y/gcEU133zmc/IO2SmdHYnNsmla7OU1e/n5oq1c2DeW527JoJPd6KRD2lZcxs3Pr0UVXr7tAtJ7RjkdyZyBzbJp2sSzn+zi54u2cvmAbrxw6ygr/A5sYEIUb9w1lpCgAGbO/Zwv9h1zOpJpBVb6xmNzlufxy/e2c83QBJ65eaTNkOkH+sZ35s27xxLTKYSb561lw96jTkcy58hK33hkzvI8frM0l+tGJPLHmSMIDrS/Ov4iKSaCBXeNpVtUGLNfWM+GvbbF78vsX65p1l9WNBb+9OE9+e03hxEYYHPo+JvuUWG8fscY4iNDmf3COtvV48Os9M0ZzV25i6fez2Xa8J787obhVvh+rEd0Y/HHdQ5h9vPryCo47nQkcxas9M1pvbp2L/+7ZDtXD03gd7aFb3AV/51j6NIpmNkvriP3QLnTkUwLWembJr21aT+PLNrKZQO68fQNwwmyffjGJSE6nFdvG0NIYADffn4te4+ccDqSaQH7l2y+5sOcg/xoQRYXpHblLzedbxOnma/pFRvBq7dfQF19AzfNW0tx6UmnIxkP2b9m8x/W7znK91/7gsE9o5g3e5SdlmlOK617JH//7gUcr6zllufXUVpZ63Qk4wErffOVnQfLuf2lTJK6hPPid0bT2S68Ms0YkhTNc7dksPdIJXf8PZOq2nqnI5lmWOkbAA6UVjH7hXWEBAXw0ndH07VTiNORjI8Y2zeW390wjHV7jvJfb2yivsG7pnYx/8lK31BWVcutL66j9GQtL946iuSuNnmaaZlrh/XkkasH8t7WAzz+djbeNqeX+Tf7/d3P1dY3cM+rX5B3qIIXvzOKwYk2PbI5O7eP78OB0irmrdpNr9hO3HZRqtORTBOs9P2YqvKLxdl8uvMwT80YajfGNufs4SkDKThWyZPv5pASG8HlA+2mOt7Gdu/4sRdX7+HVtfu465I+3DAqufkVjGlGQIDw9I3DGdQzivte30hOUZnTkcwprPT91MfbD/LkuzlMTO/OT68a4HQc04FEhATx/OxRRIUFc/tL6zlUVuV0JOPGSt8P7TxYzn2vbWRgQhR/mDmcAJtewbSy7lFhzJudwfGTtdz1ygaq6+xUTm9hpe9nSk/WcufLGwgPCeS5WzKICLHDOqZtDE6M5nffHMbGfcd5dJGd0eMtPCp9EZkkIrkikiciD55mzA0ikiMi2SLymtvy2SKy0/U1u7WCm5arb1B+MH8jBUcr+ctNI+nZJdzpSKaDmzwkgXsnnMcbmQW8snaf03EMHpy9IyK
BwBzgSqAQWC8ii91vcC4iacBDwDhVPSYi3VzLuwKPARmAAhtc69pk3A74/bJcVuSW8OT0wYxO7ep0HOMn/uvKfuQUl/Hfi7Pp3z3S/u45zJMt/dFAnqrmq2oNMB+YdsqYO4A5X5a5qh5yLb8KWKaqR12vLQMmtU500xLvby1mzvJdzBqdzE0X9HI6jvEjga4zenp1jeD7r27gQKkd2HWSJ6WfCBS4PS90LXPXD+gnIqtFZI2ITGrBuojInSKSKSKZJSUlnqc3Htl9+AQPvLmZYcld+MXUQYjYgVvTvqLDg3n25pFU1tRz3+tfUFvf4HQkv+VJ6TfVEKcekQkC0oBLgVnAPBHp4uG6qOpcVc1Q1Yz4eLtAqDVV1dbzvVc2EBgozPnWCEKDbNZM44y07pH88htDWL/nGL9Zmut0HL/lSekXAu5X7iQBRU2MeUtVa1V1N5BL4w8BT9Y1bejRt7aSe7CcP9w4nKQYm1PHOGva8ERuGdubuSvzeX/rAafj+CVPSn89kCYiqSISAswEFp8yZhEwAUBE4mjc3ZMPLAUmikiMiMQAE13LTDtYkFnAgsxC7ptwHpf27+Z0HGMA+NnVAxmWFM0Db2ax57Dddau9NVv6qloH3EtjWW8DFqhqtog8LiJTXcOWAkdEJAdYDjygqkdU9SjwBI0/ONYDj7uWmTa282A5j761lXHnxfKDK/o5HceYr4QGBTLnpvMJCBDun7+Rmjrbv9+exNsumMjIyNDMzEynY/i0qtp6ps9ZTUl5Ne/9YDzdosKcjmTM1yzNPsBdL2/gjvGp/OzqdKfj+DwR2aCqGc2NsytyO6An381h+4FyfnfDMCt847WuGtSDW8b25rlPd7M891DzK5hWYaXfwby/tZhX1uzjzov72H584/UenjKQAT0i+cmCLJuYrZ1Y6XcgRcdP8v8WbmZYUjQ/mdjf6TjGNCssOJD/+9YIKmvq+a8Fm2iwWy22OSv9DqKhQfnxgizqGpQ/zhxBSJD9rzW+4bxukTx2bTqr847wwurdTsfp8KwZOogXVu/m8/wjPHZtOilxnZyOY0yL3DgqmSsGdueppbnkHih3Ok6HZqXfAeQeKOeppblcmd6dGzLsDljG94gIv5oxhKiwIH74xiabf78NWen7uOq6en4wfyNRYUH88htDbF4d47PiOofy6xlD2VZcxu+X7XA6Todlpe/jnl62k+0Hynnq+qHEdQ51Oo4x5+Tygd2ZNboXc1fms263XcfZFqz0fdgX+44xd+UuZo5K5rIB3Z2OY0yreOTqgSTHRPDAwiwqa+qcjtPhWOn7qKraeh54M4seUWH87OqBTscxptV0Cg3i1zOGsvdIJU+9b7NxtjYrfR/19LId7Co5wa9mDCUyLNjpOMa0qrF9Y5k9tjd/+2wPa/OPOB2nQ7HS90Eb9h7juU/zmTU6mYv72f0HTMf008kD6NU1ggcWbrbdPK3ISt/HVNXW88DCLBKiw3l4iu3WMR1XREgQT10/lH1HbTdPa7LS9zFzlueRX3KCX35jiO3WMR3emD6Nu3le+nwPX+w75nScDsFK34dsP1DGX1fs4hvnJ9puHeM3Hpg0gB5RYTz4j802934rsNL3EfUNyoP/2EJ0eDA/t7nHjR/pHBrEk9MHs+NgBc9+ssvpOD7PSt9H/P3zPWwqOM6j16YT0ynE6TjGtKvLB3bnmqEJ/PnjPPIOVTgdx6d5VPoiMklEckUkT0QebOL1W0WkREQ2ub5ud3ut3m35qffWNR7Yf/wkv1may6X945k6rKfTcYxxxGPXDiI8JJCH/7nFpmA+B82WvogEAnOAyUA6MEtEmtq/8IaqDnd9zXNbftJt+dQm1jPNeOytbFThyemDbW4d47fiI0P52dUDWbfnKG9uKHA6js/yZEt/NJCnqvmqWgPMB6a1bSzzpWU5B/lw20F+eEUaSTERTscxxlHfHJnE6JSu/PK97Rw9UeN0HJ/kSeknAu4/Vgtdy041Q0Q2i8hCEXG
f3zdMRDJFZI2ITD+XsP6msqaOXyzOpl/3znz3olSn4xjjOBHhiemDqaiq49fvbXc6jk/ypPSb2p9w6g61t4EUVR0KfAi85PZaL9cd2r8F/EFE+n7tDUTudP1gyCwpKfEwesf354/z2H/8JE9OH0JwoB1zNwagf49IbrsolTcyC8jcYzNxtpQnTVIIuG+5JwFF7gNU9YiqVruePgeMdHutyPXffGAFMOLUN1DVuaqaoaoZ8fF2/jnAzoPlPLcyn+tHJjE6tavTcYzxKvdfnkbP6DAeWbSV2no7d78lPCn99UCaiKSKSAgwE/iPs3BEJMHt6VRgm2t5jIiEuh7HAeOAnNYI3pGpKj9/ayudQoN4aPIAp+MY43U6hQbx2NRBbD9Qzkuf7XE6jk9ptvRVtQ64F1hKY5kvUNVsEXlcRL48G+d+EckWkSzgfuBW1/KBQKZr+XLgV6pqpd+MdzYXsyb/KA9c1Z9YuzGKMU2amN6dCf3j+cOHOzlUXuV0HJ8hqt51vmtGRoZmZmY6HcMxlTV1XP67T4iJCOHt+y4iMMBO0TTmdPJLKrjqDyuZNjyR335zmNNxHCUiG1zHT8/Ijg56mb+u2EVxaRX/PW2QFb4xzegT33hm28INhWy0Cdk8YqXvRfYdqeTZlflMG96TUSl28NYYT9x3WRrdIkP5xeJsu1LXA1b6XuSJd3MIChAemmzz5Bvjqc6hQTw0ZQBZhaUs3FDodByvZ6XvJVbtPMyynIPcM+E8ekSHOR3HGJ8yfXgiI3vH8NTS7ZRX1Todx6tZ6XuB+gblyXdzSO4azm125a0xLSYiPHpNOocravjrCpt++Uys9L3Am5kFbD9QzoOTBhIWHOh0HGN80rDkLlw3IpF5q3ZTeKzS6They0rfYRXVdfz2gx1k9I5hypAeTscxxqc9cFV/AgS7p+4ZWOk77JkVuzhcUc0j16TbtMnGnKOeXcK5c3wfFmcV2T11T8NK30H7j5/kuU/zmT68J8OTuzgdx5gO4a5L+hIfGcqT7+TgbRefegMrfQf95v3GqWEfmGTz6xjTWjqFBvHAxP58se84724pdjqO17HSd8jW/aUs2lTEdy9KJbFLuNNxjOlQZoxMYkCPSH6zNNdm4TyFlb5DnlqaS5eIYO6+5Gu3FzDGnKPAAOGnkwaw90gl89ftczqOV7HSd8DqvMOs3FHCvRPOIzo82Ok4xnRIl/aP54LUrvzxo51UVNc5HcdrWOm3s4YG5VfvbSexSzg3j+3tdBxjOiwR4aEpAzlcUcO8T/OdjuM1rPTb2btbitmyv5QfT+xHaJBdiGVMWxqe3IUpQ3rw3Mp8Ssqrm1/BD1jpt6OaugZ++0EuA3pEMm14U/eWN8a0tp9M7E9VXQN//nin01G8gpV+O1qQWcDeI5X8dNIAmyvfmHbSJ74zN45K5vV1+yg4atMzWOm3k6raev788U5G9o7h0v5283dj2tN9l52HiPCnj2xr36PSF5FJIpIrInki8mATr98qIiUissn1dbvba7NFZKfra3Zrhvclr6zZy8Gyan4ysb9Nt2BMO0uIDufmMb35xxeF7CqpcDqOo5otfREJBOYAk4F0YJaIpDcx9A1VHe76mudatyvwGHABMBp4TERiWi29jzhRXcdfV+ziovPiGNs31uk4xvil713al7DgQJ5etsPpKI7yZEt/NJCnqvmqWgPMB6Z5+P2vApap6lFVPQYsAyadXVTf9eLq3Rw5UcNPrurvdBRj/FZc51C+Oy6VdzYXk1NU5nQcx3hS+olAgdvzQteyU80Qkc0islBEklu4bodVWlnLsyvzuWJgd5tUzRiH3XFxH6LCgvj9Mv+detmT0m9qB/SpU9e9DaSo6lDgQ+ClFqyLiNwpIpkikllSUuJBJN8xb1U+5VV1/OjKfk5HMcbvRYcHc+fFffhw2yE2FRx3Oo4jPCn9QiDZ7XkSUOQ+QFWPqOqXVz48B4z0dF3X+nNVNUNVM+LjO86ZLccra3hx9R6
mDOlBes8op+MYY4Bbx6XSJSKYP37on/v2PSn99UCaiKSKSAgwE1jsPkBEEtyeTgW2uR4vBSaKSIzrAO5E1zK/8MKq3VRU13H/5WlORzHGuHQODeKO8X1YnltClh9u7Tdb+qpaB9xLY1lvAxaoaraIPC4iU13D7heRbBHJAu4HbnWtexR4gsYfHOuBx13LOrzSylpeXL2HyYN7MKCHbeUb401mX5jSuLXvh+ftB3kySFWXAEtOWfao2+OHgIdOs+4LwAvnkNEnPb96N+W2lW+MV/pya/83S3PZXHicoUn+c5KFXZHbBkora3lx1W4mDerBwATbyjfGG90ytjfR4cH88UP/2tq30m8DL9hWvjFeLzIsmDvGp/LR9kNsKSx1Ok67sdJvZWVVtbywejdXDepuZ+wY4+VmX5hCdHgwf/KjGTit9FvZK2v2Ul5Vx70TbCvfGG8XGRbMrRemsCznILkHyp2O0y6s9FvRyZp6nv90N5f0i2dIUrTTcYwxHvjOuBQiQgL5y4o8p6O0Cyv9VvTG+n0cOVHDPRPOczqKMcZDXSJC+PaY3rydVcTeIyecjtPmrPRbSU1dA8+uzGd0SldGp3Z1Oo4xpgVuvyiVoMAAnvlkl9NR2pyVfitZtHE/xaVVfH9CX6ejGGNaqFtUGDdkJLFwQyEHSqucjtOmrPRbQX2D8tdPdjE4MYpL+nWcuYOM8Sd3XdyXBoW5K/OdjtKmrPRbwXtbi9l9+AT3XHqe3RXLGB+V3DWCacN78tq6vRw9UeN0nDZjpX+OVJVnP8mnT1wnJg7q4XQcY8w5uPuSvlTVNvDy53udjtJmrPTP0ee7jrBlfym3j+9DYIBt5Rvjy/p1j+SyAd146fM9nKypdzpOm7DSP0fPrswnrnMI3zjfr24IZkyHddfFfTh6ooaFXxQ6HaVNWOmfg23FZXyyo4TvjEslLDjQ6TjGmFYwOrUrw5O7MO/TfOobvnajP59npX8OnluZT0RIIN++oLfTUYwxrUREuOviPuw9UsnS7ANOx2l1Vvpnaf/xkyzOKmLmqF5ERwQ7HccY04omDupBSmwEz36yC9WOtbVvpX+WXli1GwVuG5/qdBRjTCsLDBDuuLgPWYWlrMnvWDf7s9I/C+VVtbyxvoBrhiaQ2CXc6TjGmDYw4/wkYjuF8PyqjnWxlkelLyKTRCRXRPJE5MEzjLteRFREMlzPU0TkpIhscn0901rBnfTG+gIqquu47SLbyjemowoLDuSmC3rx0fZD7D7ccSZia7b0RSQQmANMBtKBWSKS3sS4SBpvir72lJd2qepw19fdrZDZUfUNyt8+28OolBi/uq+mMf7o22N7ExwQwIurdzsdpdV4sqU/GshT1XxVrQHmA9OaGPcE8BTQoWcr+iD7AIXHTtpWvjF+oFtkGFOH9+TNzEJKK2udjtMqPCn9RKDA7Xmha9lXRGQEkKyq7zSxfqqIbBSRT0Rk/NlH9Q7Pr9pNctdwrky3KReM8QffHZfKydp6Xl+/z+korcKT0m9qboGvzmESkQDgaeDHTYwrBnqp6gjgR8BrIvK1G8eKyJ0ikikimSUlJZ4ld0BWwXEy9x7j1gtTbcoFY/xEes8oLuwby0uf7aG2vsHpOOfMk9IvBJLdnicBRW7PI4HBwAoR2QOMARaLSIaqVqvqEQBV3QDsAvqd+gaqOldVM1Q1Iz7ee6cmfn7VbjqHBnFDRpLTUYwx7ei2i1IpLq3iva2+f7GWJ6W/HkgTkVQRCQFmAou/fFFVS1U1TlVTVDUFWANMVdVMEYl3HQhGRPoAaYBPnv9UXHqSJVuKuXFUMpFhdjGWMf5kQv9u9InrxPOrfP+AbrOlr6p1wL3AUmAbsEBVs0XkcRGZ2szqFwObRSQLWAjcrao+eaXDq2v2Ua/K7LEpTkcxxrSzgADh1nEpZBUcZ1PBcafjnJMgTwap6hJgySnLHj3N2EvdHv8D+Mc55PMK1XX1vL5uH5cP6Eav2Ain4xhjHPCN85N
46v1c/v7ZHobfONzpOGfNrsj1wLubizlyoobZF6Y4HcUY45DOoUFcPzKJdzYXc7ii2uk4Z81K3wMvfb6XPvFlRRUkAAAPGUlEQVSdGNc3zukoxhgH3Ty2NzX1Dby+1ndP37TSb8amguNkFRxn9tgUAuw0TWP8Wt/4zoxPi+PVtft89vRNK/1m/P2zPXQODWLGSDtN0xgDs8emcKCsig+yDzod5axY6Z/B4Ypq3tlczIzzE+kc6tExb2NMBzdhQDeSu4bz0md7nI5yVqz0z2D+un3U1Ddwix3ANca4BAYIN4/pzbo9R8kpKnM6TotZ6Z9GfYPy2tp9jDsvlr7xnZ2OY4zxIjdkJBMSFMCra/c6HaXFrPRPY0XuIYpKq+z+t8aYr+kSEcI1QxNYtHE/FdV1TsdpESv903hlzV66RYZyRXp3p6MYY7zQt8f05kRNPW9t2u90lBax0m9CwdFKVuwoYeaoZIID7SMyxnzdiOQuDEyI4pU1+3zq5unWaE2Yv34fAtw4upfTUYwxXkpEuOmCXmwrLmOjD83HY6V/ipq6Bt5YX8BlA7rZTc+NMWc0fUQinUICeXWN71yha6V/ig9yDnC4ooabxtgBXGPMmXUODWL6iETe2VzE8coap+N4xEr/FK+u2UdSTDgXp3nvzVyMMd7jpgt6U13XwMINhU5H8YiVvpv8kgo+zz/CrNG97HaIxhiPpPeMYnhyF+avL/CJA7pW+m4WZBYSGCBcb/PsGGNaYNboZPIOVfDFvmNOR2mWlb5LbX3jr2cT+neje1SY03GMMT7kmqE96RQSyPx1BU5HaZaVvsvH2w9xuKKamaOSmx9sjDFuOoUGce2wnryzuZjyqlqn45yRR6UvIpNEJFdE8kTkwTOMu15EVEQy3JY95FovV0Suao3QbeGN9QV0jwrl0v52ANcY03IzR/fiZG09i7OKnI5yRs2WvogEAnOAyUA6MEtE0psYFwncD6x1W5YOzAQGAZOAv7i+n1cpLj3JitxDfHNkMkF2Ba4x5iwMS4pmQI9I3ljv3bt4PGm40UCequarag0wH5jWxLgngKeAKrdl04D5qlqtqruBPNf38yoLMwtp0MaZ84wx5myICDeOSmZzYSnZRaVOxzktT0o/EXD/0VXoWvYVERkBJKvqOy1d12kNDcobmQWMOy+WXrERTscxxviw60YkEhIUwAIv3tr3pPSbOmH9q5NRRSQAeBr4cUvXdfsed4pIpohklpSUeBCp9Xy26wiFx05y4yibZ8cYc266RIQwaVAP/rVxP1W19U7HaZInpV8IuO/3SALcj1REAoOBFSKyBxgDLHYdzG1uXQBUda6qZqhqRnx8+x5IfXNDAdHhwUy0KZSNMa3gxlHJlFXV8UGOd95D15PSXw+kiUiqiITQeGB28ZcvqmqpqsapaoqqpgBrgKmqmukaN1NEQkUkFUgD1rX6n+IslVXV8v7WA0wd1pOwYK87vmyM8UFj+8TSMzrMa6dlaLb0VbUOuBdYCmwDFqhqtog8LiJTm1k3G1gA5ADvA/eoqtf8zvPu5mKq6xrsClxjTKsJCBBmjExi1c4SDpRWNb9CO/Po/ERVXaKq/VS1r6r+j2vZo6q6uImxl7q28r98/j+u9fqr6nutF/3cLdxQSFq3zgxNinY6ijGmA5lxfhINCv/c6H1b+357Unp+SQUb9h7j+pFJiNjkasaY1pMS14lRKTEs3FDodZOw+W3pL9xQSIA0nmJljDGt7fqRSeSXnPC6u2r5ZenXNyj//GI/l/SLp5tNrmaMaQNThiQQFhzgdQd0/bL0V+cd5kBZFdePtCtwjTFtIzIsmMmDE3g7q8irztn3y9JfuKGQ6PBgLh/YzekoxpgO7PqRSZRX1bE0+4DTUb7id6VfXlXL0uwDXDsswc7NN8a0qS/P2V+0cb/TUb7id6W/NPsg1XUNXDfCzs03xrStgABh2ohEVu48zOGKaqfjAH5Y+os27qdX1wjO79XF6SjGGD9w3Yh
E6huUd7xknn2/Kv2DZVWs3nWY6SMS7dx8Y0y76Nc9kvSEKP61yUq/3S3eVIQqTB/e0+koxhg/ct2IRLIKjpNfUuF0FP8q/X9u3M+w5C70ie/sdBRjjB+ZOrwnInjFAV2/Kf3tB8rYVlzGdbaVb4xpZ92jwhjXN45/bdrv+LQMflP6izYWERggXDPMSt8Y0/6mj0ik4OhJvth3zNEcflH6DQ3KW5sap12I6xzqdBxjjB+aNLgHYcEB/MvhXTx+Ufrr9hyluLSKabZrxxjjkM6hQUxM78E7m4uprW9wLIdflP7bWUWEBwdypd0S0RjjoKnDenK8spZVOw87lqHDl35tfQNLthRzRXp3IkKCnI5jjPFj4/vFERUWxNsOXqjV4Ut/dd5hjlXWMtUO4BpjHBYaFMjkwQl8kHPQsZk3PSp9EZkkIrkikiciDzbx+t0iskVENonIKhFJdy1PEZGTruWbROSZ1v4DNOftrGIiw4K4uF9ce7+1McZ8zbXDelJRXceK3EOOvH+zpS8igcAcYDKQDsz6stTdvKaqQ1R1OPAU8Hu313ap6nDX192tFdwTVbX1fJB9gEmDehAaZDNqGmOcN6ZPV+I6h7DYoV08nmzpjwbyVDVfVWuA+cA09wGqWub2tBPgFTeF/GRHCeXVdUy1s3aMMV4iKDCAq4ck8NG2Q1RU17X7+3tS+olAgdvzQtey/yAi94jILhq39O93eylVRDaKyCciMv6c0rbQ21lFxHYKYWyf2PZ8W2OMOaNrh/Wkuq6BD3MOtvt7e1L6TU1H+bUteVWdo6p9gZ8Cj7gWFwO9VHUE8CPgNRGJ+tobiNwpIpkikllSUuJ5+jOorKnjo22HmDIkgaDADn+82hjjQ87vFUPP6DBHzuLxpA0LAfebySYBZ0o6H5gOoKrVqnrE9XgDsAvod+oKqjpXVTNUNSM+Pt7T7Ge0LOcgJ2vrudbO2jHGeJkA15QwK3eWcLyypn3f24Mx64E0EUkVkRBgJrDYfYCIpLk9vRrY6Voe7zoQjIj0AdKA/NYI3px3NhfTIyqMjN4x7fF2xhjTIlOH9aS2Xnl/a/veP7fZ0lfVOuBeYCmwDVigqtki8riITHUNu1dEskVkE427cWa7ll8MbBaRLGAhcLeqHm31P8Upyqtq+WRHCVOGJBAQYDdLMcZ4n0E9o+jVNYIl7Vz6Hl2iqqpLgCWnLHvU7fEPTrPeP4B/nEvAs/Hx9kPU1DUwZUiP9n5rY4zxiIgweUgPnv90N8cra+gSEdIu79shj3C+t+UA3aNCOb+X7doxxnivq4ckUNegLGvHs3g6XOmfqK5jee4hJg+2XTvGGO82JDGaxC7hLNlS3G7v2eFKf3nuIarrGpg82HbtGGO8m4gwZUgPVuUdpvRkbbu8Z4cr/SVbiomPDCUjpavTUYwxpllThiRQW6/tdqFWhyr9ypo6lm8vYdKgHgTarh1jjA8YntyFntFhvLe1fXbxdKjS/yS3hJO19UwZkuB0FGOM8UjjWTwJrNxxmPKqtt/F06FK/90txcR2CmF0qu3aMcb4jilDelBT38BH29p+uuUOU/pVtfV8vP0QVw22XTvGGN8yIjmGHlFh7XIWT4e5f2DZyVquGNjd7pBljPE5AQHCzWN7U1nT9lMti6pXTH3/lYyMDM3MzHQ6hjHG+BQR2aCqGc2N6zC7d4wxxjTPSt8YY/yIlb4xxvgRK31jjPEjVvrGGONHrPSNMcaPWOkbY4wfsdI3xhg/4nUXZ4lICbD3HL5FHHC4leK0JsvVMparZSxXy3TEXL1VNb65QV5X+udKRDI9uSqtvVmulrFcLWO5Wsafc9nuHWOM8SNW+sYY40c6YunPdTrAaViulrFcLWO5WsZvc3W4ffrGGGNOryNu6RtjjDkNnyx9EZkkIrkikiciDzbx+sUi8oWI1InI9V6U60cikiMim0XkIxHp7UXZ7haRLSKySURWiUi6N+RyG3e9iKiItMsZFx58XreKSInr89o
kIrd7Qy7XmBtcf8+yReQ1b8glIk+7fVY7ROS4l+TqJSLLRWSj69/lFC/J1dvVEZtFZIWIJLXam6uqT30BgcAuoA8QAmQB6aeMSQGGAn8HrveiXBOACNfj7wFveFG2KLfHU4H3vSGXa1wksBJYA2R4Qy7gVuD/2uP/XwtzpQEbgRjX827ekOuU8fcBL3hDLhr3oX/P9Tgd2OMlud4EZrseXwa83Frv74tb+qOBPFXNV9UaYD4wzX2Aqu5R1c1Ag5flWq6qla6na4DW++l97tnK3J52AtrjYE+zuVyeAJ4CqtohU0tytTdPct0BzFHVYwCq2vZ32m755zULeN1LcikQ5XocDRR5Sa504CPX4+VNvH7WfLH0E4ECt+eFrmVOa2mu24D32jTRv3mUTUTuEZFdNBbs/d6QS0RGAMmq+k475PE4l8sM16/fC0Uk2Uty9QP6ichqEVkjIpO8JBfQuNsCSAU+9pJcvwC+LSKFwBIafwvxhlxZwAzX4+uASBGJbY0398XSlyaWecMpSB7nEpFvAxnAb9o0kdtbNrHsa9lUdY6q9gV+CjzS5qmaySUiAcDTwI/bIYs7Tz6vt4EUVR0KfAi81OapPMsVROMunktp3KKeJyJdvCDXl2YCC1W1vg3zfMmTXLOAv6lqEjAFeNn1987pXD8BLhGRjcAlwH6gVe6a7oulXwi4b1Ul0T6/kjXHo1wicgXwM2CqqlZ7UzY384HpbZqoUXO5IoHBwAoR2QOMARa3w8HcZj8vVT3i9v/vOWBkG2fyKJdrzFuqWququ4FcGn8IOJ3rSzNpn1074Fmu24AFAKr6ORBG4/w3juZS1SJV/YaqjqCxL1DV0lZ597Y+aNEGB0GCgHwaf0X88iDIoNOM/RvtdyC32VzACBoP4KR522fmngm4Fsj0hlynjF9B+xzI9eTzSnB7fB2wxktyTQJecj2Oo3E3QqzTuVzj+gN7cF0f5CWf13vAra7HA2ks3zbN52GuOCDA9fh/gMdb7f3b48Nvgw9tCrDDVaA/cy17nMatZ4BRNP40PQEcAbK9JNeHwEFgk+trsRd9Zn8Esl25lp+pfNsz1ylj26X0Pfy8fun6vLJcn9cAL8klwO+BHGALMNMbcrme/wL4VXvkacHnlQ6sdv1/3ARM9JJc1wM7XWPmAaGt9d52Ra4xxvgRX9ynb4wx5ixZ6RtjjB+x0jfGGD9ipW+MMX7ESt8YY/yIlb4xxvgRK31jjPEjVvrGGONH/j/PIx4FDGggPAAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(p,entropy(p))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "当两者概率均为0.5时，熵取得最大值，通过最大化熵，可以使得分布更“等可能”；另外，熵还有优秀的性质，它是一个凹函数，所以最大化熵其实是一个凸问题。 \n", "\n", "对于“已知事实”，可以用约束条件来描述，比如4个值的随机变量分布，其中已知$p_1+p_2=0.4$，它的求解可以表述如下： \n", "\n", "$$\n", "\\max_{p} -\\sum_{i=1}^4 p_ilogp_i \\\\\n", "s.t. 
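p_1+p_2=0.4\\\\\n", "p_i\\geq 0,i=1,2,3,4\\\\\n", "\\sum_i p_i=1\n", "$$ \n", "Clearly, the optimal solution is $p_1=0.2,p_2=0.2,p_3=0.3,p_4=0.3$: the probability mass is spread as evenly as the constraint allows." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Maximum Entropy Model\n", "The maximum entropy model applies the maximum entropy principle to classification. It assumes the classifier is a conditional probability distribution $P(Y|X)$: for a given input $X$, it outputs $Y$ with probability $P(Y|X)$. The objective of the maximum entropy model is then defined as the conditional entropy: \n", "\n", "$$\n", "H(P)=-\\sum_{x,y}\\tilde{P}(x)P(y|x)logP(y|x)\n", "$$ \n", "\n", "Here $\\tilde{P}(x)$ is the empirical distribution of the marginal $P(X)$, $\\tilde{P}(x)=\\frac{v(X=x)}{N}$, where $v(X=x)$ is the number of times input $x$ occurs in the training samples and $N$ is the total number of training samples. \n", "\n", "The “known facts” of the maximum entropy model are imposed through the following equality constraint: \n", "\n", "$$\n", "\\sum_{x,y}\\tilde{P}(x)P(y|x)f(x,y)=\\sum_{x,y}\\tilde{P}(x,y)f(x,y)\n", "$$\n", "\n", "For convenience, denote the left-hand side by $E_P(f)$ and the right-hand side by $E_{\\tilde{P}}(f)$. The equality states that the expectation of a function $f(x,y)$ under the model $P(Y|X)$ and the empirical distribution $\\tilde{P}(X)$ equals the expectation of $f(x,y)$ under the empirical distribution $\\tilde{P}(X,Y)$ (here $\\tilde{P}(x,y)=\\frac{v(X=x,Y=y)}{N}$). \n", "The important constraint information is therefore carried by $f(x,y)$, which is defined as follows: \n", "$$\n", "f(x,y)=\\left\\{\\begin{matrix}\n", "1 & \\text{if } x \\text{ and } y \\text{ satisfy a certain fact}\\\\ \n", "0 & \\text{otherwise}\n", "\\end{matrix}\\right.\n", "$$ \n", "\n", "The maximum entropy model can thus be understood as maximizing the conditional entropy subject to the model's expectation of certain facts matching that of the training set. With $n$ constraints, the maximum entropy model can be written as: \n", "\n", "$$\n", "\\max_P -\\sum_{x,y}\\tilde{P}(x)P(y|x)logP(y|x) \\\\\n", "s.t. E_P(f_i)=E_{\\tilde{P}}(f_i),i=1,2,...,n\\\\\n", "\\sum_y P(y|x)=1\n", "$$ \n", "\n", "Following the usual convention for optimization problems, this can be rewritten as: \n", "\n", "$$\n", "\\min_P \\sum_{x,y}\\tilde{P}(x)P(y|x)logP(y|x) \\\\\n", "s.t. 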
E_P(f_i)-E_{\\tilde{P}}(f_i)=0,i=1,2,...,n\\\\\n", "\\sum_y P(y|x)-1=0\n", "$$ \n", "\n", "Since the objective is convex and the constraints are affine, we can obtain the optimal solution of the primal problem by solving its dual. First introduce the Lagrange multipliers $w_0,w_1,...,w_n$ and define the Lagrangian $L(P,w)$: \n", "\n", "$$\n", "L(P,w)=-H(P)+w_0(1-\\sum_yP(y|x))+\\sum_{i=1}^nw_i(E_{\\tilde{P}}(f_i)-E_P(f_i))\n", "$$ \n", "\n", "The primal problem is therefore equivalent to: \n", "$$\n", "\\min_P\\max_w L(P,w)\n", "$$ \n", "and its dual problem is: \n", "$$\n", "\\max_w\\min_P L(P,w)\n", "$$ \n", "\n", "First, solve the inner problem $\\min_P L(P,w)$. For any $w$, $L(P,w)$ is convex in $P$, because $-H(P)$ is convex in $P$ and the remaining term $w_0(1-\\sum_yP(y|x))+\\sum_{i=1}^nw_i(E_{\\tilde{P}}(f_i)-E_P(f_i))$ is affine in $P(y|x)$. So we take the partial derivative of $L(P,w)$ with respect to $P(y|x)$ and set it to 0 to obtain the optimal $P(y|x)$, denoted $P_w(y|x)$: \n", "$$\n", "\\frac{\\partial L(P,w)}{\\partial P(y|x)}=\\sum_{x,y}\\tilde{P}(x)(logP(y|x)+1)-\\sum_yw_0-\\sum_{i=1}^n\\sum_{x,y}\\tilde{P}(x)f_i(x,y)w_i\\\\\n", "=\\sum_{x,y}\\tilde{P}(x)(logP(y|x)+1-w_0-\\sum_{i=1}^nw_if_i(x,y))\\\\\n", "=0\n", "$$ \n", "\n", "This must hold for every sample $(x,y)$ in the training set, i.e. $\\tilde{P}(x)(logP(y|x)+1-w_0-\\sum_{i=1}^nw_if_i(x,y))=0$. Clearly $\\tilde{P}(x)>0$ ($x$ is itself a training sample, so its empirical probability is positive), hence $logP(y|x)+1-w_0-\\sum_{i=1}^nw_if_i(x,y)=0$, and therefore: \n", "$$\n", "P_w(y|x)=exp(\\sum_{i=1}^nw_if_i(x,y)+w_0-1)\\\\\n", "=\\frac{exp(\\sum_{i=1}^nw_if_i(x,y))}{exp(1-w_0)}\\\\\n", "=\\frac{exp(\\sum_{i=1}^nw_if_i(x,y))}{\\sum_{y^{'}} exp(\\sum_{i=1}^nw_if_i(x,y^{'}))}\n", "$$ \n", "\n", "This is the expression of the maximum entropy model. The last step uses $\\sum_y P(y|x)=1$, which forces $exp(1-w_0)=\\sum_{y^{'}} exp(\\sum_{i=1}^nw_if_i(x,y^{'}))$. Here $w$ is the model parameter. As attentive readers may already have noticed, the maximum entropy model is simply a linear function wrapped in a **softmax** function, roughly as illustrated below: \n", "![avatar](./source/05_最大熵模型.svg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, plug $L(P_w,w)$ into the outer $\\max$ to obtain the optimal parameters $w^*$: \n", "\n", "$$\n", "w^*=arg\\max_w L(P_w,w)\n", "$$ \n", "\n", "Now derive the gradient update formula of the model: \n", "$$\n", "L(P_w,w)=\\sum_{x,y}\\tilde{P}(x)P_w(y|x)logP_w(y|x)+\\sum_{i=1}^nw_i\\sum_{x,y}(\\tilde{P}(x,y)f_i(x,y)-\\tilde{P}(x)P_w(y|x)f_i(x,y))\\\\\n", "=\\sum_{x,y}\\tilde{P}(x,y)\\sum_{i=1}^nw_if_i(x,y)+\\sum_{x,y}\\tilde{P}(x)P_w(y|x)(logP_w(y|x)-\\sum_{i=1}^nw_if_i(x,y))\\\\\n", 
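"=\\sum_{x,y}\\tilde{P}(x,y)\\sum_{i=1}^nw_if_i(x,y)-\\sum_{x,y}\\tilde{P}(x)P_w(y|x)log(\\sum_{y^{'}}exp(\\sum_{i=1}^nw_if_i(x,y^{'})))\\\\\n", "=\\sum_{x,y}\\tilde{P}(x,y)\\sum_{i=1}^nw_if_i(x,y)-\\sum_{x}\\tilde{P}(x)log(\\sum_{y^{'}}exp(\\sum_{i=1}^nw_if_i(x,y^{'})))\\\\\n", "=\\sum_{x,y}\\tilde{P}(x,y)w^Tf(x,y)-\\sum_{x}\\tilde{P}(x)log(\\sum_{y^{'}}exp(w^Tf(x,y^{'})))\n", "$$ \n", "Here, the step from the third-to-last to the second-to-last expression uses $\\sum_yP_w(y|x)=1$. In the last step $w=[w_1,w_2,...,w_n]^T$ and $f(x,y)=[f_1(x,y),f_2(x,y),...,f_n(x,y)]^T$. Therefore: \n", "$$\n", "\\frac{\\partial L(P_w,w)}{\\partial w}=\\sum_{x,y}\\tilde{P}(x,y)f(x,y)-\\sum_x\\tilde{P}(x)\\sum_y\\frac{exp(w^Tf(x,y))f(x,y)}{\\sum_{y^{'}}exp(w^Tf(x,y^{'}))}\\\\\n", "=E_{\\tilde{P}}(f)-E_{P_w}(f)\n", "$$ \n", "\n", "The gradient is the gap between the empirical and the model expectations of the feature functions, so $w$ is naturally updated by gradient ascent: \n", "$$\n", "w=w+\\eta\\frac{\\partial L(P_w,w)}{\\partial w}\n", "$$ \n", "where $\\eta$ is the learning rate." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. A Closer Look at Feature Functions\n", "Having derived the gradient update of the maximum entropy model, you may still be puzzled by $f(x,y)$: how should the phrase **“satisfy a certain fact”** be understood? It depends on the learning objective, which determines our **“facts”**. Suppose, for example, the task is to decide whether the character “打” is used as a measure word or as a verb, and we have collected the following corpus: \n", "\n", "| Sentence $x$ | Target $y$ |\n", "|-|-|\n", "| $x_1$: 一打火柴 (a dozen matches) | $y_1$: measure word |\n", "| $x_2$: 三打啤酒 (three dozen beers) | $y_2$: measure word |\n", "| $x_3$: 打电话 (make a phone call) | $y_3$: verb |\n", "| $x_4$: 打篮球 (play basketball) | $y_4$: verb | \n", "\n", "By inspection, we can design the following two feature functions to recognize the “measure word” and “verb” cases respectively: \n", "$$\n", "f_1(x,y)=\\left\\{\\begin{matrix}\n", "1 & \\text{the character before “打” is a numeral and } y=\\text{measure word}\\\\ \n", "0 & \\text{otherwise}\n", "\\end{matrix}\\right.\n", "$$ \n", "\n", "$$\n", "f_2(x,y)=\\left\\{\\begin{matrix}\n", "1 & \\text{“打” is followed by a noun, no numeral precedes it, and } y=\\text{verb}\\\\ \n", "0 & \\text{otherwise}\n", "\\end{matrix}\\right.\n", "$$ \n", "\n", "Of course, you could also design feature functions like the following for the “measure word” case: \n", "$$\n", "f(x,y)=\\left\\{\\begin{matrix}\n", "1 & \\text{“打” is preceded by “一” and followed by “火柴”, and } y=\\text{measure word}\\\\ \n", "0 & \\text{otherwise}\n", "\\end{matrix}\\right.\n", "$$ \n", "\n", "$$\n", "f(x,y)=\\left\\{\\begin{matrix}\n", "1 & \\text{“打” is preceded by “三” and followed by “啤酒”, and } y=\\text{measure word}\\\\ \n", "0 & \\text{otherwise}\n", "\\end{matrix}\\right.\n", "$$ \n", 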
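"\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the contrast concrete, here is a minimal sketch of $f_1$ and of the two over-specific feature functions above, written as plain Python predicates. It rests on a few hypothetical assumptions: a sentence is a raw string, the word to classify is the first “打” in it, and a small hand-picked set of Chinese numerals is enough to detect a numeral. The names `NUMERALS`, `SPECIFIC_PATTERNS`, `f1` and `f_specific` are illustrative only and are not part of `ml_models`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical helper data: a small, non-exhaustive set of Chinese numerals\n", "NUMERALS = set('一二三四五六七八九十两')\n", "# (character before 打, text after 打) pairs matched by the two over-specific feature functions\n", "SPECIFIC_PATTERNS = {('一', '火柴'), ('三', '啤酒')}\n", "\n", "def f1(x, y):\n", "    # 1 if the character before the first 打 is a numeral and y is 'measure word'\n", "    i = x.find('打')\n", "    return int(i > 0 and x[i - 1] in NUMERALS and y == 'measure word')\n", "\n", "def f_specific(x, y):\n", "    # 1 only for the literal patterns 一打火柴 and 三打啤酒 with y == 'measure word'\n", "    i = x.find('打')\n", "    return int(i > 0 and (x[i - 1], x[i + 1:]) in SPECIFIC_PATTERNS and y == 'measure word')\n", "\n", "# 三打火柴 is unseen: f1 still fires, the over-specific functions do not\n", "for sentence in ['一打火柴', '三打啤酒', '三打火柴', '打电话']:\n", "    print(sentence, f1(sentence, 'measure word'), f_specific(sentence, 'measure word'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 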
"However, such feature functions weaken the model's ability to learn. For example, given “三打火柴” (three dozen matches), the latter design cannot recognize “打” as a measure word, whereas the first design recognizes it easily. A model that generalizes better therefore needs better feature functions, and designing them usually relies on human experience. For natural language processing tasks (such as the example above), it is relatively easy to distill useful empirical knowledge. In other settings, however, people often struggle to summarize general rules, so for such problems we need more **“general”** feature functions. \n", "#### A simple feature function design\n", "We can simply combine “a given feature of $x$ takes a given value” with “$y$ takes a given class” to form a feature function (continuous features can be binned first). This gives the following two kinds of feature functions: \n", "\n", "(1) Discrete: \n", "$$\n", "f(x,y)=\\left\\{\\begin{matrix}\n", "1 & x_i=\\text{some value},y=\\text{some class}\\\\ \n", "0 & \\text{otherwise}\n", "\\end{matrix}\\right.\n", "$$ \n", "\n", "(2) Continuous: \n", "$$\n", "f(x,y)=\\left\\{\\begin{matrix}\n", "1 & \\text{value}_1\\leq x_i< \\text{value}_2,y=\\text{some class}\\\\ \n", "0 & \\text{otherwise}\n", "\\end{matrix}\\right.\n", "$$ " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Code Implementation\n", "For ease of demonstration, we first build the training and test data" ] }, 
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Prepare the data: iris, 80/20 train/test split\n", "from sklearn import datasets\n", "from sklearn import model_selection\n", "from sklearn.metrics import f1_score\n", "\n", "iris = datasets.load_iris()\n", "data = iris['data']\n", "target = iris['target']\n", "X_train, X_test, y_train, y_test = model_selection.train_test_split(data, target, test_size=0.2,random_state=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make binning convenient, we wrap it in a DataBinWrapper class and use it to transform X_train and X_test (this class lives in ml_models.wrapper_models)" ] }, 
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "class DataBinWrapper(object):\n", "    def __init__(self, max_bins=10):\n", "        # number of bins\n", "        self.max_bins = max_bins\n", "        # bin edges for each feature of x\n", "        self.XrangeMap = None\n", "\n", "    def fit(self, x):\n", "        n_sample, n_feature = x.shape\n", "        # compute the percentile-based bin edges of every feature\n", "        self.XrangeMap = [[] for _ in range(0, n_feature)]\n", "        for index in range(0, n_feature):\n", "            tmp = x[:, index]\n", "            for percent in range(1, self.max_bins):\n", "                percent_value = np.percentile(tmp, (1.0 * percent / self.max_bins) * 100.0 // 1)\n", "                self.XrangeMap[index].append(percent_value)\n", "\n", "    def transform(self, x):\n", "        \"\"\"\n", "        map x to its bin indices (x_bin_index)\n", "        :param x:\n", "        :return:\n", "        \"\"\"\n", "        if x.ndim == 1:\n", "            return np.asarray([np.digitize(x[i], self.XrangeMap[i]) for i in range(0, x.size)])\n", "        else:\n", "            return np.asarray([np.digitize(x[:, i], self.XrangeMap[i]) for i in range(0, x.shape[1])]).T" ] }, 
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "data_bin_wrapper=DataBinWrapper(max_bins=10)\n", "data_bin_wrapper.fit(X_train)\n", "X_train=data_bin_wrapper.transform(X_train)\n", "X_test=data_bin_wrapper.transform(X_test)" ] }, 
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[7, 6, 8, 7],\n", "       [3, 5, 5, 6],\n", "       [2, 8, 2, 2],\n", "       [6, 5, 6, 7],\n", "       [7, 2, 8, 8]], dtype=int64)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train[:5,:]" ] }, 
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[5, 2, 7, 9],\n", "       [5, 0, 4, 3],\n", "       [3, 9, 1, 2],\n", "       [9, 3, 9, 7],\n", "       [1, 8, 2, 2]], dtype=int64)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_test[:5,:]" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "Since feature functions can take different forms, we decouple them from the model into a SimpleFeatureFunction class. Any more complex feature-function class built later must define the same method names. This class lives in ml_models.linear_model." ] }, 
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "class SimpleFeatureFunction(object):\n", "    def __init__(self):\n", "        \"\"\"\n", "        stores the feature functions as\n", "        {\n", "            (x_index,x_value,y_index)\n", "        }\n", "        \"\"\"\n", "        self.feature_funcs = set()\n", "\n", "    # build the feature functions\n", "    def build_feature_funcs(self, X, y):\n", "        n_sample, _ = X.shape\n", "        for index in range(0, n_sample):\n", "            x = X[index, :].tolist()\n", "            for feature_index in range(0, len(x)):\n", "                self.feature_funcs.add(tuple([feature_index, x[feature_index], y[index]]))\n", "\n", "    # total number of feature functions\n", "    def get_feature_funcs_num(self):\n", "        return len(self.feature_funcs)\n", "\n", "    # indices of the feature functions that fire on (x, y);\n", "    # relies on the set's iteration order staying fixed once it is built\n", "    def match_feature_funcs_indices(self, x, y):\n", "        match_indices = []\n", "        index = 0\n", "        for feature_index, feature_value, feature_y in self.feature_funcs:\n", "            if feature_y == y and x[feature_index] == feature_value:\n", "                match_indices.append(index)\n", "            index += 1\n", "        return match_indices" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "Next we implement the MaxEnt class, starting with a softmax function (ml_models.utils)" ] }, 
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def softmax(x):\n", "    # subtract the (row-wise) max before exponentiating to avoid overflow\n", "    if x.ndim == 1:\n", "        e = np.exp(x - np.max(x))\n", "        return e / e.sum()\n", "    else:\n", "        e = np.exp(x - np.max(x, axis=1, keepdims=True))\n", "        return e / e.sum(axis=1, keepdims=True)" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "The concrete implementation of the MaxEnt class (ml_models.linear_model)" ] }, 
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from ml_models import utils\n", "class MaxEnt(object):\n", "    def __init__(self, feature_func, epochs=5, eta=0.01):\n", "        self.feature_func = feature_func\n", "        self.epochs = epochs\n", "        self.eta = eta\n", "\n", "        self.class_num = None\n", "        \"\"\"\n", "        empirical joint distribution:\n", "        {\n", "            (x_0,x_1,...,x_p,y_index):p\n", "        }\n", "        \"\"\"\n", "        self.Pxy = {}\n", "        \"\"\"\n", "        empirical marginal distribution:\n", "        {\n", "            (x_0,x_1,...,x_p):p\n", "        }\n", "        \"\"\"\n", "        self.Px = {}\n", "\n", "        \"\"\"\n", "        w[i]-->feature_func[i]\n", "        \"\"\"\n", "        self.w = None\n", "\n", "    def init_params(self, X, y):\n", "        \"\"\"\n", "        initialize the empirical distributions and the weights\n", "        :return:\n", "        \"\"\"\n", "        n_sample, n_feature = X.shape\n", "        self.class_num = np.max(y) + 1\n", "\n", "        # count occurrences for the empirical joint and marginal distributions\n", "        for index in range(0, n_sample):\n", "            range_indices = X[index, :].tolist()\n", "\n", "            if self.Px.get(tuple(range_indices)) is None:\n", "                self.Px[tuple(range_indices)] = 1\n", "            else:\n", "                self.Px[tuple(range_indices)] += 1\n", "\n", "            if self.Pxy.get(tuple(range_indices + [y[index]])) is None:\n", "                self.Pxy[tuple(range_indices + [y[index]])] = 1\n", "            else:\n", "                self.Pxy[tuple(range_indices + [y[index]])] += 1\n", "\n", "        # normalize the counts into empirical probabilities\n", "        for key, value in self.Pxy.items():\n", "            self.Pxy[key] = 1.0 * self.Pxy[key] / n_sample\n", "        for key, value in self.Px.items():\n", "            self.Px[key] = 1.0 * self.Px[key] / n_sample\n", "\n", "        # initialize the weights\n", "        self.w = np.zeros(self.feature_func.get_feature_funcs_num())\n", "\n", "    def _sum_exp_w_on_all_y(self, x):\n", "        \"\"\"\n", "        sum_y exp(self._sum_exp_w_on_y(x, y))\n", "        :param x:\n", "        :return:\n", "        \"\"\"\n", "        sum_w = 0\n", "        for y in range(0, self.class_num):\n", "            tmp_w = self._sum_exp_w_on_y(x, y)\n", "            sum_w += np.exp(tmp_w)\n", "        return sum_w\n", "\n", "    def _sum_exp_w_on_y(self, x, y):\n", "        # despite its name, this returns w^T f(x, y) without the exp\n", "        tmp_w = 0\n", "        match_feature_func_indices = self.feature_func.match_feature_funcs_indices(x, y)\n", "        for match_feature_func_index in match_feature_func_indices:\n", "            tmp_w += self.w[match_feature_func_index]\n", "        return tmp_w\n", "\n", "    def fit(self, X, y):\n", "        self.eta = max(1.0 / np.sqrt(X.shape[0]), self.eta)\n", "        self.init_params(X, y)\n", "        x_y = np.c_[X, y]\n", "        for epoch in range(self.epochs):\n", "            count = 0\n", "            np.random.shuffle(x_y)\n", "            for index in range(x_y.shape[0]):\n", "                count += 1\n", "                x_point = x_y[index, :-1]\n", "                y_point = x_y[index, -1:][0]\n", "                # look up the empirical joint probability\n", "                p_xy = self.Pxy.get(tuple(x_point.tolist() + [y_point]))\n", "                # look up the empirical marginal probability\n", "                p_x = self.Px.get(tuple(x_point))\n", "                # gradient w.r.t. w (only the feature functions that fire on (x_point, y_point))\n", "                dw = np.zeros(shape=self.w.shape)\n", "                match_feature_func_indices = self.feature_func.match_feature_funcs_indices(x_point, y_point)\n", "                if len(match_feature_func_indices) == 0:\n", "                    continue\n", "                if p_xy is not None:\n", "                    for match_feature_func_index in match_feature_func_indices:\n", "                        dw[match_feature_func_index] = p_xy\n", "                if p_x is not None:\n", "                    sum_w = self._sum_exp_w_on_all_y(x_point)\n", "                    for match_feature_func_index in match_feature_func_indices:\n", "                        dw[match_feature_func_index] -= p_x * np.exp(self._sum_exp_w_on_y(x_point, y_point)) / (\n", "                                1e-7 + sum_w)\n", "                # gradient ascent step\n", "                self.w += self.eta * dw\n", "                # print training progress\n", "                if count % max(1, X.shape[0] // 4) == 0:\n", "                    print(\"processing:\\tepoch:\" + str(epoch + 1) + \"/\" + str(self.epochs) + \",percent:\" + str(\n", "                        count) + \"/\" + str(X.shape[0]))\n", "\n", "    def predict_proba(self, x):\n", "        \"\"\"\n", "        predicted probability distribution over y\n", "        :param x:\n", "        :return:\n", "        \"\"\"\n", "        y = []\n", "        for x_point in x:\n", "            y_tmp = []\n", "            for y_index in range(0, self.class_num):\n", "                match_feature_func_indices = self.feature_func.match_feature_funcs_indices(x_point, y_index)\n", "                tmp = 0\n", "                for match_feature_func_index in match_feature_func_indices:\n", "                    tmp += self.w[match_feature_func_index]\n", "                y_tmp.append(tmp)\n", "            y.append(y_tmp)\n", "        return utils.softmax(np.asarray(y))\n", "\n", "    def predict(self, x):\n", "        return np.argmax(self.predict_proba(x), axis=1)" ] }, 
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "processing:\tepoch:1/5,percent:30/120\n", "processing:\tepoch:1/5,percent:60/120\n", "processing:\tepoch:1/5,percent:90/120\n", "processing:\tepoch:1/5,percent:120/120\n", "processing:\tepoch:2/5,percent:30/120\n", "processing:\tepoch:2/5,percent:60/120\n", "processing:\tepoch:2/5,percent:90/120\n", "processing:\tepoch:2/5,percent:120/120\n", "processing:\tepoch:3/5,percent:30/120\n", "processing:\tepoch:3/5,percent:60/120\n", "processing:\tepoch:3/5,percent:90/120\n", "processing:\tepoch:3/5,percent:120/120\n", "processing:\tepoch:4/5,percent:30/120\n", "processing:\tepoch:4/5,percent:60/120\n", "processing:\tepoch:4/5,percent:90/120\n", "processing:\tepoch:4/5,percent:120/120\n", "processing:\tepoch:5/5,percent:30/120\n", "processing:\tepoch:5/5,percent:60/120\n", "processing:\tepoch:5/5,percent:90/120\n", "processing:\tepoch:5/5,percent:120/120\n", "f1: 0.9295631904327557\n" ] } ], "source": [ "# build the feature-function object\n", 
"feature_func=SimpleFeatureFunction()\n", "feature_func.build_feature_funcs(X_train,y_train)\n", "\n", "maxEnt = MaxEnt(feature_func=feature_func)\n", "maxEnt.fit(X_train, y_train)\n", "y = maxEnt.predict(X_test)\n", "\n", "print('f1:', f1_score(y_test, y, average='macro'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "通过前面的分析，我们知道特征函数的复杂程度决定了模型的复杂度，下面我们添加更复杂的特征函数来增强MaxEnt的效果，上面的特征函数仅考虑了单个特征与目标的关系，我们进一步考虑二个特征与目标的关系，即： \n", "\n", "$$\n", "f(x,y)=\\left\\{\\begin{matrix}\n", "1 & x_i=某值,x_j=某值,y=某类\\\\ \n", "0 & 否则\n", "\\end{matrix}\\right.\n", "$$ \n", "\n", "如此，我们可以定义一个新的UserDefineFeatureFunction类（**注意:类中的方法名称要和SimpleFeatureFunction一样**）" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "class UserDefineFeatureFunction(object):\n", " def __init__(self):\n", " \"\"\"\n", " 记录特征函数\n", " {\n", " (x_index1,x_value1,x_index2,x_value2,y_index)\n", " }\n", " \"\"\"\n", " self.feature_funcs = set()\n", "\n", " # 构建特征函数\n", " def build_feature_funcs(self, X, y):\n", " n_sample, _ = X.shape\n", " for index in range(0, n_sample):\n", " x = X[index, :].tolist()\n", " for feature_index in range(0, len(x)):\n", " self.feature_funcs.add(tuple([feature_index, x[feature_index], y[index]]))\n", " for new_feature_index in range(0,len(x)):\n", " if feature_index!=new_feature_index:\n", " self.feature_funcs.add(tuple([feature_index, x[feature_index],new_feature_index,x[new_feature_index],y[index]]))\n", "\n", " # 获取特征函数总数\n", " def get_feature_funcs_num(self):\n", " return len(self.feature_funcs)\n", "\n", " # 分别命中了那几个特征函数\n", " def match_feature_funcs_indices(self, x, y):\n", " match_indices = []\n", " index = 0\n", " for item in self.feature_funcs:\n", " if len(item)==5:\n", " feature_index1, feature_value1,feature_index2,feature_value2, feature_y=item\n", " if feature_y == y and x[feature_index1] == feature_value1 and x[feature_index2]==feature_value2:\n", " match_indices.append(index)\n", " else:\n", " feature_index1, 
feature_value1, feature_y=item\n", " if feature_y == y and x[feature_index1] == feature_value1:\n", " match_indices.append(index)\n", " index += 1\n", " return match_indices" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "processing:\tepoch:1/5,percent:30/120\n", "processing:\tepoch:1/5,percent:60/120\n", "processing:\tepoch:1/5,percent:90/120\n", "processing:\tepoch:1/5,percent:120/120\n", "processing:\tepoch:2/5,percent:30/120\n", "processing:\tepoch:2/5,percent:60/120\n", "processing:\tepoch:2/5,percent:90/120\n", "processing:\tepoch:2/5,percent:120/120\n", "processing:\tepoch:3/5,percent:30/120\n", "processing:\tepoch:3/5,percent:60/120\n", "processing:\tepoch:3/5,percent:90/120\n", "processing:\tepoch:3/5,percent:120/120\n", "processing:\tepoch:4/5,percent:30/120\n", "processing:\tepoch:4/5,percent:60/120\n", "processing:\tepoch:4/5,percent:90/120\n", "processing:\tepoch:4/5,percent:120/120\n", "processing:\tepoch:5/5,percent:30/120\n", "processing:\tepoch:5/5,percent:60/120\n", "processing:\tepoch:5/5,percent:90/120\n", "processing:\tepoch:5/5,percent:120/120\n", "f1: 0.957351290684624\n" ] } ], "source": [ "# 检验\n", "feature_func=UserDefineFeatureFunction()\n", "feature_func.build_feature_funcs(X_train,y_train)\n", "\n", "maxEnt = MaxEnt(feature_func=feature_func)\n", "maxEnt.fit(X_train, y_train)\n", "y = maxEnt.predict(X_test)\n", "\n", "print('f1:', f1_score(y_test, y, average='macro'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "我们可以根据自己对数据的认识，不断为模型添加一些新特征函数去增强模型的效果，只需要修改build_feature_funcs和match_feature_funcs_indices这两个函数即可（**但注意控制函数的数量规模**） \n", "简单总结一下MaxEnt的优缺点，优点很明显：我们可以diy任意复杂的特征函数进去，缺点也很明显：训练很耗时，而且特征函数的设计好坏需要先验知识，对于某些任务很难直观获取" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": 
{ "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }