{
"cells": [
{
"cell_type": "markdown",
"id": "eb46f4c0",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# 多层感知机的简洁实现\n",
"\n",
"通过高级API更简洁地实现多层感知机"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f4b9d183",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:04:20.711610Z",
"iopub.status.busy": "2023-08-18T07:04:20.711337Z",
"iopub.status.idle": "2023-08-18T07:04:22.715766Z",
"shell.execute_reply": "2023-08-18T07:04:22.714884Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "8b016771",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"隐藏层包含256个隐藏单元，并使用了ReLU激活函数"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a11cfbe9",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:04:22.719981Z",
"iopub.status.busy": "2023-08-18T07:04:22.719298Z",
"iopub.status.idle": "2023-08-18T07:04:22.748628Z",
"shell.execute_reply": "2023-08-18T07:04:22.747813Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"# MLP: flatten each input sample, apply one 256-unit hidden layer with ReLU,\n",
"# then a linear output layer with 10 units.\n",
"net = nn.Sequential(\n",
"    nn.Flatten(),\n",
"    nn.Linear(784, 256),\n",
"    nn.ReLU(),\n",
"    nn.Linear(256, 10),\n",
")\n",
"\n",
"def init_weights(m):\n",
"    \"\"\"Re-initialize the weight of a linear layer from a normal distribution.\n",
"\n",
"    Intended to be passed to ``nn.Module.apply``, which calls it once per\n",
"    submodule. Only ``nn.Linear`` modules (including subclasses) are touched:\n",
"    their ``weight`` is overwritten in place with samples drawn with\n",
"    ``std=0.01`` (``nn.init.normal_`` uses its default mean). Biases and all\n",
"    other module types are left unchanged.\n",
"\n",
"    Args:\n",
"        m: The module to (possibly) initialize.\n",
"    \"\"\"\n",
"    # isinstance rather than an exact type comparison, so subclasses of\n",
"    # nn.Linear are initialized as well.\n",
"    if isinstance(m, nn.Linear):\n",
"        nn.init.normal_(m.weight, std=0.01)\n",
"\n",
"# Module.apply calls init_weights on every submodule of net and on net\n",
"# itself; the trailing ';' keeps the notebook from echoing the returned module.\n",
"net.apply(fn=init_weights);"
]
},
{
"cell_type": "markdown",
"id": "8e13fc47",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"训练过程"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "78ac9bf1",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:04:22.761842Z",
"iopub.status.busy": "2023-08-18T07:04:22.761295Z",
"iopub.status.idle": "2023-08-18T07:05:05.308680Z",
"shell.execute_reply": "2023-08-18T07:05:05.307786Z"
},
"origin_pos": 15,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Hyperparameters for this run.\n",
"batch_size = 256\n",
"lr = 0.1\n",
"num_epochs = 10\n",
"\n",
"train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n",
"\n",
"# reduction='none' makes the criterion return one loss value per sample\n",
"# instead of a single batch-reduced scalar.\n",
"loss = nn.CrossEntropyLoss(reduction='none')\n",
"trainer = torch.optim.SGD(net.parameters(), lr=lr)\n",
"\n",
"d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"language_info": {
"name": "python"
},
"required_libs": [],
"rise": {
"autolaunch": true,
"enable_chalkboard": true,
"overlay": "",
"scroll": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}