{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### 参考\n", "* [Logistic function](https://en.wikipedia.org/wiki/Logistic_function)\n", "* [Softmax function](https://en.wikipedia.org/wiki/Softmax_function)\n", "* [如何理解K-L散度(相对熵)](https://www.jianshu.com/p/43318a3dc715?from=timeline&isappinstalled=0)\n", "* [Cross entropy](https://en.wikipedia.org/wiki/Cross_entropy)\n", "* [如何通俗的解释交叉熵与相对熵?](https://www.zhihu.com/question/41252833)\n", "* [Cross Entropy Loss with Softmax的求导](https://blog.csdn.net/jiajunlee/article/details/79665062)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "推导和实现 Softmax、KL 散度、交叉熵和 Cross Entropy Loss,最后实现一个线性分类器 $y=xW+b$做 mnist 分类。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Softmax 函数" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Softmax 函数是 Logistic function 的一种泛化,所以先介绍 Logistic function。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Logistic function 形式如下:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "f ( x ) = \\frac { L } { 1 + e ^ { - k \\left( x - x _ { 0 } \\right) } }\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "其中 $k$ 是 Logistic function 的抖动程度,$x_0$ 是函数关于 $x=x_0$ 对称,$L$ 曲线的最大值。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "当 $k=1,L=1,x_0=0$ 时,Logistic function 就是 sigmoid 函数:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "f ( x ) = \\frac { 1 } { 1 + e ^ { - x } }\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "sigmoid 函数已经在激活函数中介绍,满足如下性质:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "\\tanh ( x ) = \\frac { e ^ { x } - e ^ { - x } } { e ^ { x } + e ^ { - x } }=2 f ( 2 x ) - 1\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "f(x)= \\frac { 1 } { 2 } + \\frac { 1 } { 2 } \\tanh \\left( \\frac { x } { 2 } \\right)\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", 
"\\frac { d f ( x )} { d x } = \\frac { e ^ { x } \\cdot \\left( 1 + e ^ { x } \\right) - e ^ { x } \\cdot e ^ { x } } { \\left( 1 + e ^ { x } \\right) ^ { 2 } } = \\frac { e ^ { x } } { \\left( 1 + e ^ { x } \\right) ^ { 2 } } = f ( x ) ( 1 - f ( x ) )\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$K$ 维向量经过 softmax 函数,输出归一化之后的向量可以作为 $K$ 个不同结果发生的概率,概率值和为 1,范围在 $[0,1]$。深度学习中常用于网络输出之后使用 softmax 函数输出概率。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "softmax 函数公式如下:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "\\sigma ( \\mathbf { z } ) _ { j } = \\frac { e ^ { z _ { j } } } { \\sum _ { k = 1 } ^ { K } e ^ { z _ { k } } } \\quad \\text { for } j = 1 , \\ldots , K\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "numpy 实现 softmax 函数,输入为一个 $batch\\_size\\times class\\_total$ 的矩阵。" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "def softmax(input):\n", " exp_value = np.exp(input) #首先计算指数\n", " output = exp_value/np.sum(exp_value, axis=1)[:, np.newaxis] # 然后按行标准化\n", " return output" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([[-0.33234003, 0.44697765, -1.4677743 ],\n", " [-1.06887423, -0.41212265, -1.07418837]]),\n", " array([[0.2856109 , 0.62262729, 0.09176181],\n", " [0.25489283, 0.49156529, 0.25354189]]))" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output = np.random.randn(2, 3)\n", "prob = softmax(output)\n", "output, prob" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### K-L 散度和 Cross Entropy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "在 [如何理解K-L散度(相对熵)](https://www.jianshu.com/p/43318a3dc715?from=timeline&isappinstalled=0) 已经介绍的非常详细了,看完就可以理解 K-L 散度了。下面我提出自己的理解。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "K-L 散度(Kullback–Leibler 
divergence, or relative entropy) is a way of measuring the difference between two probability distributions $p$ and $q$: $D_{KL}(p||q)$ is the information lost when the distribution $q$ is used to approximate the true distribution $p$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, **entropy** is the basic measure of information in information theory:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "H(p) = - \\sum _ { i = 1 } ^ { N } p \\left( x _ { i } \\right) \\cdot \\log p \\left( x _ { i } \\right)\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With base-2 logarithms, this is the minimum number of bits needed to encode the distribution $p$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Entropy tells us the minimum space needed to encode $p$, while K-L divergence measures how much information is lost when one distribution is used to represent another.**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "K-L divergence is defined as:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "D _ { K L } ( p \\| q ) = \\sum _ { i = 1 } ^ { N } p \\left( x _ { i } \\right) \\cdot \\left( \\log p \\left( x _ { i } \\right) - \\log q \\left( x _ { i } \\right) \\right)\\\\\n", "=\\sum _ { i = 1 } ^ { N } p \\left( x _ { i } \\right) \\cdot \\log \\frac { p \\left( x _ { i } \\right) } { q \\left( x _ { i } \\right) }\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here $p$ is the true distribution and $q$ is used to approximate $p$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The formula shows that $D _ { K L } ( p \\| q )$ is the expectation, under $p$, of the log difference between $p$ and $q$, so K-L divergence can also be written as:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "D _ { K L } ( p \\| q ) = E [ \\log p ( x ) - \\log q ( x ) ]\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we keep using base-2 logarithms, the K-L divergence is the number of bits of information lost." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**K-L divergence is not a distance**: it is not symmetric, whereas a distance metric must be, as the L1 and L2 distances are." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$D _ { K L } ( p \\| q )\\ne D _ { K L } ( q \\| p )$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Cross entropy is the average code length when $q$ is used to encode $p$.**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "H ( p , q ) = \\mathrm { E } _ { p } [ - \\log q ] 
= H ( p ) + D _ { \\mathrm { KL } } ( p \\| q )\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Cross entropy is exactly **entropy plus K-L divergence**, and expands to:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "H ( p , q ) = - \\sum _ { x } p ( x ) \\log q ( x )\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**So why train with cross entropy instead of K-L divergence? Isn't K-L divergence the more direct objective?**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [Why do we use Kullback-Leibler divergence rather than cross entropy in the t-SNE objective function?](https://stats.stackexchange.com/questions/265966/why-do-we-use-kullback-leibler-divergence-rather-than-cross-entropy-in-the-t-sne) and [Why train with cross-entropy instead of KL divergence in classification?](https://www.reddit.com/r/MachineLearning/comments/4mebvf/why_train_with_crossentropy_instead_of_kl/)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Because for training they amount to the same thing.\n", "\n", "In machine learning the true distribution $p$ is fixed: it is given by the labels we assign to the training set, so its entropy $H(p)$ is a constant. The K-L divergence and the cross entropy then differ only by that constant, their gradients are identical, and the cross entropy is simpler to compute." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That is the real meaning of cross entropy. Combining cross entropy with the softmax function gives the cross-entropy loss. **The goal is to maximize the probability of the true label.**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we introduce and implement the cross-entropy loss, then train a linear classifier with it." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Cross Entropy Loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://tuchuang-1252747889.cosgz.myqcloud.com/2018-11-27-%E5%B1%8F%E5%B9%95%E5%BF%AB%E7%85%A7%202018-11-27%20%E4%B8%8B%E5%8D%8811.03.41.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The figure above shows the full pipeline of the cross-entropy loss, taken from the cs231n lecture [Loss Functions and Optimization](http://cs231n.stanford.edu/syllabus.html): the network output first goes through softmax normalization, then the cross-entropy loss is computed against the true probabilities." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cross-entropy loss used in deep learning has the following form, obtained by substituting the softmax function into the cross-entropy formula:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "L = - \\frac{1}{N} \\sum _ { i } y_i \\log \\frac 
{ e ^ { z _ { i } } } { \\sum e ^ { z _ { j } } }\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$y_i$ is the one-hot encoding of the true label, **so the sum reduces to the log-probabilities of the true classes only**." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this reason PyTorch's [torch.nn.CrossEntropyLoss](https://pytorch.org/docs/stable/nn.html#crossentropyloss) takes class indices rather than one-hot vectors: indexing the true class is faster than a matrix multiplication and saves memory. Its basic formula is:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\\text{loss}(x, class) = -\\log\\left(\\frac{\\exp(x[class])}{\\sum_j \\exp(x[j])}\\right)\n", " = -x[class] + \\log\\left(\\sum_j \\exp(x[j])\\right)$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, the **forward pass** of the cross-entropy loss:" ] }, { "cell_type": "code", "execution_count": 168, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.6763745848377714" ] }, "execution_count": 168, "metadata": {}, "output_type": "execute_result" } ], "source": [ "N, C = 10, 3  # random data, batch*class\n", "\n", "input = np.random.randn(N, C)\n", "from sklearn.preprocessing import OneHotEncoder\n", "encoder = OneHotEncoder()\n", "labels = encoder.fit_transform(np.random.randint(0, C,(N, 1))).toarray()  # generate labels and one-hot encode them\n", "\n", "prob = softmax(input)\n", "# labels N*C\n", "# prob N*C\n", "loss = -np.sum(np.multiply(labels, np.log(prob)))/labels.shape[0]  # compute the loss from the formula and average over the batch\n", "loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The **backward pass** requires the partial derivatives of the cross-entropy loss." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Differentiate the softmax function first, then apply the chain rule. Consider a single input $x\\in \\mathbb{R}^{1\\times C}$, $x=(x_0,\\dots,x_i,\\dots)$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Because the true distribution $Y=(y_0,\\dots,y_i,\\dots)$ is one-hot, only one position equals 1, so for the index $i$ with $y_i=1$ the per-sample loss $L$ simplifies to:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "L = - \\sum _ { i } y_{i} (\\log (e ^ { x _ { i } }) - \\log(\\sum e ^ { x _ { j } } ))\\\\\n", "=-(\\log (e ^ { x _ { i } }) - \\log(\\sum e ^ { x _ { j } } )),y_i=1\\\\\n", "=\\log(\\sum e ^ { x _ { j } } 
)-x_i,y_i=1\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Dropping the sample subscript, the partial derivative of $L$ with respect to $x_i$ (the component with $y_i=1$) is:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\\frac{\\partial L}{\\partial x_i}=\\frac{\\partial}{\\partial x_i}\\left(\\log(\\sum e ^ { x _ { j } } )-x_i\\right),y_i=1\\\\\n", "=\\frac{e^{x_ { i }}}{\\sum e ^ { x _ { j } }}-1\\\\\n", "=\\frac{e^{x_ { i }}}{\\sum e ^ { x _ { j } }}-y_i\\\\\n", "=p_i-y_i$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That covers only the component $x_i$ with $y_i=1$. For a component $x_j$ with $y_j\\ne 1$, i.e. $y_j=0$:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\\frac{\\partial L}{\\partial x_j}=\\frac{\\partial}{\\partial x_j}\\left(\\log(\\sum e ^ { x _ { k } } )-x_i\\right),y_j=0\\\\\n", "=\\frac{e^{x_ { j }}}{\\sum e ^ { x _ { k } }}\\\\\n", "=p_j\\\\\n", "=p_j-0\\\\\n", "=p_j-y_j$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$p$ is the probability after softmax and $y$ is the one-hot matrix of true probabilities.\n", "\n", "So the gradient of the cross-entropy loss is simply **the predicted probability minus the true probability!**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The final answer is therefore:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\\frac{\\partial L}{\\partial x}=\\frac{1}{N}(P-Y)$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The whole backward pass is then trivial to implement (the $1/N$ factor is left out of this quick check):" ] }, { "cell_type": "code", "execution_count": 169, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 0.10707167, -0.12583787,  0.01876621],\n", "       [-0.89813245,  0.32953853,  0.56859392],\n", "       [ 0.73705126, -0.83412744,  0.09707618],\n", "       [ 0.1666249 , -0.90880152,  0.74217661],\n", "       [ 0.41876515,  0.50426485, -0.92303   ],\n", "       [-0.73118834,  0.26607816,  0.46511018],\n", "       [-0.91878922,  0.07298688,  0.84580235],\n", "       [ 0.68774499, -0.77065113,  0.08290614],\n", "       [ 0.21098268, -0.75737717,  0.54639448],\n", "       [-0.58370983,  0.51392215,  0.06978768]])" ] }, "execution_count": 169, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grad = prob - labels\n", "grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, wrap the whole cross-entropy loss in a reusable layer:" ] }, { "cell_type": "code", "execution_count": 188, "metadata": {}, "outputs": [], 
"source": [ "class CrossEntropyLossLayer():\n", " def __init__(self):\n", " pass\n", " \n", " def forward(self, input, labels):\n", " # 做一些防止误用的措施,输入数据必须是二维的,且标签和数据必须维度一致\n", " assert len(input.shape)==2, '输入的数据必须是一个二维矩阵'\n", " assert len(labels.shape)==2, '输入的标签必须是独热编码'\n", " assert labels.shape==input.shape, '数据和标签数量必须一致'\n", " self.data = input\n", " self.labels = labels\n", " self.prob = np.clip(softmax(input), 1e-9, 1.0) #在取对数时不能为 0,所以用极小数代替 0\n", " loss = -np.sum(np.multiply(self.labels, np.log(self.prob)))/self.labels.shape[0]\n", " return loss\n", " \n", " def backward(self):\n", " self.grad = (self.prob - self.labels)/self.labels.shape[0] # 根据公式计算梯度" ] }, { "cell_type": "code", "execution_count": 189, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([[ 0.10171029, -0.21004019, 0.1083299 ],\n", " [-0.10804168, 0.03667105, 0.07137063],\n", " [ 0.14886429, 0.09000375, -0.23886804]]), 0.8824119515040371)" ] }, "execution_count": 189, "metadata": {}, "output_type": "execute_result" } ], "source": [ "N, C = 3, 3\n", "data = np.random.randn(N, C)\n", "labels = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]], dtype=np.float32)\n", "loss = CrossEntropyLossLayer()\n", "l = loss.forward(data, labels)\n", "loss.backward()\n", "loss.grad, l" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "利用 Pytorch 的自动求导机制检验计算是否正确:" ] }, { "cell_type": "code", "execution_count": 190, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[ 0.1017, -0.2100, 0.1083],\n", " [-0.1080, 0.0367, 0.0714],\n", " [ 0.1489, 0.0900, -0.2389]]), 0.8824119567871094)" ] }, "execution_count": 190, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "data = torch.from_numpy(data).float()\n", "data.requires_grad = True\n", "labels = torch.from_numpy(labels)\n", "prob = torch.nn.functional.softmax(data, dim=1)\n", "loss = -torch.sum(torch.mul(torch.log(prob), labels))/prob.size(0)\n", "loss.backward()\n", "data.grad, loss.item()" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "**两个结果对输入求梯度是一样的,所以我最后的答案是正确的。**\n", "\n", "只是因为精度不同有些差异,numpy 使用 64 位浮点数,而 Pytorch 使用 32 位,因为 GPU 计算单精度浮点数速度比双精度快得多。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 线性分类器" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "实现一个线性分类器,使用交叉熵损失函数。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$y=xW$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "梯度下降更新参数,先求 $W$ 的偏导:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\\frac{\\partial y}{\\partial W}=x^T$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "再根据链式法则可以求得 $L$ 关于 $W$ 的偏导:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\\frac{\\partial L}{\\partial W}=x^T \\frac{\\partial L}{\\partial y}\\\\\n", "=\\frac{1}{N} x^T(P-Y)$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$P$ 为 softmax 之后的预测概率,$Y$ 为独热编码的真实概率。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "既然我的实现是正确的,因此读取数据:" ] }, { "cell_type": "code", "execution_count": 191, "metadata": {}, "outputs": [], "source": [ "import struct\n", "import numpy as np\n", "def read_mnist(filename):\n", " with open(filename, 'rb') as f:\n", " zero, data_type, dims = struct.unpack('>HBB', f.read(4))\n", " shape = tuple(struct.unpack('>I', f.read(4))[0] for d in range(dims))\n", " return np.frombuffer(f.read(), dtype=np.uint8).reshape(shape)\n", "\n", "# 读取并归一化数据,不归一化会导致 nan\n", "test_data = (read_mnist('../data/mnist/t10k-images.idx3-ubyte').reshape((-1, 784))-127.0)/255.0\n", "train_data = (read_mnist('../data/mnist/train-images.idx3-ubyte').reshape((-1, 784))-127.0)/255.0\n", "# 独热编码标签\n", "from sklearn.preprocessing import OneHotEncoder\n", "encoder = OneHotEncoder()\n", "encoder.fit(np.arange(10).reshape((-1, 1)))\n", "train_labels = encoder.transform(read_mnist('../data/mnist/train-labels.idx1-ubyte').reshape((-1, 1))).toarray()\n", "test_labels = 
encoder.transform(read_mnist('../data/mnist/t10k-labels.idx1-ubyte').reshape((-1, 1))).toarray()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from nn import lr_scheduler  # small helper from the local nn module\n", "loss_layer = CrossEntropyLossLayer()\n", "lr = 0.1\n", "D, C = 784, 10\n", "np.random.seed(1)  # fix the seed so the weights are reproducible\n", "W = np.random.randn(D, C)*0.01  # Gaussian init: mean 0, std 0.01\n", "b = np.zeros((1, C))  # bias term\n", "best_acc = -float('inf')\n", "max_iter = 900\n", "step_size = 400\n", "scheduler = lr_scheduler(lr, step_size)\n", "loss_list = []\n", "from tqdm import tqdm_notebook\n", "for epoch in tqdm_notebook(range(max_iter)):\n", "    # evaluate\n", "    test_pred = np.dot(test_data, W) + b\n", "    pred_labels = np.argmax(test_pred, axis=1)\n", "    real_labels = np.argmax(test_labels, axis=1)\n", "    acc = np.mean(pred_labels==real_labels)\n", "    if acc>best_acc: best_acc=acc\n", "    # train\n", "    train_pred = np.dot(train_data, W) + b\n", "    loss = loss_layer.forward(train_pred, train_labels)\n", "    loss_list.append(loss)\n", "    loss_layer.backward()\n", "    grad = np.dot(train_data.T, loss_layer.grad)\n", "    # update the parameters\n", "    W -= scheduler.get_lr()*grad\n", "    b -= scheduler.get_lr()*np.mean(loss_layer.grad, axis=0)\n", "    scheduler.step()" ] }, { "cell_type": "code", "execution_count": 200, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8992" ] }, "execution_count": 200, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_acc" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We reach ~90% test accuracy. Plot the loss curve:" ] }, { "cell_type": "code", "execution_count": 201, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAX4AAAD8CAYAAABw1c+bAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAHoJJREFUeJzt3Xt0HOWZ5/Hv0ze1bpZkSZaNZWEDNgQSMETc4uwA2QEMmwnZmewGhiEkmxyf2SEzSSZnd5PZPckOmbNnstnNbUJgvIRcZhNINkCGYQjg4RISCMQyIeZiGxsbsLCNZVu2Zcmybs/+0SW71e5ute2WWq76fc7p011vvdX9dLn9q9Jb1V3m7oiISHTEKl2AiIhMLwW/iEjEKPhFRCJGwS8iEjEKfhGRiFHwi4hEjIJfRCRiFPwiIhGj4BcRiZhEpQvIp6WlxRcuXFjpMkREThpr1qzZ5e6tpfSdkcG/cOFCurq6Kl2GiMhJw8zeKLWvhnpERCJGwS8iEjEKfhGRiFHwi4hEjIJfRCRiFPwiIhGj4BcRiZhQBf83H9vIL17tqXQZIiIzWqiC/45fvMYvFfwiIkWFKvirEjGGRscqXYaIyIwWsuCPc2hYwS8iUkyogj+ViHFoZLTSZYiIzGihCn4N9YiITG7S4DezBWb2hJmtM7OXzexTefrcaGZrg9szZnZe1rzXzexFM3vBzKb0JzerkjEN9YiITKKUn2UeAT7r7s+bWT2wxsxWufsrWX22AJe5e6+ZXQOsBC7Omn+Fu+8qX9n5VSXiHBpR8IuIFDNp8Lv7dmB78LjPzNYB84FXsvo8k7XIs0B7messSSquMX4Rkckc0xi/mS0EzgeeK9Lt48DPs6YdeNTM1pjZimMt8FhUJWMMaY9fRKSokq/AZWZ1wL3Ap919f4E+V5AJ/vdmNS9z921mNgdYZWbr3f2pPMuuAFYAdHR0HMNbOKIqEdNQj4jIJEra4zezJJnQ/6G731egz7nAncB17r57vN3dtwX3O4H7gYvyLe/uK9290907W1tLumzkUTTGLyIyuVLO6jHgO8A6d/9qgT4dwH3ATe7+alZ7bXBAGDOrBa4CXipH4fmkEjEODWuMX0SkmFKGepYBNwEvmtkLQdtfAR0A7n4H8AWgGfh2ZjvBiLt3Am3A/UFbAviRuz9c1neQRefxi4hMrpSzen4F2CR9PgF8Ik/7ZuC8o5eYGvrJBhGRyYXrm7tJHdwVEZlMqII/Fc8M9bh7pUsREZmxQhX8VcnM29Fev4hIYeEK/kQcUPCLiBQTsuAf3+PXKZ0iIoWEKvhTQfDrZxtERAoLVfAf2eNX8IuIFBKy4A/G+HUuv4hIQeEK/qTG+EVEJhOq4E8He/yD2uMXESkoVMFfnRoPfu3xi4gUEq7gTyr4RUQmE6rgTwdj/AcV/CIiBYUq+Mf3+BX8IiKFhSr408EY/8EhBb+ISCGhCv7xPX59gUtEpLBQBX8yHiMeM+3xi4gUUco1dxeY2RNmts7MXjazT+XpY2b2TTPbZGZrzeyCrHk3m9nG4HZzud9ArupkXGP8IiJFlHLN3RHgs+7+fHDh9DVmtsrdX8nqcw2wOLhdDNwOXGxms4EvAp2AB8s+4O69ZX0XWdIKfhGRoibd43f37e7+fPC4D1gHzM/pdh3wA894Fmg0s3nA1cAqd98ThP0qYHlZ30GOdDKm8/hFRIo4pjF+M1sInA88lzNrPrA1a7o7aCvUnu+5V5hZl5l19fT0HEtZE1Qn4wp+EZEiSg5+M6sD7gU+7e77c2fnWcSLtB/d6L7S3TvdvbO1tbXUso5SnYrr4K6ISBElBb+ZJcmE/g/d/b48XbqBBVnT7cC2Iu1TRmP8IiLFlXJWjwHfAda5+1cLdHsA+Ehwds8lwD533w48AlxlZk1m1gRcFbRNmXQyrl/nFBEpopSzepYBNwEvmtkLQdtfAR0A7n4H8BBwLbAJGAA+FszbY2ZfAlY
Hy93q7nvKV/7RqpMxdu7XHr+ISCGTBr+7/4r8Y/XZfRy4pcC8u4C7jqu646Dz+EVEigvVN3chc3BXZ/WIiBQWuuCvSuisHhGRYkIX/Jk9fh3cFREpJHzBn4wzNDrG6FjerwuIiEReKIMfdPlFEZFCQhf8uvyiiEhxIQx+XYVLRKSY0AV/dUpDPSIixYQv+A+P8evMHhGRfEIb/ANDIxWuRERkZgpd8NdUZX6FYkBj/CIieYUu+GuDMf5+7fGLiOQVuuAf3+PvP6TgFxHJJ3TBX5caD34N9YiI5BO64K+p0sFdEZFiQhf8yXiMVCLGAe3xi4jkNemFWMzsLuD9wE53f2ee+f8JuDHr+d4BtAZX33od6ANGgRF37yxX4cXUpuLa4xcRKaCUPf7vAcsLzXT3r7j7UndfCnwe+EXO5RWvCOZPS+gD1KQSGuMXESlg0uB396eAUq+TewNw9wlVVAZ1VQmd1SMiUkDZxvjNrIbMXwb3ZjU78KiZrTGzFeV6rcnUVMV1Hr+ISAGTjvEfgz8Ans4Z5lnm7tvMbA6wyszWB39BHCXYMKwA6OjoOKFCalPa4xcRKaScZ/VcT84wj7tvC+53AvcDFxVa2N1Xununu3e2traeUCG1VXH9ZIOISAFlCX4zawAuA/4xq63WzOrHHwNXAS+V4/UmU5tKaKhHRKSAUk7nvBu4HGgxs27gi0ASwN3vCLr9W+BRd+/PWrQNuN/Mxl/nR+7+cPlKL6ymKq6zekRECpg0+N39hhL6fI/MaZ/ZbZuB8463sBNRq7N6REQKCt03dyEz1HNoZIyRUV2MRUQkVyiDv+bwTzNruEdEJFcog7/u8MVYNNwjIpIrlMF/5Df5tccvIpIrlMF/+CpcOsArInKUcAa/rsIlIlJQKIN/VjoJwP5BBb+ISK5QBn99OrPHv39wuMKViIjMPKEM/lnVwR7/QQW/iEiuUAZ/fVUCM+jTUI+IyFFCGfyxmFGXSmioR0Qkj1AGP2SGe/Yf1B6/iEiu0AZ/fTpBn/b4RUSOEtrgn5VOaqhHRCSP8AZ/dUJDPSIieYQ3+LXHLyKSV2iDPzPGrz1+EZFckwa/md1lZjvNLO/1cs3scjPbZ2YvBLcvZM1bbmYbzGyTmX2unIVPZlZ1kr7BYcbGfDpfVkRkxitlj/97wPJJ+vzS3ZcGt1sBzCwO3AZcA5wN3GBmZ59IscdiVjrJmKOLrouI5Jg0+N39KWDPcTz3RcAmd9/s7kPAPcB1x/E8x2VWdeb3ejTcIyIyUbnG+C81s9+Z2c/N7JygbT6wNatPd9CWl5mtMLMuM+vq6ek54YLqD/9Cpw7wiohkK0fwPw+c6u7nAX8H/Cxotzx9Cw64u/tKd+90987W1tYTLmr8p5n3DSj4RUSynXDwu/t+dz8QPH4ISJpZC5k9/AVZXduBbSf6eqVqrAmCX7/QKSIywQkHv5nNNTMLHl8UPOduYDWw2MwWmVkKuB544ERfr1Tjwd87MDRdLykiclJITNbBzO4GLgdazKwb+CKQBHD3O4APAf/RzEaAg8D17u7AiJl9EngEiAN3ufvLU/Iu8phdmwKgV0M9IiITTBr87n7DJPO/BXyrwLyHgIeOr7QTU52MU5WI0duvPX4RkWyh/eaumTG7NsUeBb+IyAShDX6AppqUxvhFRHKEO/hrk9rjFxHJEe7gr0np4K6ISI5QB//sWg31iIjkCnXwN9Wk2HdwmJHRsUqXIiIyY4Q8+JO469u7IiLZwh38h7/EpeEeEZFxoQ7+8W/v7unXHr+IyLhIBP/uA4cqXImIyMwR6uCfU58GYGefgl9EZFyog392bYp4zOhR8IuIHBbq4I/HjObaFDv7BitdiojIjBHq4AeYM6tKe/wiIlnCH/z1aY3xi4hkCX3wt9ZVKfhFRLJMGvxmdpeZ7TSzlwrMv9HM1ga3Z8zsvKx5r5vZi2b2gpl1lbPwUs2ZVcXuA4cYHSt4nXcRkUgpZY//e8DyIvO3AJe5+7n
Al4CVOfOvcPel7t55fCWemDn1VYw57O7XXr+ICJQQ/O7+FLCnyPxn3L03mHwWaC9TbWXRGpzLrwO8IiIZ5R7j/zjw86xpBx41szVmtqLMr1WS1voqAHbuV/CLiEAJF1svlZldQSb435vVvMzdt5nZHGCVma0P/oLIt/wKYAVAR0dHucrilMbMHv+2fQfL9pwiIiezsuzxm9m5wJ3Ade6+e7zd3bcF9zuB+4GLCj2Hu690905372xtbS1HWUDmdM5EzNi2V8EvIgJlCH4z6wDuA25y91ez2mvNrH78MXAVkPfMoKkUjxlzG9Js26tv74qIQAlDPWZ2N3A50GJm3cAXgSSAu98BfAFoBr5tZgAjwRk8bcD9QVsC+JG7PzwF72FSpzRW85b2+EVEgBKC391vmGT+J4BP5GnfDJx39BLTb35jNb/ZUvDEJBGRSAn9N3chc4B3x/5BfYlLRISIBP/8xhpGx1y/0ikiQkSCf/yUzrd6Nc4vIhKJ4G9vqgZga+9AhSsREam8iAR/DWbw+i4Fv4hIJII/nYxzSkM1b+zur3QpIiIVF4ngB1jYUsOW3drjFxGJTvA312qPX0SEiAX/3oFh9g4MVboUEZGKik7wt9QC8LqGe0Qk4iIT/IuC4H9t54EKVyIiUlmRCf5Tm2tIxWO8+nZfpUsREamoyAR/Mh7j9Dl1bFDwi0jERSb4Ac6aW8+GHQp+EYm2SAX/krZ6tu8bZN/AcKVLERGpmEgF/1lz6wE03CMikRap4D9zPPh37K9wJSIilVNS8JvZXWa208zyXjPXMr5pZpvMbK2ZXZA172Yz2xjcbi5X4cdjXkOa+nSC9RrnF5EIK3WP/3vA8iLzrwEWB7cVwO0AZjabzDV6LwYuAr5oZk3HW+yJMjPeMW8WL23THr+IRFdJwe/uTwHFLlp7HfADz3gWaDSzecDVwCp33+PuvcAqim9Aptz5CxpZt20/h0ZGK1mGiEjFlGuMfz6wNWu6O2gr1H4UM1thZl1m1tXT01Omso62dEEjQ6NjvKK9fhGJqHIFv+Vp8yLtRze6r3T3TnfvbG1tLVNZR1va0QjAC1v3TtlriIjMZOUK/m5gQdZ0O7CtSHvFzGuoZu6sNL99U8EvItFUruB/APhIcHbPJcA+d98OPAJcZWZNwUHdq4K2ilq6oFF7/CISWYlSOpnZ3cDlQIuZdZM5UycJ4O53AA8B1wKbgAHgY8G8PWb2JWB18FS3unuxg8TT4vyORh5+eQc7+waZU5+udDkiItOqpOB39xsmme/ALQXm3QXcdeylTZ1LT28G4Nev7ea6pXmPNYuIhFakvrk77pxTGpiVTvD0pl2VLkVEZNpFMvjjMePS05t5etNuMn+siIhERySDH2DZGS28tfcgb+7RpRhFJFoiHfwAT23UcI+IREtkg/+0lloWNtfwL6+8XelSRESmVWSD38y48uw2nnltF32DujCLiERHZIMf4Kpz5jI86jy5Yep+G0hEZKaJdPBf0NFEc22KR17eUelSRESmTaSDPx4zlr9zLv+y7m0OHBqpdDkiItMi0sEP8IcXzGdweIyfv7i90qWIiEyLyAf/BR1NnNpcw/2/favSpYiITIvIB7+Z8Yfnt/PrzbvZqi9ziUgERD74AT584QJiZnz/mdcrXYqIyJRT8ANzG9Jc8865/LhrK/06yCsiIafgD3xs2SL6Bke47/nuSpciIjKlFPyBCzoaOa+9gTt/tYWR0bFKlyMiMmVKCn4zW25mG8xsk5l9Ls/8r5nZC8HtVTPbmzVvNGveA+UsvpzMjFuuOIM3dg9w3/M6w0dEwmvSK3CZWRy4DbiSzMXTV5vZA+7+yngfd/9MVv8/B87PeoqD7r60fCVPnSvPbuPc9ga+8dhGPnj+fFIJ/UEkIuFTSrJdBGxy983uPgTcA1xXpP8NwN3lKG66mRl/eeUS3tp7kHtWv1npckREpkQpwT8f2Jo13R20HcXMTgUWAY9nNafNrMvMnjWzDx5
3pdPksiWtXHpaM//70VfZ0z9U6XJERMqulOC3PG2Frld4PfBTdx/Nautw907gj4Gvm9npeV/EbEWwgejq6ancr2WaGX993Tn0HxrhK49sqFgdIiJTpZTg7wYWZE23A9sK9L2enGEed98W3G8GnmTi+H92v5Xu3ununa2trSWUNXWWtNXz0fcs5J7Vb7L69T0VrUVEpNxKCf7VwGIzW2RmKTLhftTZOWZ2JtAE/DqrrcnMqoLHLcAy4JXcZWeiT1+5hPamav7yJy/olztFJFQmDX53HwE+CTwCrAN+4u4vm9mtZvaBrK43APe4e/Yw0DuALjP7HfAE8LfZZwPNZHVVCb7275fyVu9BvvRPJ0XJIiIlsYk5PTN0dnZ6V1dXpcsA4CuPrOe2J17j6x9eygfPz3tMW0Sk4sxsTXA8dVI6UX0Sn/79JVy0aDb/5d61rO3eO/kCIiIznIJ/Esl4jNtvvICWuipW/GANO/YNVrokEZETouAvQXNdFSs/8m76Boe56TvP0avz+0XkJKbgL9E5pzRw580X8uaeAW7+7m/YPzhc6ZJERI6Lgv8YXHp6M7f/yQWs276f6//+WXr6DlW6JBGRY6bgP0bvO6uNO2++kC27+vl3dzyjyzWKyElHwX8cLlvSyv/9xMXs6R/ig7c9zbObd1e6JBGRkin4j9O7T23i/luW0ViT5MY7n+OuX21hJn4nQkQkl4L/BJzeWsfPblnGFWfO4dYHX+Hj3+9iZ59O9xSRmU3Bf4Lq00lW3vRu/vsfnM3Tm3ax/Ou/5KEXt2vvX0RmLAV/GcRixkeXLeLBP38vpzSm+bMfPs/N313Nll39lS5NROQoCv4yWtxWz8/+bBlfeP/ZPP9GL1d/7Sn+58Pr2XdQ5/yLyMyh4C+zRDzGf3jvIh7/7GVc+665fPvJ1/hXX36c257YxMCQft5ZRCpPv845xV7eto+vPvoqj63fSUtdio9cupA/ueRUZtemKl2aiITIsfw6p4J/mqx5o5e/e3wjT27oIZ2M8UcXtPOxZYs4Y05dpUsTkRBQ8M9gG9/u485fbuH+377F0OgYFy5s4sMXdnDtu+ZSk0pUujwROUkp+E8CPX2H+Omabn7StZUtu/qpq0rw/nPn8W/OncelpzWTiOvwi4iUruzBb2bLgW8AceBOd//bnPkfBb4CvBU0fcvd7wzm3Qz8t6D9b9z9+5O9XhSCf5y7s/r1Xn68eisPv7Sd/qFRZtemuPqcNq591zwuXtRMKqGNgIgUV9bgN7M48CpwJdBN5uLrN2RfOzcI/k53/2TOsrOBLqATcGAN8G537y32mlEK/myDw6P84tUe/nntdh5b9zb9Q6PUVSVYdkYzl585h8vPbGVeQ3WlyxSRGehYgr+UQeWLgE3uvjl48nuA64BSrkB+NbDK3fcEy64ClgN3l1Jc1KSTca4+Zy5XnzOXweFRfrlxF09s2MkvNvTwyMtvA3BmWz3vOaOZixc1c9Gi2To7SESOWSnBPx/YmjXdDVycp98fmdnvkfnr4DPuvrXAsnmvWG5mK4AVAB0dHSWUFW7pZJwrz27jyrPbcHc27jzAkxt28uSGHu7+zZt89+nXAVjSVsfFi5q5cNFszl/QSHtTNWZW2eJFZEYrJfjzpUju+NA/AXe7+yEz+1Pg+8D7Slw20+i+ElgJmaGeEuqKDDNjSVs9S9rqWfF7pzM0Msba7r08t2UPz27ezb3Pd/MPz74BwOzaFOe1N3DegkbOa2/k3PYGmuuqKvwORGQmKSX4u4EFWdPtwLbsDu6e/YP0/wf4ctayl+cs++SxFikTpRIxOhfOpnPhbG654gxGRsd4Zft+fte9j7Vb9/K77r08+WoP44dv5jWkOXNuPWfOreesufWcNXcWp7fW6aCxSESVEvyrgcVmtojMWTvXA3+c3cHM5rn79mDyA8C64PEjwP8ws6Zg+irg8ydctUyQiMc4t72Rc9sb4ZJTAThwaISX3trH2u69rNvex/odfTy
9aRfDo5mtQSJmnNZay5K2ek5rreP01loWtWRu9elkJd+OiEyxSYPf3UfM7JNkQjwO3OXuL5vZrUCXuz8A/IWZfQAYAfYAHw2W3WNmXyKz8QC4dfxAr0ytuqoEl5zWzCWnNR9uGx4dY8uuftbv6GP99v1s2NHH77r38s8vbif75K7W+ioWtdQe3hh0zK6lvamaBU01zKpO6BiCyElOX+ASBodH2bpngNd6+tmyq58tuw6wOXi8u39oQt/6qgTzm6ppb6rJbAxmZ+7nN1YzryHN7NqUNgwiFVDu0zkl5NLJOIvb6lncVn/UvH0Dw2ztHaC79yDdE+4H+PVru+gfGp3QPxWPMWdWFXNnpWlrSDN3Vvrw47b6KuY2pGmblSadjE/X2xORHAp+KaqhJklDTQPvnN9w1Dx3Z9/B4cMbg+37Btmxf5C3g/tXtu3n8XU7OTg8etSy9ekErXVVNNelaAnum2uraKmvoqU2RUt9Fc3BfX2VhpdEyknBL8fNzGisSdFYk8q7YYDMxmH/4Ahv7x9kR9aGYdeBQ+zqH2JX3yE27jzAs5sP0TuQ/4I1qUSMltoUTbUpmmpSNNQkaapJ0lidorEmSWNNKjMdPG6sTtJQndTvHYkUoOCXKWVmNARBvCTPUFK24dExevuH6DlwiN0Hhtjdf4hdfUPsCu57B4bYOzDEtr0H2XtwmL0DQ4wVOUQ1K504vFGoTyepSsRIJWLEY0bMDLPMF01iwYOY2eFps0ztZhAzMCxzH/zlEcued/i5xvsceS7McpbPft48r0fW8lltsZxaxucdfs7cttzXm/A6+Z53/LVtwuvHMm+iQL3B68Qmvh5Z62pivYVfL/96L/B6OetKfw0eOwW/zBjJeIw5s9LMmZUuqf/YmNM3OMLeg0P0DmQ2BHuD+96BYfYdHA42FsPsHxymd2CMoZExRsccB8bccQfHGRsLnjNoG/NMH8+ZHguWJbtt/HnGn8szyxXbKEl5Fd3QkLOBieXf4GZvlCD/hqqgAjOLLZNvgzW7JsVP/vTSkt7ziVDwy0krFrPgGESSU5sn718JfnjjcmRj4BzZUBy98TjSf0IbORulPBuaiRuqIxu0fK93ZMM0vvwkr5fpWmTDmFNL1vvMbRvLU8v4645lrati62fiesytd+K87Ncj599hfF6+1yPr36HYRrzQmZFFt/sFZtanpyeSFfwiU2h8TzOYqmQpIofp6JeISMQo+EVEIkbBLyISMQp+EZGIUfCLiESMgl9EJGIU/CIiEaPgFxGJmBn5e/xm1gO8cZyLtwC7yljOyU7r4witi4m0PiY62dfHqe7eWkrHGRn8J8LMukq9GEEUaH0coXUxkdbHRFFaHxrqERGJGAW/iEjEhDH4V1a6gBlG6+MIrYuJtD4misz6CN0Yv4iIFBfGPX4RESkiNMFvZsvNbIOZbTKzz1W6nulgZgvM7AkzW2dmL5vZp4L22Wa2ysw2BvdNQbuZ2TeDdbTWzC6o7DsoPzOLm9lvzezBYHqRmT0XrIsfm1kqaK8KpjcF8xdWsu6pYGaNZvZTM1sffEYujfhn4zPB/5OXzOxuM0tH9fMRiuA3szhwG3ANcDZwg5mdXdmqpsUI8Fl3fwdwCXBL8L4/Bzzm7ouBx4JpyKyfxcFtBXD79Jc85T4FrMua/jLwtWBd9AIfD9o/DvS6+xnA14J+YfMN4GF3Pws4j8x6ieRnw8zmA38BdLr7O4E4cD1R/XxkLjl2ct+AS4FHsqY/D3y+0nVVYD38I3AlsAGYF7TNAzYEj/8euCGr/+F+YbgB7WTC7H3Ag2QuebULSOR+ToBHgEuDx4mgn1X6PZRxXcwCtuS+pwh/NuYDW4HZwb/3g8DVUf18hGKPnyP/qOO6g7bICP4UPR94Dmhz9+0Awf2coFvY19PXgf8MBJdOpxnY6+4jwXT2+z28LoL5+4L+YXEa0AN8Nxj6utPMaonoZ8Pd3wL+F/AmsJ3Mv/caIvr5CEvw57uYaWROVzK
zOuBe4NPuvr9Y1zxtoVhPZvZ+YKe7r8luztPVS5gXBgngAuB2dz8f6OfIsE4+oV4fwbGM64BFwClALZnhrVyR+HyEJfi7gQVZ0+3AtgrVMq3MLEkm9H/o7vcFzW+b2bxg/jxgZ9Ae5vW0DPiAmb0O3ENmuOfrQKOZJYI+2e/38LoI5jcAe6az4CnWDXS7+3PB9E/JbAii+NkA+H1gi7v3uPswcB/wHiL6+QhL8K8GFgdH6FNkDto8UOGappyZGfAdYJ27fzVr1gPAzcHjm8mM/Y+3fyQ4g+MSYN/4n/0nO3f/vLu3u/tCMv/+j7v7jcATwIeCbrnrYnwdfSjoH5o9OnffAWw1szODpn8NvEIEPxuBN4FLzKwm+H8zvj4i+fmo+EGGct2Aa4FXgdeA/1rpeqbpPb+XzJ+fa4EXgtu1ZMYiHwM2Bvezg/5G5uyn14AXyZzhUPH3MQXr5XLgweDxacBvgE3A/wOqgvZ0ML0pmH9apeuegvWwFOgKPh8/A5qi/NkA/hpYD7wE/ANQFdXPh765KyISMWEZ6hERkRIp+EVEIkbBLyISMQp+EZGIUfCLiESMgl9EJGIU/CIiEaPgFxGJmP8PF7WpyFsma00AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "plt.plot(np.arange(max_iter), loss_list)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "使用 Pytorch 实现一个线性分类器。" ] }, { "cell_type": "code", "execution_count": 202, "metadata": {}, "outputs": [], "source": [ "import torch\n", "train_data = torch.from_numpy(train_data).float()\n", "train_labels = torch.from_numpy(train_labels).float()\n", "test_data = torch.from_numpy(test_data).float()\n", "test_labels = torch.from_numpy(test_labels).float()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 初始化一样的权重和偏置\n", "np.random.seed(1) # 固定随机生成的权重\n", "linear_classifier = torch.nn.Linear(784, 10)\n", "linear_classifier.weight.data = torch.from_numpy((np.random.randn(D, C)*0.01).transpose((1, 0))).float()\n", "_ = linear_classifier.bias.data.fill_(0)\n", "best_acc = -float('inf')\n", "lr = 0.1\n", "max_iter = 900\n", "step_size = 400\n", "scheduler = lr_scheduler(lr, step_size)\n", "loss_list = []\n", "criterion = torch.nn.CrossEntropyLoss()\n", "for epoch in tqdm_notebook(range(max_iter)):\n", " with torch.no_grad():\n", " # 测试\n", " test_pred = linear_classifier(test_data)\n", " pred_labels = torch.argmax(test_pred, dim=1)\n", " real_labels = torch.argmax(test_labels, dim=1)\n", " acc = torch.mean((pred_labels==real_labels).float())\n", " if acc>best_acc: best_acc=acc\n", " train_pred = linear_classifier(train_data)\n", " real_labels = torch.argmax(train_labels, dim=1)\n", " loss = criterion(train_pred, real_labels)\n", " loss.backward()\n", " for p in linear_classifier.parameters():\n", " p.data.add_(-scheduler.get_lr(), p.grad.data)\n", " linear_classifier.zero_grad()\n", " loss_list.append(loss.item())\n", " scheduler.step()" ] }, { "cell_type": "code", "execution_count": 204, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.8992)" ] }, 
"execution_count": 204, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_acc" ] }, { "cell_type": "code", "execution_count": 205, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAD8CAYAAABw1c+bAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAHlVJREFUeJzt3X2UVPWd5/H3t576mX7mQZoWUNSoUTQd1OBOdGZEdDMhu8nmyDjRZJLDmR0zk2Rzdk4yuyeeMXv2TCZ78uDG6DLGPMwmOtlEM47HaBjjQ6LR0BiDCoIIKC0gDTR00w3003f/qNtQ3dQTUN3V3Pt5nVOn6v7u71Z961J87u3fvVXX3B0REYmOWLkLEBGRqaXgFxGJGAW/iEjEKPhFRCJGwS8iEjEKfhGRiFHwi4hEjIJfRCRiFPwiIhGTKHcB2bS0tPj8+fPLXYaIyBlj3bp1e929tZi+0zL458+fT2dnZ7nLEBE5Y5jZm8X21VCPiEjEKPhFRCJGwS8iEjEKfhGRiFHwi4hEjIJfRCRiFPwiIhETquC/84nXeXpzd7nLEBGZ1kIV/Pc8/Qa/UvCLiOQVquCvSMQYHBktdxkiItNayII/ztEhBb+ISD6hCv5UIsbR4ZFylyEiMq2FKvg11CMiUljB4DezeWb2pJltNLNXzewzWfrcbGbrg9tzZnZpxrztZvaymb1kZpP6k5sVyZiGekRECijmZ5mHgc+7+4tmVgesM7M17r4ho8824P3u3mNmNwCrgSsy5l/r7ntLV3Z2FYk4R4cV/CIi+RQMfnffBewKHveZ2UZgLrAho89zGYs8D7SVuM6ipOIa4xcRKeSkxvjNbD5wGfBCnm6fBH6eMe3AL8xsnZmtOtkCT0ZFMsag9vhFRPIq+gpcZlYL/BT4rLv35uhzLengvzqjeam77zSzmcAaM3vN3Z/JsuwqYBVAe3v7SbyF4yoSMQ31iIgUUNQev5klSYf+D939wRx9LgHuBVa4+76xdnffGdzvAR4ClmRb3t1Xu3uHu3e0thZ12cgTaIxfRKSwYs7qMeA7wEZ3/1qOPu3Ag8DH3H1zRntNcEAYM6sBlgGvlKLwbFKJGEeHNMYvIpJPMUM9S4GPAS+b2UtB298C7QDufg/wJaAZ+HZ6O8Gwu3cAs4CHgrYE8CN3f6yk7yCDzuMXESmsmLN6fg1YgT6fAj6VpX0rcOmJS0wO/WSDiEhh4frmblIHd0VECglV8Kfi6aEedy93KSIi01aogr8imX472usXEcktXMGfiAMKfhGRfEIW/GN7/DqlU0Qkl1AFfyoIfv1sg4hIbqEK/uN7/Ap+EZFcQhb8wRi/zuUXEckpXMGf1Bi/iEghoQr+ymCP/4j2+EVEcgpV8FelxoJfe/wiIrmEK/iTCn4RkUJCFfyVwRj/YQW/iEhOoQr+sT1+Bb+ISG6hCv7KYIz/8KCCX0Qkl3AFv36rR0SkoFAFfzJuxGOmPX4RkTyKuebuPDN70sw2mtmrZvaZLH3MzO40sy1mtt7MLs+Yd6uZvR7cbi31G5hQB1XJuMb4RUTyKOaau8PA5939xeDC6evMbI27b8jocwOwKLhdAdwNXGFmTcDtQAfgwbIPu3tPSd9FhkoFv4hIXgX3+N19l7u/GDzuAzYCcyd0WwH8wNOeBxrMbA5wPbDG3fcHYb8GWF7SdzBBZTKm8/hFRPI4qTF+M5sPXAa8MGHWXGBHxnRX0JarPdtzrzKzTjPr7O7uPpmyxqlKxhX8IiJ5FB38ZlYL/BT4rLv3T
pydZRHP035io/tqd+9w947W1tZiyzpBVSqug7siInkUFfxmliQd+j909wezdOkC5mVMtwE787RPmspkXD/SJiKSRzFn9RjwHWCju38tR7eHgVuCs3uuBA66+y7gcWCZmTWaWSOwLGibNDq4KyKSXzFn9SwFPga8bGYvBW1/C7QDuPs9wKPAjcAWYAD4RDBvv5l9GVgbLHeHu+8vXfknqkrG2NOr4BcRyaVg8Lv7r8k+Vp/Zx4Hbcsy7D7jvlKo7BTqPX0Qkv1B9cxfSB3d1Vo+ISG6hC/6KhM7qERHJJ3TBn97j11k9IiK5hC/4k3EGR0YZGc36dQERkcgLZfCDLr8oIpJL6IJfl18UEckvhMGvq3CJiOQTuuCvSmmoR0Qkn9AF/9jlF3Vmj4hIdqEL/upgj39gcLjMlYiITE/hC/6K9K9QDGiMX0Qkq9AFf02wx9+vPX4RkaxCF/xje/z9RxX8IiLZhC74a1Njwa+hHhGRbEIX/FU6uCsiklfogj+ViJGKxzikPX4RkawKXojFzO4DPgDscfeLs8z/r8DNGc/3LqA1uPrWdqAPGAGG3b2jVIXnU1MR1x6/iEgOxezxfw9Ynmumu3/V3Re7+2Lgi8DTEy6veG0wf0pCH6A6ldAYv4hIDgWD392fAYq9Tu5K4P7TqqgEairiOqtHRCSHko3xm1k16b8MfprR7MAvzGydma0q1WsVUlOR0Hn8IiI5FBzjPwl/Ajw7YZhnqbvvNLOZwBozey34C+IEwYZhFUB7e/tpFVKTSuibuyIiOZTyrJ6bmDDM4+47g/s9wEPAklwLu/tqd+9w947W1tbTKkRDPSIiuZUk+M2sHng/8C8ZbTVmVjf2GFgGvFKK1yukJqWhHhGRXIo5nfN+4Bqgxcy6gNuBJIC73xN0+w/AL9y9P2PRWcBDZjb2Oj9y98dKV3pu1RVxndUjIpJDweB395VF9Pke6dM+M9u2ApeeamGno6YioaEeEZEcQvfNXUgP9RwdHmV4RBdjERGZKJTBX33sp5k13CMiMlEog7/22MVYNNwjIjJRKIP/+G/ya49fRGSiUAb/satw6QCviMgJwhn8ugqXiEhOoQz+usp08PceUfCLiEwUyuCfUZkEoPfIUJkrERGZfsIZ/FXp4O/THr+IyAlCGfxjp3P2HtYev4jIRKEM/njMqKtIaKhHRCSLUAY/pId7eg9rqEdEZKLQBn9dZYI+7fGLiJwgtME/ozKpoR4RkSzCG/xVCQ31iIhkEd7gr0zSd1R7/CIiE4U2+OsqtccvIpJNweA3s/vMbI+ZZb1erpldY2YHzeyl4PaljHnLzWyTmW0xsy+UsvBCZlQl6TsyxOioT+XLiohMe8Xs8X8PWF6gz6/cfXFwuwPAzOLAXcANwIXASjO78HSKPRkzKpOMOrrouojIBAWD392fAfafwnMvAba4+1Z3HwQeAFacwvOckrEfatPPNoiIjFeqMf6rzOz3ZvZzM7soaJsL7Mjo0xW0ZWVmq8ys08w6u7u7T7ugsd/r0SmdIiLjlSL4XwTOdvdLgf8N/Cxotyx9cw64u/tqd+9w947W1tbTLurYL3TqAK+IyDinHfzu3uvuh4LHjwJJM2shvYc/L6NrG7DzdF+vWPXBHv+BgcGpekkRkTPCaQe/mc02MwseLwmecx+wFlhkZgvMLAXcBDx8uq9XrMaadPD3KPhFRMZJFOpgZvcD1wAtZtYF3A4kAdz9HuAjwH82s2HgMHCTuzswbGafBh4H4sB97v7qpLyLLJpqUgD0DGiMX0QkU8Hgd/eVBeZ/C/hWjnmPAo+eWmmnpyoZpyIRo6dfe/wiIplC+81dM6OpJsV+Bb+IyDihDX6AhuqUxvhFRCYIdfA31SS1xy8iMkGog7+xOqWDuyIiE4Q6+JtqNNQjIjJRqIO/sTrFwcNDDI+MlrsUEZFpI+TBn8QdDh7WcI+IyJhwB/+xL3FpuEdEZEyog3/s27v7+7XHLyIyJhLBv+/Q0TJXI
iIyfYQ6+GfWVQLQreAXETkm1MHfVJMiZrCnV8EvIjIm1MEfjxkttRXs6TtS7lJERKaNUAc/wMwZFXT3aY9fRGRM+IO/rpI9Cn4RkWNCH/yttdrjFxHJVDD4zew+M9tjZq/kmH+zma0Pbs+Z2aUZ87ab2ctm9pKZdZay8GLNnFHB3kNHGRnNeZ13EZFIKWaP/3vA8jzztwHvd/dLgC8DqyfMv9bdF7t7x6mVeHpm1lUw6rCvX3v9IiJQRPC7+zPA/jzzn3P3nmDyeaCtRLWVRGtdBYCGe0REAqUe4/8k8POMaQd+YWbrzGxViV+rKK3Bl7h0Lr+ISFrBi60Xy8yuJR38V2c0L3X3nWY2E1hjZq8Ff0FkW34VsAqgvb29VGVxVkM6+Hcd1Ln8IiJQoj1+M7sEuBdY4e77xtrdfWdwvwd4CFiS6zncfbW7d7h7R2traynKAtKncyZixtsHBkr2nCIiZ7LTDn4zawceBD7m7psz2mvMrG7sMbAMyHpm0GSKx4zZ9ZXsPKA9fhERKGKox8zuB64BWsysC7gdSAK4+z3Al4Bm4NtmBjAcnMEzC3goaEsAP3L3xybhPRR0VkMVbx84XI6XFhGZdgoGv7uvLDD/U8CnsrRvBS49cYmpN7ehit9uy3likohIpIT+m7uQPsC7u/eIvsQlIkJkgr+KkVHXr3SKiBCh4Ad4u0fj/CIikQj+eY3p4N/Ro1M6RUQiEfxtjdWYwZv7FPwiIpEI/spknLPqq9i+t7/cpYiIlF0kgh9gfks127XHLyISoeBvrmH7Pu3xi4hEKvgPDAxxYGCw3KWIiJRVdIK/pQZAwz0iEnmRCf4FLdUAbO0+VOZKRETKKzLBf3ZzDal4jE3v9JW7FBGRsopM8CfjMc6ZWcum3Qp+EYm2yAQ/wAWz6xT8IhJ5kQr+82bVsevgEQ4ODJW7FBGRsolU8F8wuw5A4/wiEmmRCv7zFfwiIsUFv5ndZ2Z7zCzrNXMt7U4z22Jm683s8ox5t5rZ68Ht1lIVfirm1FdSV5lg0+7ecpYhIlJWxe7xfw9Ynmf+DcCi4LYKuBvAzJpIX6P3CmAJcLuZNZ5qsafLzHjXnBm88raCX0Siq6jgd/dngHwXrV0B/MDTngcazGwOcD2wxt33u3sPsIb8G5BJd9m8Bjbs7OXo8Eg5yxARKZtSjfHPBXZkTHcFbbnaT2Bmq8ys08w6u7u7S1TWiRbPa2BwZJQNO7XXLyLRVKrgtyxtnqf9xEb31e7e4e4dra2tJSrrRIvbGwB4aceBSXsNEZHprFTB3wXMy5huA3bmaS+bOfVVzJ5RqeAXkcgqVfA/DNwSnN1zJXDQ3XcBjwPLzKwxOKi7LGgrq8XzGvjdWwp+EYmmRDGdzOx+4Bqgxcy6SJ+pkwRw93uAR4EbgS3AAPCJYN5+M/sysDZ4qjvcPd9B4ilxWXsDj726mz19R5hZV1nuckREplRRwe/uKwvMd+C2HPPuA+47+dImz1XnNAPwmzf2sWJx1mPNIiKhFalv7o656Kx6ZlQmeHbL3nKXIiIy5SIZ/PGYcdU5zTy7ZR/pP1ZERKIjksEPsPTcFt4+cJi39utSjCISLZEOfoBnXtdwj4hES2SDf2FLDWc3V/NvG94pdykiIlMqssFvZiy7cBbPvbGXviO6MIuIREdkgx9g2UWzGRpxnt48eb8NJCIy3UQ6+C9vb6S5JsVjr+wudykiIlMm0sEfjxnXXzybf9v4DoeODpe7HBGRKRHp4Af48OVzOTI0qr1+EYmMyAf/5e2NnN1czYMvdpW7FBGRKRH54Dcz/uNlbfxm6z526MtcIhIBkQ9+gI++t42YGT/4zfZylyIiMukU/KQvznLDxbN5YO0O+nWQV0RCTsEf+MTS+fQdGdZYv4iEnoI/cHl7I5e01fOdX29jeGS03OWIiEyaooLfzJab2SYz22JmX8gy/+tm9lJw22xmBzLmjWTMe7iUxZeSmfHpa89l+74BHvzd2+UuR0Rk0
hS8ApeZxYG7gOtIXzx9rZk97O4bxvq4++cy+v8VcFnGUxx298WlK3nyXHfhLC5pq+fOJ17nQ4vnkkroDyIRCZ9ikm0JsMXdt7r7IPAAsCJP/5XA/aUobqqZGZ+77jy6eg7zwNq3yl2OiMikKCb45wI7Mqa7grYTmNnZwALglxnNlWbWaWbPm9mHTrnSKXLNea1ctbCZr63ZTE//YLnLEREpuWKC37K05bpe4U3AT9x9JKOt3d07gD8FvmFm52R9EbNVwQais7u7fL+WaWb83YqLOHRkmH94fFPZ6hARmSzFBH8XMC9jug3YmaPvTUwY5nH3ncH9VuApxo//Z/Zb7e4d7t7R2tpaRFmT57xZdXz8ffN5YO1bdG7fX9ZaRERKrZjgXwssMrMFZpYiHe4nnJ1jZucDjcBvMtoazawieNwCLAU2TFx2OvrsdefR1ljFf/nx7/XLnSISKgWD392HgU8DjwMbgR+7+6tmdoeZfTCj60rgAXfPHAZ6F9BpZr8HngT+PvNsoOmstiLB1z+6mK6eAb78r2dEySIiRbHxOT09dHR0eGdnZ7nLAOCrj7/GXU++wTdvWsyKxVmPaYuIlJ2ZrQuOpxakE9UL+Owfn8eSBU38zU/Ws77rQOEFRESmOQV/Acl4jLtvvpyW2gpW/WAd7/QeKXdJIiKnRcFfhObaClbf8h76jgzxZ/e+oPP7ReSMpuAv0kVn1XPvre/lzf0D3Prd39J7ZKjcJYmInBIF/0m46pxm7r75cjbu6mXl6ufZe+houUsSETlpCv6T9EfvmsU/3tLB1u5+/tM9v9HlGkXkjKPgPwXXnD+T//upJew7dJQP3fUsz2/dV+6SRESKpuA/Re85u4mHbltKfXWSm+99gft+vY3p+J0IEZGJFPyn4ZzWWn5221KuPb+VOx7ZwCe/38mePp3uKSLTm4L/NM2oTLL6Yx3c/icX8uyWvSz/xq949OVd2vsXkWlLwV8CsZjxiaULeOSvruashkr+8ocvcut317Jtb3+5SxMROYGCv4QWzarjZ3+5lC994EJefLOH67/+DP/w2GscPKxz/kVk+lDwl1giHuPPr17ALz//fm5892y+/dQb/Luv/JK7ntzCwKB+3llEyk+/zjnJXt15kK/9YjNPvLaHltoUt1w1nz+78myaalLlLk1EQuRkfp1TwT9F1r3Zw51PvM7Tm7upTMb48OVt/PnVCzintbbcpYlICCj4p7HN7/TxnV9t46Hfvc3gyChL5jfx0ffO48Z3z6Y6lSh3eSJyhlLwnwG6+47yk3Vd/LhzB9v29lNbkeADl8zh318yh6sWNpOI6/CLiBSv5MFvZsuBbwJx4F53//sJ8z8OfBV4O2j6lrvfG8y7FfjvQfv/cPfvF3q9KAT/GHdn7fYe/nntDh57ZRf9gyM01aS4/qJZ3PjuOVyxoJlUQhsBEcmvpMFvZnFgM3Ad0EX64usrM6+dGwR/h7t/esKyTUAn0AE4sA54j7v35HvNKAV/piNDIzy1qZtHX97FExvfoX9whNqKBEvPbeaa82dyzfmtzKmvKneZIjINnUzwFzOovATY4u5bgyd/AFgBFHMF8uuBNe6+P1h2DbAcuL+Y4qKmMhln+cWzWX7xbI4MjfDM5m6e3NTN05v28Pir7wBw/qw63nduM1csaGbJgiadHSQiJ62Y4J8L7MiY7gKuyNLvw2b2B6T/Ovicu+/IsWzWK5ab2SpgFUB7e3sRZYVbZTLOsotms+yi2bg7r+85xFOb9vDUpm7u/+1bfPfZ7QCcN6uWKxY0894FTVw2r4G2xirMrLzFi8i0VkzwZ0uRieND/wrc7+5HzewvgO8Df1jksulG99XAakgP9RRRV2SYGefNquO8WXWs+oNzGBweZX3XAV7Ytp/nt+7jpy928U/PvwlAU02KS9vquXReA5e2NXBJWz3NtRVlfgciMp0UE/xdwLyM6TZgZ2YHd8/8Qfp/BL6Ssew1E5Z96mSLlPFSiRgd85vom
N/Ebdeey/DIKBt29fL7roOs33GA33cd4KnN3YwdvplTX8n5s+u4YPYMLphdxwVz6ljYUquDxiIRVUzwrwUWmdkC0mft3AT8aWYHM5vj7ruCyQ8CG4PHjwP/08wag+llwBdPu2oZJxGPcUlbA5e0NcCVZwNw6Ogwr7x9kPVdB9i4q4+Nu3p5dstehkbSW4NEzDintZZFs2pZ2FrLwpYaFrbWsKClhrrKZDnfjohMsoLB7+7DZvZp0iEeB+5z91fN7A6g090fBv7azD4IDAP7gY8Hy+43sy+T3ngA3DF2oFcmV21FgisXNnPlwuZjbUMjo2zt7ue13b28truPTbv7WN91kEdf3sVoxuBaa13FuA1Be1M1bY3VtDVWUV+V1DEEkTOcvsAlHB0e4a19A7zR3c+2vf1s7T6Uvt/bz/7+wXF96yoSzG2soq2xmnlNVcc2CHMbqphdX0lTdYpYTBsGkalW6tM5JeQqEnEWzapj0ay6E+YdHBhiR88AXT2H6Qrud+wfYMf+AZ57Yy8DgyPj+qfiMVrrKphdX8nsGZXMmlHJ7PqK9P2MSmbXp9sqk/GpensiMoGCX/Kqr05SX13PxXPrT5jn7vQMDNHVM8DbPYfZ3XuE3b1HeOdg+n7jrl6e3LTnhI0DQH1VkubaFC01FbTUpWiuqaCltiLdVpsKHqen6yoSGl4SKSEFv5wyM6OpJkVTTSp9YDkLd6fv6PCxjcHug0d4p/cIe/qOsu/QIHsPHWXT7j729e/jwED2C9akEjFaalI0VKdorEmm76uTNFSlaKhO0lidvh9rb6xOMaMqSVxDTiJZKfhlUpkZMyqTzKhMZh1KyjQ0Msr+/vTGYO+hQfYdOr5x6D50lIMDQ/QMDLLrQC8HDg9xYGBw3EHp8a+bvh5yY3WS2soEVck4VakE1ck4ibgRM8OM9H1QZ3oaDCMWA7D0dNAvFvzVEcvsGzy2CX3HP2d6OhbL3nf8cwZ9xpaZUFdmn3x9j9d+/LUy+x1fNrOe7H1z1Vmwb2x8e/bXmrCOCvSV0lDwy7SRjMeYFRwXKMboqNN3ZJgDhwfpCTYKYxuHnoEhDgb3/UeHGRgc4eDhIXYfPMzwiOPAqDvux+/dJ7YDOKM+vi9j02QuCx70dfecGyQ5PekNwviNXmzCRuJYnxM2PLk2UOOXz/naJz0j76ysr9VUneLHf3FVnqVKQ8EvZ6xYzIJjEEnObi7cf6r52EaB9AYi70Yi68Yk1wbq5PpOfK1jfUePL0Owocvse2zZUXK+Vr6+ed/Tsdc+vnzWvhPqYuw1j9U59r7G9z3efmLfia81Vme+jXWuWfnOisy77c8xs65yaiJZwS8yScb2PgHieff9RKaWvrMvIhIxCn4RkYhR8IuIRIyCX0QkYhT8IiIRo+AXEYkYBb+ISMQo+EVEImZa/h6/mXUDb57i4i3A3hKWc6bT+jhO62I8rY/xzvT1cba7txbTcVoG/+kws85iL0YQBVofx2ldjKf1MV6U1oeGekREIkbBLyISMWEM/tXlLmCa0fo4TutiPK2P8SKzPkI3xi8iIvmFcY9fRETyCE3wm9lyM9tkZlvM7AvlrmcqmNk8M3vSzDaa2atm9pmgvcnM1pjZ68F9Y9BuZnZnsI7Wm9nl5X0HpWdmcTP7nZk9EkwvMLMXgnXxz2aWCtorguktwfz55ax7MphZg5n9xMxeCz4jV0X8s/G54P/JK2Z2v5lVRvXzEYrgN7M4cBdwA3AhsNLMLixvVVNiGPi8u78LuBK4LXjfXwCecPdFwBPBNKTXz6Lgtgq4e+pLnnSfATZmTH8F+HqwLnqATwbtnwR63P1c4OtBv7D5JvCYu18AXEp6vUTys2Fmc4G/Bjrc/WIgDtxEVD8f6UuRndk34Crg8YzpLwJfLHddZVgP/wJcB2wC5gRtc4BNweP/A6zM6H+sXxhuQBvpMPtD4BHSlzzdCyQmfk6Ax4GrgseJoJ+V+z2UcF3MALZNfE8R/mzMBXYAT
cG/9yPA9VH9fIRij5/j/6hjuoK2yAj+FL0MeAGY5e67AIL7mUG3sK+nbwB/A4wG083AAXcfDqYz3++xdRHMPxj0D4uFQDfw3WDo614zqyGinw13fxv4X8BbwC7S/97riOjnIyzBn+2CppE5XcnMaoGfAp919958XbO0hWI9mdkHgD3uvi6zOUtXL2JeGCSAy4G73f0yoJ/jwzrZhHp9BMcyVgALgLOAGtLDWxNF4vMRluDvAuZlTLcBO8tUy5QysyTp0P+huz8YNL9jZnOC+XOAPUF7mNfTUuCDZrYdeID0cM83gAYzSwR9Mt/vsXURzK8H9k9lwZOsC+hy9xeC6Z+Q3hBE8bMB8MfANnfvdvch4EHgfUT08xGW4F8LLAqO0KdIH7R5uMw1TTozM+A7wEZ3/1rGrIeBW4PHt5Ie+x9rvyU4g+NK4ODYn/1nOnf/oru3uft80v/+v3T3m4EngY8E3Saui7F19JGgf2j26Nx9N7DDzM4Pmv4I2EAEPxuBt4Arzaw6+H8ztj4i+fko+0GGUt2AG4HNwBvAfyt3PVP0nq8m/efneuCl4HYj6bHIJ4DXg/umoL+RPvvpDeBl0mc4lP19TMJ6uQZ4JHi8EPgtsAX4f0BF0F4ZTG8J5i8sd92TsB4WA53B5+NnQGOUPxvA3wGvAa8A/wRURPXzoW/uiohETFiGekREpEgKfhGRiFHwi4hEjIJfRCRiFPwiIhGj4BcRiRgFv4hIxCj4RUQi5v8Drkqh4kr5XLYAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "plt.plot(np.arange(max_iter), loss_list)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "同样的参数,获得了一样的结果,所以我的实现是完全正确的!区别在于运行速度,Pytorch 比 numpy 要快大约一倍,可能是 Pytorch 优化过矩阵预算,且单精度比双精度计算快。" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.2" } }, "nbformat": 4, "nbformat_minor": 2 }