{ "cells": [ { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2019-04-07T02:19:02.272939Z", "start_time": "2019-04-07T02:19:02.268259Z" }, "slideshow": { "slide_type": "slide" } }, "source": [ "***\n", "***\n", "\n", "# Introduction to Neural Networks\n", "\n", "\n", "***\n", "***" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2019-04-07T05:41:09.971531Z", "start_time": "2019-04-07T05:41:09.967172Z" }, "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "\n", "\n", "http://playground.tensorflow.org/" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2019-04-07T09:37:45.438709Z", "start_time": "2019-04-07T09:37:45.434406Z" }, "slideshow": { "slide_type": "subslide" } }, "source": [ "- Deep Learning http://www.deeplearningbook.org\n", "    - An MIT Press book by Ian Goodfellow, Yoshua Bengio and Aaron Courville\n", "\n", "- Neural Networks and Deep Learning http://neuralnetworksanddeeplearning.com/index.html\n", "\n", "    - A free online book explaining the core ideas behind artificial neural networks and deep learning. [Code](https://github.com/mnielsen/neural-networks-and-deep-learning). By [Michael Nielsen](http://michaelnielsen.org/) / Dec 2017" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## House Price\n", "\n", "Let’s start with a simple example. \n", "- Say you’re helping a friend who wants to buy a house.\n", "\n", "- She was quoted $400,000 for a 2,000 sq ft house (about 186 square meters). \n", "\n", "Is this a good price or not?"
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "So you ask friends who have bought houses in the same neighborhood, and you end up with three data points:\n", "\n", "\n", "\n", "| Area in sq ft (x) | Price in USD (y) |\n", "| -------------|:-------------:|\n", "|2,104|399,900|\n", "|1,600|329,900|\n", "|2,400|369,000|" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "$$y = f(X) = W X$$\n", "\n", "- Calculating the prediction is a simple multiplication.\n", "- But first we need to choose the weight $W$ we multiply by.\n", "- “Training” a neural network means finding the weights we use to calculate the prediction.\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "A simple predictive model (a “regression model”)\n", "- takes an input, \n", "- does a calculation, \n", "- and gives an output. \n", "\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2019-04-07T03:15:14.317623Z", "start_time": "2019-04-07T03:15:14.313438Z" }, "slideshow": { "slide_type": "subslide" } }, "source": [ "**Model Evaluation**\n", "- If we apply our model to the three data points we have, how well does it do?" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "**Loss Function (also called cost function)**\n", "\n", "- For each point, the error is the difference between the **actual value** and the **predicted value**, squared. \n", "- Averaging these squared errors over all $N$ points gives the **Mean Squared Error (MSE)**:\n", "\n", "$$\\mathrm{MSE} = \\frac{1}{N} \\sum_{n=1}^{N} \\left( y_n - \\hat{y}_n \\right)^2$$
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- Varying the weight alone no longer improves the model much. \n", "- Adding a bias (intercept) term gives the model a second parameter, and tuning it improves the fit (here $y$ is in thousands of dollars):\n", "\n", "\n", "\n", "$$y = 0.1 X + 150$$" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "**Gradient Descent**\n", "\n", "- Finds good weight and bias values automatically, \n", "- by repeatedly adjusting them in the direction that reduces the loss function.\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## Softmax\n", "\n", "The softmax function, also known as softargmax or the normalized exponential function, takes as input a vector of $K$ real numbers and normalizes it into a probability distribution consisting of $K$ probabilities. 
\n", "\n", "$$\\mathrm{softmax}(x)_i = \\frac{e^{x_i}}{\\sum_{j=1}^{K} e^{x_j}}, \\quad i = 1, \\dots, K$$\n" ] }, { "cell_type": "code", "execution_count": 128, "metadata": { "ExecuteTime": { "end_time": "2019-04-07T13:57:33.466056Z", "start_time": "2019-04-07T13:57:33.459580Z" }, "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/plain": [ "array([0.02364054, 0.06426166, 0.1746813 , 0.474833 , 0.02364054,\n", "       0.06426166, 0.1746813 ])" ] }, "execution_count": 128, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "\n", "def softmax(s):\n", "    # exponentiate every component, then divide by the sum so the results add up to 1\n", "    return np.exp(s) / np.sum(np.exp(s), axis=0)\n", "\n", "softmax([1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0])" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "That is, before softmax is applied, some components may be negative or greater than one, and they need not sum to 1. After softmax, every component lies in the interval (0, 1) and the components sum to 1, so they can be interpreted as probabilities. \n", "\n", "Furthermore, larger input components correspond to larger probabilities. \n", "\n", "Softmax is often used in neural networks to map the non-normalized output of a network to a probability distribution over predicted output classes."
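 ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "The following sketch is an addition, not part of the original lesson, and assumes float64 NumPy inputs. `np.exp` overflows to `inf` above roughly 709, so the naive `softmax` above returns `nan` for inputs such as `[1000.0, 1001.0]`. Subtracting the maximum first gives the same probabilities, because the common factor $e^{-\\max(x)}$ cancels between numerator and denominator, while keeping every exponent at most $e^0 = 1$. The name `stable_softmax` is illustrative, not a library function." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "import numpy as np\n", "\n", "def stable_softmax(s):\n", "    # hypothetical helper: shift by the maximum so np.exp never overflows\n", "    s = np.asarray(s, dtype=np.float64)\n", "    e = np.exp(s - np.max(s))\n", "    return e / np.sum(e)\n", "\n", "# same result as softmax() on small inputs, but stays finite on large ones\n", "print(stable_softmax([1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]))\n", "print(stable_softmax([1000.0, 1001.0, 1002.0]))"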
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## Activation Function\n", "\n", "An activation function is applied to a neuron's weighted input to introduce non-linearity. Two common choices are the sigmoid, $\\sigma(x) = \\frac{1}{1 + e^{-x}}$, which squashes any input into (0, 1), and the rectified linear unit, $\\mathrm{ReLU}(x) = \\max(0, x)$.\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2019-04-07T03:59:39.599371Z", "start_time": "2019-04-07T03:59:39.593104Z" }, "code_folding": [], "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/plain": [ "0.6224593312018546" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "\n", "def sigmoid(x):\n", "    return 1 / (1 + np.exp(-x))\n", "\n", "sigmoid(0.5)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "ExecuteTime": { "end_time": "2019-04-07T04:01:24.575040Z", "start_time": "2019-04-07T04:01:24.567328Z" }, "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/plain": [ "0.5" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Naive scalar ReLU implementation. 
\n", "# In practice, activations are applied element-wise to whole arrays, e.g. np.maximum(0, x)\n", "def relu(x):\n", "    if x < 0:\n", "        return 0\n", "    else:\n", "        return x\n", "\n", "relu(0.5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# WHAT IS PYTORCH?\n", "\n", "It’s a Python-based scientific computing package targeted at two sets of audiences:\n", "\n", "- A replacement for NumPy to use the power of GPUs\n", "- A deep learning research platform that provides maximum flexibility and speed" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "ExecuteTime": { "end_time": "2019-04-07T04:43:15.243401Z", "start_time": "2019-04-07T04:43:14.333403Z" } }, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import torch\n", "from torch import nn, optim\n", "from torch.autograd import Variable  # deprecated since PyTorch 0.4; plain tensors work directly\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 104, "metadata": { "ExecuteTime": { "end_time": "2019-04-07T05:26:28.414285Z", "start_time": "2019-04-07T05:26:28.201996Z" } }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAFRdJREFUeJzt3X+QXWd93/H3B0n+0RIswEuiWKKisWkSPLHsbISVhERI40Q2TEQbu3HKDC526oFAQphM+FE6mXFdxjFu69QdxoxTh8iUBrvmRxwHJri2FcKMLGVFZAWDgQVMbeHiBWwTD7Uci2//uI/weitp77XuatcP79fMmfuc5zzn3u89u/vZs+ecuydVhSSpX89Z7AIkSQvLoJekzhn0ktQ5g16SOmfQS1LnDHpJ6pxBL0mdM+glqXMGvSR1bvliFwBw8skn19q1axe7DEl6Vtm9e/c3q2pivnFLIujXrl3L1NTUYpchSc8qSb42zDgP3UhS5wx6SeqcQS9JnTPoJalzQwd9kmVJ/jbJrW3+JUl2JplOcmOS41r/8W1+ui1fuzClS5KGMcoe/VuAz8+avxK4uqpOBR4GLmn9lwAPt/6r2zhJ0iIZKuiTrAZeBfy3Nh9gE3BzG7INeE1rb23ztOWb23hJ47RjB1xxxeBROoJhr6P/Q+BtwA+1+RcCj1TVk23+AeCU1j4FuB+gqp5M8mgb/82xVCxpEO6bN8MTT8Bxx8Htt8OGDYtdlZaoeffok7waeKiqdo/zhZNcmmQqydTMzMw4n1rq3/btg5A/cGDwuH37YlekJWyYQzc/B/xKkvuADzE4ZPNfgJVJDv5FsBrY19r7gDUAbflJwLfmPmlVXVdVk1U1OTEx7yd4Jc22ceNgT37ZssHjxo2LXZGWsHmDvqreWVWrq2otcCFwR1W9FrgTOL8Nuwj4s9a+pc3Tlt9RVTXWqqUfdBs2DA7XXH65h200r6P5XzdvBz6U5D8Afwtc3/qvBz6QZBr4NoNfDpLGbcMGA15DGSnoq2o7sL21vwKsP8SYx4ELxlCbJGkM/GSsJHXOoJekzhn0ktQ5g16SOmfQS1LnDHpJ6pxBL0mdM+glqXMGvSR1zqCXpM4Z9JLUOYNekjpn0EtS5wx6SeqcQS9JnTPoJalzBr0kdc6gl6TOGfSS1DmDXpI6N2/QJzkhya4kdye5J8llrX9Tks8k+WySbUmWt/4kuSbJdJK9Sc5a6DchSTq8Yfbo9wObquoMYB2wJcnPAtuAC6vqdOBrwEVt/LnAaW26FLh27FVLkoY2b9DXwGNtdkWbDgBPVNUXW/9twK+29lbghrbeXcDKJKvGXLckaUhDHaNPsizJHuAhBqG+C1ieZLINOR9Y09qnAPfPWv2B1jf3OS9NMpVkamZm5pnWL0max1BBX1UHqmodsBpYD7wMuBC4Osku4O8Z7OUPraquq6rJqpqcmJgYsWxJ0rBGuuqmqh4B7gS2VNWOqnpFVa0HPgUcPIyzj6f27mHwy2HfOIqVJI1umKtuJpKsbO0TgXOAe5O8qPUdD7wdeF9b5Rbgde3qm7OBR6vqwQWpXpI0r+VDjFkFbEuyjMEvhpuq6tYkVyV5deu7tqruaOM/DpwHTAPfBV6/AHVLkoaUqlrsGpicnKypqanFLkOSnlWS7K6qyfnG+clYSeqcQS9JnTPoJalzBr0kdc6gl6TOGfSS1DmDXpI6Z9BLUucMeknqnEEvSZ0z6CWpcwa9JHXOoJekzhn0ktQ5g16SOmfQS1LnDHpJ6pxBL0mdM+glqXPzBn2SE5LsSnJ3knuSXNb6Nyf5TJI9ST6d5NTWf3ySG5NMJ9mZZO3CvgVJ0pEMs0e/H9hUVWcA64AtSc4GrgVeW1XrgP8B/Ls2/hLg4ao6FbgauHL8ZUuShjVv0NfAY212RZuqTc9r/ScBX2/trcC21r4Z2JwkY6tYkjSS5cMMSrIM2A2cCry3qnYm+Q3g40n+L/A
d4Ow2/BTgfoCqejLJo8ALgW/Oec5LgUsBXvziF4/hrUiSDmWok7FVdaAdolkNrE9yOvBW4LyqWg28H/jPo7xwVV1XVZNVNTkxMTFq3ZKkIY101U1VPQLcCZwLnFFVO9uiG4Gfbe19wBqAJMsZHNb51liqlSSNbJirbiaSrGztE4FzgM8DJyV5aRt2sA/gFuCi1j4fuKOqaqxVS5KGNswx+lXAtnac/jnATVV1a5J/A3w4yfeAh4GL2/jrgQ8kmQa+DVy4AHVLkoY0b9BX1V7gzEP0fxT46CH6HwcuGEt1kqSj5idjJalzBr0kdc6gl6TOGfSS1DmDXpI6Z9BLUucMeknqnEEvSZ0z6CWpcwa9JHXOoJekzhn0ktQ5g16SOmfQS1LnDHpJ6pxBL0mdM+glqXMGvSR1zqCXpM7NG/RJTkiyK8ndSe5Jclnr/+ske9r09SQfa/1Jck2S6SR7k5y10G9CknR4894cHNgPbKqqx5KsAD6d5BNV9YqDA5J8GPizNnsucFqbXg5c2x4lSYtg3j36Gnisza5oUx1cnuR5wCbgY61rK3BDW+8uYGWSVeMtW5I0rKGO0SdZlmQP8BBwW1XtnLX4NcDtVfWdNn8KcP+s5Q+0PknSIhgq6KvqQFWtA1YD65OcPmvxrwN/OuoLJ7k0yVSSqZmZmVFXlyQNaaSrbqrqEeBOYAtAkpOB9cBfzBq2D1gza35165v7XNdV1WRVTU5MTIxatyQ9++3YAVdcMXhcQPOejE0yAfxDVT2S5ETgHODKtvh84NaqenzWKrcAb07yIQYnYR+tqgfHXLckPbvt2AGbN8MTT8Bxx8Htt8OGDQvyUsPs0a8C7kyyF/gbBsfob23LLuT/P2zzceArwDTwR8BvjqlWSerH9u2DkD9wYPC4ffuCvdS8e/RVtRc48zDLNh6ir4A3HXVlktSzjRsHe/IH9+g3blywlxrmOnpJ0rht2DA4XLN9+yDkF+iwDRj0krR4NmxY0IA/yP91I0mdM+glqXMGvSR1zqCXpM4Z9JLUOYNekjpn0EtS5wx6SeqcQS9JnTPoJalzBr0kdc6gl6TOGfSS1DmDXpI6Z9BLUucMeknqnEEvSZ2bN+iTnJBkV5K7k9yT5LLWnyTvTvLFJJ9P8tuz+q9JMp1kb5KzFvpNSJIOb5hbCe4HNlXVY0lWAJ9O8gngJ4A1wI9X1feSvKiNPxc4rU0vB65tj5KkRTBv0FdVAY+12RVtKuCNwL+qqu+1cQ+1MVuBG9p6dyVZmWRVVT049uolSfMa6hh9kmVJ9gAPAbdV1U7gx4BfSzKV5BNJTmvDTwHun7X6A61PkrQIhgr6qjpQVeuA1cD6JKcDxwOPV9Uk8EfAH4/ywkkubb8kpmZmZkatW5I0pJGuuqmqR4A7gS0M9tQ/0hZ9FPip1t7H4Nj9Qatb39znuq6qJqtqcmJiYtS6JUlDGuaqm4kkK1v7ROAc4F7gY8Ar27BfBL7Y2rcAr2tX35wNPOrxeUlaPMNcdbMK2JZkGYNfDDdV1a1JPg18MMlbGZys/Y02/uPAecA08F3g9eMvW5I0rGGuutkLnHmI/keAVx2iv4A3jaU6SdJR85OxktQ5g16SOmfQS1LnDHpJ6pxBL0mdM+glqXMGvSR1zqCXpM4Z9JLUOYNekjpn0EtS5wx6SeqcQS9JnTPoJalzBr0kdc6gl6TOGfSS1DmDXpI6Z9BLUufmDfokJyTZleTuJPckuaz1/0mSrybZ06Z1rT9JrkkynWRvkrMW+k1Ikg5v3puDA/uBTVX1WJIVwKeTfKIt+72qunnO+HOB09r0cuDa9ihJWgTz7tHXwGNtdkWb6girbAVuaOvdBaxMsuroS5UkPRNDHaNPsizJHuAh4Laq2tkWvbsdnrk6yfGt7xTg/lmrP9D6JEmLYKigr6oDVbUOWA2sT3I68E7gx4GfAV4AvH2UF05yaZKpJFMzMzMjli1JGtZIV91U1SPAncCWqnqwHZ7ZD7wfWN+
G7QPWzFptdeub+1zXVdVkVU1OTEw8s+olSfMa5qqbiSQrW/tE4Bzg3oPH3ZMEeA3w2bbKLcDr2tU3ZwOPVtWDC1K9JGlew1x1swrYlmQZg18MN1XVrUnuSDIBBNgDvKGN/zhwHjANfBd4/fjLliQNa96gr6q9wJmH6N90mPEFvOnoS5MkjYOfjJWkzhn0ktQ5g16SOmfQS1LnDHpJ6pxBL0mdM+glqXMGvSR1zqCXpM4Z9JLUOYNekjpn0EtS5wx6SeqcQS9JnTPoJalzBr0kdc6gl6TOGfSS1DmDXpI6Z9BLUufmDfokJyTZleTuJPckuWzO8muSPDZr/vgkNyaZTrIzydrxly1JGtYwe/T7gU1VdQawDtiS5GyAJJPA8+eMvwR4uKpOBa4GrhxjvZKkEc0b9DVwcI99RZsqyTLgKuBtc1bZCmxr7ZuBzUkypnolSSMa6hh9kmVJ9gAPAbdV1U7gzcAtVfXgnOGnAPcDVNWTwKPACw/xnJcmmUoyNTMzczTvQZJ0BEMFfVUdqKp1wGpgfZJfAC4A/uszfeGquq6qJqtqcmJi4pk+jSRpHiNddVNVjwB3Aq8ETgWmk9wH/KMk023YPmANQJLlwEnAt8ZVsCRpNMNcdTORZGVrnwicA+yuqh+pqrVVtRb4bjv5CnALcFFrnw/cUVU1/tIlScNYPsSYVcC2dvL1OcBNVXXrEcZfD3yg7eF/G7jw6MuUJD1T8wZ9Ve0FzpxnzHNntR9ncPxekrQE+MlYSeqcQS9JnTPoJalzBr0kdc6gl6TOGfSS1DmDXpI6Z9BLUucMeknqnEEvSZ0z6CWpcwa9JHXOoJekzhn0ktQ5g16SOmfQS1LnDHpJ6pxBL0mdM+glqXPzBn2SE5LsSnJ3knuSXNb6r299e5PcnOS5rf/4JDcmmU6yM8nahX0LkqQjGWaPfj+wqarOANYBW5KcDby1qs6oqp8C/jfw5jb+EuDhqjoVuBq4cgHqliQNad6gr4HH2uyKNlVVfQcgSYATgWpjtgLbWvtmYHMbM347dsAVVwweJUmHtHyYQUmWAbuBU4H3VtXO1v9+4Dzgc8DvtuGnAPcDVNWTSR4FXgh8c6yV79gBmzfDE0/AccfB7bfDhg1jfQlJ6sFQJ2Or6kBVrQNWA+uTnN76Xw/8KPB54NdGeeEklyaZSjI1MzMzYtnA9u2DkD9wYPC4ffvozyFJPwBGuuqmqh4B7gS2zOo7AHwI+NXWtQ9YA5BkOXAS8K1DPNd1VTVZVZMTExOjV75x42BPftmywePGjaM/hyT9ABjmqpuJJCtb+0TgHOALSU5tfQF+Bbi3rXILcFFrnw/cUVXFuG3YMDhcc/nlHraRpCMY5hj9KmBbO07/HOAm4C+Av07yPCDA3cAb2/jrgQ8kmQa+DVw49qoP2rDBgJekecwb9FW1FzjzEIt+7jDjHwcuOMq6JElj4idjJalzBr0kdc6gl6TOGfSS1DmDXpI6l4W4xH3kIpIZ4GvPcPWTGfe/VxgP6xqNdY1uqdZmXaM5mrr+SVXN+4nTJRH0RyPJVFVNLnYdc1nXaKxrdEu1NusazbGoy0M3ktQ5g16SOtdD0F+32AUchnWNxrpGt1Rrs67RLHhdz/pj9JKkI+thj16SdARLMuiT/HGSh5J8dk7/byW5t92k/D2z+t/Zbkb+hSS/PKt/S+ubTvKOhair3Qh9T5vuS7JnidS1Lsldra6pJOtbf5Jc0157b5KzZq1zUZIvtemiQ73WGOo6I8mOJH+X5M/bf0A9uOxYba81Se5M8rn2vfSW1v+CJLe1939bkue3/mOyzY5Q1wVt/ntJJuess+Db7Ah1XdV+Hvcm+WjavzNfAnVd3mrak+STSX609S/q13HW8t9NUklOPmZ1VdWSm4BfAM4CPjur75XA/wKOb/Mvao8/yeDfJB8PvAT4MrCsTV8G/ilwXBvzk+Oua87y/wT8/lKoC/gkcG5rnwdsn9X+BIN/L302sLP1vwD4Snt
8fms/fwHq+hvgF1v7YuDyRdheq4CzWvuHgC+2138P8I7W/w7gymO5zY5Q108A/wzYDkzOGn9MttkR6volYHnrv3LW9lrsup43a8xvA+9bCl/HNr8G+EsGnxs6+VjVtST36KvqUwz+l/1sbwT+oKr2tzEPtf6twIeqan9VfRWYBta3abqqvlJVTzC4C9bWBagL+P4NWP4l8KdLpK4CDu4tnwR8fVZdN9TAXcDKJKuAXwZuq6pvV9XDwG3MupPYGOt6KfCp1r6Np+5Mdiy314NV9ZnW/nsGt8I8haff2H4b8JpZtS34NjtcXVX1+ar6wiFWOSbb7Ah1fbKqnmzD7mJwq9GlUNd3Zg37xwx+Fg7WtWhfx7b4auBts2o6JnUtyaA/jJcCr0iyM8lfJfmZ1v/9m5E3D7S+w/UvlFcA36iqLy2Run4HuCrJ/cB/BN65ROq6h6d+uC+g3XZysepKspbB/RZ2Aj9cVQ+2Rf8H+OHFqm1OXYezlOq6mMFe6ZKoK8m72/f+a4HfXwp1JdkK7Kuqu+cMW/C6nk1Bv5zBnzBnA78H3NT2opeKX+epvfml4I3AW6tqDfBWBnf+WgouBn4zyW4Gf9Y+sViFJHku8GHgd+bsBVKDv50X5ZK0I9W1mA5XV5J3AU8CH1wqdVXVu9r3/geBNy92XQy2z7/lqV86x9SzKegfAD7S/rzZBXyPwf+I+P7NyJvVre9w/WOXwU3Q/wVw46zuxa7rIuAjrf0/GfzZvOh1VdW9VfVLVfXTDH4xfnkx6kqygsEP4Qer6uB2+kb7k5n2ePDw4DGr7TB1Hc6i15XkXwOvBl7bfjkuibpm+SBPHR5czLp+jMH5iruT3Nde4zNJfuSY1PVMDuwfiwlYy9NP4r0B+Pet/VIGf9IEeBlPP/HzFQYnfZa39kt46sTPy8ZdV+vbAvzVnL5FrYvBccGNrb0Z2N3ar+LpJ3521VMnfr7K4KTP81v7BQtQ18GT6M8BbgAuPtbbq733G4A/nNN/FU8/GfueY7nNDlfXrOXbefrJ2GOyzY6wvbYAnwMmFuN7/wh1nTar/VvAzUvp69jG3MdTJ2MXvK6j+iFeqInBnt6DwD8w2JO/pH1j/Hfgs8BngE2zxr+LwZ7hF2hXmrT+8xic8f4y8K6FqKv1/wnwhkOMX7S6gJ8Hdrcfpp3AT8/6Jnxve+2/4+nBcTGDE2fTwOsXqK63tPf+ReAPaB/aO8bb6+cZHJbZC+xp03nAC4HbgS8xuMLrBcdymx2hrn/ett9+4BvAXx7LbXaEuqYZ7HAd7HvfEqnrwwxyYi/w5wxO0C7613HOmPt4KugXvC4/GStJnXs2HaOXJD0DBr0kdc6gl6TOGfSS1DmDXpI6Z9BLUucMeknqnEEvSZ37fx1iIYZvFmeYAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# living area in square feet (x) and price in thousands of dollars (y)\n", "x_train = np.array([[2104], [1600], [2400]], dtype=np.float32)\n", "y_train = np.array([[399.900], [329.900], [369.000]], dtype=np.float32)\n", "\n", "# x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],\n", "#                     [9.779], [6.182], [7.59], [2.167], [7.042],\n", "#                     [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)\n", "\n", "# y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],\n", "#                     [3.366], [2.596], [2.53], [1.221], [2.827],\n", "#                     [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)\n", "\n", "plt.plot(x_train, y_train, 'r.')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 105, "metadata": { "ExecuteTime": { "end_time": "2019-04-07T05:26:28.745318Z", "start_time": "2019-04-07T05:26:28.741971Z" } }, "outputs": [], "source": [ "# convert the NumPy arrays to PyTorch tensors (they share memory)\n", "x_train = torch.from_numpy(x_train)\n", "y_train = torch.from_numpy(y_train)" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2019-04-07T04:47:17.479477Z", "start_time": "2019-04-07T04:47:17.468259Z" } }, "source": [ "`nn.Linear`\n", "\n", "Applies a linear transformation to the incoming data: $y = xA^T + b$, where $A$ is the learnable weight matrix and $b$ the learnable bias. With one input feature and one output feature, both are scalars, so this is exactly the weight-plus-bias model from the house-price example." ] }, { "cell_type": "code", "execution_count": 106, "metadata": { "ExecuteTime": { "end_time": "2019-04-07T05:26:30.527459Z", "start_time": "2019-04-07T05:26:30.516864Z" } }, "outputs": [], "source": [ "# Linear Regression Model\n", "class LinearRegression(nn.Module):\n", "    def __init__(self):\n", "        super(LinearRegression, self).__init__()\n", "        self.linear = nn.Linear(1, 1)  # one input feature (area), one output (price)\n", "\n", "    def forward(self, x):\n", "        out = self.linear(x)\n", "        return out\n", "\n", "model = LinearRegression()" ] }, { "cell_type": "code", "execution_count": 107, "metadata": { "ExecuteTime": { "end_time": "2019-04-07T05:26:31.378932Z", "start_time": "2019-04-07T05:26:31.375266Z" } }, "outputs": [], "source": [ "# Define the loss and optimization functions\n", "criterion = 
nn.MSELoss()\n", "# very small learning rate: the unscaled areas (~2000) make the gradients large, and a rate such as 1e-4 diverges\n", "optimizer = optim.SGD(model.parameters(), lr=1e-9)" ] }, { "cell_type": "code", "execution_count": 108, "metadata": { "ExecuteTime": { "end_time": "2019-04-07T05:26:32.848256Z", "start_time": "2019-04-07T05:26:32.681980Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch[50/1000], loss: 3798.923096\n", "Epoch[100/1000], loss: 2773.092773\n", "Epoch[150/1000], loss: 2336.129639\n", "Epoch[200/1000], loss: 2150.003906\n", "Epoch[250/1000], loss: 2070.720459\n", "Epoch[300/1000], loss: 2036.951782\n", "Epoch[350/1000], loss: 2022.566162\n", "Epoch[400/1000], loss: 2016.437866\n", "Epoch[450/1000], loss: 2013.828125\n", "Epoch[500/1000], loss: 2012.717407\n", "Epoch[550/1000], loss: 2012.243286\n", "Epoch[600/1000], loss: 2012.041260\n", "Epoch[650/1000], loss: 2011.954956\n", "Epoch[700/1000], loss: 2011.918335\n", "Epoch[750/1000], loss: 2011.904053\n", "Epoch[800/1000], loss: 2011.897217\n", "Epoch[850/1000], loss: 2011.894409\n", "Epoch[900/1000], loss: 2011.893311\n", "Epoch[950/1000], loss: 2011.892456\n", "Epoch[1000/1000], loss: 2011.890991\n" ] } ], "source": [ "num_epochs = 1000\n", "for epoch in range(num_epochs):\n", "    # plain tensors work directly; wrapping them in Variable is no longer needed\n", "    inputs = x_train\n", "    target = y_train\n", "\n", "    # forward pass: predict and measure the loss\n", "    out = model(inputs)\n", "    loss = criterion(out, target)\n", "    # backward pass: clear old gradients, backpropagate, update the parameters\n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "    optimizer.step()\n", "\n", "    if (epoch+1) % 50 == 0:\n", "        print('Epoch[{}/{}], loss: {:.6f}'\n", "              .format(epoch+1, num_epochs, loss.item()))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`nn.MSELoss` computes\n", "\n", "$$\n", "\\ell(x, y) = \\frac{1}{N} \\sum_{n=1}^{N} l_n, \\quad\n", "l_n = \\left( x_n - y_n \\right)^2,\n", "$$\n", "where $x_n$ is the model output, $y_n$ the target, and $N$ the batch size (averaging is the default, `reduction='mean'`). 
" ] }, { "cell_type": "code", "execution_count": 109, "metadata": { "ExecuteTime": { "end_time": "2019-04-07T05:26:38.112863Z", "start_time": "2019-04-07T05:26:38.108523Z" } }, "outputs": [ { "data": { "text/plain": [ "LinearRegression(\n", " (linear): Linear(in_features=1, out_features=1, bias=True)\n", ")" ] }, "execution_count": 109, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.eval()" ] }, { "cell_type": "code", "execution_count": 118, "metadata": { "ExecuteTime": { "end_time": "2019-04-07T05:54:13.026497Z", "start_time": "2019-04-07T05:54:12.777585Z" } }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAAETCAYAAAD3WTuEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAIABJREFUeJzt3Xl8VNX5x/HPY1gCAoKIgGyhij9ZjRpXfloFF2oRrNL+sLjghlWpuGFVrFTRFoWKtIKKQgWbVlREqKKCAu6y7yASNEBQISCgNIIs5/fHuUMWCJAhM3cy832/XvPKPWfuzDyZLM8899x7jjnnEBERicZhYQcgIiIVl5KIiIhETUlERESipiQiIiJRUxIREZGoKYmIiEjUlERERCRqSiIiIhI1JREREYlapbADiLWjjjrKZWRkhB2GiEiFMmfOnA3OuXoH2i/pk0hGRgazZ88OOwwRkQrFzFYdzH46nCUiIlFTEhERkagpiYiISNSSfkxkX3bs2EFeXh7btm0LOxQB0tPTady4MZUrVw47FBEpo5RMInl5edSsWZOMjAzMLOxwUppzjo0bN5KXl0fz5s3DDkdEyiglD2dt27aNunXrKoEkADOjbt26qgpFykmDBmC2961Bg9i8XkomEUAJJIHoZyFSftatK1v/oUrZJCIiIodOSSQkeXl5dO3alRYtWnDsscfSp08ffvrpp33u+/XXX9OtW7cDPufFF1/M5s2bo4rnT3/6E4MHDz7gfjVq1Njv/Zs3b2b48OFRxSAiFY+SyMHIzoaMDDjsMP81O/uQns45x2WXXcall17KihUr+OKLL9i6dSv9+vXba9+dO3dyzDHH8Oqrrx7weSdNmkTt2rUPKbZDpSQiEo4tW6BXr/i/rpLIgWRn+5/MqlXgnP/aq9chJZKpU6eSnp7OtddeC0BaWhpDhgxh1KhRFBQU8MILL9ClSxc6dOhAx44dyc3NpU2bNgAUFBTwm9/8hlatWvGrX/2K008/fc+0LhkZGWzYsIHc3FxatmzJjTfeSOvWrbnwwgv58ccfAXjuuec49dRTOfHEE7n88sspKCjYb6xfffUVZ555Jm3btuWBBx7Y079161Y6duzIySefTNu2bZkwYQIA9957LytXriQzM5O+ffuWup+IlJ+JE6FVKxg5Mv6vnRBJxMzSzGyemb0RtLPNbLmZLTazUWZWOeg3M/ubmeWY2UIzOznmwfXrByX/0RYU+P4oLVmyhFNOOaVYX61atWjatCk5OTkAzJ07l1dffZX333+/2H7Dhw+nTp06LF26lAEDBjBnzpx9vsaKFSu49dZbWbJkCbVr12bcuHEAXHbZZcyaNYsFCxb
QsmVLRh7gt65Pnz7cfPPNLFq0iIYNG+7pT09PZ/z48cydO5dp06Zx11134Zxj4MCBHHvsscyfP59BgwaVup+IHLr166F7d+jaFerWhRkzoH79fe9bWv+hSogkAvQBlhVpZwMnAG2BasANQf8vgBbBrRfwdMwjW726bP3l5IILLuDII4/cq/+jjz6ie/fuALRp04Z27drt8/HNmzcnMzMTgFNOOYXc3FwAFi9ezNlnn03btm3Jzs5myZIl+43j448/5oorrgDgqquu2tPvnOP++++nXbt2nH/++axdu5Z1+zj942D3E5GD55w/GNKqFYwfDwMGwOzZkJUF337r7y95+/bb2MQSehIxs8bAL4HnI33OuUkuAMwEGgd3dQXGBHd9BtQ2s4Z7PWl5atq0bP0HoVWrVntVEN9//z2rV6/muOOOA+Dwww+P+vkBqlatumc7LS2NnTt3AtCzZ0+eeuopFi1aRP/+/Q/q+ox9nYKbnZ1Nfn4+c+bMYf78+dSvX3+fz3Ww+4nIwVmzBi65BK68Elq0gHnz4IEHoEqVcOIJPYkATwL3ALtL3hEcxroKeDvoagSsKbJLXtBX8nG9zGy2mc3Oz88/tOgefRSqVy/eV726749Sx44dKSgoYMyYMQDs2rWLu+66i549e1K95GuV0L59e15++WUAli5dyqJFi8r02j/88AMNGzZkx44dZB/EuE779u156aWXAIrtv2XLFo4++mgqV67MtGnTWLXKzxpds2ZNfvjhhwPuJyJls3s3PPMMtG4N06bBk0/CRx/5aiRMoSYRM+sMrHfO7fvAPgwHPnDOfViW53XOjXDOZTnnsurVO+CaKvvXoweMGAHNmvnLPps18+0ePaJ+SjNj/PjxvPLKK7Ro0YLjjz+e9PR0/vznPx/wsbfccgv5+fm0atWKBx54gNatW3PEEUcc9GsPGDCA008/nfbt23PCCScccP+hQ4cybNgw2rZty9q1a/f09+jRg9mzZ9O2bVvGjBmz57nq1q1L+/btadOmDX379i11PxE5eCtWQIcOcPPNcPrpsHgx9OkDaWlhRwYW5iCnmf0FX2nsBNKBWsBrzrkrzaw/cBJwmXNud7D/s8B059y/g/Zy4Fzn3DelvUZWVpYruSjVsmXLaNmyZSy+pZjbtWsXO3bsID09nZUrV3L++eezfPlyqoRVy5aTivwzEYmVnTthyBB48EGoWhWeeAKuvdZ/no01M5vjnMs60H6hTsDonLsPuA/AzM4F7g4SyA3ARUDHSAIJTAR6m9lLwOnAlv0lkGRUUFDAeeedx44dO3DOMXz48AqfQERkbwsWwPXXw5w5cOmlMGwYHHNM2FHtLVFn8X0GWAV8GgzqvuacexiYBFwM5AAFwLWhRRiSmjVrarlfkSS2fTs88ggMHAhHHgkvvwzdusWn+ohGwiQR59x0YHqwvc+4grO1bo1fVCIi8fPpp776WLYMrr7aH76qWzfsqPYvEc7OEhFJaf/9L9x+O7RvD1u3wqRJMHp04icQSKBKREQkFb37Ltx4I+Tmwq23wl/+AjVrhh3VwVMlIiISgk2b/KGrCy7wFwp+8AE89VTFSiCgJBKatLQ0MjMz99xyc3OZPXs2t912GwDTp0/nk08+2bP/66+/ztKlS/e0H3zwQd59991yiSUycWNREydOZODAgeXy/CJS3Pjx/iLB0aPh3nv9mVhnnx12VNHR4awDaNBg3yuC1a9/aHPRVKtWjfnz5xfry8jIICvLn5Y9ffp0atSowVlnnQX4JNK5c2daBZenPvzww9G/+EHo0qULXbp0ielriKSadevg97+HV16BzEx48004OfbTyMaUKpEDiOdSk9OnT6dz587k5ubyzDPPMGTIEDIzM3n//feZOHEiffv2JTMzk5UrV9KzZ889a4xkZGTQv3//PdOtf/755wDk5+dzwQUX0Lp1a2644QaaNWu2V8VRmhdeeIHevXsDfr6t2267jbPOOouf/exnxdY2GTRoEKe
eeirt2rWjf//+5fyOiCQH52DMGGjZEiZM8LMmzZxZ8RMIqBLh9tuhREFw0M49d9/9mZl+Xpv9+fHHH/fMstu8eXPGjx+/576MjAx+97vfUaNGDe6++27AVwadO3cudYXDo446irlz5zJ8+HAGDx7M888/z0MPPUSHDh247777ePvttw847fv+fPPNN3z00Ud8/vnndOnShW7dujF58mRWrFjBzJkzcc7RpUsXPvjgA84555yoX0ck2axeDTfdBG+/DWed5df8SKbZf1I+iYRlX4ezDsVll10G+GnfX3vtNcBPGx9JTp06daJOnTpRP/+ll17KYYcdRqtWrfZM5T558mQmT57MSSedBPiFqlasWKEkIoKfMPHpp/2Yh3Pw97/DLbf4BVKTSconkQNVDPu7SnT69HIN5ZBEpn4vOu17LJ4f2LOolHOO++67j5tuuqncX0+kIlu+HG64wc+ye+GF8OyzfmXtZJRkOTF5lJxSvWT7YBSdNn7y5Mls2rSpXGO86KKLGDVqFFu3bgVg7dq1rF+/vlxfQ6Qi2bHDT1dy4omwZAm88II/jJWsCQSURA4o3ktNRlxyySWMHz+ezMxMPvzwQ7p3786gQYM46aSTWLly5UE9R//+/Zk8eTJt2rThlVdeoUGDBtQs5ST0du3a0bhxYxo3bsydd955UM9/4YUX8tvf/nbPGuzdunUrc6ITSRbz5vlp2u+7Dzp3hqVL4ZprEnfOq/IS6lTw8ZBsU8GXxfbt20lLS6NSpUp8+umn3HzzzeU6DlOeUuVnIsln2za/PO1jj8FRR/nZdi+/POyoDl2FmApeYmv16tX85je/Yffu3VSpUoXnnnsu7JBEksrHH/urzpcv9+t8DB7sZ95NJUoiSaxFixbMmzcv7DBEks7WrXD//X6akqZN4Z13/AB6KkrZMZFkP4xXkehnIRXJO+/4dc6fespffb54ceomEEjRJJKens7GjRv1zysBOOfYuHEj6enpYYcisl/ffQc9e0KnTlC9Onz4IQwdCjVqhB1ZuFLycFbjxo3Jy8sjPz8/7FAEn9QbN24cdhgipRo3zk/TvmED9OsHDzwA+tzjpWQSqVy5Ms2bNw87DBFJcN98A717w2uvwUkn+Ws+gtmKJJCSh7NERPbHOX+hYKtWfqbdgQP9hIlKIHtLiCRiZmlmNs/M3gjazc1shpnlmNlYM6sS9FcN2jnB/Rlhxi0iySc3Fy66yJ+y27YtLFwIf/gDVErJ4zYHlhBJBOgDLCvSfgwY4pw7DtgEXB/0Xw9sCvqHBPuJSHnKzvbzdBx2mP+anR12RHGxaxf87W/Qpg18+qm/aHD6dDj++LAjS2yhJxEzawz8Eng+aBvQAYgsWjEauDTY7hq0Ce7vGOwvIuUhOxt69YJVq/wxnVWrfDvJE8myZXDOOdCnj19hcMmS5JxxNxYS4S16ErgH2B206wKbnXORqWjzgEbBdiNgDUBw/5Zg/2LMrJeZzTaz2ToDS6QM+vWDgoLifQUFvj8J7djhF4jKzITPP/cLR02a5C8glIMTahIxs87AeufcnPJ8XufcCOdclnMuq169euX51CLJbfXqsvVXYHPnwqmn+tN1L73UT5h41VXJP2FieQu7EmkPdDGzXOAl/GGsoUBtM4sMYzUG1gbba4EmAMH9RwAb4xmwSFIr7SN4En00//FHv1DUaafB+vUwfjyMHRv7mbmTVahJxDl3n3OusXMuA+gOTHXO9QCmAZF1YK8BJgTbE4M2wf1TnS47Fyk/jz7qL8cuqnp1358EPvjAr/Xx2GP+6vOlS30VItELuxIpzR+AO80sBz/mEVkcfCRQN+i/E7g3pPhEklOPHjBiBDRr5o/rNGvm2z16hB3ZIfn+e3/F+c9/Djt3wrvvwvPPQ+3aYUdW8aXkeiIikjreegtuugny8vzZV488AocfHnZUiU/riYhIStu4Ee64A1580V95/skncMYZYUeVfBL1cJaISFScg5dfhpYt4d//hj/
+0Z+JpQQSG6pERCRpfP21v0hwwgTIyvJjH+3ahR1VclMlIiIVnnMwcqQ/bPXOOzBokJ+6RAkk9lSJiEiF9uWXcOONMHWqP/vq+efhuOPCjip1qBIRkQpp1y548kk/0+6sWfDMMz6RKIHElyoREalwliyB66+HGTPgl7/0CUSLY4ZDlYiIVBg//QQDBvhVBnNy/OTC//mPEkiYVImISIUwa5avPhYtgiuugKFDQfOrhk+ViIgktIIC6NvXX+fx3XcwcSL8619KIIlClYiIJKzp0/2ZVzk5fm2sxx+HI44IOyopSpWIiCScLVvgd7+D887z14BMnQrPPqsEkoiUREQkobz5JrRuDc89B3fdBQsX+mQiiUlJREQSQn6+n3G+c2eoU8dfcT548N7Lm0hiURIRkVA55ydKbNUKXnkF/vQnmDPHrzwoiU8D6yISmrw8uPlmeOMNnzRGjoQ2bcKOSspClYiIxN3u3X7BxNat4b334Ikn/HofSiAVjyoREYmrnBx/2u706X7A/Lnn4Nhjw45KoqVKRETiYtcu+Otf/fTsc+f65PHee0ogFV2oScTM0s1sppktMLMlZvZQ0N/RzOaa2Xwz+8jMjgv6q5rZWDPLMbMZZpYRZvwicnAWLYIzz4S774bzz4elS+GGG8As7MjkUIVdiWwHOjjnTgQygU5mdgbwNNDDOZcJ/At4INj/emCTc+44YAjwWAgxi8hB2r4d+veHk0+G3Fx46SW/6mCjRmFHJuUl1CTivK1Bs3Jwc8GtVtB/BPB1sN0VGB1svwp0NNNnGZFENGMGnHIKPPwwdO/uq4//+z9VH8km7EoEM0szs/nAemCKc24GcAMwyczygKuAgcHujYA1AM65ncAWoO4+nrOXmc02s9n5+fnx+DZEJPDf/8Kdd/rDV1u2+NN3X3wRjjoq7MgkFkJPIs65XcFhq8bAaWbWBrgDuNg51xj4B/BEGZ9zhHMuyzmXVU9TfYrEzdSpfuB8yBA/99WSJX7RKEleoSeRCOfcZmAa8AvgxKAiARgLnBVsrwWaAJhZJfyhro1xDlVESti82Z+227EjHHaYP313+HCoVeuAD5UKLuyzs+qZWe1guxpwAbAMOMLMjg92i/QBTASuCba7AVOdcy6OIYtICRMn+osGR42Ce+7xEyb+/OdhRyXxEvbFhg2B0WaWhk9oLzvn3jCzG4FxZrYb2ARcF+w/EnjRzHKA74DuYQQtIrB+Pdx2G4wd6w9hTZgAWVlhRyXxFmoScc4tBE7aR/94YPw++rcBv45DaCJSCuf82uZ9+sDWrX7N8z/8ASpXDjsyCUPYlYiIVCBr1vgB80mT/HK1I0f62XcldSXMwLqIJK7du+Hpp/3Yx/Tp8OST8NFHSiCiSkREimjQANat27u/cmXYscNPWTJiBDRvHv/YJDEpiYjIHvtKIOATyKhR0LOnrjiX4pREROSgXHtt2BFIItKYiIiIRE1JREQAv9aHSFkpiYikuHXr/DjH3XeHHYmUm+xsyMjwc9BkZPh2jCiJiKSwe+7xZ2RFlDZfaf368YlHykF2NvTqBatW+StDV63y7RglEiURkRT05Ze++hg0yLf/8hf//2b9ev+15O3bb8ONV8qgXz8oKCjeV1Dg+2NAZ2eJpJirroJ//rOwvWkT1K4dXjxSzlavLlv/IVIlIpIiFizw1UckgYwc6asMJZAk07Rp2foPkZKISJJzzq/zkZnp27Vq+aMb1123/8dJBfXoo1C9evG+6tV9fwwoiYgksQ8+8CfoTJ3q26+/7pesrVYt3Lgkhnr08HPTNGvmS89mzXy7R4+YvJzGRESS0M6d0KYNLF/u2yecAIsWQSX9xaeGHj1iljRKUiUikmQmTPATJkYSyPvvw7JlSiASG/q1EkkSP/4IRx/tF4oC6NAB3n1XEyZKbKkSEUkCo0b5sdNIApk/H957TwlEYk+ViEgFtnkz1KlT2L7ySnjxxfDikdQTaiViZulmNtPMFpjZEjN7KOg
3M3vUzL4ws2VmdluR/r+ZWY6ZLTSzk8OMXyRMjz1WPIGsXKkEIvEXdiWyHejgnNtqZpWBj8zsLaAl0AQ4wTm328yODvb/BdAiuJ0OPB18FUkZ33wDxxxT2O7bFx5/PLx4JLWFmkSccw4IjuJSObg54Gbgt8653cF+64N9ugJjgsd9Zma1zayhc+6bOIcuEoo77vDrm0d8+60mR5RwhT6wbmZpZjYfWA9Mcc7NAI4F/s/MZpvZW2bWIti9EbCmyMPzgr6Sz9kreOzs/Pz8WH8LIjGXk+MHySMJZPBgfyW6EoiELfQk4pzb5ZzLBBoDp5lZG6AqsM05lwU8B4wq43OOcM5lOeey6pU2t7VIBeAcdO8OLVoU9m3ZAnfdFV5MIkWFnkQinHObgWlAJ3yF8Vpw13igXbC9Fj9WEtE46BNJOvPm+SlLxo717dGjfVKpVSvcuESKCvvsrHpmVjvYrgZcAHwOvA6cF+z2c+CLYHsicHVwltYZwBaNh0iy2b0bzj4bTg7OPaxb119IePXV4cYlsi9lGlg3s/rOuXXl+PoNgdFmloZPaC87594ws4+AbDO7Az/wfkOw/yTgYiAHKACuLcdYREI3bZq/0jzijTfgl78MLx6RAynr2Vmrzex14Fnn3NRDfXHn3ELgpH30bwb2+tMJzsq69VBfVyTR7NjhJ0n88kvfbtcO5s6FtLRw4xI5kLIezvoC+DUwJbgQ8C4zqxuDuERSxrhxUKVKYQL5+GO/gJQSiFQEZUoizrm2wP8CL+JPrR0E5JlZtpmdE4P4RJJWQYFf16NbN9/u1MmPh5x1VrhxiZRFmQfWnXOfOOd6AscAffDjE1cA08xsqZn1MbM6+3sOkVQ3YgQcfjhs2+bbixbBW29pwkSpeKI+O8s5t8U59/ci1ckYoBnwBL46ecHMssopTpGk8N13PlHcdJNvX3utP223TZtw4xKJVnmd4rsB2ARsAwx/seDVwAwze93Mjiyn1xGpsB55xJ+uG/HVV34Kd5GKLOokYmaVzay7mU0DlgG3A/nAncBRQAfgHaALMKwcYhWpkNau9dXHH//o2/ff76uPjIxQwxIpF2WegNHMjgN6AT2BusBu/MWBw51z7xXZdTow3cxexV+FLpJyeveGYUU+Qq1fD5qJR5JJWS82fA84F3/I6htgADDCOff1fh42B/hVtAGKVETLl/vrPiKGDoXbbgsvHpFYKWslch5+fqvhwOvOuV0H8Zj/APtLMiJJwzl/yu5rrxX2ff891KwZXkwisVTWJNLSObe8LA9wzi0GFpfxdUQqnFmz4LTTCtvZ2fDb34YXj0g8lCmJlDWBiKSC3bvhzDNh5kzfbtAAcnOhatVQwxKJi4SZCl6kIpoyxU9PEkkgb73ll69VApFUEfYa6yIV0k8/wbHHQl6eb59yCsyYofmuJPWoEhEpo5df9pVGJIF8+inMnq0EIqlJlYjIQdq6FY44wo+BAFxyCUyYoPmuJLWpEhE5CMOG+dN0IwlkyRKYOFEJRESViMh+bNhQ/Arzm26CZ54JLx6RRKNKRKQU/fsXTyCrVyuBiJSkJCJSwpo1/jDVww/7dv/+/kr0Jk3CjUskEYWaRMws3cxmmtkCM1tiZg+VuP9vZra1SLuqmY01sxwzm2FmGfGOWZLbTTdB06aF7Q0b4E9/Ci0ckYQXdiWyHejgnDsRyAQ6mdkZAMGCViVXSLwe2OScOw4YAjwWz2AleS1d6quPESN8e9gwX30UXf9DRPYWahJxXqTSqBzcnJml4ddvv6fEQ7oCo4PtV4GOZjo/RqLnHHTuDK1b+3ZaGvzwA9xyS7hxiVQUYVcimFmamc0H1gNTnHMzgN7AROfcNyV2bwSsAXDO7QS24Nc0KfmcvcxstpnNzs/Pj+03IBXWZ5/BYYfBm2/69tixsHMn1KgRblwiFUnop/gG08lnmlltYLyZnQP8Gr9uSbTPOQIYAZCVleXKI05JHrt2wamnwrx5vt20KaxYAVWqhBu
XSEUUeiUS4ZzbjF+r5DzgOCDHzHKB6maWE+y2FmgCYGaVgCOAjfGPViqqt9+GSpUKE8iUKbBqlRKISLRCrUTMrB6wwzm32cyqARcAjznnGhTZZ2swkA4wEbgG+BToBkx1zqnSkAPavh2aNYN163z7jDPg44/94SwRiV7Yf0INgWlmthCYhR8TeWM/+48E6gaVyZ3AvXGIUSq4f/0L0tMLE8isWX7SRCUQkUMXaiXinFsInHSAfWoU2d6GHy8ROaAffoBatQrbl10Gr76q+a5EypM+i0lSGjq0eAL5/HMYN04JRKS8hX52lkh5Wr8e6tcvbPfuDX//e3jxiCQ7VSKSNPr1K55A8vKUQERiTUlEKrzcXH+Y6s9/9u1HHvFXojdqFGpYIilBh7OkQrvuOvjHPwrbGzfCkUeGF49IqlElsi/Z2ZCR4c8BzcjwbUkoixb56iOSQJ591lcfSiAi8aVKpKTsbOjVCwoKfHvVKt8G6NEjvLgE8ImiUyeYPNm3q1Xz07VXrx5uXCKpSpVISf36FSaQiIIC3y+hilxhHkkg48b5H40SiEh4VImUtHp12fol5nbtgsxMWLzYt487zq//UblyuHGJiCqRvRVd1u5g+iWm3njDT5gYSSBTp/oZd5VARBKDkkhJjz669/GR6tV9v8TNtm1+kPySS3z7nHN8RXLeeeHGJSLFKYmU1KOHXyO1WTN/+k+zZr6tQfW4GT3aD5hv2uTbc+bA++9rwkSRRKQxkX3p0UNJIwRbtkDt2oXt7t39DLya70okcemznSSEv/61eAJZsQL+/W8lEJFEp0pEQrVuHTRoUNi+4w544onw4hGRslElIqHp27d4Avn6ayUQkYpGSUTi7ssv/WGqwYN9e+BAfyV6w4bhxiUiZafDWRJXV15ZfCqyTZuKj4WISMWiSkTiYsECX31EEsjIkb76UAIRqdhCTSJmlm5mM81sgZktMbOHgv5sM1tuZovNbJSZVQ76zcz+ZmY5ZrbQzE4OM345MOf8BYKZmb5dq5af7+q668KNS0TKR9iVyHagg3PuRCAT6GRmZwDZwAlAW6AacEOw/y+AFsGtF/B03COWg/bBB/4CwenTffv11/21INWqhRqWiJSjUMdEnHMO2Bo0Kwc355ybFNnHzGYCjYNmV2BM8LjPzKy2mTV0zn0Tz7hl/3buhNat4YsvfPuEE/z6H5U0AieSdMKuRDCzNDObD6wHpjjnZhS5rzJwFfB20NUIWFPk4XlBX8nn7GVms81sdn5+fuyCl728/rqfHDGSQN5/H5YtUwIRSVahJxHn3C7nXCa+2jjNzNoUuXs48IFz7sMyPucI51yWcy6rXr165RmulOLHH6FGDfjVr3y7QwfYvdtPnCgiySv0JBLhnNsMTAM6AZhZf6AecGeR3dYCTYq0Gwd9EqKRI/1Ex//9r2/Pnw/vvacpS0RSQdhnZ9Uzs9rBdjXgAuBzM7sBuAi4wjm3u8hDJgJXB2dpnQFs0XhIeDZv9onihuC0hyuv9GdjnXhiuHGJSPyEfaS6ITDazNLwCe1l59wbZrYTWAV8av7j7GvOuYeBScDFQA5QAFwbTtgycCDcd19he+VK+NnPwotHRMIR9tlZC4GT9tG/z7iCs7JujXVcUrqvv4ZGRU5luOceeOyx8OIRkXCFXYlIBXLHHfDkk4Xtb7+F+vXDi0dEwpcwA+uSuFas8GMfkQQyeLAf+1ACERFVIlIq5/zqgi+/XNi3ZYufukREBFSJSCm8CQyyAAANOElEQVTmzvVTlkQSyJgxPqkogYhIUapEpJjIBYIff+zbRx0Fa9ZAenq4cYlIYlIlIntMmwZpaYUJ5I03ID9fCURESqdKRNixA/7nf+Crr3y7XTt/OCstLdy4RCTxqRJJcePGQZUqhQnk44/9AlJKICJyMFSJpKj//hfq1oXt2327UyeYNEnzXYlI2agSSUHPPutn3I0kkEWL4K23lEBEpOxUiaSQ777z1UfEddf5GXhFRKK
lSiRFDBhQPIF89ZUSiIgcOlUiSW7tWmjcuLB9//3w6KPhxSMiyUVJJIn17g3DhhW2168HLfQoIuVJh7OS0PLlfpA8kkCGDvVTliiBiEh5UyWSRJyDyy6D118v7Pv+e6hZM7yYRCS5qRJJErNm+QkTIwkkO9snFSUQEYklVSIV3O7dcOaZMHOmbzds6M+8qlo13LhEJDWoEqnApkzx05NEEsjbb/vla5VARCReQk0iZpZuZjPNbIGZLTGzh4L+5mY2w8xyzGysmVUJ+qsG7Zzg/oww4w/LTz9BkyZw4YW+fcopsHMnXHRRuHGJSOoJuxLZDnRwzp0IZAKdzOwM4DFgiHPuOGATcH2w//XApqB/SLBfShk71lcaeXm+/dlnMHu2JkwUkXCEmkSctzVoVg5uDugAvBr0jwYuDba7Bm2C+zuapcaMT1u3+oHz7t19u0sXPx5y+unhxiUiqS3sSgQzSzOz+cB6YAqwEtjsnNsZ7JIHNAq2GwFrAIL7twB1KcHMepnZbDObnZ+fH+tvIeaeesqfZeWcby9dChMmaMJEEQlf6EnEObfLOZcJNAZOA04oh+cc4ZzLcs5l1avAV9ht2OATxe9/79s33eQTScuW4cYlIhIRehKJcM5tBqYBZwK1zSxy+nFjYG2wvRZoAhDcfwSwMc6hxsWDDxa/wnz1anjmmfDiERHZl7DPzqpnZrWD7WrABcAyfDLpFux2DTAh2J4YtAnun+pc5CBPcli92lcfAwb4dv/+vvpo0iTcuERE9iXsiw0bAqPNLA2f0F52zr1hZkuBl8zsEWAeEJm0fCTwopnlAN8B3cMIOlZ69YLnnitsb9hQfPp2EZFEE2oScc4tBE7aR/+X+PGRkv3bgF/HIbS4WroUWrcubA8bBrfcEl48IiIHK+xKJKU5B5dcAm++6duVKsHmzXD44eHGJSJysBJmYD3VfPqpv+4jkkDGjoUdO5RARKRiUSUSZ7t2wamnwrx5vt20KaxYAVWqhBuXiEg0VInE0Vtv+UNWkQQyZQqsWqUEIiIVlyqRONi+3Vcc69f79plnwkcf+cNZIiIVmf6NxVh2NqSnFyaQWbPgk0+UQEQkOagSiZHvv4cjjihsX345vPKK5rsSkeSiz8Mx8OSTxRPI8uXw6qtKICKSfFSJlKP166F+/cJ2797w97+HF4+ISKypEikn999fPIHk5SmBiEjyUxI5RLm5/jDVX/7i24884q9Eb9Rovw8TEUkKOpx1CHr2hNGjC9vffQd16oQWjohI3KkSicKiRb76iCSQESN89aEEIiKpRpVIGTgHF13krzQHqFbNT9devXq4cYmIhEWVyEH6+GN/gWAkgYwbBwUFSiAiktpUiZTQoAGsW1f6/cceC8uWQeXK8YtJRCRRqRIpYX8JZNo0yMlRAhERiVASKYNzzw07AhGRxKIkIiIiUQs1iZhZEzObZmZLzWyJmfUJ+jPN7DMzm29ms83stKDfzOxvZpZjZgvN7OQw4xcRSXVhD6zvBO5yzs01s5rAHDObAjwOPOSce8vMLg7a5wK/AFoEt9OBp4OvIiISglArEefcN865ucH2D8AyoBHggFrBbkcAXwfbXYExzvsMqG1mDcszpqLzXx1Mv4hIKgu7EtnDzDKAk4AZwO3AO2Y2GJ/ozgp2awSsKfKwvKDvmxLP1QvoBdC0adMyxfHtt2UOXUQkZSXEwLqZ1QDGAbc7574HbgbucM41Ae4ARpbl+ZxzI5xzWc65rHr16pV/wCIiAiRAEjGzyvgEku2cey3ovgaIbL8CnBZsrwWaFHl446BPRERCEPbZWYavMpY5554octfXwM+D7Q7AimB7InB1cJbWGcAW51yxQ1kiIhI/YY+JtAeuAhaZ2fyg737gRmComVUCthGMbwCTgIuBHKAAuDa+4YqISFGhJhHn3EdAaSuPn7KP/R1wa0yDEhGRg2b+/3LyMrN8YFWUDz8K2FCO4ZQXxVU2iqvsEjU2xVU2hxJXM+fcAc9
MSvokcijMbLZzLivsOEpSXGWjuMouUWNTXGUTj7hCPztLREQqLiURERGJmpLI/o0IO4BSKK6yUVxll6ixKa6yiXlcGhMREZGoqRIREZGopVwSMbNRZrbezBaX6P+9mX0erGvyeJH++4L1S5ab2UVF+jsFfTlmdm8s4jKzscGaKvPNLLfIBZlhx1Xm9V7M7BozWxHcrolRXCea2admtsjM/mNmtYrcF6/3q7Q1co40synB9z/FzOoE/XF5z/YT16+D9m4zyyrxmJi/Z/uJa1Dw97jQzMabWe0EiWtAENN8M5tsZscE/aH+HIvcf5eZOTM7Km5xOedS6gacA5wMLC7Sdx7wLlA1aB8dfG0FLACqAs2BlUBacFsJ/AyoEuzTqrzjKnH/X4EHEyEuYDLwi2D7YmB6ke238BeQngHMCPqPBL4MvtYJtuvEIK5ZwM+D7euAASG8Xw2Bk4PtmsAXwes/Dtwb9N8LPBbP92w/cbUE/geYDmQV2T8u79l+4roQqBT0P1bk/Qo7rlpF9rkNeCYRfo5BuwnwDv66uKPiFVfKVSLOuQ+A70p03wwMdM5tD/ZZH/R3BV5yzm13zn2Fn27ltOCW45z70jn3E/BSsG95xwXsmWPsN8C/EySusq73chEwxTn3nXNuEzAF6BSDuI4HPgi2pwCXF4krXu9XaWvkdAVGB7uNBi4tElvM37PS4nLOLXPOLd/HQ+Lynu0nrsnOuZ3Bbp/hJ1tNhLi+L7Lb4fi/hUhcof0cg7uHAPcUiSkucaVcEinF8cDZZjbDzN43s1OD/tLWLymtP1bOBtY55yITUYYd1+3AIDNbAwwG7kuQuJZQ+I/j1xTO+BxKXFZ8jZz6rnCy0G+ByDJncY+tRFylSaS4rsN/mk6IuMzs0eB3vwfwYCLEZWZdgbXOuQUldot5XEoiXiV8WXcG0Bd4Ofj0nyiuoLAKSQSHtN5LDF0H3GJmc/Cl/k9hBWJ7r5Gzh/PHE0I5LXJ/cYWptLjMrB9+Ge3sRInLOdcv+N3PBnqHHRf+/bmfwoQWV0oiXh7wWlDyzQR24+ecKW39krita2J+JuPLgLFFusOOq6zrvcQlLufc5865C51zp+CT7sow4rJ9r5GzLjiMQPA1csg0brGVEldpQo/LzHoCnYEeQeJNiLiKyKbwkGmYcR2LHx9aYGa5wWvMNbMGcYkrmoGUin4DMig+IPs74OFg+3h8mWdAa4oP4n2JH8CrFGw3p3AQr3V5xxX0dQLeL9EXalz447DnBtsdgTnB9i8pPog30xUO4n2FH8CrE2wfGYO4IidEHAaMAa6L9/sVfO9jgCdL9A+i+MD64/F8z0qLq8j90yk+sB6X92w/71cnYClQL4zf/f3E1aLI9u+BVxPp5xjsk0vhwHrM4zqkP+KKeMN/Qv0G2IGvQK4Pfun+CSwG5gIdiuzfD/+JdjnBGUlB/8X4MyNWAv1iEVfQ/wLwu33sH1pcwP8Cc4I/1BnAKUV+wYcFr72I4v+UrsMPguYA18Yorj7B9/4FMJDgYto4v1//iz9UtRCYH9wuBuoC7+EXWHs38gcbr/dsP3H9Knj/tgPrgHfi+Z7tJ64c/Ie5SN8zCRLXOPz/iYXAf/CD7aH/HEvsk0thEol5XLpiXUREoqYxERERiZqSiIiIRE1JREREoqYkIiIiUVMSERGRqCmJiIhI1JREROLAzK4Ppuh+az/7vBnsc0s8YxM5FEoiInHgnBsJTAQ6mdmtJe83s5sJpu12zg2Pd3wi0dLFhiJxYmZH4692Phy/JsTyoP94YB7wI9DGOfdteFGKlI0qEZE4cX6dmhuB6sA/zaxSMMHmP4O+XkogUtFUCjsAkVTinJtgZqPw8xZFpu4+FXjBHXhmXZGEo8NZInFmZjXxk1c2DbrWAO2cX6lOpELR4SyROAuSxcMUrg1+sxKIVFRKIiJxZmbVgD8U6fp1WLGIHColEZH4exw4ARiKXw/iOjO7JNyQRKKjMRGRODKzC4G38af
6ngq0AGYDm/Gn924IMTyRMlMlIhInZnYk8A/8aoxXOue2O+cWA38E6gNPhxmfSDSURETi5xngGOAB59zCIv1/BT4EupnZlaFEJhIlHc4SiQMzuwoYA3wAnOec213i/ub4dbN3Am2dc3nxj1Kk7JRERGLMzJriE4ThrwdZVcp+NwDPAVOAi5z+OKUCUBIREZGoaUxERESipiQiIiJRUxIREZGoKYmIiEjUlERERCRqSiIiIhI1JREREYmakoiIiERNSURERKKmJCIiIlH7f1gLS3bZv0mkAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Run the trained model on the training inputs and convert the predictions to a NumPy array\n", "predict = model(Variable(x_train))\n", "predict = predict.data.numpy()\n", "# Plot the original data points (red circles) against the fitted line (blue line with square markers)\n", "plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')\n", "plt.plot(x_train.numpy(), predict, 'b-s', label='Fitting Line')\n", "plt.xlabel('X', fontsize=20)\n", "plt.ylabel('y', fontsize=20)\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using PyTorch to build a convolutional neural network and process the MNIST data:\n", "https://computational-communication.com/pytorch-mnist/" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python [default]", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.4" }, "latex_envs": { "LaTeX_envs_menu_present": true, "autoclose": false, "autocomplete": true, "bibliofile": "biblio.bib", "cite_by": "apalike", "current_citInitial": 1, "eqLabelWithNumbers": true, "eqNumInitial": 1, "hotkeys": { "equation": "Ctrl-E", "itemize": "Ctrl-I" }, "labels_anchors": false, "latex_user_defs": false, "report_style_numbering": false, "user_envs_cfg": false }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }