{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 3.8 Multilayer Perceptron" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.8.1 Hidden Layers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Single linear transformation $$ \\hat{\\mathbf{o}} = \\mathrm{softmax}(\\mathbf{W} \\mathbf{x} + \\mathbf{b}) $$ ![](https://github.com/diveintodeeplearning/d2l-en/raw/master/img/singlelayer.svg?sanitize=true)\n", " - ***linearity*** is a strong assumption\n", " - It means that increasing the value of the input should either drive the value of the output up or drive it down, irrespective of the value of the other inputs\n", " - In case of classifying cats and dogs based on black and white images\n", " - Increasing each pixel value either increases the probability that it depicts a dog or decreases it --> not reasonable\n", " \n", "- From one to many\n", " - We can model more generally by incorporating ***one or more hidden layers***. \n", " - Multi-layer Perceptron (MLP)\n", " - E.g., a hidden layer with 5 hidden units \n", " - both the hidden layer and the output layer in the multilayer perceptron are fully connected layers.\n", " ![](https://github.com/diveintodeeplearning/d2l-en/raw/master/img/mlp.svg?sanitize=true)\n", "\n", "- From linear to nonlinear\n", " - Wrong mathematical model\n", "$$ \\begin{aligned} \\mathbf{h} & = \\mathbf{W}_1 \\mathbf{x} + \\mathbf{b}_1 \\\\ \\mathbf{o} & = \\mathbf{W}_2 \\mathbf{h} + \\mathbf{b}_2 \\\\ \\hat{\\mathbf{y}} & = \\mathrm{softmax}(\\mathbf{o}) \\end{aligned} $$
\n", " - We can collapse out the hidden layer by an equivalently parametrized single layer perceptron

$$\\mathbf{o} = \\mathbf{W}_2 \\mathbf{h} + \\mathbf{b}_2 = \\mathbf{W}_2 (\\mathbf{W}_1 \\mathbf{x} + \\mathbf{b}_1) + \\mathbf{b}_2 = (\\mathbf{W}_2 \\mathbf{W}_1) \\mathbf{x} + (\\mathbf{W}_2 \\mathbf{b}_1 + \\mathbf{b}_2) = \\mathbf{W} \\mathbf{x} + \\mathbf{b}$$
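
 - A quick numerical check of this identity (a minimal sketch, assuming the `mxnet.nd` API imported later in this notebook):

```python
from mxnet import nd

W1, b1 = nd.random.normal(shape=(4, 3)), nd.random.normal(shape=(4, 1))
W2, b2 = nd.random.normal(shape=(2, 4)), nd.random.normal(shape=(2, 1))
x = nd.random.normal(shape=(3, 1))

o_two_layer = nd.dot(W2, nd.dot(W1, x) + b1) + b2   # linear hidden layer, then linear output layer
W, b = nd.dot(W2, W1), nd.dot(W2, b1) + b2          # collapsed parameters
o_one_layer = nd.dot(W, x) + b                      # equivalent single layer
print((o_two_layer - o_one_layer).abs().max())      # ~0, up to floating point error
```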
\n", " - To fix this we need another key ingredient - a nonlinearity $\\sigma$ such as $\\mathrm{max}(x,0)$ after each layer

$$ \\begin{aligned} \\mathbf{h} & = \\sigma(\\mathbf{W}_1 \\mathbf{x} + \\mathbf{b}_1) \\\\ \\mathbf{o} & = \\mathbf{W}_2 \\mathbf{h} + \\mathbf{b}_2 \\\\ \\hat{\\mathbf{y}} & = \\mathrm{softmax}(\\mathbf{o}) \\end{aligned} $$
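
 - A one-pass sketch of these three equations (hypothetical sizes: 3 inputs, 4 hidden units, 2 outputs; again assuming `mxnet.nd`):

```python
from mxnet import nd

W1, b1 = nd.random.normal(shape=(4, 3)), nd.random.normal(shape=(4, 1))
W2, b2 = nd.random.normal(shape=(2, 4)), nd.random.normal(shape=(2, 1))
x = nd.random.normal(shape=(3, 1))

h = nd.relu(nd.dot(W1, x) + b1)   # sigma = max(x, 0), applied elementwise
o = nd.dot(W2, h) + b2
y_hat = nd.softmax(o, axis=0)     # 2 class probabilities, summing to 1
```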
\n", " - Clearly we could continue stacking such hidden layers, e.g. $\\mathbf{h}_1 = \\sigma(\\mathbf{W}_1 \\mathbf{x} + \\mathbf{b}_1)$ and $\\mathbf{h}_2 = \\sigma(\\mathbf{W}_2 \\mathbf{h}_1 + \\mathbf{b}_2)$ on top of each other to obtain a true multilayer perceptron.\n", " \n", " - Multilayer perceptrons are universal approximators.\n", " - Even for a single-hidden-layer neural network, with enough nodes, and the right set of weights, it could model any function at all! \n", " - Actually learning that function is the hard part. \n", " - It turns out that we can approximate functions much more compactly if we use deeper (vs wider) neural networks.\n", " \n", "- Vectorization and mini-batch\n", " - denote by $\\mathbf{X}$ the matrix of inputs from a minibatch. \n", " - Then an MLP with two hidden layers can be expressed as $$ \\begin{aligned} \\mathbf{H}_1 & = \\sigma(\\mathbf{W}_1 \\mathbf{X} + \\mathbf{b}_1) \\\\ \\mathbf{H}_2 & = \\sigma(\\mathbf{W}_2 \\mathbf{H}_1 + \\mathbf{b}_2) \\\\ \\mathbf{O} & = \\mathrm{softmax}(\\mathbf{W}_3 \\mathbf{H}_2 + \\mathbf{b}_3) \\end{aligned} $$\n", " - With some abuse of notation, we define the nonlinearity $\\sigma$ to apply to its inputs on a ***row-wise fashion***\n", " - i.e. ***one observation at a time***, often one coordinate at a time. \n", " - This is true for most activation functions\n", " - But, **batch normalization*** is a notable exception from that rule.\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.8.2 Activation Functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- ReLU (rectified linear unit) function $$\\mathrm{ReLU}(x) = \\max(x, 0).$$\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import gluonbook as gb\n", "from mxnet import autograd, nd\n", "\n", "def xyplot(x_vals, y_vals, name):\n", " gb.set_figsize(figsize=(5, 2.5))\n", " gb.plt.plot(x_vals.asnumpy(), y_vals.asnumpy())\n", " gb.plt.xlabel('x')\n", " gb.plt.ylabel(name + '(x)')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- ReLU activation function is a two-stage linear function.\n", "- ***ReLU reduces the issue of the vanishing gradient problem***. 
" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "x = nd.arange(-8.0, 8.0, 0.1)\n", "x.attach_grad()\n", "with autograd.record():\n", " y = x.relu()\n", "xyplot(x, y, 'relu')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- The input is negative --> the derivative of ReLU function is 0\n", "- The input is positive --> the derivative of ReLU function is 1.\n", "- Note that the ReLU function is not differentiable when the input is 0.\n", " - Instead, we pick its left-hand-side (LHS) derivative 0 at location 0." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y.backward()\n", "xyplot(x, x.grad, 'grad of relu')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Parameterized ReLU\n", " - https://arxiv.org/abs/1502.01852: Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification (Feb 2015) $$\\mathrm{pReLU}(x) = \\max(0, x) - \\alpha x$$\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Sigmoid Function\n", " - The Sigmoid function can transform the value of an element in $\\mathbb{R}$ to the interval $(0,1)$. $$\\mathrm{sigmoid}(x) = \\frac{1}{1 + \\exp(-x)}.$$\n", " - In the “Recurrent Neural Network”, we will describe how to utilize the function's ability to control the flow of information in a neural network thanks to its capacity to transform the value range between 0 and 1. \n", " - When the input is close to 0, the Sigmoid function approaches a linear transformation." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "with autograd.record():\n", " y = x.sigmoid()\n", "xyplot(x, y, 'sigmoid')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- The derivative of Sigmoid function $$\\frac{d}{dx} \\mathrm{sigmoid}(x) = \\frac{-\\exp(x)}{(1 + \\exp(-x))^2} = \\mathrm{sigmoid}(x)\\left(1-\\mathrm{sigmoid}(x)\\right).$$\n", " - The input is 0, the derivative of the Sigmoid function reaches a maximum of 0.25\n", " - As the input deviates further from 0, the derivative of Sigmoid function approaches 0." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y.backward()\n", "xyplot(x, x.grad, 'grad of sigmoid')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Tanh Function\n", " - The Tanh (Hyperbolic Tangent) function transforms the value of an element to the interval between -1 and 1: $$\\text{tanh}(x) = \\frac{1 - \\exp(-2x)}{1 + \\exp(-2x)}.$$\n", " - Tanh function is symmetric at the origin of the coordinate system." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "with autograd.record():\n", " y = x.tanh()\n", "xyplot(x, y, 'tanh')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- The derivative of the Tanh function $$\\frac{d}{dx} \\mathrm{tanh}(x) = 1 - \\mathrm{tanh}^2(x).$$\n", " - The input is 0, the derivative of the Tanh function reaches a maximum of 1.0\n", " - As the input deviates further from 0, the derivative of Than function approaches 0." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y.backward()\n", "xyplot(x, x.grad, 'grad of tanh')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 3.9 Implementing a Multilayer Perceptron from Scratch" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import gluonbook as gb\n", "from mxnet import nd\n", "from mxnet.gluon import loss as gloss" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "batch_size = 256\n", "train_iter, test_iter = gb.load_data_fashion_mnist(batch_size)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.9.1 Initialize Model Parameters" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- an MLP with one hidden layer\n", " - the number of hidden units: 256" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "num_inputs, num_outputs, num_hiddens = 784, 10, 256\n", "\n", "W1 = nd.random.normal(scale=0.01, shape=(num_inputs, num_hiddens))\n", "b1 = nd.zeros(num_hiddens)\n", "W2 = nd.random.normal(scale=0.01, shape=(num_hiddens, num_outputs))\n", "b2 = nd.zeros(num_outputs)\n", "params = [W1, b1, W2, b2]\n", "\n", "for param in params:\n", " param.attach_grad()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.9.2 Activation Function" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "def relu(X):\n", " return nd.maximum(X, 0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.9.3 The model" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "def net(X):\n", " X = X.reshape((-1, num_inputs)) # X: (-1, 784), W1: (784, 256)\n", " H = relu(nd.dot(X, W1) + b1) # nd.dot(X, W1): (-1, 256), b1: (256,), W2: (256, 10)\n", " return nd.dot(H, W2) + b2 # nd.dot(H, W2): (-1, 10), b2: (10,)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.9.4 The Loss Function\n", "- For better numerical stability, we use Gluon’s functions, including softmax calculation and cross-entropy loss calculation." 
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "loss = gloss.SoftmaxCrossEntropyLoss()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.9.5 Training" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 1, loss 0.8047, train acc 0.699, test acc 0.791\n", "epoch 2, loss 0.4860, train acc 0.818, test acc 0.840\n", "epoch 3, loss 0.4252, train acc 0.842, test acc 0.861\n", "epoch 4, loss 0.3975, train acc 0.852, test acc 0.861\n", "epoch 5, loss 0.3736, train acc 0.862, test acc 0.873\n", "epoch 6, loss 0.3519, train acc 0.871, test acc 0.875\n", "epoch 7, loss 0.3374, train acc 0.875, test acc 0.875\n", "epoch 8, loss 0.3235, train acc 0.882, test acc 0.880\n", "epoch 9, loss 0.3195, train acc 0.882, test acc 0.877\n", "epoch 10, loss 0.3070, train acc 0.886, test acc 0.884\n" ] } ], "source": [ "num_epochs, lr = 10, 0.5\n", "\n", "gb.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)" ] },
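{ "cell_type": "markdown", "metadata": {}, "source": [ "- Roughly what `gb.train_ch3` does each epoch (a sketch of the gluonbook helper; details may differ):\n", "```python\n", "for X, y in train_iter:\n", "    with autograd.record():\n", "        l = loss(net(X), y)\n", "    l.backward()\n", "    gb.sgd(params, lr, batch_size)  # plain SGD step; a gluon Trainer is passed instead in 3.10\n", "# then report the average loss, train accuracy, and test accuracy\n", "```" ] },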
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for X, y in test_iter:\n", " break\n", "\n", "true_labels = gb.get_fashion_mnist_labels(y.asnumpy())\n", "pred_labels = gb.get_fashion_mnist_labels(net(X).argmax(axis=1).asnumpy())\n", "titles = [truelabel + '\\n' + predlabel for truelabel, predlabel in zip(true_labels, pred_labels)]\n", "\n", "gb.show_fashion_mnist(X[0:9], titles[0:9])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 3.10 Multilayer Perceptron in Gluon" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "import gluonbook as gb\n", "from mxnet import gluon, init\n", "from mxnet.gluon import loss as gloss, nn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.10.1 The Model\n", "- Note that Gluon automagically infers the missing parameteters, such as the fact that the second layer needs a matrix of size 256×10. \n", " - This happens the first time the network is invoked." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "net = nn.Sequential()\n", "net.add(nn.Dense(256, activation='relu'))\n", "net.add(nn.Dense(10))\n", "net.initialize(init.Normal(sigma=0.01))" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 1, loss 0.8092, train acc 0.700, test acc 0.816\n", "epoch 2, loss 0.4949, train acc 0.816, test acc 0.846\n", "epoch 3, loss 0.4288, train acc 0.843, test acc 0.862\n", "epoch 4, loss 0.3998, train acc 0.853, test acc 0.867\n", "epoch 5, loss 0.3741, train acc 0.862, test acc 0.862\n", "epoch 6, loss 0.3517, train acc 0.870, test acc 0.874\n", "epoch 7, loss 0.3437, train acc 0.874, test acc 0.878\n", "epoch 8, loss 0.3287, train acc 0.879, test acc 0.880\n", "epoch 9, loss 0.3201, train acc 0.882, test acc 0.876\n", "epoch 10, loss 0.3086, train acc 0.886, test acc 0.877\n" ] } ], "source": [ "batch_size = 256\n", "train_iter, test_iter = gb.load_data_fashion_mnist(batch_size)\n", "\n", "loss = gloss.SoftmaxCrossEntropyLoss()\n", "\n", "trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.5})\n", "\n", "num_epochs = 10\n", "gb.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, trainer)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for X, y in test_iter:\n", " break\n", "\n", "true_labels = gb.get_fashion_mnist_labels(y.asnumpy())\n", "pred_labels = gb.get_fashion_mnist_labels(net(X).argmax(axis=1).asnumpy())\n", "titles = [truelabel + '\\n' + predlabel for truelabel, predlabel in zip(true_labels, pred_labels)]\n", "\n", "gb.show_fashion_mnist(X[0:9], titles[0:9])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }