# Softmax regression implemented from scratch.
import d2l
import mxnet
from mxnet import autograd, np, npx, gluon
from IPython import display

npx.set_np()

# Seed MXNet's random number generators so the weight initialisation below
# is repeatable across "Restart Kernel -> Run All".
# NOTE(review): this does not necessarily fix the shuffling order used by the
# data iterators -- confirm if bit-identical training curves are required.
mxnet.random.seed(0)

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# Initialize the model parameters.
num_inputs = 784   # each 28 x 28 image is flattened into one row vector
num_outputs = 10   # one output (logit) per class

# Small random weights and a zero bias.
W = np.random.normal(0, 0.01, (num_inputs, num_outputs))
b = np.zeros(num_outputs)

# Allocate gradient buffers so autograd can write into W.grad / b.grad.
W.attach_grad()
b.attach_grad()
# Implement the softmax operator.
def softmax(X):
    """Turn each row of ``X`` into non-negative values that sum to 1.

    Args:
        X: 2-D array of scores; normalisation runs along ``axis=1``.

    Returns:
        Array with the same shape as ``X`` holding the row-wise softmax.
    """
    # exp() of a large score overflows to inf and the division then yields
    # nan. Shifting every row by its own maximum keeps all exponents <= 0 and
    # leaves the result mathematically unchanged, because the common factor
    # exp(-max) cancels between numerator and denominator.
    X_exp = np.exp(X - X.max(axis=1, keepdims=True))
    partition = X_exp.sum(axis=1, keepdims=True)
    return X_exp / partition  # broadcasting divides each row by its own sum


# Quick sanity check: the output keeps the input shape and every row sums to 1.
X = np.random.normal(size=(2, 5))
X_prob = softmax(X)
X_prob.shape, X_prob.sum(axis=1)


# Define the model.
def net(X):
    """Softmax regression: flatten ``X``, apply the affine map, then softmax.

    Reads the module-level parameters ``W`` and ``b`` and the constant
    ``num_inputs``; every example is reshaped to a row of length
    ``num_inputs`` before the matrix product.
    """
    return softmax(np.dot(X.reshape((-1, num_inputs)), W) + b)


# Define the loss function.
def cross_entropy(y_hat, y):
    """Per-example cross-entropy loss.

    Args:
        y_hat: predicted probabilities, one row per example.
        y: integer class index of the true label for each example.

    Returns:
        1-D array with ``-log`` of the probability assigned to the true class.
    """
    # Fancy indexing picks, for row i, the probability in column y[i].
    return -np.log(y_hat[range(len(y_hat)), y])


# Classification accuracy.
def accuracy(y_hat, y):
    """Return the number (not the fraction) of correct predictions as a float."""
    # Labels are cast to float32 before being compared with the argmax result.
    return float((y_hat.argmax(axis=1) == y.astype('float32')).sum())


def evaluate_accuracy(net, data_iter):
    """Return the fraction of examples in ``data_iter`` that ``net`` classifies correctly."""
    metric = d2l.Accumulator(2)  # number of correct predictions, number of examples
    for X, y in data_iter:
        metric.add(accuracy(net(X), y), y.size)
    return metric[0] / metric[1]


# Accuracy of the still-untrained model on the test set.
evaluate_accuracy(net, test_iter)


# Train for one epoch.
def train_epoch_ch3(net, train_iter, loss, updater):
    """Run one pass over ``train_iter`` and update the parameters.

    Args:
        net: callable mapping a batch of inputs to predictions.
        train_iter: iterable of ``(X, y)`` minibatches.
        loss: callable ``loss(y_hat, y)`` returning one loss value per example.
        updater: either a ``gluon.Trainer`` (its ``step`` method is used) or a
            callable that takes the batch size and applies the update.

    Returns:
        Tuple ``(average training loss, training accuracy)`` over the epoch.
    """
    metric = d2l.Accumulator(3)  # sum of training loss, number correct, number of examples
    if isinstance(updater, gluon.Trainer):
        updater = updater.step
    for X, y in train_iter:
        # Only the forward pass is recorded; backward() runs outside the scope.
        with autograd.record():
            y_hat = net(X)
            l = loss(y_hat, y)
        l.backward()
        updater(X.shape[0])
        metric.add(float(l.sum()), accuracy(y_hat, y), y.size)
    # Training loss and training accuracy, both averaged over all examples.
    return metric[0]/metric[2], metric[1]/metric[2]


# The complete training function.
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):
    """Train ``net`` for ``num_epochs`` epochs and plot progress after each one.

    After every epoch the training loss, training accuracy and test accuracy
    are added to a ``d2l.Animator`` plot. Nothing is returned.
    """
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            ylim=[0.3, 0.9],
                            legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        # Epochs are numbered from 1 on the x-axis.
        animator.add(epoch+1, train_metrics+(test_acc,))
"metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:03:37.341014Z", "start_time": "2019-07-03T22:03:02.940645Z" }, "attributes": { "classes": [], "id": "", "n": "18" }, "scrolled": true, "slideshow": { "slide_type": "-" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
# Train the model.
num_epochs = 10
lr = 0.1


def updater(batch_size):
    """Call ``d2l.sgd`` on the parameters ``W`` and ``b`` with learning rate ``lr``."""
    return d2l.sgd([W, b], lr, batch_size)


train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
# Prediction.
def predict_ch3(net, test_iter, n=6):
    """Show the first ``n`` images of the first test batch with their labels.

    Each image is titled with its true label on the first line and the label
    predicted by ``net`` on the second.

    Args:
        net: callable mapping a batch of images to per-class scores.
        test_iter: iterable of ``(X, y)`` minibatches; only the first is used.
        n: maximum number of images to display. If the first batch holds
            fewer than ``n`` examples, all of them are shown.

    Raises:
        ValueError: if ``test_iter`` yields no batches.
    """
    first_batch = next(iter(test_iter), None)
    if first_batch is None:
        raise ValueError('test_iter yielded no batches')
    X, y = first_batch

    # Never try to display more images than the batch actually contains;
    # reshaping to a fixed (n, 28, 28) would fail on a short batch.
    shown = min(n, X.shape[0])

    # The model still runs on the whole batch; only the label lookup is
    # restricted to the examples that end up on screen.
    predicted = net(X).argmax(axis=1)
    trues = d2l.get_fashion_mnist_labels(y[:shown])
    preds = d2l.get_fashion_mnist_labels(predicted[:shown])
    titles = [true + '\n' + pred for true, pred in zip(trues, preds)]
    d2l.show_images(X[:shown].reshape((shown, 28, 28)), 1, shown, titles=titles)


predict_ch3(net, test_iter)