{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# 线性回归的简洁实现" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T21:59:54.049026Z", "start_time": "2019-07-03T21:59:52.500967Z" }, "attributes": { "classes": [], "id": "", "n": "2" } }, "outputs": [], "source": [ "import d2l\n", "from mxnet import autograd, np, npx, gluon\n", "npx.set_np()\n", "\n", "true_w = np.array([2, -3.4])\n", "true_b = 4.2\n", "features, labels = d2l.synthetic_data(true_w, true_b, 1000)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "读取数据。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T21:59:54.120474Z", "start_time": "2019-07-03T21:59:54.051605Z" }, "attributes": { "classes": [], "id": "", "n": "3" }, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X =\n", "[[ 0.8370042 -1.1026353 ]\n", " [-1.2860336 -1.6586353 ]\n", " [-0.591276 -1.2689118 ]\n", " [ 1.1089611 1.827097 ]\n", " [-2.1164808 -1.1797674 ]\n", " [ 0.4593352 -0.20153503]\n", " [-0.16823442 -0.38846034]\n", " [ 0.5477088 -1.7779099 ]\n", " [-1.8187165 -1.2048249 ]\n", " [ 1.0532789 0.24552767]]y =\n", "[ 9.621203 7.265286 7.3253717 0.20482737 3.9688 5.806935\n", " 5.183671 11.342728 4.669879 5.468723 ]\n" ] } ], "source": [ "def load_array(data_arrays, batch_size, is_train=True):\n", " dataset = gluon.data.ArrayDataset(*data_arrays)\n", " return gluon.data.DataLoader(dataset, batch_size, shuffle=is_train)\n", " \n", "batch_size = 10\n", "data_iter = load_array((features, labels), batch_size)\n", "for X, y in data_iter:\n", " print('X =\\n%sy =\\n%s' %(X, y))\n", " break" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "定义模型和初始化模型参数。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T21:59:54.129952Z", "start_time": "2019-07-03T21:59:54.124685Z" }, "attributes": { "classes": [], "id": "", "n": "5" }, "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "from mxnet.gluon import nn\n", "from mxnet import init\n", "\n", "net = nn.Sequential()\n", "net.add(nn.Dense(1))\n", "net.initialize(init.Normal(sigma=0.01))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "定义损失函数和优化函数。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T21:59:54.135619Z", "start_time": "2019-07-03T21:59:54.132150Z" }, "attributes": { "classes": [], "id": "", "n": "8" } }, "outputs": [], "source": [ "from mxnet import gluon\n", "\n", "loss = gluon.loss.L2Loss() \n", "trainer = gluon.Trainer(net.collect_params(),\n", " 'sgd', {'learning_rate': 0.03})" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "训练。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T21:59:56.549677Z", "start_time": "2019-07-03T21:59:54.137373Z" }, "attributes": { "classes": [], "id": "", "n": "10" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 1, loss: 0.040415\n", "epoch 2, loss: 0.000156\n", "epoch 3, loss: 0.000051\n", "Error in estimating w [[ 0.000283 -0.00072527]]\n", "Error in estimating b [0.00046492]\n" ] } ], "source": [ "for epoch in range(1, 4):\n", " for X, y in data_iter:\n", " with autograd.record():\n", " l = loss(net(X), y)\n", " l.backward()\n", " trainer.step(batch_size)\n", " l = loss(net(features), labels)\n", " print('epoch %d, loss: %f' % (epoch, l.mean()))\n", " \n", "w = net[0].weight.data()\n", "print('Error in estimating w', true_w.reshape(w.shape) - w)\n", "b = net[0].bias.data()\n", "print('Error in estimating b', true_b - b) " ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }