{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Stochastic Gradient Descent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Gradient descent computes the loss:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "L ( W ) = \\frac { 1 } { N } \\sum _ { i = 1 } ^ { N } L _ { i } \\left( x _ { i } , y _ { i } , W \\right)\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and the gradient of the loss with respect to the weights:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "\\nabla _ { W } L ( W ) = \\frac { 1 } { N } \\sum _ { i = 1 } ^ { N } \\nabla _ { W } L _ { i } \\left( x _ { i } , y _ { i } , W \\right)\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When $N$ is very large, full-batch computation is infeasible: no amount of RAM or GPU memory can hold all the data at once. Instead, a minibatch, typically of size 32/64/128, is used to estimate the gradient over the full dataset." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An illustration of SGD:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://tuchuang-1252747889.cosgz.myqcloud.com/2018-11-29-1601E590-7AEE-4743-B77B-10C783A461A4.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Read the data batch by batch, compute the gradient, and update the parameters." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we implement a Dataloader modeled on PyTorch's, overriding `__getitem__`, `__iter__`, and `__len__` so that data can be fetched by index, iterated over in batches, and queried for length." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "class Dataloader(object):\n", "    def __init__(self, data, labels, batch_size, shuffle=True):\n", "        self.data = data\n", "        self.batch_size = batch_size\n", "        self.shuffle = shuffle\n", "        self.labels = labels\n", "\n", "    def __getitem__(self, index):\n", "        return self.data[index], self.labels[index]\n", "\n", "    def __iter__(self):\n", "        datasize = self.data.shape[0]\n", "        data_seq = np.arange(datasize)\n", "        if self.shuffle:\n", "            np.random.shuffle(data_seq)\n", "        # batch boundaries: [0, batch_size, 2*batch_size, ..., datasize]\n", "        interval_list = np.append(np.arange(0, datasize, self.batch_size), datasize)\n", "        for index in range(interval_list.shape[0]-1):\n", "            s = 
data_seq[interval_list[index]:interval_list[index+1]]\n", "            # yield one minibatch of data and labels\n", "            yield self.data[s], self.labels[s]\n", "\n", "    def __len__(self):\n", "        return self.data.shape[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In PyTorch, you first subclass a dataset (see [torchvision.datasets](https://pytorch.org/docs/stable/torchvision/datasets.html)), then load it with [torch.utils.data.DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), which supports parallel loading and shuffling." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SGD is only the simplest optimizer and is rarely used on its own; SGD+Momentum and Adam are more common. These optimizers differ only in how they update the parameters. I hope to write up the various optimizers in a later post." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Implementing a Linear Layer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In [Softmax、K-L 散度、交叉熵和 Cross Entropy Loss 推导和实现](https://hzzone.io/3.%20cs231n/Softmax、KL%20散度和%20Cross%20Entropy%20Loss%20推导和实现.html) we derived the cross-entropy loss and implemented a linear classifier.\n", "\n", "The linear layer itself computes just $y = xW + b$ (row-vector convention, matching the implementation below); such a linear classifier is a single-layer neural network.\n", "\n", "Differentiating with respect to the parameters $W$ and $b$ gives:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\\frac{\\partial y}{\\partial W}=x^T$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\\frac{\\partial y}{\\partial b}=1$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then apply the chain rule and multiply by the gradient from the layer above, i.e. $\\frac{\\partial L}{\\partial W} = x^T \\frac{\\partial L}{\\partial y}$ and $\\frac{\\partial L}{\\partial b} = \\sum_i \\frac{\\partial L}{\\partial y_i}$ (summing over the batch)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "class Linear(object):\n", "    def __init__(self, D_in, D_out):\n", "        self.weight = np.random.randn(D_in, D_out).astype(np.float32)*0.01\n", "        self.bias = np.zeros((1, D_out), dtype=np.float32)\n", "\n", "    def forward(self, input):\n", "        self.data = input\n", "        return np.dot(self.data, self.weight)+self.bias\n", "\n", "    def backward(self, top_grad, lr):\n", "        # gradient passed down to the layer below\n", "        self.grad = np.dot(top_grad, self.weight.T).astype(np.float32)\n", "        # update the parameters; the bias gradient sums top_grad over the\n", "        # batch, consistent with the weight update\n", "        self.weight -= lr*np.dot(self.data.T, top_grad)\n", "        self.bias -= lr*np.sum(top_grad, axis=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we use SGD 
to retrain a single-layer neural network to classify MNIST." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from utils import read_mnist\n", "from nn import CrossEntropyLossLayer, lr_scheduler\n", "\n", "# read and normalize the data; skipping normalization makes the loss go to NaN\n", "test_data = ((read_mnist('../data/mnist/t10k-images.idx3-ubyte').reshape((-1, 784))-127.0)/255.0).astype(np.float32)\n", "train_data = ((read_mnist('../data/mnist/train-images.idx3-ubyte').reshape((-1, 784))-127.0)/255.0).astype(np.float32)\n", "# one-hot encode the labels\n", "from sklearn.preprocessing import OneHotEncoder\n", "encoder = OneHotEncoder()\n", "encoder.fit(np.arange(10).reshape((-1, 1)))\n", "train_labels = encoder.transform(read_mnist('../data/mnist/train-labels.idx1-ubyte').reshape((-1, 1))).toarray().astype(np.float32)\n", "test_labels = encoder.transform(read_mnist('../data/mnist/t10k-labels.idx1-ubyte').reshape((-1, 1))).toarray().astype(np.float32)\n", "\n", "loss_layer = CrossEntropyLossLayer()\n", "lr = 0.1\n", "D, C = 784, 10\n", "np.random.seed(1)  # fix the randomly initialized weights\n", "best_acc = -float('inf')\n", "max_iter = 900\n", "step_size = 400\n", "scheduler = lr_scheduler(lr, step_size)\n", "loss_list = []\n", "\n", "batch_size = 120\n", "\n", "train_dataloader = Dataloader(train_data, train_labels, batch_size, shuffle=True)\n", "test_dataloader = Dataloader(test_data, test_labels, batch_size, shuffle=False)\n", "\n", "linear_classifer = Linear(D, C)\n", "\n", "from tqdm import tqdm_notebook\n", "for epoch in tqdm_notebook(range(max_iter)):\n", "    # evaluate on the test set\n", "    correct = 0\n", "    for data, labels in test_dataloader:\n", "        test_pred = linear_classifer.forward(data)\n", "        pred_labels = np.argmax(test_pred, axis=1)\n", "        real_labels = np.argmax(labels, axis=1)\n", "        correct += np.sum(pred_labels==real_labels)\n", "    acc = correct/len(test_dataloader)\n", "    if acc>best_acc: best_acc=acc\n", "    # train for one epoch; iterate over the *training* dataloader\n", "    total_loss = 0\n", "    for data, labels in train_dataloader:\n", "        train_pred = linear_classifer.forward(data)\n", "        loss = 
loss_layer.forward(train_pred, labels)\n", "        total_loss += loss\n", "        loss_layer.backward()\n", "        linear_classifer.backward(loss_layer.grad, scheduler.get_lr())\n", "    loss_list.append(total_loss)\n", "    scheduler.step()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.967" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_acc" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot the loss curve." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "plt.plot(np.arange(max_iter), loss_list)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each epoch runs many times faster than full-batch training, and the best accuracy reaches ~97%, probably because the training data is shuffled each epoch and possibly also because SGD updates the parameters once per minibatch." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.2" } }, "nbformat": 4, "nbformat_minor": 2 }