{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 概念\n", "- 梯度是向量,和参数维度一样。简单地来说,多元函数的导数(derivative)就是梯度(gradient),分别对每个变量进行微分,然后用逗号分割开,梯度是用括号包括起来,说明梯度其实是一个向量。\n", "\n", "- 计算过程:\n", " - ①、对各参数求偏导,得出 $\\triangledown f$;\n", " - ②、设置初始参数向量、学习率 η及阈值;\n", " - ③、下个参数向量=上一个参数向量-η*$\\triangledown f$;\n", " - ④、直到$\\triangledown f$≤阈值停止,此时寻找到参数向量是局部最优解。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 手工代码实现" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 一元函数\n", " $f(x)=3x^2+5x$" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def loss_function(x):\n", " return 3*(x**2)+5*x\n", "\n", "def det_function(x):\n", " return 6*x+5" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def get_GD(od_f=None,f=None,x_0=None,eta=0.001,threshold=0):\n", " x_all=[]\n", " od_f_all=[]\n", " det_f_all=[]\n", " count_n=0\n", " while True:\n", " count_n+=1\n", " y=od_f(x_0)\n", " #计算导数在x处的值\n", " det_f=f(x_0) \n", " od_f_all.append(y)\n", " x_all.append(x_0)\n", " det_f_all.append(det_f)\n", " #计算下一个点的值\n", " x_0=x_0-eta*det_f\n", " #判断是否到达目的地\n", " if det_f<=threshold:\n", " break\n", " \n", " return x_all,od_f_all,det_f_all,count_n\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "x,y,dety,n=get_GD(loss_function,det_function,x_0=1,eta=0.1,threshold=0.0001)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XlcVNX/x/HXYRhAEEEERGVVcV8Td03LUivT0hYtSzM1bbHy12Jf21fLLLPya6a2mtqelZWpbWYu4IqKIiKLuCCogMg65/fHmF8VF5Rh7iyf5+PRw5g73ftuqre3M+eeo7TWCCGEcB0eRgcQQghhW1LsQgjhYqTYhRDCxUixCyGEi5FiF0IIFyPFLoQQLkaKXQghXIwUuxBCuBgpdiGEcDGeRlw0ODhYR0dHG3FpIYRwWgkJCYe01iEXep8hxR4dHU18fLwRlxZCCKellEqrzPtkKEYIIVyMFLsQQrgYKXYhhHAxUuxCCOFipNiFEMLFVLrYlVLzlFIHlVKJp7w2VSmVpJTarJT6RikVWD0xhRBCVNbF3LF/CPQ/47VfgVZa6zbATuAJG+USQghxiSpd7FrrP4HcM15bqrUuO/HjaiDchtkqWJl8iJm/76rOSwghhNOz5Rj7KOCncx1USo1VSsUrpeKzs7Mv6QJ/JWczbelODuYVXWpGIYRweTYpdqXUZKAMmH+u92itZ2ut47TWcSEhF3wi9qxu7RhBuUXzRULmJSYVQgjXV+ViV0qNAAYAt2utddUjnVvDkJp0jgli0boMLJZqvZQQQjitKhW7Uqo/8DgwUGtdaJtI53db50jScwtZlZJjj8sJIYTTuZjpjguAf4CmSqlMpdTdwDuAP/CrUmqjUmpWNeU8qV/LMAJ9zSxYl17dlxJCCKdU6dUdtdbDzvLyXBtmqRQfs4kb2zfg09Vp5BQUU6emt70jCCHERTteUs6MFcnc3jmS8Nq+1Xotp3zydFinSErLNV+tly9RhRDOYcmWffz39xQyco9X+7Wcstib1PWnQ1RtFq7LoJq/rxVCCJtYuC6dmGA/ujQMqvZrOWWxAwztGMHu7GOsTc298JuFEMJAyQfyWbfnMEM7RqCUqvbrOW2xD2hTH38fTxaslS9RhRCObeG6DMwmxZAO1fpw/klOW+w1vEzc0K4BSxL3c6SwxOg4QghxVkWl5Xy1PpO+LcIIttNkD6ctdoChnSIoKbPwzYa9RkcRQoiz+mXrfo4UljK0U4TdrunUxd6yfgBtwwNYsDZdvkQVQjik+avTia7jS/dGwXa7plMXO8DQTpHsPFDA+vQjRkcRQojT7Nifz9o9udzWORIPj+r/0vRfTl/s17etj5+XiYXyJaoQwsF8tiYNL08Pbupgv2EYcIFir+ntycB29fl+cxZHj5caHUcIIQA4VlzG1+v3cl3regT5edn12k5f7AC3d46iqNTCN/IkqhDCQXy/KYv84jJu7xxp92u7RLG3ahBA24hAPlmdJl+iCiEMp7Xm0zVpND3xlLy9uUSxA9zRJYqU7GP8s1uW8xVCGGtz5lES9+YxvEukXZ40PZPLFPuANvUIqGFm/mr5ElUIYaz5a9Lw9TJxQ/sGhlzfZYrdx2zilrhwftm6X/ZEFUIY5mhhKYs3ZTGoXX38fcyGZHCZYgfrl6hlFs3CdRlGRxFCuKmvN2RSVGrh9s5RhmVwqWKPDvajZ2wwn61Jp6zcYnQcIYSb0Vozf006bSMCadUgwLAcLlXsYP0SdX9eEcu2HzQ6ihDCzaxJzWXXwQKGGzDF8VQuV+xXNgulXoAP89ekGR1FCOFm5q9Jp5aPJwPa1Dc0x8VsZj1PKXVQKZV4ymtBSqlflVLJJ361/4TNM3iaPLitUyR/JR8i9dAxo+MIIdxEdn4xPyfuY0iHcGp4mQzNcjF37B8C/c94bRKwXGsdCyw/8bPhbu0UgaeHYv5quWsXQtjH5/EZlJZrQ540PVOli11r/Sd
w5j50g4CPTvz5R8ANNspVJaH+PvRrFcYXCZkcLyk3Oo4QwsWVlVv4dHUa3RvXoXGov9FxqjzGXldrvQ/gxK+h53qjUmqsUipeKRWfnZ1dxcte2B1dojh6vJTvN2dV+7WEEO5t6bYD7DtaxIiu0UZHAez45anWerbWOk5rHRcSElLt1+scE0RsaE0ZjhFCVLsPV+0hvHYN+jSva3QUoOrFfkApVQ/gxK8OM8dQKcXwLlFsyjzKxgzZhEMIUT2278tjbWoud3SJwmTHzTTOp6rFvhgYceLPRwDfVfF8NjX4sgbU9Pbko1V7jI4ihHBRH63ag4/Zg1s72nczjfO5mOmOC4B/gKZKqUyl1N3AFOBqpVQycPWJnx2Gv4+ZmzqE88PmLFk/Rghhc0cKS/h2415uaNeAQF/7bqZxPhczK2aY1rqe1tqstQ7XWs/VWudorftorWNP/HrmrBnDjewWTZnF+pivEELY0qJ1GRSVWhjRLdroKKdxuSdPzxQd7McVTUOZvyad4jKZ+iiEsI1yi+aT1Wl0igmieb1aRsc5jcsXO1jv2g8VFLNkyz6jowghXMSKpINkHj7OSAe7Wwc3KfaescE0CvHjg7/3yNZ5Qgib+GjVHuoF+NC3hWNMcTyVWxS7UoqR3WPYnHmU9eky9VEIUTW7DuazctchhneJwtPkeDXqeImqyeD2DfD38eRDmfoohKiij1al4eXpwVAHmuJ4Krcpdj9vT26Ni+CnLfvYf1SmPgohLk1eUSlfrc/k+jb1qVPT2+g4Z+U2xQ4wols05VrzqSwzIIS4RF/EZ1JYUu6QX5r+y62KPSLIl6ua1+WztekUlcrURyHExSm3aD74O5W4qNq0Djdu67sLcatiB7irWzS5x0r4fpOs+iiEuDhLt+4n8/BxRveMMTrKebldsXdtVIemdf1l6qMQ4qLNXZlKRFANrm4RZnSU83K7YrdOfYxm27481qQ63AoIQggHtSH9MPFphxnVPcZhVnE8F7crdoAb2jUgyM+LOX+lGh1FCOEk5q5Mxd/Hk5vjHHOK46ncsthreJkY3iWKZdsPkJJdYHQcIYSDyzxcyE+J+7mtUyQ1vT2NjnNBblnsAHd2jcLL04O5K+WuXQhxfv/u6eBoqziei9sWe3BNb4Zc1oCvEjLJKSg2Oo4QwkHlF5WycG0G17WuR/3AGkbHqRS3LXaAu3s0pLjMwifywJIQ4hw+j88kv7jM4ac4nsqti71xaE36NAvl43/S5IElIUQFZeUWPvg7lU7RQbQJDzQ6TqW5dbEDjO7ZkNxjJXy9fq/RUYQQDmbptgNkHj7O3U50tw5S7HRpGETrBgHMWbkbi0UeWBJC/M/clalE1bEuReJMbFLsSqmHlVJblVKJSqkFSikfW5zXHpRSjO4Zw+7sY6xIOmh0HCGE0ebPh+hotIcHb026gReObXL4B5LOVOViV0o1ACYAcVrrVoAJGFrV89rTta3rUT/Ah/f/2m10FCGEkebPh7FjIS0NpTXhedn0nDrZ+roTsdVQjCdQQynlCfgCTrXCltnkwageMaxJzWVzpuywJITbmjwZCgtPe0kVFlpfdyJVLnat9V7gdSAd2Acc1VovPfN9SqmxSql4pVR8dnZ2VS9rc7d2jMDf25P3ZZkBIdxXevrFve6gbDEUUxsYBMQA9QE/pdTwM9+ntZ6ttY7TWseFhIRU9bI25+9jZljnSJZs2UdGbuGF/wIhhOuJjLy41x2ULYZirgJStdbZWutS4Gugmw3Oa3cju0WjQJYZEMJdTZpEhblxvr7w0ktGpLlktij2dKCLUspXKaWAPsB2G5zX7uoH1uCG9g1YuC5dlhkQwg2V/vEn5SiOBAaDUhAVBbNnw+23Gx3tothijH0N8CWwHthy4pyzq3peo4zr1YjiMgsf/L3H6ChCCHv64QfMCxfwTvehZGzZBRYL7NnjdKUONpoVo7V+RmvdTGvdSmt9h9baaW93G4fWpF+LMD76Zw/5RaVGxxF
C2MORI+h77mFX3WjWDx/v0PuZVobbP3l6NuN7NyK/qIz5a5zrm3AhxCV69FH0/v1M7DeBMX2aGZ2myqTYz6JtRCDdG9dh7spUWRxMCFe3bBnMmcOiy2+hvEMcPRoHG52oyqTYz+He3o3Jzi/my4RMo6MIIapLQQGMGcOx6EY8e9nN3NOrEdY5IM5Niv0cujWqQ9vwAN77M4WycovRcYQQ1eE//0GnpfHCjRMJrRvIta3CjE5kE1Ls56CUYnzvxmTkHufHLfuMjiOEsLWVK+Gdd8gafjcLvaIY36sxnibXqETX+LuoJn1b1KVxaE3++3sKWsuSvkK4jOPH4e67ISqKSR2GElbLhyEdGhidymak2M/Dw0Mxrlcjkvbny5K+QriSZ5+FnTtJen4af+0r4p5eDfH2NBmdymak2C9gULv6NAiswUy5axfCNaxbB6+/DqNH83JpOME1vRja0bnWgrkQKfYLMJs8GNMzhoS0w6xNzTU6jhCiKkpKYNQoqFePLQ89yZ87s7m7R0NqeLnO3TpIsVfKrR0jCa7pxTu/7TI6ihCiKl5+GRITYdYsZiRkE1DDzPAurnW3DlLslVLDy8SYng35K/kQCWmHjY4jhLgUmzdbV2m8/Xa2d7icX7cd4K7u0fj7mI1OZnNS7JU0vEsUQX5ezFiebHQUIcTFKiuzDsEEBcFbb/Hub7uo6e3JyG7RRierFlLsleTn7cnonjH8sTObjRmyfZ4QTmXaNEhIgHfeIcXizY9b9jG8SxSBvl5GJ6sWUuwX4c6u0QT6muWuXQhnsmMHPPMMDB4MN93EzN9S8Pb0YHTPGKOTVRsp9otQ09uTMT0bsiLpoGx6LYQzKC+3DsH4+sK775KaU8i3G/dyW6cogmt6G52u2kixX6Q7u0YRUMPMjOUyQ0YIh/fuu7BqFUyfDmFhvL08GbNJMa53Q6OTVSsp9ovk72Pm7h4xLNt+gMS9R42OI4Q4l9RUeOIJuOYauOMOUrIL+HbjXoZ3jiLU38fodNVKiv0SjOgWjb+Pp4y1C+GotIbRo8FkgvfeA6V4e3kyXp4e3NOrkdHpqp0U+yUIqGFmVPcYlm47wLasPKPjCCHONGcOrFgBU6dCRAS7DhaweFMWI7pGE+LvumPr/7JJsSulApVSXyqlkpRS25VSXW1xXkc2qnsM/t6eTF+20+goQohTZWbCI4/AFVfAmDEAzFiejI/ZxNjLXXts/V+2umN/C/hZa90MaAtst9F5HVaAr5nRPRuydNsBtmTKWLsQDkFruOce6wNJ778PHh4kH8jn+81Z3Nk1mjouPBPmVFUudqVULeByYC6A1rpEa+0WcwFH9bDOa3996Q6jowghAObPhyVLrEsHNLKOpb+1PBlfN7pbB9vcsTcEsoEPlFIblFJzlFJ+Z75JKTVWKRWvlIrPzs62wWWN5+9jZlyvRvyxM5t1e2TlRyEMdeAAPPggdO0KDzwAwM4D+fy4ZR8jukUT5OeaT5mejS2K3RO4DPiv1ro9cAyYdOabtNaztdZxWuu4kJAQG1zWMdzZ1fqgw+u/7JD12oUw0v33w7FjMG+edTYM8NayZPy8rA8WuhNbFHsmkKm1XnPi5y+xFr1b8PXy5P4rGrEmNZe/d+UYHUcI9/TVV/Dll9alA5o1AyBx71F+3LKPu7pHU9uN7tbBBsWutd4PZCilmp54qQ+wrarndSbDOkdSP8CH15fKXbsQdpeTA/fdB5ddZp0Nc8LUX3YQ6GtmjBuNrf/LVrNiHgDmK6U2A+2Al210Xqfg7WnigT6xbMw4InujCmFvDz9sLfd588BsXVt9ze4c/tiZzfhejajlguutX4hNil1rvfHE+HkbrfUNWmu3243ipg7hRNXx5fWlO7FY5K5dCLtYsgQ++cS6dEDbtgBorXntlx3UreXNCBddb/1C5MlTGzGbPHjoqli278tjSeI+o+MI4fqOHrXOWW/ZEiZPPvnyiqSDJKQdZkKfWHzMrrWXaWVJsdvQwLY
NiA2tyRtLd1JabjE6jhCu7bHHICvLOgTjbX3wyGLRTP1lB9F1fLklLsLggMaRYrchk4fi0X5N2X3oGJ/HZxgdRwjXtWIFzJ4NEydCp04nX/5+cxZJ+/OZ2LcpZpP71pv7/p1Xk6tb1CUuqjbTlyVTWFJmdBwhXM+xY9aVG2Nj4fnnT75cUmZh2tKdtKhXiwGt6xkY0HhS7DamlOKJa5uRnV/MvJWpRscRwvVMnmxda33uXKhR4+TLi+IzSM8t5NF+TfHwUAYGNJ4UezXoEBVE3xZ1mfXHbnIKio2OI4Tr+PtvmDHDOm+9Z8+TLxeWlPH28mQ6Rtemd1PXebL9UkmxV5PH+jelsKSMd36TLfSEsImiIrj7boiMhFdeOe3Q+3+mcjC/mEnXNEMp975bByn2atM41J9b4iL4dHUaGbmFRscRwvk99xzs2GFdjtff/+TLB/OLeO/PFK5pFUaHqCADAzoOKfZq9NBVTTB5KKbJsr5CVE1CgnU3pFGj4OqrTzv05q/JlJRZeKx/M4PCOR4p9moUFuDDqO4xfLsxSza+FuJSlZRYCz00FKZNO+1Q8oF8Fq1LZ3iXKGKCK6wW7rak2KvZuN6NCPQ1M+WnJFkgTIhLMWUKbN4Ms2ZBYODph35Kws/Lkwl9Yg0K55ik2KtZLR8zD/aJZeWuQ/y2QxYIE+KiJCbCiy/CsGEwcOBph1alHGJ50kHuvaKxW22iURlS7HYwvEsUDUP8ePHH7bLUgBCVVVZmHYIJDLROcTyFxaJ5ecl26gf4cFf3aGPyOTApdjswmzyYfG1zdmcf49PVaUbHEcI5vPkmrFsH77wDwcGnHVq8KYvEvXk80q+p2y70dT5S7HZyZbNQejQOZvqyZI4UlhgdRwjHtnMnPP003HAD3HzzaYeKSsuZ+ssOWtavxQ3tGhgU0LFJsduJUoonBzQnv6iUt5YnGx1HCMdlsVgfRPLxgZkz4YwHjmb/uZu9R44z+brmbr90wLlIsdtRs7Ba3Noxkk/+SSMlu8DoOEI4ppkzYeVKmD4d6p2+mFfWkePM/H0X17QKo1uj4HOcQEix29nEq5vgYzbxypLtRkcRwvHs2QOTJkH//nDnnRUOv/pzEhYN/7m2uf2zOREpdjsL8ffmvisas2z7Qf7edcjoOEI4Dq1hzBjr0Mt771UYgonfk8t3G7O45/KGRAT5GhTSOdis2JVSJqXUBqXUD7Y6p6u6q3s04bVr8MIP2yiT6Y9CWM2bB8uWWZcOiIw87ZDFonn2+62E1fJhfO9GBgV0Hra8Y38QkPGFSvAxm3jyuuYk7c+X6Y9CAOzda90NqVcvGDu2wuEvEzJJ3JvHE9c2w9fL04CAzsUmxa6UCgeuA+bY4nzuoF/LMHrGBjPt151k58ua7cKNaQ3jxkFpKcyZAx6n11JeUSmv/ZJEh6jaDGxb36CQzsVWd+zTgccAGVeoJKUUzw5sSVFpOa/9nGR0HCGMs2AB/PADvPQSNG5c4fA7K3aRc6yEZ65vIWutV1KVi10pNQA4qLVOuMD7xiql4pVS8dnZ2VW9rEtoFFKTUT1i+CIhk/Xph42OI4T9HTwIEyZAly7WX8+w62ABH/ydys0dwmkTHniWE4izscUde3dgoFJqD7AQuFIp9emZb9Jaz9Zax2mt40JCZOuqf024Mpa6tbx5+rtEyi2y+qNwMw88APn51v1LTacvDaC15unvEvExm3i0n6y1fjGqXOxa6ye01uFa62hgKLBCaz28ysnchJ+3J5Ova0Hi3jwWrks3Oo4Q9vP11/D55/DMM9CiRYXDizdlsSolh8f6NyPE39uAgM5L5rE7gOvb1KNLwyCm/rKDw8dkHRnhBnJz4d57oV07ePTRCofzikp58cfttAkP4LZOkWc5gTgfmxa71vp3rfUAW57THSileG5gK/KLypgq2+gJdzBxIuTkwAcfgNlc4fAbS3dyqKCYF29ohUnWg7locsfuIJqG+TOyWzQL1qaTkCZfpAoX9tN
P8NFH1qUD2rWrcDhx71E+/mcPwztHyReml0iK3YE8fHUTwmr5MPmbLbIhh3BNeXnWB5BatIAnn6xw2GLRPPltIkF+XjzSr6kBAV2DFLsDqentyXMDW5K0P585f6UaHUcI23v8ccjKsi4f4F3xC9FF8RlszDjC5OuaE1Cj4hCNqBwpdgfTt2UY/VrW5a3lO0nPKTQ6jhC289tv1g2pH34YOneucDinoJgpPyXROSZINtCoIil2B/TcwFZ4enjw5HeJaC1z24ULOHYMRo+2Pln6/PNnfcvzP2yjsKSMF29oJU+YVpEUuwMKC/Dhkb5N+HNnNos3ZRkdR4iqe+op2L3buhaMb8Uld1ckHeC7jVncd0VjYuv6GxDQtUixO6g7ukbTNjyAF37YJnukCuf2zz/W3ZDuvde6euMZCorLePKbRJrUrcm9vSuuFSMunhS7gzJ5KF4e3JrDhaVM+UkWCRNOqqgIRo2CiAiYMuWsb3nt5yT25RXxyuA2eHlKJdmCfIoOrGX9AO7uEcPCdRmskt2WhDN64QVISoLZs8G/4hBL/J5cPlmdxoiu0XSIqm1AQNckxe7gHr6qCdF1fHn8680cKy4zOo4Qlbd+Pbz6Ktx1F/TrV+FwcVk5k77eQv2AGjwqc9ZtSordwdXwMvHaTW3JPHxc1m0XzqO01DoEExIC06ad9S3vrtjFroMFvHhjK/y8ZVckW5JidwKdYoIY2S2aj/5JY/XuHKPjCHFhr74KmzZZ563XrjjEkrQ/j5m/p3BDu/pc0TTUgICuTYrdSTzarylRdXx57MvNFJbIkIxwYFu3WueqDx0KgwZVOFxSZmHiok0E+pp5+vqWBgR0fVLsTsLXy5PXhrQhPbeQqb/ICpDCQZWXW4dgAgJgxoyzvuXtFcls25fHSze2JsjPy84B3YMUuxPp3LAOI7pG8eGqPaxNzTU6jhAVTZ8Oa9fC229bx9fPsDHjCDN/T2HwZQ3o1zLMgIDuQYrdyTzWvxnhtWvw2JebZEhGOJbkZOuKjYMGwa23VjhcVFrO/32+kVB/b56RIZhqJcXuZPy8PXltSFvScgt56cftRscRwspigbvvtq7YOHMmnGWtl6m/7CAl+xivDmkjKzdWMyl2J9S1UR3G9mzI/DXpLN9+wOg4Qlhnv/z1F7z5JtSvX+Hw6t05zPs7leFdIrm8iWxmX92k2J3UxL5NaF6vFo99uZns/GKj4wh3lpZmXWe9b18YObLC4YLiMh79chMRtX154prm9s/nhqpc7EqpCKXUb0qp7UqprUqpB20RTJyft6eJ6be2I7+4jElfbZblfYUxtLbuiATWZQPOMgTz3OKtZB4+zrRb2sqDSHZiizv2MuD/tNbNgS7AfUqpFjY4r7iApmH+TOrfjOVJB/lsbbrRcYQ7+vBDWLrU+kBSVFSFw99vyuKLhEzu692YjtFB9s/npqpc7FrrfVrr9Sf+PB/YDsj2J3Yysls0PWODeeGHbaRkFxgdR7iTrCzrbkiXXw7jxlU4nJFbyH++3kL7yEAevCrWgIDuy6Zj7EqpaKA9sMaW5xXn5uGheP3mtviYTTy8aKNsgi3sQ2sYPx6Ki62bZ3icXiVl5RYeXLgBgBlD22M2ydd59mSzT1spVRP4CnhIa513luNjlVLxSqn47OxsW11WAHVr+fDKja3ZnHlUnkoV9rFoESxeDC++CLEV78ZnLE9mffoRXryxFRFBFXdMEtXLJsWulDJjLfX5Wuuvz/YerfVsrXWc1jou5CxPpImquaZ1PYZ3iWT2n7tZkSRTIEU1ys6GBx6ATp3goYcqHF6zO4d3ftvFkMvCGSSbUhvCFrNiFDAX2K61fqPqkcSlevK6FrSoV4uJn28i68hxo+MIVzVhAhw9CvPmgcl02qEjhSU8tGgjkUG+PDdIni41ii3u2LsDdwBXKqU2nvjjWhucV1wkH7OJd2+/jNIyCw8s2CDj7cL2vv0WFi6Ep5+GlqcXt8WieeSLTWTnFzNjWHt
qytRGw9hiVsxKrbXSWrfRWrc78ccSW4QTFy8m2I9XhrQhIe0w05buNDqOcCWHD1u/MG3XzvpA0hlm/ZnCsu0HmXxdc9qEBxoQUPxLfkt1QQPb1mf17hxm/ZFC54ZBspGBsI3/+z/r+PqSJWA+fa2XVSmHeP2XHVzXph4ju0Ubk0+cJHOQXNTTA1rQLMyfiYs2yni7qLpffoEPPrDeqbdvf9qh/UeLmLBgAzHBfrw6pA3qLE+fCvuSYndRPmYTM2+/jNJyzfhPEygqLTc6knBW+fnWZQOaN4ennjrtUGm5hfs/W09hSTmzhneQcXUHIcXuwhqG1OSNW9qyKfMoT36bKOvJiEszaRJkZMDcueDjc9qhKT8lEZ92mFcGtya2rr9BAcWZpNhdXN+WYUzoE8uXCZl8sjrN6DjC2fzxh3V99Ycegq5dTzv04+Z9zF2ZyoiuUTJf3cFIsbuBh/rE0qdZKM9/v401u3OMjiOcRWGhdfOMRo2sT5ieYltWHo98sYn2kYFMvk7W/HM0UuxuwMND8ebQdkQG+XLfZ+vly1RROU8/DSkp1rVgfP+3LEBOQTFjPo4noIaZ94Z3wMtTasTRyD8RN1HLx8x7d3TgeEm5fJkqLmz1autuSOPGQe/eJ18uKbMwfv56DhUUM/vODoTW8jn3OYRhpNjdSGxdf964tR2bMo/yuGzOIc6luBhGjYIGDazrrJ/i2e+3sjY1l9duaiMPITkwKXY3069lGI/0bcJ3G7OYsXyX0XGEI3rxRdi+3bojUq1aJ1/++J89fLYmnfG9G8mXpQ5OJp26ofuuaMzuQ8d4c9lOooN95T9S8T8bNsArr8CIEdC//8mXVyQd4NnFW7mqeSiP9G1qYEBRGXLH7oaUUrwyuDWdYoJ49IvNJKTlGh1JOILSUusQTEgIvPG/hVoT9x7l/s820LJ+ADOGtcfkIU+WOjopdjfl7WniveEdqB/ow5iPE0jPKTQ6kjDa1KmwcaN13nqQdX/SvUeOM+rDddT29WLuiDh8veR/8p2BFLsbq+3nxbyRHbFozZ3z1nCooNjoSMIo27bBc8/BLbfAjTcCkFdUyqgP1nG8pJwP7uooM2CciBS7m2sYUpO5IzqyP69EikvhAAANeElEQVSIuz5YR0FxmdGRhL2Vl1uHYPz94e23ASguK2fcJwmkZBcw644ONJHlApyKFLugQ1Rt3r3tMrbty2P8pwmUlMkGHW5lxgxYs8b6a2iodSPqBRtZlZLDaze1oXvjYKMTioskxS4A6NO8LlMGt+av5EM88sUmLBaZ4+4Wdu2CyZPh+uth2DC01kz+JpGft+7n6QEtGHxZuNEJxSWQb0LESTfHRZBdUMxrP+8g0NfMcwNbytrarsxigdGjwcsL/vtfUIpXf0piUXwGD1zZmFE9YoxOKC6RFLs4zfhejTh8rIT3/0qlhtnEpGuaSbm7qtmzras3zpkDDRrw3h8pzPojheFdIpl4dROj04kqsEmxK6X6A28BJmCO1nqKLc4r7E8pxX+ubU5RqYX3/tyNj9nEw/IfuetJT4dHH4WrroJRo5i3MpVXfkpiQJt6PDewlfxm7uSqXOxKKRPwLnA1kAmsU0ot1lpvq+q5hTGUUjw3sCVFpeW8tTwZH7OJ8b0bGR1L2IrW1h2RtIb33+fj1Wk8/8M2+rcM481b28kDSC7AFnfsnYBdWuvdAEqphcAgQIrdiXl4KKYMaUNxmYVXf07CbFKM7tnQ6FjCFj7+2LqH6dtv8+l+xdPfJXJ1i7rMGNYes0nmU7gCWxR7AyDjlJ8zgc42OK8wmMlDMe2WtpRZLLz443br/qly5+7c9u2z7obUowcL4gbw5LeJ9GkWyru3XSbrqrsQWxT72f6/rcJcOaXUWGAsQGRkpA0uK+zBbPJgxtD2mE2bePXnJIrLynmwT6yMwTojreHee6GoiC/HP8sT326ld9MQZg6XUnc1tij2TCDilJ/Dgawz36S1ng3MBoiLi5NJ0k7E0+T
BG7e0w2zyYPqyZErKLDzar6mUu7P54gv49ltWjnmMRzYX0b9lGG8Na4e3p8noZMLGbFHs64BYpVQMsBcYCtxmg/MKB2LyULw2pA1enh7M/D2FwpJynh7QAg/5os05ZGej77+ffU1aMyKwO4PbN+C1m9rgKWPqLqnK/1S11mXA/cAvwHbgc6311qqeVzgeDw/FSze0YnSPGD5ctYcJCzdQXCZb7Dms+fMhOho8PNDR0VgOHWJkj3u4rVtDXr+5rZS6C7PJPHat9RJgiS3OJRybUorJ1zUnxN+bV35K4nBhCbOGd8Dfx2x0NHGq+fOtUxoLrcsxq8JCLB4m7q9TyPWD5IliVye/ZYuLppTinl6NeOOWtqzZncut763mYH6R0bHEqSZPPlnq/zJbyhn4+btS6m5Ail1cssGXhTNnRBx7co5x47ur2L4vz+hI4l/p6Rf3unApUuyiSno3DeXze7pSbtEM+e8qft12wOhIYs0aLKZzjLLKVGO3IMUuqqxVgwC+u787saE1GftJPLP+SEFrmdFqd7m56LFj0V27kmfyosR0xvcevr7w0kvGZBN2JcUubKJuLR8W3dOVAW3qM+WnJCZ+vonCEtmNyS60hg8/RDdtimXuPObEDWLy9B8pnzMHoqJAKeuvs2fD7bcbnVbYgSzbK2zGx2xixtB2NK1bk2m/7mRr1lFm3t6BxqE1jY7muhITYfx4WLmSpIatmXj901w9rB9v94m1PmMw8k6jEwoDyB27sCmlFPdfGcvHozpxqKCEQe+sZPGmCg8ii6oqKLAuu9uuHSWJW3lm4ERuuf1VJjx8ExOvbiIPjrk5uWMX1aJnbAg/TujB/Z9tYMKCDaxNzWHytS2o4SWPr1eJ1vDNN/Dgg5CZyfq+QxjVbAgRjSP4flh7ooP9jE4oHIDcsYtqUy+gBgvHdmFMzxg+XZ3OgLf/YnPmEaNjOa/du2HAABgyhOKA2jzy8H8Z3P4ubryqDV+O7yqlLk5SRsxeiIuL0/Hx8Xa/rjDOyhObZB8qKObBPrGM791IHmmvrOJieO01ePlltKcn/4x4kLtrdcHHx4spQ9rQr2WY0QmFnSilErTWcRd8nxS7sJejhaU89V0iizdl0TYikCmDW9O8Xi2jYzm2X3+F++6D5GQKBg3moU53sCzPTN8WdXnpxtaE+HsbnVDYUWWLXW6ZhN0E+JqZMaw9M4a1JyO3kOvfXsmUn5I4XiILiVWQlQVDh0Lfvli0ZvGr84hrPZq1JTWYfms73rujg5S6OCf58lTY3cC29enROJhXlmxn1h8p/LglixcGtaJ301CjoxmvrAzefReeegpKSkib8Bhjw/qwI7eU/i1DeW5QS+rW8jE6pXBwcscuDBHk58XUm9uyYEwXzB4ejPxgHXd9sJbkA/lGRzPO6tXQsSM89BDHO3XhmZcX0avG5RR7mvnwro7MuqODlLqoFBljF4YrLivnw7/38M5vuygsKWdoxwgevroJwTXdZKghJweeeALef5/y+g346s5HmKxi8fDw4L4rGjP28ob4mGWaqJAvT4UTyj1WwlvLdjJ/TTrenh7c2S2a0T1iqOOqBW+xwEcfwWOPoQ8fZsPgEYxrOIBDyotbO0bwYJ8mhAXIHbr4Hyl24bRSsguYviyZHzZn4eNpYniXSMZc3pBQfxcquS1brEsB/P03mS0uY8LlY1gfEEH/lmE80q+pLMMgzkqKXTi9XQcLmPnbLr7duBdPkwcD29ZnRNdoWocHGB2t8ubPt256kZ5uXTL3qadg+3b09OkU+vrzQq+RfN7ySq5p3YBxvRo519+bsDspduEy9hw6xpyVu/l6/V4KS8ppHxnInV2j6N+ynmMvUXDG9nQAGlDAZ237Mf2Ku+jTszljL29EjDw1KirBLsWulJoKXA+UACnAXVrrCz4zLsUuLkVeUSlfJWTy8T9ppB46hp+XiX4twxjYzjp90uGeZI2OhrS0Ci/n1KzNFz+s46YO4e7zBbGwCXsVe19
ghda6TCn1KoDW+vEL/XVS7KIqLBbN6tQcFm/MYsmWfeQVlVHHz4veTUPp3TSEy2NDCPCtps21zxxaeeml09c4z8mhcPlvZP+4jMiP3+NsayxqpVAWS/XkEy7N7kMxSqkbgZu01hdcyV+KXdhKcVk5v+/I5sfN+/hjZzZHj5fioaB9ZG06RgfRLiKQ9pGBlZ//7eUFpaUVX4+Kgmuvtc5iOXWT6Bo1KLxjJLkFxXiv+ouQPckAFHl6obTGu/wc59qz5+L/ZoXbM6LYvwcWaa0/vdB7pdhFdSi3aDZmHOGPHQf5I/kQ27KOUlpu/fe7foAPTcL8iQn2o2GwH9HBfoTV8qG2nxe1fb0weahzl/oJ/46Pn80xsw8J4S3IaNUBU+/eRPfvTYeE5ZjHjTv9NwJfX9nJSFyyyhb7BZcUUEotA862fNxkrfV3J94zGSgD5p/nPGOBsQCRsqGuqAYmD0WHqNp0iKrNxL5NKSotZ2tWHhszjrAp4wgp2QWsTc2l8Iy1aZSCgBpmNpSWnrO44dylrlEkbkmlfVQdLvc5ZQioxR3g4XH+oRshqkGV79iVUiOAcUAfrXXhhd4PcscujKO15mB+Mbuzj5FdUMzhYyXknvjj+Rtbn7fYz0mGVoSd2OyO/QIX6Q88DvSqbKkLYSSlFHVr+Vz6mitKWXcx+pevr/UuXAgHUtX5Ye8A/sCvSqmNSqlZNsgkhDHMF5hJ4+sL48ZZ79CVsv4q4+XCAVXpjl1r3dhWQYQwXEnJ+WfFyPi4cBKyHrsQpyopMTqBEFXmYI/qCSGEqCopdiGEcDFS7EII4WKk2IUQwsVIsQshhIsxZD12pVQ2UHE908oJBg7ZMI4rkc/m3OSzOTf5bM7OET+XKK11yIXeZEixV4VSKr4yj9S6I/lszk0+m3OTz+bsnPlzkaEYIYRwMVLsQgjhYpyx2GcbHcCByWdzbvLZnJt8NmfntJ+L042xCyGEOD9nvGMXQghxHk5Z7EqpF5RSm08sFbxUKVXf6EyOQik1VSmVdOLz+UYpFWh0JkeglLpZKbVVKWVRSjnlTAdbU0r1V0rtUErtUkpNMjqPo1BKzVNKHVRKJRqd5VI5ZbEDU7XWbbTW7YAfgKeNDuRAfgVaaa3bADuBJwzO4ygSgcHAn0YHcQRKKRPwLnAN0AIYppRqYWwqh/Eh0N/oEFXhlMWutc475Uc/rPsMC0BrvVRrXXbix9VAuJF5HIXWervWeofRORxIJ2CX1nq31roEWAgMMjiTQ9Ba/wnkGp2jKpx2PXal1EvAncBR4AqD4ziqUcAio0MIh9QAyDjl50ygs0FZhI05bLErpZYBYWc5NFlr/Z3WejIwWSn1BHA/8IxdAxroQp/NifdMBsqA+fbMZqTKfC7ipLPt2y3/5+siHLbYtdZXVfKtnwE/4kbFfqHPRik1AhgA9NFuNJ/1Iv6dEdY79IhTfg4HsgzKImzMKcfYlVKxp/w4EEgyKoujUUr1Bx4HBmqtC43OIxzWOiBWKRWjlPIChgKLDc4kbMQpH1BSSn0FNAUsWFeJHKe13mtsKseglNoFeAM5J15arbUeZ2Akh6CUuhF4GwgBjgAbtdb9jE1lLKXUtcB0wATM01q/ZHAkh6CUWgD0xrq64wHgGa31XENDXSSnLHYhhBDn5pRDMUIIIc5Nil0IIVyMFLsQQrgYKXYhhHAxUuxCCOFipNiFEMLFSLELIYSLkWIXQggX8//o6uPJvPG/sAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_x = np.linspace(-3,1.3,300)\n", "plot_y = loss_function(plot_x)\n", "plt.plot(plot_x,plot_y)\n", "plt.plot(x,y, color='red', marker='o')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 多元方程式\n", " - $f(x)=β_{0}+β_{1}x_{1}+β_{2}x_{2}+...+β_{i}x_{i}$\n", " - 向量化后梯度:$\\triangledown J(θ)=\\frac{1}{m} X_{b}^{T}(X_{b}θ-y)$\n", " - 把代码封装在线性回归代码中" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def fit_gd(self, X_train, y_train, eta=0.01, n_iters=1e4):\n", " \"\"\"根据训练数据集X_train, y_train, 使用梯度下降法训练Linear Regression模型\"\"\"\n", " assert X_train.shape[0] == y_train.shape[0], \\\n", " \"the size of X_train must be equal to the size of y_train\"\n", " \n", " def J(theta, X_b, y):\n", " try:\n", " return np.sum((y - X_b.dot(theta)) ** 2) / len(y)\n", " except:\n", " return float('inf')\n", " \n", " def dJ(theta, X_b, y):\n", " return X_b.T.dot(X_b.dot(theta) - y) * 2. 
/ len(y)\n", " \n", " \n", " def gradient_descent(X_b, y, initial_theta, eta, n_iters=1e4, epsilon=1e-8):\n", "\n", " theta = initial_theta\n", " cur_iter = 0\n", "\n", " while cur_iter < n_iters:\n", " gradient = dJ(theta, X_b, y)\n", " last_theta = theta\n", " theta = theta - eta * gradient\n", " if (abs(J(theta, X_b, y) - J(last_theta, X_b, y)) < epsilon):\n", " break\n", "\n", " cur_iter += 1\n", "\n", " return theta\n", " \n", " X_b = np.hstack([np.ones((len(X_train), 1)), X_train])\n", " initial_theta = np.zeros(X_b.shape[1])\n", " self._theta = gradient_descent(X_b, y_train, initial_theta, eta, n_iters)\n", "\n", " self.intercept_ = self._theta[0]\n", " self.coef_ = self._theta[1:]\n", "\n", " return self\n", " " ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# 用小例子验证我们的代码\n", "import numpy as np\n", "from sklearn import datasets\n", "\n", "boston = datasets.load_boston()\n", "X = boston.data\n", "y = boston.target\n", "\n", "X = X[y < 50.0]\n", "y = y[y < 50.0]" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "D:\\tool\\anaconda\\lib\\site-packages\\numpy\\core\\fromnumeric.py:86: RuntimeWarning: overflow encountered in reduce\n", " return ufunc.reduce(obj, axis, dtype, out, **passkwargs)\n", "E:\\Data analysis\\machinelearning\\机器学习小组二期\\代码\\myML_Algorithm\\LinearRegression.py:43: RuntimeWarning: overflow encountered in square\n", " return np.sum((y - X_b.dot(theta)) ** 2) / len(y)\n", "E:\\Data analysis\\machinelearning\\机器学习小组二期\\代码\\myML_Algorithm\\LinearRegression.py:60: RuntimeWarning: invalid value encountered in double_scalars\n", " if (abs(J(theta, X_b, y) - J(last_theta, X_b, y)) < epsilon):\n" ] }, { "data": { "text/plain": [ "array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from 
myML_Algorithm.LinearRegression import linearRegression as LR\n", "from myML_Algorithm.model_selection import train_test_split\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, seed=666)\n", "\n", "lr = LR()\n", "lr.fit_gd(X_train, y_train)\n", "lr.coef_" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[6.3200e-03, 1.8000e+01, 2.3100e+00, 0.0000e+00, 5.3800e-01,\n", " 6.5750e+00, 6.5200e+01, 4.0900e+00, 1.0000e+00, 2.9600e+02,\n", " 1.5300e+01, 3.9690e+02, 4.9800e+00],\n", " [2.7310e-02, 0.0000e+00, 7.0700e+00, 0.0000e+00, 4.6900e-01,\n", " 6.4210e+00, 7.8900e+01, 4.9671e+00, 2.0000e+00, 2.4200e+02,\n", " 1.7800e+01, 3.9690e+02, 9.1400e+00],\n", " [2.7290e-02, 0.0000e+00, 7.0700e+00, 0.0000e+00, 4.6900e-01,\n", " 7.1850e+00, 6.1100e+01, 4.9671e+00, 2.0000e+00, 2.4200e+02,\n", " 1.7800e+01, 3.9283e+02, 4.0300e+00]])" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X[:3]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Every entry of `lr.coef_` is NaN. The features are on very different scales (some values are in the hundreds, others near 0.006), so with a fixed learning rate the updates overflow. The fix is to rescale the features: standardize the data before running gradient descent." ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[-1.04042202 0.83093351 -0.24794356 0.01179456 -1.35034756 2.25074\n", " -0.66384353 -2.53568774 2.25572406 -2.34011572 -1.76565394 0.70923397\n", " -2.72677064]\n" ] } ], "source": [ "# After standardization\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "standardScaler = StandardScaler()\n", "standardScaler.fit(X_train)\n", "X_train_std = standardScaler.transform(X_train)\n", "\n", "lr.fit_gd(X_train_std, y_train)\n", "print(lr.coef_)\n", "X_test_std = standardScaler.transform(X_test)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Stochastic gradient descent\n", "- In day-to-day work the datasets are often very large, and batch gradient descent becomes slow because every step computes the gradient over all samples. Stochastic gradient descent instead updates the parameters from one randomly chosen sample at a time, which makes each step much cheaper. The noisy updates can help it escape poor local optima on non-convex losses, but they do not guarantee a global optimum. For linear regression the squared-error loss is convex, so both methods head toward the same minimum." ] }, { "cell_type": "code", 
"execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-1.04042202, 0.83093351, -0.24794356, 0.01179456, -1.35034756,\n", " 2.25074 , -0.66384353, -2.53568774, 2.25572406, -2.34011572,\n", " -1.76565394, 0.70923397, -2.72677064])" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def fit_sgd(self, X_train, y_train, n_iters=50, t0=5, t1=50):\n", " \"\"\"根据训练数据集X_train, y_train, 使用梯度下降法训练Linear Regression模型\"\"\"\n", " assert X_train.shape[0] == y_train.shape[0], \\\n", " \"the size of X_train must be equal to the size of y_train\"\n", " assert n_iters >= 1\n", "\n", " def dJ_sgd(theta, X_b_i, y_i):\n", " return X_b_i * (X_b_i.dot(theta) - y_i) * 2.\n", "\n", " def sgd(X_b, y, initial_theta, n_iters=5, t0=5, t1=50):\n", "\n", " def learning_rate(t):\n", " return t0 / (t + t1)\n", "\n", " theta = initial_theta\n", " m = len(X_b)\n", " for i_iter in range(n_iters):\n", " # 将原本的数据随机打乱,然后再按顺序取值就相当于随机取值\n", " indexes = np.random.permutation(m)\n", " X_b_new = X_b[indexes,:]\n", " y_new = y[indexes]\n", " for i in range(m):\n", " gradient = dJ_sgd(theta, X_b_new[i], y_new[i])\n", " theta = theta - learning_rate(i_iter * m + i) * gradient\n", "\n", " return theta\n", "\n", " X_b = np.hstack([np.ones((len(X_train), 1)), X_train])\n", " initial_theta = np.random.randn(X_b.shape[1])\n", " self._theta = sgd(X_b, y_train, initial_theta, n_iters, t0, t1)\n", " self.intercept_ = self._theta[0]\n", " self.coef_ = self._theta[1:]\n", " return self" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# sklearn中的SGD" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Wall time: 168 ms\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "D:\\tool\\anaconda\\lib\\site-packages\\sklearn\\linear_model\\stochastic_gradient.py:166: FutureWarning: max_iter and tol parameters have been added in SGDRegressor in 
0.19. If both are left unset, they default to max_iter=5 and tol=None. If tol is not None, max_iter defaults to max_iter=1000. From 0.21, default max_iter will be 1000, and default tol will be 1e-3.\n", " FutureWarning)\n" ] }, { "data": { "text/plain": [ "0.8032440859338719" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.linear_model import SGDRegressor\n", "\n", "sgd_reg = SGDRegressor() # 5 iterations by default in this sklearn version\n", "%time sgd_reg.fit(X_train_std, y_train)\n", "sgd_reg.score(X_test_std, y_test)" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Wall time: 3.99 ms\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "D:\\tool\\anaconda\\lib\\site-packages\\sklearn\\linear_model\\stochastic_gradient.py:152: DeprecationWarning: n_iter parameter is deprecated in 0.19 and will be removed in 0.21. Use max_iter and tol instead.\n", " DeprecationWarning)\n" ] }, { "data": { "text/plain": [ "0.8127865478262196" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Increase the number of iterations\n", "# (n_iter is deprecated since 0.19 and removed in 0.21; newer versions use max_iter instead)\n", "sgd_reg = SGDRegressor(n_iter=100)\n", "%time sgd_reg.fit(X_train_std, y_train)\n", "sgd_reg.score(X_test_std, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training is very fast. Increasing the number of iterations raises the test score (R²) slightly, from about 0.803 to about 0.813." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 
4, "nbformat_minor": 2 }