{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 1. Learning Rate Scheduling Strategies\n", "\n", "Q: What are the main attributes and methods of `_LRScheduler`?\n", "- Attributes:\n", "  - optimizer: the optimizer the scheduler is attached to\n", "  - last_epoch: index of the last epoch\n", "  - base_lrs: the initial learning rates (one per parameter group)\n", "- Methods:\n", "  - step(): updates the learning rate for the next epoch\n", "  - get_lr(): virtual function, overridden by each subclass, that computes the learning rate for the next epoch\n", "\n", "Q: How do you decay the learning rate at fixed intervals?\n", "- `torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)`\n", "- Update rule: lr = lr * gamma, applied once every step_size epochs\n", "- step_size: number of epochs between adjustments\n", "- gamma: multiplicative decay factor\n", "\n", "Q: StepLR code example" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import torch\n", "import torch.optim as optim\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "torch.manual_seed(1)\n", "\n", "LR = 0.1\n", "iteration = 10\n", "max_epoch = 200\n", "# ------------------------------ fake data and optimizer ------------------------------\n", "\n", "weights = torch.randn((1), requires_grad=True)\n", "target = torch.zeros((1))\n", "\n", "optimizer = optim.SGD([weights], lr=LR, momentum=0.9)\n", "\n", "# ------------------------------ 1 Step LR ------------------------------\n", "\n", "scheduler_lr = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1) # 设置学习率下降策略\n", "\n", "lr_list, epoch_list = list(), list()\n", "for epoch in range(max_epoch):\n", "\n", " # 获取当前lr,新版本用 get_last_lr()函数,旧版本用get_lr()函数,具体看UserWarning\n", " lr_list.append(scheduler_lr.get_lr())\n", " epoch_list.append(epoch)\n", "\n", " for i in range(iteration):\n", "\n", " loss = torch.pow((weights - target), 2)\n", " loss.backward()\n", "\n", " optimizer.step()\n", " optimizer.zero_grad()\n", "\n", " scheduler_lr.step()\n", "\n", "plt.plot(epoch_list, lr_list, label=\"Step LR Scheduler\")\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Learning rate\")\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Q:如何按给定间隔调整学习率?\n", "- `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1)`\n", "- milestones:设定调整时刻数\n", "- gamma:调整系数\n", "\n", "Q:MultiStepLR的代码示例" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import torch\n", "import torch.optim as optim\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "torch.manual_seed(1)\n", "\n", "LR = 0.1\n", "iteration = 10\n", "max_epoch = 200\n", "# ------------------------------ fake data and optimizer ------------------------------\n", "\n", "weights = torch.randn((1), requires_grad=True)\n", "target = torch.zeros((1))\n", "\n", "optimizer = optim.SGD([weights], lr=LR, momentum=0.9)\n", "\n", "# ------------------------------ 2 Multi Step LR ------------------------------\n", "milestones = [50, 125, 160]\n", "scheduler_lr = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)\n", "\n", "lr_list, epoch_list = list(), list()\n", "for epoch in range(max_epoch):\n", "\n", " lr_list.append(scheduler_lr.get_lr())\n", " epoch_list.append(epoch)\n", "\n", " for i in range(iteration):\n", "\n", " loss = torch.pow((weights - target), 2)\n", " loss.backward()\n", "\n", " optimizer.step()\n", " optimizer.zero_grad()\n", "\n", " scheduler_lr.step()\n", "\n", "plt.plot(epoch_list, lr_list, label=\"Multi Step LR Scheduler\\nmilestones:{}\".format(milestones))\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Learning rate\")\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Q:如何按指数衰减调整学习率?\n", "- `torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1)`\n", "- gamma:指数的底,通常设为一个接近1的数字,如0.95\n", "- 调整方式:lr = lr * gamma ** epoch\n", "\n", "Q:ExponentialLR代码示例" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import torch\n", "import torch.optim as optim\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "torch.manual_seed(1)\n", "\n", "LR = 0.1\n", "iteration = 10\n", "max_epoch = 200\n", "# ------------------------------ fake data and optimizer ------------------------------\n", "\n", "weights = torch.randn((1), requires_grad=True)\n", "target = torch.zeros((1))\n", "\n", "optimizer = optim.SGD([weights], lr=LR, momentum=0.9)\n", "\n", "# ------------------------------ 3 Exponential LR ------------------------------\n", "gamma = 0.95\n", "scheduler_lr = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)\n", "\n", "lr_list, epoch_list = list(), list()\n", "for epoch in range(max_epoch):\n", "\n", " lr_list.append(scheduler_lr.get_lr())\n", " epoch_list.append(epoch)\n", "\n", " for i in range(iteration):\n", "\n", " loss = torch.pow((weights - target), 2)\n", " loss.backward()\n", "\n", " optimizer.step()\n", " optimizer.zero_grad()\n", "\n", " scheduler_lr.step()\n", "\n", "plt.plot(epoch_list, lr_list, label=\"Exponential LR Scheduler\\ngamma:{}\".format(gamma))\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Learning rate\")\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Q:如何按余弦周期调整学习率?\n", "- `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1)`\n", "- T_max:下降周期\n", "- eta_min:学习率下限\n", "- 调整方式:$$\\eta_{t}=\\eta_{\\min }+\\frac{1}{2}\\left(\\eta_{\\max }-\\eta_{\\min }\\right)\\left(1+\\cos \\left(\\frac{T_{c u r}}{T_{\\max }} \\pi\\right)\\right)$$\n", "\n", "Q:CosineAnnealingLR代码示例" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import torch\n", "import torch.optim as optim\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "torch.manual_seed(1)\n", "\n", "LR = 0.1\n", "iteration = 10\n", "max_epoch = 200\n", "# ------------------------------ fake data and optimizer ------------------------------\n", "\n", "weights = torch.randn((1), requires_grad=True)\n", "target = torch.zeros((1))\n", "\n", "optimizer = optim.SGD([weights], lr=LR, momentum=0.9)\n", "\n", "# ------------------------------ 4 Cosine Annealing LR ------------------------------\n", "t_max = 50\n", "scheduler_lr = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max, eta_min=0.)\n", "\n", "lr_list, epoch_list = list(), list()\n", "for epoch in range(max_epoch):\n", "\n", " lr_list.append(scheduler_lr.get_lr())\n", " epoch_list.append(epoch)\n", "\n", " for i in range(iteration):\n", "\n", " loss = torch.pow((weights - target), 2)\n", " loss.backward()\n", "\n", " optimizer.step()\n", " optimizer.zero_grad()\n", "\n", " scheduler_lr.step()\n", "\n", "plt.plot(epoch_list, lr_list, label=\"CosineAnnealingLR Scheduler\\nT_max:{}\".format(t_max))\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Learning rate\")\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Q:如何监控指标, 当指标不再变化则调整学习率?\n", "- `torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)`\n", "- mode:min/max两种模式:如min代表不下降就调整\n", "- factor:调整系数\n", "- patience:“耐心”,接受几次不变化\n", "- cooldown:“冷却时间”,停止监控一段时间\n", "- verbose:是否打印日志\n", "- min_lr:学习率下限\n", "- eps:学习率衰减最小值\n", "\n", "Q:ReduceLROnPlateau代码示例" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 16: reducing learning rate of group 0 to 
1.0000e-02.\n", "Epoch 37: reducing learning rate of group 0 to 1.0000e-03.\n", "Epoch 58: reducing learning rate of group 0 to 1.0000e-04.\n" ] } ], "source": [ "import torch\n", "import torch.optim as optim\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "torch.manual_seed(1)\n", "\n", "LR = 0.1\n", "iteration = 10\n", "max_epoch = 200\n", "# ------------------------------ fake data and optimizer ------------------------------\n", "\n", "weights = torch.randn((1), requires_grad=True)\n", "target = torch.zeros((1))\n", "\n", "optimizer = optim.SGD([weights], lr=LR, momentum=0.9)\n", "\n", "# ------------------------------ 5 Reduce LR On Plateau ------------------------------\n", "loss_value = 0.5  # stand-in for the monitored validation loss\n", "accuracy = 0.9\n", "\n", "factor = 0.1\n", "mode = \"min\"\n", "patience = 10\n", "cooldown = 10\n", "min_lr = 1e-4\n", "verbose = True\n", "\n", "scheduler_lr = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=factor, mode=mode, patience=patience,\n", "                                                    cooldown=cooldown, min_lr=min_lr, verbose=verbose)\n", "\n", "for epoch in range(max_epoch):\n", "    for i in range(iteration):\n", "\n", "        # train(...)\n", "\n", "        optimizer.step()\n", "        optimizer.zero_grad()\n", "\n", "    # simulate a single improvement of the monitored loss at epoch 5\n", "    if epoch == 5:\n", "        loss_value = 0.4\n", "\n", "    # the scheduler must be given the monitored metric\n", "    scheduler_lr.step(loss_value)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Q: How do you define a custom learning rate schedule?\n", "- `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)`\n", "- lr_lambda: a function or a list of functions; if a list, every element must be a function, one per parameter group\n", "\n", "Q: LambdaLR code example" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 0, lr:[0.1, 0.095]\n", "epoch: 1, lr:[0.1, 0.09025]\n", "epoch: 2, lr:[0.1, 0.0857375]\n", "epoch: 3, lr:[0.1, 0.081450625]\n", "epoch: 4, lr:[0.1, 0.07737809374999999]\n", "epoch: 5, lr:[0.1, 0.07350918906249998]\n", "epoch: 6, lr:[0.1, 0.06983372960937498]\n", "epoch: 7, lr:[0.1, 0.06634204312890622]\n", "epoch: 8, 
lr:[0.1, 0.0630249409724609]\n", "epoch: 9, lr:[0.1, 0.05987369392383787]\n", "epoch: 10, lr:[0.1, 0.05688000922764597]\n", "epoch: 11, lr:[0.1, 0.05403600876626367]\n", "epoch: 12, lr:[0.1, 0.051334208327950485]\n", "epoch: 13, lr:[0.1, 0.04876749791155296]\n", "epoch: 14, lr:[0.1, 0.046329123015975304]\n", "epoch: 15, lr:[0.1, 0.04401266686517654]\n", "epoch: 16, lr:[0.1, 0.04181203352191771]\n", "epoch: 17, lr:[0.1, 0.039721431845821824]\n", "epoch: 18, lr:[0.1, 0.03773536025353073]\n", "epoch: 19, lr:[0.010000000000000002, 0.03584859224085419]\n", "epoch: 20, lr:[0.010000000000000002, 0.03405616262881148]\n", "epoch: 21, lr:[0.010000000000000002, 0.0323533544973709]\n", "epoch: 22, lr:[0.010000000000000002, 0.03073568677250236]\n", "epoch: 23, lr:[0.010000000000000002, 0.02919890243387724]\n", "epoch: 24, lr:[0.010000000000000002, 0.027738957312183378]\n", "epoch: 25, lr:[0.010000000000000002, 0.026352009446574204]\n", "epoch: 26, lr:[0.010000000000000002, 0.025034408974245494]\n", "epoch: 27, lr:[0.010000000000000002, 0.023782688525533217]\n", "epoch: 28, lr:[0.010000000000000002, 0.022593554099256556]\n", "epoch: 29, lr:[0.010000000000000002, 0.02146387639429373]\n", "epoch: 30, lr:[0.010000000000000002, 0.02039068257457904]\n", "epoch: 31, lr:[0.010000000000000002, 0.019371148445850087]\n", "epoch: 32, lr:[0.010000000000000002, 0.018402591023557582]\n", "epoch: 33, lr:[0.010000000000000002, 0.017482461472379703]\n", "epoch: 34, lr:[0.010000000000000002, 0.016608338398760716]\n", "epoch: 35, lr:[0.010000000000000002, 0.01577792147882268]\n", "epoch: 36, lr:[0.010000000000000002, 0.014989025404881546]\n", "epoch: 37, lr:[0.010000000000000002, 0.014239574134637467]\n", "epoch: 38, lr:[0.010000000000000002, 0.013527595427905593]\n", "epoch: 39, lr:[0.0010000000000000002, 0.012851215656510312]\n", "epoch: 40, lr:[0.0010000000000000002, 0.012208654873684797]\n", "epoch: 41, lr:[0.0010000000000000002, 0.011598222130000557]\n", "epoch: 42, lr:[0.0010000000000000002, 
0.011018311023500529]\n", "epoch: 43, lr:[0.0010000000000000002, 0.010467395472325502]\n", "epoch: 44, lr:[0.0010000000000000002, 0.009944025698709225]\n", "epoch: 45, lr:[0.0010000000000000002, 0.009446824413773765]\n", "epoch: 46, lr:[0.0010000000000000002, 0.008974483193085076]\n", "epoch: 47, lr:[0.0010000000000000002, 0.00852575903343082]\n", "epoch: 48, lr:[0.0010000000000000002, 0.00809947108175928]\n", "epoch: 49, lr:[0.0010000000000000002, 0.007694497527671315]\n", "epoch: 50, lr:[0.0010000000000000002, 0.007309772651287749]\n", "epoch: 51, lr:[0.0010000000000000002, 0.006944284018723362]\n", "epoch: 52, lr:[0.0010000000000000002, 0.0065970698177871935]\n", "epoch: 53, lr:[0.0010000000000000002, 0.006267216326897833]\n", "epoch: 54, lr:[0.0010000000000000002, 0.005953855510552941]\n", "epoch: 55, lr:[0.0010000000000000002, 0.005656162735025293]\n", "epoch: 56, lr:[0.0010000000000000002, 0.005373354598274029]\n", "epoch: 57, lr:[0.0010000000000000002, 0.005104686868360327]\n", "epoch: 58, lr:[0.0010000000000000002, 0.004849452524942311]\n", "epoch: 59, lr:[0.00010000000000000003, 0.004606979898695194]\n", "epoch: 60, lr:[0.00010000000000000003, 0.004376630903760435]\n", "epoch: 61, lr:[0.00010000000000000003, 0.004157799358572413]\n", "epoch: 62, lr:[0.00010000000000000003, 0.003949909390643792]\n", "epoch: 63, lr:[0.00010000000000000003, 0.003752413921111602]\n", "epoch: 64, lr:[0.00010000000000000003, 0.003564793225056022]\n", "epoch: 65, lr:[0.00010000000000000003, 0.0033865535638032207]\n", "epoch: 66, lr:[0.00010000000000000003, 0.0032172258856130592]\n", "epoch: 67, lr:[0.00010000000000000003, 0.0030563645913324064]\n", "epoch: 68, lr:[0.00010000000000000003, 0.002903546361765786]\n", "epoch: 69, lr:[0.00010000000000000003, 0.0027583690436774966]\n", "epoch: 70, lr:[0.00010000000000000003, 0.0026204505914936217]\n", "epoch: 71, lr:[0.00010000000000000003, 0.0024894280619189406]\n", "epoch: 72, lr:[0.00010000000000000003, 0.0023649566588229936]\n", 
"epoch: 73, lr:[0.00010000000000000003, 0.0022467088258818434]\n", "epoch: 74, lr:[0.00010000000000000003, 0.002134373384587751]\n", "epoch: 75, lr:[0.00010000000000000003, 0.0020276547153583635]\n", "epoch: 76, lr:[0.00010000000000000003, 0.0019262719795904452]\n", "epoch: 77, lr:[0.00010000000000000003, 0.001829958380610923]\n", "epoch: 78, lr:[0.00010000000000000003, 0.0017384604615803768]\n", "epoch: 79, lr:[1.0000000000000003e-05, 0.001651537438501358]\n", "epoch: 80, lr:[1.0000000000000003e-05, 0.00156896056657629]\n", "epoch: 81, lr:[1.0000000000000003e-05, 0.0014905125382474755]\n", "epoch: 82, lr:[1.0000000000000003e-05, 0.0014159869113351015]\n", "epoch: 83, lr:[1.0000000000000003e-05, 0.0013451875657683465]\n", "epoch: 84, lr:[1.0000000000000003e-05, 0.001277928187479929]\n", "epoch: 85, lr:[1.0000000000000003e-05, 0.0012140317781059325]\n", "epoch: 86, lr:[1.0000000000000003e-05, 0.0011533301892006358]\n", "epoch: 87, lr:[1.0000000000000003e-05, 0.001095663679740604]\n", "epoch: 88, lr:[1.0000000000000003e-05, 0.0010408804957535737]\n", "epoch: 89, lr:[1.0000000000000003e-05, 0.000988836470965895]\n", "epoch: 90, lr:[1.0000000000000003e-05, 0.0009393946474176001]\n", "epoch: 91, lr:[1.0000000000000003e-05, 0.0008924249150467202]\n", "epoch: 92, lr:[1.0000000000000003e-05, 0.0008478036692943841]\n", "epoch: 93, lr:[1.0000000000000003e-05, 0.0008054134858296649]\n", "epoch: 94, lr:[1.0000000000000003e-05, 0.0007651428115381816]\n", "epoch: 95, lr:[1.0000000000000003e-05, 0.0007268856709612725]\n", "epoch: 96, lr:[1.0000000000000003e-05, 0.0006905413874132089]\n", "epoch: 97, lr:[1.0000000000000003e-05, 0.0006560143180425484]\n", "epoch: 98, lr:[1.0000000000000003e-05, 0.0006232136021404209]\n", "epoch: 99, lr:[1.0000000000000004e-06, 0.0005920529220333997]\n", "epoch: 100, lr:[1.0000000000000004e-06, 0.0005624502759317298]\n", "epoch: 101, lr:[1.0000000000000004e-06, 0.0005343277621351433]\n", "epoch: 102, lr:[1.0000000000000004e-06, 
0.0005076113740283861]\n", "epoch: 103, lr:[1.0000000000000004e-06, 0.00048223080532696673]\n", "epoch: 104, lr:[1.0000000000000004e-06, 0.0004581192650606184]\n", "epoch: 105, lr:[1.0000000000000004e-06, 0.00043521330180758743]\n", "epoch: 106, lr:[1.0000000000000004e-06, 0.00041345263671720806]\n", "epoch: 107, lr:[1.0000000000000004e-06, 0.0003927800048813476]\n", "epoch: 108, lr:[1.0000000000000004e-06, 0.00037314100463728026]\n", "epoch: 109, lr:[1.0000000000000004e-06, 0.00035448395440541624]\n", "epoch: 110, lr:[1.0000000000000004e-06, 0.0003367597566851454]\n", "epoch: 111, lr:[1.0000000000000004e-06, 0.0003199217688508881]\n", "epoch: 112, lr:[1.0000000000000004e-06, 0.0003039256804083437]\n", "epoch: 113, lr:[1.0000000000000004e-06, 0.0002887293963879265]\n", "epoch: 114, lr:[1.0000000000000004e-06, 0.00027429292656853016]\n", "epoch: 115, lr:[1.0000000000000004e-06, 0.00026057828024010366]\n", "epoch: 116, lr:[1.0000000000000004e-06, 0.0002475493662280985]\n", "epoch: 117, lr:[1.0000000000000004e-06, 0.00023517189791669353]\n", "epoch: 118, lr:[1.0000000000000004e-06, 0.0002234133030208588]\n", "epoch: 119, lr:[1.0000000000000005e-07, 0.00021224263786981585]\n", "epoch: 120, lr:[1.0000000000000005e-07, 0.00020163050597632508]\n", "epoch: 121, lr:[1.0000000000000005e-07, 0.0001915489806775088]\n", "epoch: 122, lr:[1.0000000000000005e-07, 0.00018197153164363337]\n", "epoch: 123, lr:[1.0000000000000005e-07, 0.00017287295506145168]\n", "epoch: 124, lr:[1.0000000000000005e-07, 0.00016422930730837908]\n", "epoch: 125, lr:[1.0000000000000005e-07, 0.00015601784194296014]\n", "epoch: 126, lr:[1.0000000000000005e-07, 0.00014821694984581212]\n", "epoch: 127, lr:[1.0000000000000005e-07, 0.0001408061023535215]\n", "epoch: 128, lr:[1.0000000000000005e-07, 0.00013376579723584542]\n", "epoch: 129, lr:[1.0000000000000005e-07, 0.00012707750737405313]\n", "epoch: 130, lr:[1.0000000000000005e-07, 0.00012072363200535048]\n", "epoch: 131, lr:[1.0000000000000005e-07, 
0.00011468745040508295]\n", "epoch: 132, lr:[1.0000000000000005e-07, 0.0001089530778848288]\n", "epoch: 133, lr:[1.0000000000000005e-07, 0.00010350542399058736]\n", "epoch: 134, lr:[1.0000000000000005e-07, 9.833015279105799e-05]\n", "epoch: 135, lr:[1.0000000000000005e-07, 9.341364515150508e-05]\n", "epoch: 136, lr:[1.0000000000000005e-07, 8.874296289392982e-05]\n", "epoch: 137, lr:[1.0000000000000005e-07, 8.430581474923332e-05]\n", "epoch: 138, lr:[1.0000000000000005e-07, 8.009052401177165e-05]\n", "epoch: 139, lr:[1.0000000000000004e-08, 7.608599781118307e-05]\n", "epoch: 140, lr:[1.0000000000000004e-08, 7.228169792062392e-05]\n", "epoch: 141, lr:[1.0000000000000004e-08, 6.866761302459272e-05]\n", "epoch: 142, lr:[1.0000000000000004e-08, 6.523423237336307e-05]\n", "epoch: 143, lr:[1.0000000000000004e-08, 6.197252075469492e-05]\n", "epoch: 144, lr:[1.0000000000000004e-08, 5.8873894716960165e-05]\n", "epoch: 145, lr:[1.0000000000000004e-08, 5.593019998111216e-05]\n", "epoch: 146, lr:[1.0000000000000004e-08, 5.313368998205655e-05]\n", "epoch: 147, lr:[1.0000000000000004e-08, 5.0477005482953716e-05]\n", "epoch: 148, lr:[1.0000000000000004e-08, 4.795315520880603e-05]\n", "epoch: 149, lr:[1.0000000000000004e-08, 4.555549744836572e-05]\n", "epoch: 150, lr:[1.0000000000000004e-08, 4.327772257594744e-05]\n", "epoch: 151, lr:[1.0000000000000004e-08, 4.1113836447150066e-05]\n", "epoch: 152, lr:[1.0000000000000004e-08, 3.905814462479256e-05]\n", "epoch: 153, lr:[1.0000000000000004e-08, 3.710523739355293e-05]\n", "epoch: 154, lr:[1.0000000000000004e-08, 3.524997552387528e-05]\n", "epoch: 155, lr:[1.0000000000000004e-08, 3.3487476747681514e-05]\n", "epoch: 156, lr:[1.0000000000000004e-08, 3.181310291029744e-05]\n", "epoch: 157, lr:[1.0000000000000004e-08, 3.0222447764782564e-05]\n", "epoch: 158, lr:[1.0000000000000004e-08, 2.8711325376543437e-05]\n", "epoch: 159, lr:[1.0000000000000005e-09, 2.7275759107716264e-05]\n", "epoch: 160, lr:[1.0000000000000005e-09, 
2.5911971152330445e-05]\n", "epoch: 161, lr:[1.0000000000000005e-09, 2.4616372594713925e-05]\n", "epoch: 162, lr:[1.0000000000000005e-09, 2.3385553964978226e-05]\n", "epoch: 163, lr:[1.0000000000000005e-09, 2.2216276266729317e-05]\n", "epoch: 164, lr:[1.0000000000000005e-09, 2.110546245339285e-05]\n", "epoch: 165, lr:[1.0000000000000005e-09, 2.0050189330723204e-05]\n", "epoch: 166, lr:[1.0000000000000005e-09, 1.9047679864187045e-05]\n", "epoch: 167, lr:[1.0000000000000005e-09, 1.809529587097769e-05]\n", "epoch: 168, lr:[1.0000000000000005e-09, 1.7190531077428805e-05]\n", "epoch: 169, lr:[1.0000000000000005e-09, 1.6331004523557364e-05]\n", "epoch: 170, lr:[1.0000000000000005e-09, 1.5514454297379498e-05]\n", "epoch: 171, lr:[1.0000000000000005e-09, 1.4738731582510519e-05]\n", "epoch: 172, lr:[1.0000000000000005e-09, 1.4001795003384993e-05]\n", "epoch: 173, lr:[1.0000000000000005e-09, 1.3301705253215743e-05]\n", "epoch: 174, lr:[1.0000000000000005e-09, 1.2636619990554954e-05]\n", "epoch: 175, lr:[1.0000000000000005e-09, 1.2004788991027206e-05]\n", "epoch: 176, lr:[1.0000000000000005e-09, 1.1404549541475845e-05]\n", "epoch: 177, lr:[1.0000000000000005e-09, 1.0834322064402054e-05]\n", "epoch: 178, lr:[1.0000000000000005e-09, 1.029260596118195e-05]\n", "epoch: 179, lr:[1.0000000000000006e-10, 9.777975663122852e-06]\n", "epoch: 180, lr:[1.0000000000000006e-10, 9.28907687996671e-06]\n", "epoch: 181, lr:[1.0000000000000006e-10, 8.824623035968373e-06]\n", "epoch: 182, lr:[1.0000000000000006e-10, 8.383391884169954e-06]\n", "epoch: 183, lr:[1.0000000000000006e-10, 7.964222289961456e-06]\n", "epoch: 184, lr:[1.0000000000000006e-10, 7.566011175463383e-06]\n", "epoch: 185, lr:[1.0000000000000006e-10, 7.187710616690214e-06]\n", "epoch: 186, lr:[1.0000000000000006e-10, 6.828325085855702e-06]\n", "epoch: 187, lr:[1.0000000000000006e-10, 6.486908831562916e-06]\n", "epoch: 188, lr:[1.0000000000000006e-10, 6.16256338998477e-06]\n", "epoch: 189, lr:[1.0000000000000006e-10, 
5.854435220485532e-06]\n", "epoch: 190, lr:[1.0000000000000006e-10, 5.5617134594612554e-06]\n", "epoch: 191, lr:[1.0000000000000006e-10, 5.283627786488193e-06]\n", "epoch: 192, lr:[1.0000000000000006e-10, 5.0194463971637825e-06]\n", "epoch: 193, lr:[1.0000000000000006e-10, 4.768474077305593e-06]\n", "epoch: 194, lr:[1.0000000000000006e-10, 4.5300503734403135e-06]\n", "epoch: 195, lr:[1.0000000000000006e-10, 4.3035478547682975e-06]\n", "epoch: 196, lr:[1.0000000000000006e-10, 4.088370462029883e-06]\n", "epoch: 197, lr:[1.0000000000000006e-10, 3.883951938928388e-06]\n", "epoch: 198, lr:[1.0000000000000006e-10, 3.6897543419819688e-06]\n", "epoch: 199, lr:[1.0000000000000006e-11, 3.5052666248828703e-06]\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import torch\n", "import torch.optim as optim\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "torch.manual_seed(1)\n", "\n", "LR = 0.1\n", "iteration = 10\n", "max_epoch = 200\n", "# ------------------------------ fake data and optimizer ------------------------------\n", "\n", "weights = torch.randn((1), requires_grad=True)\n", "target = torch.zeros((1))\n", "\n", "optimizer = optim.SGD([weights], lr=LR, momentum=0.9)\n", "\n", "# ------------------------------ 6 lambda ------------------------------\n", "lr_init = 0.1\n", "\n", "weights_1 = torch.randn((6, 3, 5, 5))\n", "weights_2 = torch.ones((5, 5))\n", "\n", "optimizer = optim.SGD([\n", " {'params': [weights_1]},\n", " {'params': [weights_2]}], lr=lr_init)\n", "\n", "lambda1 = lambda epoch: 0.1 ** (epoch // 20)\n", "lambda2 = lambda epoch: 0.95 ** epoch\n", "\n", "scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])\n", "\n", "lr_list, epoch_list = list(), list()\n", "for epoch in range(max_epoch):\n", " for i in range(iteration):\n", "\n", " # train(...)\n", "\n", " optimizer.step()\n", " optimizer.zero_grad()\n", "\n", " scheduler.step()\n", "\n", " lr_list.append(scheduler.get_lr())\n", " epoch_list.append(epoch)\n", "\n", " print('epoch:{:5d}, lr:{}'.format(epoch, scheduler.get_lr()))\n", "\n", "plt.plot(epoch_list, [i[0] for i in lr_list], label=\"lambda 1\")\n", "plt.plot(epoch_list, [i[1] for i in lr_list], label=\"lambda 2\")\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Learning Rate\")\n", "plt.title(\"LambdaLR\")\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 2.TensorBoard简介与安装\n", "\n", "Q:TensorBoard安装和启动命令?\n", "- 安装:`conda install tensorboard typing-extensions`\n", "- 启动:`tensorboard --logdir=./runs`\n", "\n", "Q:TensorBoard测试代码" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, 
"outputs": [], "source": [ "import numpy as np\n", "from torch.utils.tensorboard import SummaryWriter\n", "\n", "\n", "writer = SummaryWriter(comment='test_tensorboard')\n", "\n", "for x in range(100):\n", "\n", " writer.add_scalar('y=2x', x * 2, x)\n", " writer.add_scalar('y=pow(2, x)', 2 ** x, x)\n", " \n", " writer.add_scalars('data/scalar_group', {\"xsinx\": x * np.sin(x),\n", " \"xcosx\": x * np.cos(x),\n", " \"arctanx\": np.arctan(x)}, x)\n", "writer.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 3.TensorBoard使用(一)\n", "\n", "Q:SummaryWriter的功能和属性是什么?\n", "- 功能:提供创建event file的高级接口\n", "- ```python\n", "class SummaryWriter(object):\n", " def __init__(self, log_dir=None, comment='', purge_step=None, max_queue=10, flush_secs=120, filename_suffix=’’)\n", "```\n", "- log_dir:event file输出文件夹\n", "- comment:不指定log_dir时,文件夹后缀\n", "- filename_suffix:event file文件名后缀\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }