{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d1278dcb",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "# 多层感知机的从零开始实现\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0be61c4f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:59:24.369567Z",
     "iopub.status.busy": "2023-08-18T06:59:24.368990Z",
     "iopub.status.idle": "2023-08-18T06:59:24.501326Z",
     "shell.execute_reply": "2023-08-18T06:59:24.500151Z"
    },
    "origin_pos": 5,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l\n",
    "\n",
    "# Minibatch size shared by the training and test iterators.\n",
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1484071d",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "实现一个具有单隐藏层的多层感知机，\n",
    "它包含256个隐藏单元"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7730f280",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:59:24.508163Z",
     "iopub.status.busy": "2023-08-18T06:59:24.506257Z",
     "iopub.status.idle": "2023-08-18T06:59:24.520861Z",
     "shell.execute_reply": "2023-08-18T06:59:24.519861Z"
    },
    "origin_pos": 8,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "# 784 inputs per flattened image, 10 output classes, 256 hidden units.\n",
    "num_inputs, num_outputs, num_hiddens = 784, 10, 256\n",
    "\n",
    "# Fix the global torch RNG so that re-running this cell (or Restart & Run All)\n",
    "# reproduces the same random initial weights.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Small random weights break the symmetry between hidden units; biases start at 0.\n",
    "# nn.Parameter itself marks each tensor as requiring gradients, so the tensor\n",
    "# factories below do not need requires_grad=True.\n",
    "W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens) * 0.01)\n",
    "b1 = nn.Parameter(torch.zeros(num_hiddens))\n",
    "W2 = nn.Parameter(torch.randn(num_hiddens, num_outputs) * 0.01)\n",
    "b2 = nn.Parameter(torch.zeros(num_outputs))\n",
    "\n",
    "params = [W1, b1, W2, b2]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "609b91ed",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "实现ReLU激活函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f46a813",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:59:24.528151Z",
     "iopub.status.busy": "2023-08-18T06:59:24.526356Z",
"iopub.status.idle": "2023-08-18T06:59:24.533695Z", "shell.execute_reply": "2023-08-18T06:59:24.532654Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def relu(X):\n", " a = torch.zeros_like(X)\n", " return torch.max(X, a)" ] }, { "cell_type": "markdown", "id": "10221d2d", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "实现我们的模型" ] }, { "cell_type": "code", "execution_count": 6, "id": "f55fe0ea", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:59:24.554675Z", "iopub.status.busy": "2023-08-18T06:59:24.552824Z", "iopub.status.idle": "2023-08-18T06:59:24.560084Z", "shell.execute_reply": "2023-08-18T06:59:24.559049Z" }, "origin_pos": 23, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def net(X):\n", " X = X.reshape((-1, num_inputs))\n", " H = relu(X@W1 + b1)\n", " return (H@W2 + b2)\n", "\n", "loss = nn.CrossEntropyLoss(reduction='none')" ] }, { "cell_type": "markdown", "id": "09b086e0", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "多层感知机的训练过程与softmax回归的训练过程完全相同" ] }, { "cell_type": "code", "execution_count": 7, "id": "c83cc0c7", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:59:24.567796Z", "iopub.status.busy": "2023-08-18T06:59:24.566005Z", "iopub.status.idle": "2023-08-18T07:00:19.750339Z", "shell.execute_reply": "2023-08-18T07:00:19.748990Z" }, "origin_pos": 27, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:19.710036\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "num_epochs, lr = 10, 0.1\n", "updater = torch.optim.SGD(params, lr=lr)\n", "d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)" ] }, { "cell_type": "markdown", "id": "e08ce45c", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "在一些测试数据上应用这个模型" ] }, { "cell_type": "code", "execution_count": 8, "id": "8230ba7c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:19.755336Z", "iopub.status.busy": "2023-08-18T07:00:19.754506Z", "iopub.status.idle": "2023-08-18T07:00:20.323813Z", "shell.execute_reply": "2023-08-18T07:00:20.322738Z" }, "origin_pos": 31, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:20.249980\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "d2l.predict_ch3(net, test_iter)" ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }