{ "cells": [ { "cell_type": "markdown", "id": "eb46f4c0", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# 多层感知机的简洁实现\n", "\n", "通过高级API更简洁地实现多层感知机" ] }, { "cell_type": "code", "execution_count": 1, "id": "f4b9d183", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:04:20.711610Z", "iopub.status.busy": "2023-08-18T07:04:20.711337Z", "iopub.status.idle": "2023-08-18T07:04:22.715766Z", "shell.execute_reply": "2023-08-18T07:04:22.714884Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from d2l import torch as d2l\n", "\n", "# Seed torch's global RNG so random weight initialization (and any torch-driven shuffling) is repeatable across runs.\n", "torch.manual_seed(0);" ] }, { "cell_type": "markdown", "id": "8b016771", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "隐藏层\n", "包含256个隐藏单元,并使用了ReLU激活函数" ] }, { "cell_type": "code", "execution_count": 2, "id": "a11cfbe9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:04:22.719981Z", "iopub.status.busy": "2023-08-18T07:04:22.719298Z", "iopub.status.idle": "2023-08-18T07:04:22.748628Z", "shell.execute_reply": "2023-08-18T07:04:22.747813Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "net = nn.Sequential(nn.Flatten(),\n", "                    nn.Linear(784, 256),\n", "                    nn.ReLU(),\n", "                    nn.Linear(256, 10))\n", "\n", "def init_weights(m):\n", "    \"\"\"Initialize every nn.Linear weight from N(0, 0.01^2); meant for net.apply.\n", "\n", "    Non-Linear submodules (Flatten, ReLU) are skipped, and biases keep\n", "    PyTorch's default initialization because only m.weight is touched.\n", "    \"\"\"\n", "    # isinstance (rather than type(m) == ...) also matches subclasses of nn.Linear.\n", "    if isinstance(m, nn.Linear):\n", "        nn.init.normal_(m.weight, std=0.01)\n", "\n", "net.apply(init_weights);" ] }, { "cell_type": "markdown", "id": "8e13fc47", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "训练过程" ] }, { "cell_type": "code", "execution_count": 4, "id": "78ac9bf1", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:04:22.761842Z", "iopub.status.busy": "2023-08-18T07:04:22.761295Z", "iopub.status.idle": "2023-08-18T07:05:05.308680Z", "shell.execute_reply": "2023-08-18T07:05:05.307786Z" }, "origin_pos": 15, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 
2023-08-18T07:05:05.270258\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Hyperparameters\n", "batch_size = 256\n", "lr = 0.1\n", "num_epochs = 10\n", "\n", "# Fashion-MNIST minibatch iterators for training and evaluation\n", "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n", "\n", "# Per-example cross-entropy (reduction='none') optimized with plain minibatch SGD\n", "loss = nn.CrossEntropyLoss(reduction='none')\n", "trainer = torch.optim.SGD(net.parameters(), lr=lr)\n", "\n", "d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)" ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }