{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/d2l-ai/d2l-pytorch-sagemaker-studio-lab/blob/main/chapter_multilayer-perceptrons/mlp-scratch.ipynb)\n", "\n", "# 4.2 Implementation of Multilayer Perceptrons from Scratch\n", "\n", "\n", "Now that we have characterized\n", "multilayer perceptrons (MLPs) mathematically,\n", "let us try to implement one ourselves. To compare against our previous results\n", "achieved with softmax regression\n", "([Section 3.6](../chapter_linear-networks/softmax-regression-scratch.ipynb)),\n", "we will continue to work with\n", "the Fashion-MNIST image classification dataset\n", "([Section 3.5](../chapter_linear-networks/image-classification-dataset.ipynb)).\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "origin_pos": 4, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "batch_size = 256\n", "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4.2.1 Initializing Model Parameters\n", "\n", "Recall that Fashion-MNIST contains 10 classes,\n", "and that each image consists of a $28 \\times 28 = 784$\n", "grid of grayscale pixel values.\n", "Again, we will disregard the spatial structure\n", "among the pixels for now,\n", "so we can think of this as simply a classification dataset\n", "with 784 input features and 10 classes.\n", "To begin, we will implement an MLP\n", "with one hidden layer and 256 hidden units.\n", "Note that we can regard both of these quantities\n", "as hyperparameters.\n", "Typically, we choose layer widths in powers of 2,\n", "which tend to be computationally efficient because\n", "of how memory is allocated and addressed in hardware.\n", "\n", "Again, we will represent our parameters with several tensors.\n", "Note that *for every layer*, we must keep track of\n", "one weight matrix and one bias vector.\n", "As always, we allocate memory\n", "for the gradients of the loss with respect to these parameters.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "num_inputs, num_outputs, num_hiddens = 784, 10, 256\n", "\n", "W1 = nn.Parameter(\n", " torch.randn(num_inputs, num_hiddens, requires_grad=True) * 0.01)\n", "b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))\n", "W2 = nn.Parameter(\n", " torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)\n", "b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))\n", "\n", "params = [W1, b1, W2, b2]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4.2.2 Activation Function\n", "\n", "To make sure we know how everything works,\n", "we will implement the ReLU activation ourselves\n", "using the maximum function rather than\n", "invoking the built-in `relu` function directly.\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "origin_pos": 11, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def relu(X):\n", " a = torch.zeros_like(X)\n", " return torch.max(X, a)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4.2.3 Model\n", "\n", "Because we are disregarding spatial structure,\n", "we `reshape` each 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 4.2.3 Model\n", "\n", "Because we are disregarding spatial structure,\n", "we `reshape` each two-dimensional image into\n", "a flat vector of length `num_inputs`.\n", "Finally, we implement our model\n", "with just a few lines of code.\n" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "origin_pos": 15, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def net(X):\n", "    X = X.reshape((-1, num_inputs))\n", "    H = relu(X @ W1 + b1)  # Here '@' stands for matrix multiplication\n", "    return (H @ W2 + b2)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 4.2.4 Loss Function\n", "\n", "To ensure numerical stability,\n", "and because we already implemented\n", "the softmax function from scratch\n", "([Section 3.6](../chapter_linear-networks/softmax-regression-scratch.ipynb)),\n", "we leverage the integrated function from high-level APIs\n", "for calculating the softmax and cross-entropy loss.\n", "Recall our earlier discussion of these intricacies\n", "in [Section 3.7.2](../chapter_linear-networks/softmax-regression-concise.ipynb#3.7.2-Softmax-Implementation-Revisited).\n", "We encourage the interested reader\n", "to examine the source code for the loss function\n", "to deepen their knowledge of implementation details.\n" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "origin_pos": 19, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "loss = nn.CrossEntropyLoss()" ] },
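{ "cell_type": "markdown", "metadata": {}, "source": [ "As a small illustrative aside (the tensors below are made up for demonstration), note that `nn.CrossEntropyLoss` expects unnormalized logits and integer class labels; the softmax is applied internally.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Made-up logits for 2 examples over 3 classes, plus their labels\n", "y_hat = torch.tensor([[0.1, 0.2, 0.7], [2.0, 0.0, -1.0]])\n", "y = torch.tensor([2, 0])\n", "print(loss(y_hat, y))" ] },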
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "num_epochs, lr = 10, 0.1\n", "updater = torch.optim.SGD(params, lr=lr)\n", "d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To evaluate the learned model,\n", "we apply it on some test data.\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "origin_pos": 26, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2021-07-24T08:34:05.660244\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.3.3, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "d2l.predict_ch3(net, test_iter)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4.2.6 Summary\n", "\n", "* We saw that implementing a simple MLP is easy, even when done manually.\n", "* However, with a large number of layers, implementing MLPs from scratch can still get messy (e.g., naming and keeping track of our model's parameters).\n", "\n", "\n", "## 4.2.7 Exercises\n", "\n", "1. Change the value of the hyperparameter `num_hiddens` and see how this hyperparameter influences your results. Determine the best value of this hyperparameter, keeping all others constant.\n", "1. Try adding an additional hidden layer to see how it affects the results.\n", "1. How does changing the learning rate alter your results? Fixing the model architecture and other hyperparameters (including number of epochs), what learning rate gives you the best results?\n", "1. What is the best result you can get by optimizing over all the hyperparameters (learning rate, number of epochs, number of hidden layers, number of hidden units per layer) jointly?\n", "1. Describe why it is much more challenging to deal with multiple hyperparameters.\n", "1. What is the smartest strategy you can think of for structuring a search over multiple hyperparameters?\n" ] } ], "metadata": { "instance_type": "ml.g4dn.xlarge", "kernelspec": { "display_name": "d2l", "language": "python", "name": "conda-env-d2l-py" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 4 }