{ "cells": [ { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2019-04-07T02:19:02.272939Z", "start_time": "2019-04-07T02:19:02.268259Z" }, "slideshow": { "slide_type": "slide" } }, "source": [ "***\n", "***\n", "\n", "# Introduction to Neural Networks\n", "\n", "\n", "***\n", "***" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2019-04-07T05:41:09.971531Z", "start_time": "2019-04-07T05:41:09.967172Z" }, "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "\n", "\n", "http://playground.tensorflow.org/" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2019-04-07T09:37:45.438709Z", "start_time": "2019-04-07T09:37:45.434406Z" }, "slideshow": { "slide_type": "subslide" } }, "source": [ "- Deep Learning http://www.deeplearningbook.org\n", "    - An MIT Press book by Ian Goodfellow, Yoshua Bengio and Aaron Courville\n", "\n", "- Neural Networks and Deep Learning http://neuralnetworksanddeeplearning.com/index.html\n", "\n", "    - A free online book by [Michael Nielsen](http://michaelnielsen.org/) (Dec 2017) explaining the core ideas behind artificial neural networks and deep learning. [Code](https://github.com/mnielsen/neural-networks-and-deep-learning)." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## House Price\n", "\n", "Let’s start with a simple example. \n", "- Say you’re helping a friend who wants to buy a house.\n", "\n", "- She was quoted $400,000 for a 2,000 sq ft house (about 185 square meters). \n", "\n", "Is this a good price or not?" 
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "So you ask friends who have bought houses in the same neighborhood, and you end up with three data points:\n", "\n", "\n", "\n", "| Area (sq ft) (x) | Price (USD) (y) | \n", "| -------------|:-------------:|\n", "|2,104|399,900|\n", "|1,600|329,900|\n", "|2,400|369,000|" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "$$y = f(X) = W X$$\n", "\n", "- Calculating the prediction is simple multiplication.\n", "- But before that, we need to think about the weight we’ll be multiplying by. \n", "- “Training” a neural network just means finding the weights we use to calculate the prediction.\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "A simple predictive model (“regression model”)\n", "- takes an input, \n", "- does a calculation, \n", "- and gives an output \n", "\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2019-04-07T03:15:14.317623Z", "start_time": "2019-04-07T03:15:14.313438Z" }, "slideshow": { "slide_type": "subslide" } }, "source": [ "Model Evaluation\n", "- If we apply our model to the three data points we have, how well does it do?" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "**Loss Function (also called cost function)**\n", "\n", "- For each point, the error is the difference between the **actual value** and the **predicted value**, squared. \n", "- Averaging these squared errors over all points gives the **Mean Squared Error** (MSE). 
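
To make the loss concrete, the sketch below computes the mean squared error of the one-weight model for the three data points. The weight `w = 180` (dollars per square foot) is an arbitrary guess chosen only for this illustration, and the helper name `mse` is hypothetical rather than something defined elsewhere in this notebook.

```python
# Three (area, price) observations from the table of data points.
areas = [2104, 1600, 2400]
prices = [399900, 329900, 369000]

def mse(w, xs, ys):
    '''Mean squared error of the one-weight model y = w * x.'''
    errors = [(w * x - y) ** 2 for x, y in zip(xs, ys)]
    return sum(errors) / len(errors)

w = 180  # hypothetical weight, picked only for illustration
print(mse(w, areas, prices))
```

Trying a few different values of `w` and comparing the printed losses is a quick way to see which weight fits these points better.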
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- We can't improve much on the model by varying the weight any more. \n", "- But if we add a bias (intercept) we can find values that improve the model.\n", "\n", "\n", "\n", "$$y = 0.1 X + 150$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "**Gradient Descent**\n", "\n", "- Automatically get the correct weight and bias values \n", "- minimize the loss function.\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## softmax\n", "\n", "The softmax function, also known as softargmax or normalized exponential function, is a function that takes as input a vector of K real numbers, and normalizes it into a probability distribution consisting of K probabilities. 
\n", "\n", "$$\\mathrm{softmax}(x)_i = \\frac{e^{x_i}}{\\sum_{j=1}^{K} e^{x_j}}$$\n" ] }, { "cell_type": "code", "execution_count": 128, "metadata": { "ExecuteTime": { "end_time": "2019-04-07T13:57:33.466056Z", "start_time": "2019-04-07T13:57:33.459580Z" }, "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/plain": [ "array([0.02364054, 0.06426166, 0.1746813 , 0.474833 , 0.02364054,\n", " 0.06426166, 0.1746813 ])" ] }, "execution_count": 128, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "\n", "def softmax(s):\n", "    # Subtracting the maximum leaves the result unchanged but avoids overflow in exp\n", "    e = np.exp(np.asarray(s) - np.max(s))\n", "    return e / np.sum(e, axis=0)\n", "\n", "softmax([1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0])" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "That is, before applying softmax, some vector components could be negative or greater than one, and they might not sum to 1. After applying softmax, each component lies in the interval (0, 1) and the components add up to 1, so they can be interpreted as probabilities. \n", "\n", "Furthermore, larger input components correspond to larger probabilities. \n", "\n", "Softmax is often used in neural networks to map the non-normalized output of a network to a probability distribution over the predicted output classes." 
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## Activation Function\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2019-04-07T03:59:39.599371Z", "start_time": "2019-04-07T03:59:39.593104Z" }, "code_folding": [], "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/plain": [ "0.6224593312018546" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "def sigmoid(x):\n", " return 1/(1 + np.exp(-x))\n", "\n", "sigmoid(0.5)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "ExecuteTime": { "end_time": "2019-04-07T04:01:24.575040Z", "start_time": "2019-04-07T04:01:24.567328Z" }, "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/plain": [ "0.5" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Naive scalar relu implementation. 
\n", "# In practice, most calculations are done on vectors (e.g. np.maximum(0, x))\n", "def relu(x):\n", "    if x < 0:\n", "        return 0\n", "    else:\n", "        return x\n", "\n", "relu(0.5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# What Is PyTorch?\n", "\n", "It’s a Python-based scientific computing package aimed at two audiences:\n", "\n", "- A replacement for NumPy that can use the power of GPUs\n", "- A deep learning research platform that provides maximum flexibility and speed" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2019-06-15T03:13:38.257567Z", "start_time": "2019-06-15T03:13:37.133622Z" }, "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import torch\n", "from torch import nn, optim\n", "from torch.autograd import Variable\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2019-06-15T03:13:54.278106Z", "start_time": "2019-06-15T03:13:54.075090Z" }, "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXkAAAD+CAYAAADfwXXpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAFNlJREFUeJzt3X+Q3PV93/HnC4FiEv9QLS4jIQvJwlHBSceMdZ3xtWQQwgy4k2TGeGoS0kAcG9nYUxg8MZ4y9ozDlIHB/ie4bjJKaJk6OImdgt2YGqpSq65dOZODGKt2LBnLxkgV6iEgbkNVLN27f3y/stfKSbcn3WlPH56PmZ29/eznu/u6O3jx4bPf201VIUlq0xmjDiBJWjiWvCQ1zJKXpIZZ8pLUMEtekhpmyUtSwyx5SWqYJS9JDbPkJalhZ446wDnnnFNr164ddQxJOq08+uijz1TV2GzzRl7ya9euZXJyctQxJOm0kuTJYea5XSNJDbPkJalhlrwkNcySl6SGWfKS1LChSj7J0iTfTPKH/e2bknw/yc4kbxmYd2eSPUl2JNmwUKElScMZ9hTKW4HvASQ5H3gf8PPAauA/J1kD/CJwMbAWuBS4B7hofuNKAmD7dti2DTZuhImJUafRIjZrySe5EPiHwKfpSvytwKer6n8D30zyPWADcBVwb1UdArYmGUuyoqqeXrD00kvR9u1w2WXw4ouwdCk88ohFr2M67nZNkgB3AzcNDK8GBk/C3wOsnGF8bz8+0+NuTjKZZHJqaupEcksvXdu2dQV/+HB3vW3bqBNpEZttT/49wLaqemJgbCkwPXB7Gjh8nPG/o6q2VNV4VY2Pjc36V7mSBm3c2K3glyzprjduHHUiLWKzbdf8BvCKJP8UeDXwM3Qr+1UDc14DPAXsO2r8XLpVvqT5NDHRbdG4J68hHLfkq+ofHfk6yW/S7cl/Hvhkko/Rvcj6auBrwIPAe5PcB2wCdlXVswsTW3qJm5iw3DWUOb9BWVU9muSPgG8AB4F3VVUleQC4BNgNHACumdekkqQ5S1WNNMD4+Hj5LpSSNDdJHq2q8dnm+RevktQwS16SGmbJS1LDLHlJapglL0kNs+QlqWGWvCQ1zJKXpIZZ8pLUMEtekhpmyUtSwyx5SWqYJS9JDbPkJalhlrwkNcySl6SGWfKS1DBLXpIaZslLUsNmLfkkZyTZmmRXkp1JrujH7k7y7SQ7kkwMzL8zyZ5+fMPCxpckHc+ZQ8wp4Nqq2pfkSuB24FzgNcAFwBuA+5K8HrgUuBhY2399D3DRAuSWJA1h1pV8dfb1N9cAjwPjwBeq6nBVPQYcAtYBVwH3VtWhqtoKjCVZsUDZJUmzGGpPPsktSQ4ANwO3Ad8AfjnJWUkuBF4LjAGrgScHDt0LrJzh8TYnmUwyOTU1dbLfgyTpGIYq+aq6q6qWA7cCDwN/QFfgXwc+AHwLOAAsBaYHDp0GDs/weFuqaryqxsfGxk7uO5AkHdMwe/I/UlX3J7kbeFVV3QCQ5KeAXcBTwD5g1cAh5wJ75imrJGmOhjm7Zt2RffX+LJqDwN8mWZokwIeBz1bVQeBB4LokS5JcDuyqqmcXML8k6TiGWckvAx5KsgTYD1wNnAd8Afgp4IvAu/u5DwCXALvptm+ume/AkqThzVry/dkz62e4a90Mc6eBG/uLJGnE/ItXSWqYJS9JDbPkJalhlrwkNcySl6SGWfKS1DBLXpIaZslLUsMseUlqmCUvSQ2z5CWpYZa8JDXMkpekhlnyktQwS16SGmbJS1LDLHlJapglL0kNs+QlqWGzlnySM5JsTbIryc4kV/TjH03ynSTfTvK2gfl3JtmTZEeSDQsZXpJ0fLN+kDdQwLVVtS/JlcDtSQ4CFwMX0H2g95eAf59kUz++FrgUuAe4aCGCS5JmN+tKvjr7+ptrgMeB/9cfexh4GbC/v/8q4N6qOlRVW4GxJCvmP7YkaRjDrORJcgvwQWA
KuKKqnkzyIPBV4Czg1/qpq4HPDRy6F1gJPH3U420GNgOcd955J5NfknQcQ73wWlV3VdVy4Fbg4SSrgUuA9wCfBm7qpy4FpgcOnaZb7R/9eFuqaryqxsfGxk4mvyTpOOZ0dk1V3Q+8HPgd4IGqeqyq7gDelORCYB+wauCQc4E98xVWkjQ3w5xds+7IvnqSCeAgXXGPp7MSWAH8AHgQuC7JkiSXA7uq6tmFiy9JOp5h9uSXAQ8lWUL3AuvVwG7gPuB7wAvAv6iqvUkeoNvG2Q0cAK5ZiNCSpOHMWvJV9Riwfoa7/skMc6eBG/uLJGnE/ItXSWqYJS9JDbPkJalhlrwkNcySl6SGWfKS1DBLXpIaZslLUsMseUlqmCUvSQ2z5CWpYZa8JDXMkpekhlnyktQwS16SGmbJS1LDLHlJapglL0kNs+QlqWGWvCQ1bNaST3JGkq1JdiXZmeSKJB9P8sTA5XCS1/fz70yyJ8mOJBsW/luQJB3LmUPMKeDaqtqX5Erg9qoaP3JnkvXAp6rqm0k2ARcDa4FLgXuAi+Y/tiRpGLOu5Kuzr7+5Bnj8qCnvAv5N//VVwL1VdaiqtgJjSVbMW1pJ0pwMs5InyS3AB4Ep4IqB8bOAt/Pj1fpq4HMDh+4FVgJPH/V4m4HNAOedd94JRpckzWaoF16r6q6qWg7cCjycJP1dvwJ8paqe728vBaYHDp0GDs/weFuqaryqxsfGxk48vSSdjrZvhzvu6K4X2FAr+SOq6v4kdwPLgWeA64G7BqbsA1YN3D4X2HOyISWpGdu3w2WXwYsvwtKl8MgjMDGxYE83zNk1647sqyeZAA5W1TNJ1gDnA18cmP4gcF2SJUkuB3ZV1bMLEVySTkvbtnUFf/hwd71t24I+3TAr+WXAQ0mWAPuBq/vxdwL/rqpqYO4DwCXAbuAAcM08ZpWk09/Gjd0K/shKfuPGBX26/GRHn3rj4+M1OTk50gySdEpt396t4DduPOGtmiSPDp7Ofixz2pOXJM2DiYkF3Ycf5NsaSFLDLHlJapglL0kNs+QlqWGWvCQ1zJKXpIZZ8pLUMEtekhpmyUtSwyx5SWqYJS9JDbPkJalhlrwkNcySl6SGWfKS1DBLXpIaZslLUsMseUlq2Kwln+SMJFuT7EqyM8kV/fiqJA8leSrJ9oH5dybZk2RHkg0LGV6SdHzDrOQLuLaq1gM3Abf3438M3FdVq4FNAEk2ARcDa4H3A/fMd2BJ0vBmLfnq7OtvrgEe71foqapP9nP+b3//VcC9VXWoqrYCY0lWLERwSdLshtqTT3JLkgPAzcBtwEXA3n4b51tJfrufuhp4cuDQvcDKGR5vc5LJJJNTU1Mn9x1Iko5pqJKvqruqajlwK/Aw8LPABcDb6bZnbkjyBmApMD1w6DRweIbH21JV41U1PjY2dpLfgiTpWOZ0dk1V3Q+8nK68v1RVz1XVM8BXgPXAPmDVwCHnAnvmKaskaY6GObtm3ZF99SQTwEG6F10vS/LKJMuANwF/BTwIXJdkSZLLgV1V9ezCxZckHc+ZQ8xZBjyUZAmwH7i6qr6f5GPAXwIB7qyqJ5LsBi4BdgMHgGsWKLckaQipqpEGGB8fr8nJyZFmkKTTTZJHq2p8tnn+xaskNcySl6SGWfKS1DBLXpIaZslLUsMseUlqmCUvSQ2z5CWpYZa8JDXMkpekhlnyktQwS16SGmbJS1LDLHlJapglL0kNs+QlqWGWvCQ1zJKXpIZZ8pLUsFlLPskZSbYm2ZVkZ5Ir+vEfJnmiv/zJwPw7k+xJsiPJhoUML0k6vjOHmFPAtVW1L8mVwO3Aw8Deqnrd4MQkm4CLgbXApcA9wEXzmliSNLRZV/LV2dffXAM8fpzpVwH3VtWhqtoKjCVZMQ85JUknYKg9+SS3JDkA3Azc1g8vT/KdJF9MMt6PrQaeHDh0L7ByhsfbnGQyyeTU1NRJxJckHc9QJV9Vd1XVcuBW4OEkqapXVNX5wL8GHuinLgWmBw6
dBg7P8Hhbqmq8qsbHxsZO7juQJB3TnM6uqar7gZcDywfGPgOcnWQZsA9YNXDIucCeecgpSToBw5xds+7IvnqSCeBg//Wy/votwIGqeh54ELguyZIklwO7qurZBUsvSTquYc6uWQY8lGQJsB+4mm6f/T8kmaZbvb+9n/sAcAmwGzgAXDPviSVJQ5u15KvqMWD9DHe9doa508CN/UWSNGL+xaskNcySl6SGWfKS1DBLXpIaZslLUsMseUlqmCUvSQ2z5CWpYZa8JDXMkpekhlnyktQwS16SGmbJS1LDLHlJapglL0kNs+QlqWGWvCQ1zJKXpIZZ8pLUsFlLPskZSbYm2ZVkZ5IrBu5bnmR/kg8NjN2ZZE+SHUk2LFRwSdLshlnJF3BtVa0HbgJuH7jvY8BjR24k2QRcDKwF3g/cM29JJUlzNmvJV2dff3MN8DhAkjcDh4C/GJh+FXBvVR2qqq3AWJIV85xZkjSkofbkk9yS5ABwM3BbkrOB24Bbjpq6Gnhy4PZeYOUMj7c5yWSSyampqRNLLkma1VAlX1V3VdVy4FbgYeAjwCeq6rmjpi4FpgduTwOHZ3i8LVU1XlXjY2NjJxRckjS7M+cyuaruT3I3cCOwM8kHgBVAJfkusA9YNXDIucCe+QorSZqbWUs+yTrghap6OskEcLCqzh64/yPAoaq6L8lB4L1J7gM2Abuq6tkFyi5JmsUwK/llwENJlgD7gauPM/cB4BJgN3AAuOakE0qSTliqaqQBxsfHa3JycqQZJOl0k+TRqhqfbZ5/8SpJDbPkJalhlrwkNcySl6SGWfKS1DBLXpIaZslLUsMseUlqmCUvSQ2z5CWpYZa8JDXMkpekhlnyktQwS16SGmbJS1LDLHlJapglL0kNs+QlqWGWvCQ1bNaST3JGkq1JdiXZmeSKJOck+XKSbyd5PMkbB+bfmWRPkh1JNixsfEnS8Qyzki/g2qpaD9wE3A68APxSVf0c8HvArQBJNgEXA2uB9wP3LEBmSdKQZi356uzrb64BHq+qF6rq+SRLgNXA4/39VwH3VtWhqtoKjCVZsSDJt2+HO+7oriVJMzpzmElJbgE+CEwBV/Rjvwu8A/gmcGU/dTXwuYFD9wIrgafnKW9n+3a47DJ48UVYuhQeeQQmJub1KSSpBUO98FpVd1XVcrptmYeTpKpuApYBnwI+009dCkwPHDoNHD768ZJsTjKZZHJqamruqbdt6wr+8OHuetu2uT+GJL0EzOnsmqq6H3g5sLy/PU23J/+mfso+YNXAIecCe2Z4nC1VNV5V42NjY3NPvXFjt4JfsqS73rhx7o8hSS8Bs27XJFkHvFBVTyeZAA4CK5L8sKr+Bngr8Gg//UHgvUnuAzYBu6rq2XlPPTHRbdFs29YVvFs1kjSjYfbklwEP9S+y7geuptt7//Mkh4HvAu/s5z4AXALsBg4A18x74iMmJix3SZrFrCVfVY8B62e467UzzJ0GbuwvkqQR8y9eJalhlrwkNcySl6SGWfKS1DBLXpIalqoabYBkCnjyBA8/B3hmHuPMF3PNjbnmxlxz02quNVU161+TjrzkT0aSyaoaH3WOo5lrbsw1N+aam5d6LrdrJKlhlrwkNex0L/ktow5wDOaaG3PNjbnm5iWd67Tek5ckHd/pvpKXJB3Hoi35JGcnmemN0UbKXHOzWHOdTpK8KsnfeUNAaRiLruSTvDLJZ+ne1viWgfFVSR5K8lSS7QPjdybZk2RHkg392JlJ7k2yN8lX5+NfkJlyJfl4kicGLoeTvH7Uufrxjyb5TpJvJ3nbwPgof15nJLm7z7Sj/3yCU53rZUm2JNmV5MkkN/fjNyX5fpKdSd6yGHIlWZvkP9F9GM+vHzV/lLmWJ/nT/vf4nSS/ukhynZPky32ux5O8cTHkGrhveZL9ST50SnNV1aK60H3y1GXAu4A/HBj/EvAb/ddn99ebgC/TvWXy5cDX+vHfAv4ECHA98NmFyjVw/3pgcjHkontP/+3
AWcDfB/YvklzvAO4HlgBvBP66f85TmWs58Lb+Mc+h+4/QJcAu4BXA64H/2f/sRp1rNTAB/A7woYG5o871i8DG/v7XAc8vkp/XGmBZf/97gD9bJD+v1f19/xb4wpHf5anKdVLf0EJegN/kx+WwAfhvM8z5V8C7Bm7vBVYAnwfe3I/9NPB/FiLXUeN3Ae9dDLnoPo7xL+j+T+0NwNcXSa5PANcP3LcDOH8UuQaeaxL4MPAvB8b+e/8zHHWuf9B//RF+suQXRa6BsSnglYslF90i4nbgw4vl5wW8GfiDwd/lqcq16LZrjuEiYG+SrUm+leS3+/HV/ORbIuwFVg6OV9ULwAtJ/t5ChUtyFvB2ug81H3muqvoq3UcxfhW4F/i1xZAL+Abwy0nOSnIh3QfPjI0qV5JfAF5Gt+IafP49Rz//iHL9j2NMWTS5+q2tx6rqB4shV5LfBZ6j+7/Ij/dTRp3rCeA2BrZTT2Wu06Xkfxa4gK5ILwZuSPIGYCkwPTBvGjh8nPGF8ivAV6rq+f72SHMlWU23BfEe4NPATYshF91KZi/wdeADwLfoPibylOdKcg7wSbotpLk+/ynJVf1SbgaLIleS1wEfBd69WHJV1U10H1n6KeAziyEX3er9E1X13FHTTkmu06Xk/xfwpap6rqqeAb5Ctwe+D1g1MO9culXYj8aTnA0s6VcaC+V64J6B26PO9c+BB6rqsaq6A3hTv3Ieaa6q+mFV3VBVFwI30K3inzrVufpV0eeBW6vqL2d4/tcsklzHMvJcSdYAfwZcW1XfWyy54EcfQ/p7dFtuiyHXNcAHknyNbuH1viS/fqpynS4lvxW4LN0ZG8vofnl/RbclcV2SJUkuB3ZV1bP9+Dv6Y/8Z8LmFCtb/w34+8MWB4VHnOgiMp7OSbp/vB6POle50yqVJQrcP/tmqOngqcyV5JfDndHvwX+iHHwR+NclPpzs76tXA1xZBrmMZaa4kq+heQL++us+AXiy5fiHJq/opbwUeXQy5qmp1VV1UVRcBv0+3qr/vlOWa7xca5uGFilfQ7WHtB/6m//rS/pveSXcWxG/1c88A7qbbv3oMuKAffxnwx3Srsf8KrFjAXLfRv8AzMHfUud4K/Mf++f8aeOciyfVuYDfdls0fAT8zglwfAv62z3Pksg64Ffhu//P6x4sk18/318/SbWs9AfzcIsj1XwZ+p0cuSxdBrvf1v8Mn6BaG5y+S3+O6gfs/wo9feD0luXxbA0lq2OmyXSNJOgGWvCQ1zJKXpIZZ8pLUMEtekhpmyUtSwyx5SWqYJS9JDbPkJalh/x/AETkZpfVIUQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "x_train = np.array([[2104],[1600],[2400]], dtype=np.float32)\n", "y_train = np.array([[399.900], [329.900], [369.000]], dtype=np.float32)\n", "\n", "# x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],\n", "# [9.779], [6.182], [7.59], [2.167], [7.042],\n", "# [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)\n", "\n", "# y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],\n", "# [3.366], [2.596], [2.53], [1.221], [2.827],\n", "# [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)\n", "\n", "plt.plot(x_train, y_train, 'r.')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2019-06-15T03:14:07.411792Z", "start_time": "2019-06-15T03:14:07.401767Z" } }, "outputs": [], "source": [ "x_train = torch.from_numpy(x_train)\n", "y_train = torch.from_numpy(y_train)" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2019-04-07T04:47:17.479477Z", "start_time": "2019-04-07T04:47:17.468259Z" } }, "source": [ "`nn.Linear`\n", "\n", "Applies a linear transformation to the incoming data: $y = xA^T + b$" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2019-06-15T03:14:36.493688Z", "start_time": "2019-06-15T03:14:36.471235Z" } }, "outputs": [], "source": [ "# Linear Regression Model\n", "class LinearRegression(nn.Module):\n", " def __init__(self):\n", " super(LinearRegression, self).__init__()\n", " self.linear = nn.Linear(1, 1) # input and output is 1 dimension\n", "\n", " def forward(self, x):\n", " out = self.linear(x)\n", " return out\n", "\n", "model = LinearRegression()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2019-06-15T03:15:35.560781Z", "start_time": "2019-06-15T03:15:35.556997Z" } }, "outputs": [], "source": [ "# Define Loss and Optimizatioin function\n", "criterion = nn.MSELoss()\n", 
"optimizer = optim.SGD(model.parameters(), lr=1e-9)  # tiny step: inputs are unscaled (x in the thousands), so 1e-4 diverges" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2019-06-15T03:16:37.107844Z", "start_time": "2019-06-15T03:16:36.876061Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch[50/1000], loss: 346492.500000\n", "Epoch[100/1000], loss: 148748.437500\n", "Epoch[150/1000], loss: 64518.070312\n", "Epoch[200/1000], loss: 28639.699219\n", "Epoch[250/1000], loss: 13357.108398\n", "Epoch[300/1000], loss: 6847.381348\n", "Epoch[350/1000], loss: 4074.527344\n", "Epoch[400/1000], loss: 2893.413086\n", "Epoch[450/1000], loss: 2390.306152\n", "Epoch[500/1000], loss: 2176.009766\n", "Epoch[550/1000], loss: 2084.726562\n", "Epoch[600/1000], loss: 2045.842651\n", "Epoch[650/1000], loss: 2029.282349\n", "Epoch[700/1000], loss: 2022.227905\n", "Epoch[750/1000], loss: 2019.222412\n", "Epoch[800/1000], loss: 2017.942139\n", "Epoch[850/1000], loss: 2017.397949\n", "Epoch[900/1000], loss: 2017.164795\n", "Epoch[950/1000], loss: 2017.066162\n", "Epoch[1000/1000], loss: 2017.023682\n" ] } ], "source": [ "num_epochs = 1000\n", "for epoch in range(num_epochs):\n", "    # Tensors work with autograd directly; the Variable wrapper is not needed\n", "    inputs = x_train\n", "    target = y_train\n", "\n", "    # forward\n", "    out = model(inputs)\n", "    loss = criterion(out, target)\n", "    # backward\n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "    optimizer.step()\n", "\n", "    if (epoch+1) % 50 == 0:\n", "        print('Epoch[{}/{}], loss: {:.6f}'\n", "              .format(epoch+1, num_epochs, loss.item()))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "\\ell(x, y) = L = \\{l_1,\\dots,l_N\\}^\\top, \\quad\n", "l_n = \\left( x_n - y_n \\right)^2,\n", "$$\n", "where $N$ is the batch size. With the default `reduction='mean'`, `nn.MSELoss` returns the mean of the $l_n$. 
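
As a small sanity check on this formula, the pure-Python sketch below computes the per-sample terms and their batch average, assuming the mean reduction. The helper names `squared_errors` and `mse_loss` and the sample numbers are made up for this illustration; they are not part of PyTorch or of this notebook.

```python
def squared_errors(predictions, targets):
    '''Element-wise losses l_n = (x_n - y_n)^2.'''
    return [(x - y) ** 2 for x, y in zip(predictions, targets)]

def mse_loss(predictions, targets):
    '''Average of the element-wise losses over a batch of size N.'''
    losses = squared_errors(predictions, targets)
    return sum(losses) / len(losses)

# Hypothetical predictions and targets, in thousands of dollars.
predictions = [400.0, 330.0, 370.0]
targets = [399.9, 329.9, 369.0]
print(squared_errors(predictions, targets))
print(mse_loss(predictions, targets))
```

The first printed list corresponds to the vector $L$ in the formula, and the second value is the single scalar that a training loop would backpropagate.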
" ] }, { "cell_type": "code", "execution_count": 109, "metadata": { "ExecuteTime": { "end_time": "2019-04-07T05:26:38.112863Z", "start_time": "2019-04-07T05:26:38.108523Z" } }, "outputs": [ { "data": { "text/plain": [ "LinearRegression(\n", " (linear): Linear(in_features=1, out_features=1, bias=True)\n", ")" ] }, "execution_count": 109, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.eval()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2019-06-15T03:16:56.606092Z", "start_time": "2019-06-15T03:16:56.379390Z" } }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZIAAAEXCAYAAACH/8KRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAIABJREFUeJzt3Xl8ldW1//HPIoCAYpFZpkTAARVFiSgtFsXZer2VVqDSKm0VtVqntjjV1gknfvZapepFRariVBTbSivigCiiNiAIyGBAhqBwMUwigxDW7499Qk4ggYzneU7O9/165WX2Pvucs0iEddbez7O3uTsiIiJVVS/qAEREJL0pkYiISLUokYiISLUokYiISLUokYiISLUokYiISLUokYiISLUokYiISLUokYiISLXUjzqAVGjZsqXn5OREHYaISFqZPn36V+7eam/jMiKR5OTkkJeXF3UYIiJpxcyWVmScprZERKRalEhERKRalEhERKRaMmKNpCzbtm2joKCALVu2RB1KRmvUqBEdOnSgQYMGUYciIlWUsYmkoKCApk2bkpOTg5lFHU5GcncKCwspKCjgoIMOijocEamijJ3a2rJlCy1atFASiZCZ0aJFC1WFIjWobVsw2/2rbdvae8+MTSSAkkgM6HcgUrNWrapcf03I6EQiIiLVp0SSJgoKChgyZMhex/34xz9m7dq1lX79yZMnc+qpp+513EknncR7771X6dcXkbpLiSRCa9asYciQIXTo0IF27dpx5pln8tlnn5U5tkOHDowZM2avrzlu3DgOOOCAGo604r7++mv++Mc/Rvb+Ipls4cJo3leJpKLGjoWcHKhXL/x37Nhqv+SAAQNo3749S5Ys4YsvvuAXv/gFP/jBD3ZbfHb3ar9XqhQWFjK2Bn42IlJx27fDvffCUUdF8/5KJBUxdiwMHQpLl4J7+O/QodVKJlOnTmXt2rUMHz6c+vXDVdgDBgygd+/ePPvsswDUr1+fYcOG0a1bN5YsWULXrl0B2L59O1dccQUHHXQQp5xyCmeddRbPPPMMEPYVKygo2Dn+2muvpVOnTpx66qls3rwZgAceeICDDz6YTp06cf/99+8xzm+++YbBgwfTpUsXzj33XDZu3LjzsWHDhtG1a1eys7N57rnn2LRpEyeddBLLli2ja9euzJs3j7feeovu3buTk5PDwIED2bFjR5V/ZiKyu5kz4fjj4YYb4Ac/gFblbLHYpk3txRCbRGJmDc3sUzN73MxamNkLZvaZmS0ys0FJ4+4xswIzm21mPVMS3M03w6ZNpfs2bQr9VTRr1iz69u27W/8JJ5zA7NmzASgqKuLoo49m3rx5pcaMHj2az
z//nM8++4ynn36amTNnlvken3/+Of3792fJkiVs27aN8ePHA5Cbm8uCBQuYMWMGt912G99++225cd599900bdqURYsWce+995aK5ZxzziE/P58JEyZw44030qRJEyZPnkynTp3Iz8+nW7duNG/enKlTp7J48WKWLFnClClTKv2zEpHdbdkS/gnKzYUVK2DcOHjpJfi//wufd3f9Wrmy9mKJ0w2JNwFLEt+3Ah5x98lm1hXIM7OXgBOBPkAOcDLwBNCj1iNbtqxy/RXw7bfflnnpq5nRsGHDne3+/fvvNu6NN97g0ksvpX79+rRr145+/fqV+R7t2rXjxBNPBKBPnz4sXRo28mzZsiW33XYbc+fOZevWraxevbrcOCdOnMgTTzwBQLdu3TjmmGN2Pta4cWNuuOEG5s2bx4oVK8p8fqdOnXjssceYNWsWy5YtK3eciFTc++/DL38J8+fDRRfBn/4EzZtHF08sKhIz6wYcB7wI4O7z3X1y4vt8YBvQGOgPjHH37e4+CWhlZrV4m01Cp06V66+Ao446iqlTp+7WP23aNLp37w5AVlYWjRs33m3Mpk2bSiWbbdu2lfke++yzz87vGzRoQFFREV999RX9+vWjd+/ePPzww7Rs2XKPazBbtmwptX3J1q1bAZg7dy4DBw7k7LPP5rHHHiv3Nc4++2yysrK44447OOmkk9JqvUckbjZuhKuugj59wqTIa6/BmDHRJhGIQSKx8HH7QeDqch4/C5jh7huAjkDy/vgrgAPLed5QM8szs7w9feKukOHDoUmT0n1NmoT+KurXrx9ZWVncddddFBUV4e48//zz5OXlMXDgwD0+t1evXjz55JPs2LGDxYsXM2nSpAq/75IlSzjggAM4/fTTWbx4MSv3Uu+ecMIJjB49GoAPPviAWbNmATBv3jwOPfRQvv/97zN58uSd4xs3bsz69et3/pnmzJnDoEGD2G+//cpMnCJSMa+/DkceCSNHwpVXwpw5cMYZUUcVRJ5IgMuAyYnKo5TEtNYI4NJEV0MgebV2B1BU1ou6+yh3z3X33FblrT5V1ODBMGoUZGeHvQays0N78OBqvez48eP55JNP6NixI+3bt+fZZ5/l1Vdf3esGhtdccw1bt24lOzubq666iu9///tkZWVV6D179OhBly5dyMnJYfTo0XTo0GGP42+//Xb+85//kJ2dzciRI/nud78LwOmnn8769evp3Lkzn3766c7xbdq04Xvf+x5du3ZlwYIFXH/99Rx77LEMGDCAo48+ukIxikiJNWvg5z8PSaNRI3j3XXjwQWjaNOrIkrh7pF/A+8BsYCawDCgEfgdkJ/qOTRo7Grgwqb0MaL639+jZs6fv6tNPP92tL13169fP33nnnajDqLK69LsQqUnjxrm3aeOeleV+003umzen9v2BPK/Av+ORL7a7+3eLvzezIYTF9GeBfwCXuPuMpOETgF+Z2VigH7DQ3dekMNxYWLBgAU2bNqVdu3ZMmTKF+fPn07Nnai5gE5Hat3JlmL566SU45piwFtKj9i8rqrLIE0k5bgO6As8lXbF0ODAe6AssJlQuF0QSXcRWr17NWWedxfbt22ndujUvvvgi++67b9RhiUg1ucNf/wrXXRcW0+++G37zG4j7cT2xSiTuPgYYk2heXM6wqxJfGatPnz4sXrw46jBEpAYtWQKXXhoW1fv0gccfh0MPjTqqionDYruISMbasQMeeihckfX++/CXv8A776RPEoGYVSQiIplk3jy4+OKQQM48Ex59NFwUmm5UkYiIpNi2bXDXXWEBff58eOop+Ne/0jOJgCoSEZGUmjEjbG8ycyacf36Y1qrNDRVTQRVJzFx++eUsrMChAqNGjeKFF16otTiqekCWiJRt82a48Ubo1Stc3vvyy/Dii+mfRECJJFJmRteuXenatevOzRUfeeQRDjnkECAki+Tddq+99tqd3w8dOnSvW6lUVPHW88miPiBLpC55770wjXXPPTBkCHz6K
Zx3XtRR1Rwlkgpo2zbsjLLrV9tqbheZlZVFfn4++fn5vPvuu7s9/uyzz1JYWLiz/dBDD1XvDUUkpb7+OtxYeOKJ8O23MGlSuKy3rn1GUyKpgFWrKtdfHcVnov/617/mww8/ZNCgQQwfPpxevXpRVFRE165deeONN7j11lu58847dz7nD3/4A0cffTSdOnXamZQ2btzIoEGD6Ny5M+eddx7HHXdchc9br8gBWR9//DEnnHACBx98MBdffLEOrRJJ8u9/wxFHwMMPwzXXhE0WTz016qhqhxbbCb/kcs6G2quTTiq7v0cPeOCBPT+3qKiIww47DIBTTz2VkSNH7nzsoYceYvbs2dx555306dOHm2++mfr165OfH/a23DUhrFq1ilmzZvH4449zxx138Prrr3PXXXex//77s3jxYubOnctxxx1XpT9j8QFZ999/PyeffDLjx49nwIABXHbZZYwbN46OHTvyk5/8hL///e+cV5fqdZEqKCyEa6+Fp5+Gbt1g6lTo3TvqqGqXEkmEsrKymD9/fo281oABAwDo27cvI0aMAMIBWI899hgARxxxRKlDqSqjrAOy5s+fz5w5czjttNMA2Lx5M7169aruH0MkbbmHUwqvvDLs2HvLLeEEw6RjgeosJRL2XjmUcZDhTklHcUSq+BCr4gOsoOIHYFX0tZNff/v27Rx22GFMnz69GlGL1A1ffgm/+hW88gr07Bm2OcmkUxO0RhJjjRs3Zs2aNTtPFWzQoAHr16+v8CmDvXr14vHHHwfgo48+Kvds96o49NBDWbFiBdOmTQPCesn69etr7PVF0oE7jB4dprBeew3uuw8++CCzkggokVRIedd51/b13z/72c+49NJLd05VXXzxxXTv3p233367Qs+/8847ycvLIzs7m4ceeoiePXuWewDWEUccQbNmzWjWrBlffPHFXl+7cePGPPXUUwwZMoQuXbowbNgw6tXT/06SORYvhtNPDzcXHn00fPIJ/O53UD8D53msop9u01lubq7n5eWV6ps3bx7dunWLKKJodOnShbfffptO1ThrvjZk4u9C0ldRUbgb/eabISsrVCFDh0Jd/BxlZtPdPXdv4+rgH12Kffzxx6xbtw6AsWPHsu+++9KxY8eIoxJJX59+GrZ4v/bacMXm3Llw2WV1M4lURgYWYZkjPz+fc889FzOjc+fOPP/889ierhwQkTJ9+y3cey/ceWc4K/2ZZ+CCC/Z8IU4mUSKpw84//3zOP//8qMMQSWt5eWEd5JNPYNAg+POfoXXrqKOKl4wuyDJhfSju9DuQuNq8GYYNg+OPh6++gr//HZ57TkmkLBmbSBo1akRhYaH+IYuQu1NYWEijRo2iDkWklHfegaOOghEjQjXy6adw7rlRRxVfsZnaMrOGwEzgfXe/2MyuBn4DbAaucfd/J8bdA/wUWAsMcfcq3RHXoUMHCgoKWL16dc38AaRKGjVqRIcOHaIOQ3Y1dmy4LGnZMujUCYYPh8GDo46q1m3YANdfH04q7NwZ3nwT+vWLOqr4i00iAW4ClgCYWRfgCuAIoCPwhpllAycCfYAc4GTgCaBHVd6sQYMGHHTQQdUOWqTOGTs2XM+6aVNoL10a2lCnk8mECeEKrC++gOuugzvugCZNoo4qPcRiasvMugHHAS8mus4DXnT3r939U0KC6Qn0B8a4+3Z3nwS0MrNqbuYuIqXcfHNJEim2aVPor4O++gp++lM45xz4znfC+en3368kUhmRJxIL16M+CFyd1N0RWJrULgAOLKN/RaK/rNcdamZ5Zpan6SuRSli2rHL9acodnn8+bG/y4ovwxz+GY3CPPz7qyNJP5IkEuAyY7O75SX0NgeTDLXYARXvo3427j3L3XHfPbdWqVQ2HLFKHlbfzQcx2RKiOFSvghz+En/wEDjoIpk+HW2+FpD1OpRLikEh+Bgwys5nA7YRprZVA+6QxHYDlwJe79LcjVCsiUlOGD999XqdJk9Cf5tzhscfg8MPDa
YX33w/TpkH37lFHlt4iTyTu/l137+7uPYA/AOOBVwnJpYmZHQ40J1zRNQG4yMyyzOw0YKG7r4kseJG6aPBgGDUKsrPDrdvZ2aGd5gvtixbBKaeE6wZ69oTZs8Oiejn7mEolxOmqrZ3cfbqZPQPMBbYAF7u7m9l4oC+wGCgELogwTJG6a/DgtE8cxYqKwt3ov/89NGgQcuLFF2t7k5qUsbv/ikjdN2dOuKHwo4/gv/4LHnkE2rff+/Mk0O6/IpKxvv0WbrsNjj02nBvy3HNhixMlkdoRy6ktEZGq+uijUIXMmRNm5x54AFq2jDqquk0ViYjUCZs2wW9+A717w7p18OqrYbt3JZHap4pERNLe22+HBfTFi8M2J/feC/vvH3VUmUMViYikrfXrw+W8/fqFUwonTw4L6koiqaVEIiJp6Z//DDcWPvEE/O53MGsW9O0bdVSZSYlERNLK6tVha5Nzz4UWLeDDD+G++7TJYpSUSEQkLbiHHe67dYOXXoLbbw/H4Obu9S4HqW1abBeR2Fu+HC6/PJwZcsIJYTrr8MOjjkqKqSIRkdjasSOcVnjEEeHKrAcegPfeUxKJG1UkIhJLn30Gl1wSzk8/5ZSwR1bnzlFHJWVRRSIisbJ9O4wYAUcdBTNnhmmsSZOUROJMFYmIxMasWWF7k+nT4b//Gx5+GNq1izoq2RtVJCISua1b4ZZbwhVYy5eHo2/Hj1cSSReqSEQkUtOmhSpk3jy48EL405/C/SGSPlSRiEgkvvkGrrkGvvc92LgR/vUv+OtflUTSkSoSEUm5N94IV2QtWQJXXAF33w1Nm0YdlVSVKhIRSZl168I01mmnhWNvp0yBkSOVRNKdEomIpMQrr4QbCf/6V7jhhnCF1oknRh2V1IRYJBIzq2dmk8xsoZktMLMzEv0jzGyRmX1mZj9KGn+PmRWY2Wwz6xld5CKyN6tWwYABcN550KZNOMHw7ruhceOoI5OaEpc1EgcudPcvzexMYLiZbQH6AIcBnYEpwEtm1i/RnwOcDDwB9IgkahEplzs8/XRYUP/mGxg+PGz33qBB1JFJTYtFReLBl4lmNjAL2EqIrwhoBKxKPN4fGOPu2919EtDKzNqmOmYRKd+yZXD22XDRRWG33lmz4KablETqqrhUJJjZMOB6YDVwhrsvNbMJwAdAA+AniaEdgb8nPXUFcCCwcpfXGwoMBejUqVPtBi8iQNhk8ZFHwhqIOzz4YLgqq14sPrJKbYnNr9fd73P3FsBNwEQz6wj0BS4DXgSuTgxtCOxIeuoOQtWy6+uNcvdcd89t1apV7QYvIixYEE4ovPJK6N0b5syBX/9aSSQTxO5X7O4vA/sBtwHj3X2Gu98NnGBm3YAvgfZJT2kHFKQ+UhEB2LYN7rkHjj46JI8nn4SJEyEnJ+rIJFVikUjMrHPxOoeZ9Qa2EJJDrgUHAm2BDcAE4CIzyzKz04CF7r4mqthFMtnHH8Pxx8ONN8I554RtToYMAbOoI5NUissaSTPgNTPLIiyqDwQWA2OBJcAm4EZ3X2Fm4wlTXouBQuCCSCIWyWBbtsAdd8C990LLljBuHPzoR3t/ntRNsUgk7j4DOKSMh84uY+wO4KrEl4ik2NSp4e70BQtC9XH//dC8edRRSZRiMbUlIvG3cSNcdVW4G33LlrAO8uSTSiISk4pEROKjbdtwN/qu6tULl/ReeSXcdRfst1/qY5N4UiIRkVLKSiIQ7hF5772w7btIMk1tiUiFKYlIWZRIRESkWpRIRAQI6x99+0YdhaQjJRIRYfLksJg+ZUrUkUiNGDs2bC1Qr17479ixtfp2WmwXyWDbtoXDpvLzQ/vII2H16rIX3Nu0SW1sUkVjx8LQobBpU2gvXRraAIMH18pbqiIRyVAvvwwNG5Ykkffeg9mzYeXKMM2169fKlXt+PYmJm28uSSLFNm0K/bVEFYlIhtm0KWxrsnlzaJ9+Orz2mvbHqjOWLatcfw1QRSKSQR57DPbdtySJfPJJu
ENdSaQOKe/8pVo8l0mJRCQDrF0bkkXxVPlFF4Xpqu7do41LasHw4dCkSem+Jk1Cfy1RIhGp4+66q/R+WJ9/DmPGRBaO1LbBg2HUKMjODp8esrNDu5YW2kFrJCJ11ooV0KFDSfvGG0NSkQwweHCtJo5dKZGI1EFXXQUPPVTSXrUKWreOLh6p2zS1JVKHLFwYZjOKk8gDD4S1ECURqU2qSETqAHc4/3x46aWSvg0boGnT6GKSzKGKRCTN5eWFnTCKk8gzz4TEoiQiqRKLRGJm9cxskpktNLMFZnZGor+9mb1mZsvNbFrS+HvMrMDMZptZz+giF4nOjh3Quzccd1xot2kTTi5M4RqrCBCTRAI4cKG7HwJcDRRf8PwcMNbdOwL9AMysH9AHyAGuA55IebQiEXvzTcjKgg8+CO1//StsYbLPPtHGJZkpFmsk7u7Al4lmNjArUWmYuz+dGJO4F5f+wBh33w5MMrNWZtbW3bUTkNR5334LBx9cstvFscfCRx+FpCISlbhUJJjZMDMrBK4Fbgd6ACsSU17zzey3iaEdgaVJT10BHJjaaEVS729/CxVHcRKZNg2mT1cSkejFoiIBcPf7gPvMrD8wEfgrcBhwMpAFfGhmk4CGwI6kp+4AinZ9PTMbCgwF6FSLe8yI1LZvvoFmzWD79tA+5xz4xz+0P5bER2wqkmLu/jKwHyFBTHH3te7+FTAVOIQwBdY+6SntgIIyXmeUu+e6e26rVq1SELlIzXvkEdhvv5IkMncu/POfSiISL7FIJGbW2czaJr7vDWwhLLSfYmb7m1kz4ATgY2ACcJGZZZnZacBCd18TVewitaGwMCSLX/0qtC+5JFzSe/jh0cYlUpa4TG01A14zsyxgFTDQ3ZeZ2f8D/gMYcI+755vZYqAvsBgoBC6IKmiR2nDrrXDbbSXtpUtrdQdwkWqzcMFUBQebneHuE2sxnlqRm5vreXl5UYchskfLl5dOGLfcArffHl08ImY23d1z9zauslNb/zazz83sFjPrsPfhIlIRl19eOomsXq0kIumjsolkHNAGuA343MwmmNkPE1NSIlJJ8+aFtZBHHw3tkSPDWkjLltHGJVIZlVojcfcBZrY/cD7wM+DMxNf/mdlfgSfc/bOaD1OkbnGHH/4wXMYLYa+s9evDFVoi6abSV225+wZ3f8LdTyLchX4zYdF7GDDfzCab2WAz02YNImX48MOQOIqTyPPPQ1GRkoikr2pd/uvuBe5+j7sfCRwLPA+cCDwFfGFm/2Nm7ff4IiIZoqgIevaEE04I7Q4dYOtWGDgw2rhEqqva95GYWSczGwaMBgYRLtWdRbgX5Gog38yuru77iKSziROhfn2YMSO0X389XKXVsGG0cYnUhCrdR5K4eXAAIXEcT0ge64BHgMfdfaaZ1SdssPgH4E9mttHdtVOvZJStWyEnJ+zMC3D88fD++2FqS6SuqNT/zmY21MzeApYD/0O423wKYeG9nbtf6e4zAdx9u7u/SEg0i4HflvOyInXSc89Bo0YlSeSjj8K270oiUtdUtiJJXKTIl4RNFUe7e/6enuDu35hZHnBOFeITSTtffw3771/SPu+8cHqh9seSuqqyieSfhIOkJrj7bjvu7sEk4INKvpdI2nnwQbg6aUVw3jw47LDo4hFJhcreR/LfVXkTdx9dleeJpIvVq6F165L2FVeEmwtFMoFma0Wq6fe/L51Eli9XEpHMokQiUkVLl4Z1j+HDQ/v228Md6x20C51kmLhsIy+SVn75SxidNGFbWAjNm0cXj0iUVJGIVMKcOaEKKU4ijz4aqhAlEclkqkhEKsAdzjor3KEOsM8+oQrZd99o4xKJA1UkIntRfCd6cRL5299gyxYlEZFiqkhEylFUBMccA7Nnh3bnzjB/PjRoEG1cInGjikSkDBMmhE0Wi5PIW2/BokVKIiJliUUiMbN6ZjbJzBaa2QIzOyPpsRZmtsrMfp/Ud4+ZFZjZbDPrGU3UUhdt2RJOJzwns
aFPnz6hMjn55GjjEomzWCQSwIEL3f0Qwtbzw5Me+3/AjOKGmfUD+gA5wHWELVtEqu3pp6Fx47CIDjB9Orz7rjZZFNmbWKyRuLsTNoKEcOriLAAzOxXYDnyYNLw/MMbdtwOTzKyVmbV195WpjFnqjg0b4DvfKWkPGBBOLdQmiyIVE5vPWmY2zMwKgWuB282sMXA74QjfZB2BpUntFcCBqYlS6pr77y+dRBYuhBdeUBIRqYxYVCQA7n4fcJ+Z9QcmAn8H/uLua6303+qGwI6k9g5gt52IzWwoMBSgU6dOtRW2pKlVq6Bt25L2NdfA//xPdPGIpDMLs0rxYmYFQAtgQaKrLWEd5bfAKcBkd38qMXYZ0MPd15T3erm5uZ6Xl1e7QUvauP56uO++kvYXX8CBqmlFdmNm0909d2/jYlGRmFlnYJO7rzSz3sAWd2+c9PitwHZ3H2tmW4BfmdlYoB+wcE9JRKTY4sXQpUtJ++674YYbootHpK6IRSIBmgGvmVkWsAoYuIex44G+hON7C4ELaj88SXcXXhiuyiq2di00axZdPCJ1SSyntmqaprYy16xZ0KNHSfvxx8POvSKyd2k1tSVS09zh1FPDHekATZuGBfbGjff8PBGpvNhc/itSU4pvIixOIq+8Eu4VURIRqR2qSKTO2L4duncPGysCHHpoOD+kvv4vF6lVqkikTvjHP8KGisVJZPLk8L2SiEjt018zSWubN4cbCzdsCO2TT4Y339Sd6SKppIpE0tbo0dCkSUkS+fjjsC6iJCKSWkok5Rk7FnJywqptTk5oSyysWxeSRfFlvBdcEK7SSr7MV0RSR4mkLGPHwtChsHRp+Bdq6dLQVjKJ3L33wgEHlLQXLdKvRSRqSiRluflm2LSpdN+mTaFfIvHll6EKKd7S5Le/DTm+c+do4xIRLbaXbdmyyvVLrbruutI7865cCW3aRBePiJSmiqQs5W07r+3oUyo/P1QhxUlkxIhQhSiJiMSLEklZhg8PlwMla9Ik9Eutc4dBg+Dgg0v61q0L01kiEj9KJGUZPBhGjYLs7PCRODs7tAcPjjqyOu/jj8OFci+8ENpjxoTEknyKoYjEi9ZIyjN4sBJHCu3YASedFPbJAmjeHFasgEaNIg1LRCpAFYlE7u23ISurJIn8859QWKgkIpIuVJFIZLZtg27dwr0gAEceCTNnhqQiIulDFYlE4uWXoWHDkiTy3nswe7aSiEg6UkUiKbVpE7RoAVu2hPbpp8Nrr2l/LJF0popEUmbUKNh335Ik8sknMHGikohIuotFIjGzemY2ycwWmtkCMzvDzFqY2Qtm9pmZLTKzQUnj7zGzAjObbWY9o4xd9m7NmpAsLr00tIcMCZf0du8eaVgiUkNikUgABy5090OAq4HhQCvgEXc/GDgDeNTMGphZP6APkANcBzwRTchSEcOHh6msYp9/Dk8+GV08IlLzYrFG4u4OfJloZgOz3H0+MD/xeL6ZbQMaA/2BMe6+HZhkZq3MrK27r4widinbihXQoUNJ+8Yb4a67ootHRGpPXCoSzGyYmRUC1wK37/LYWcAMd98AdASWJj28AjiwjNcbamZ5Zpa3evXqWoxcdvXrX5dOIqtWKYmI1GWxSSTufp+7twBuAiaahSVYM+sKjAASM+w0BHYkPXUHUFTG641y91x3z23VqlXtBi8ALFgQ1kJGjgztBx4IayGtW0cbl4jUrlhMbSVz95fN7EGghZntC4wjrJ8sSQz5Emif9JR4nEA8AAAOK0lEQVR2QEFqo5Rk7vDjH4d7Q4pt2ABNm0YXk4ikTiwqEjPrbGZtE9/3BrYA+wAvA5e4+4yk4ROAi8wsy8xOAxa6+5qUBy0A5OWFTRaLk8gzz4TEoiQikjniUpE0A14zsyxgFTAQuA3oCjxnJTcaHA6MB/oCi4FC4IKURyvs2AHf/S58+GFot2kTTiTeZ59o4xKR1LNwwVTdlpub63l5eVGHUWe88QacdlpJ+9//hjPPjC4eE
akdZjbd3XP3Ni4uFYmkgW+/ha5dYfny0D72WPjoI+2PJZLpYrFGIvH34oth2qo4iUybBtOnK4mIiCoS2YuNG6FZMyhKXGB9zjnwj39ofywRKaGKRMr18MPh6qviJDJ3bjh0SklERJKpIpHdFBZCy5Yl7aFD4X//N7p4RCTeVJFIKbfeWjqJLF2qJCIie6aKRICwiN6pU0n7D3+A226LLh4RSR9KJMJll5WuOlavLl2ViIjsiaa2Mti8eWHhvDiJjBwZtjdREhGRylBFkoHc4dxz4dVXQzsrC9atg/32izYuEUlPqkgyzAcfhE0Wi5PI88/D9u1KIiJSdapIMkRREfTqBTMS+yh37Aj5+dCwYbRxiUj6U0WSAV57DerXL0kir78Oy5YpiYhIzVBFUodt3Qo5ObAycZr98cfD+++HqS0RkZqif1LqqGefhUaNSpLIRx+VrI+IiNQkVSR1zNdfw/77l7T794dx47Q/lojUHn0+rUP+/OfSSWT+fHjpJSUREaldqkjqgNWroXXrkvYVV4SbC0VEUkEVSZq7+ebSSaSgQElERFIrFonEzOqZ2SQzW2hmC8zsjET/1Wa2LNF3VtL4e8yswMxmm1nP6CKPztKlYcrqrrtC+447wh3r7dtHG5eIZJ64TG05cKG7f2lmZwLDzSwfuAI4AugIvGFm2cCJQB8gBzgZeALoEUnUEfnFL+DJJ0vahYXQvHl08YhIZotFReLBl4lmNjALOA940d2/dvdPgSVAT6A/MMbdt7v7JKCVmbWNIu5UmzMnVCHFSeTRR0MVoiQiIlGKRSIBMLNhZlYIXAvcTqhCliYNKQAOLKN/RaJ/19cbamZ5Zpa3evXq2gs8BdzhzDOhe/fQbtQIvvkGLr002rhERCBGicTd73P3FsBNwESgIbAjacgOoGgP/bu+3ih3z3X33FatWtVe4LVs6tRwE+HEiaE9bhxs3gxNmkQbl4hIsbiskezk7i+b2YPAl0Dy0nEHYHkZ/e0I1UqdUlQExxwDs2eHdufO4b6QBg2ijUtEZFexqEjMrHPxOoeZ9Qa2ABOAQWbWxMwOB5oDMxP9F5lZlpmdBix09zVRxV4bJkwImywWJ5G33oJFi5RERCSe4lKRNANeM7MsYBUw0N2nm9kzwFxCYrnY3d3MxgN9gcVAIXBBVEHXtC1boEOHcBUWwIknwuTJ2h9LROLN3D3qGGpdbm6u5+XlRR3GHj31FFx0UUl7+nQ49tjo4hERMbPp7p67t3FxqUgy1vr10KxZSXvgQHjuOe2PJSLpQ5MmEbr//tJJZOHCcPStkoiIpBNVJBFYtQraJt1CefXV8MAD0cUjIlIdqkhS7PrrSyeRL75QEhGR9KZEkiKLF4cpq/vuC+177gl3rB+42z35IiLpRVNbKfCzn8Ezz5S0164tvTYiIpLOVJHUolmzQhVSnEQefzxUIUoiIlKXqCKpBe5wyinw9tuh3bRpWGBv3DjauEREaoMqkho2ZUq4E704iYwfDxs2KImISN2liqSGbN8ORx4JCxaE9qGHhvND6usnLCJ1nCqSGvDKK2FDxeIk8s47YadeJRERyQT6p64aNm+G1q1h48bQ7tcP3nhDd6aLSGZRRVJFo0eHw6WKk8jMmfDmm0oiIpJ5VJFU0rp1cMABJe3Bg0vfIyIikmlUkVTCvfeWTiKLFimJiIioIilD27bhvo/y/Pa3MGJE6uIREYkzJZIy7CmJrFwJbdqkLhYRkbjT1FYlKYmIiJSmRCIiItUSi0RiZo3MbJSZLTSzpWZ2baJ/hJktMrPPzOxHSePvMbMCM5ttZj2ji1xEROKyRrIvMBG4FGgBzDWzGUAf4DCgMzAFeMnM+iX6c4CTgSeAHhHELCIixKQicfdCd3/Jg6+A5cAOQnxFQCOgeAm8PzDG3be7+ySglZm1LfOFq6i8dRCtj4iI7C4uFclOZnYkIXG8B0wAPgAaAD9JDOkI/D3pKSuAA4GVu7zOUGAoQKdOnSoVw8qVe
x8jIiJBLCqSYmbWEnga+DnQAegLXAa8CFydGNaQUK0U20GoWkpx91Hunuvuua1atarVuEVEMllsKhIzOwB4FbjJ3f9jZvcB4919BjDDzD4xs27Al0D7pKe2AwpSH7GIiEBMKhIz2x/4J3Cnu/870b0FyLXgQKAtsIEw3XWRmWWZ2WnAQndfE0ngIiISm4rkKuAY4AEzeyDRNxC4A1gCbAJudPcVZjaeMOW1GCgELkh9uCIiUszcPeoYal1ubq7n5eVFHYaISFoxs+nunrvXcZmQSMxsNbC0ik9vCXxVg+HUFMVVOYqrchRX5dTVuLLdfa9XK2VEIqkOM8urSEZONcVVOYqrchRX5WR6XLFYbBcRkfSlRCIiItWiRLJ3o6IOoByKq3IUV+UorsrJ6Li0RiIiItWiikRERKoloxOJmTU2s0OijmNXiqty4hpXOjGz75jZQVHHIekpIxOJme1vZq8QtqYfltTf3sxeM7PlZjYtqX+3g7TMrL6ZjTGzFWb2QU38JSwrLjN7yMzyk76KzOzwqONK9Ff44LEU/rzqmdmDiZhmm1nvCOIq76C2q81smZktMLOz4hCXmeWY2euEPewG7zI+yrhamNkLid/jIjMbFJO4WprZe4m4ZpnZsXGIK+mxFma2ysx+n9K43D3jvoD9gFOAi4HHk/qnAD9LfN848d9+hC3t6wOnATMT/b8AngcMuAR4pbbiSnr8ECAvDnERtqmZRtji/1BgVUzi+jnwMpAFHAvMS7xnKuNqAfwo8ZotCYmuL7AQaAocDnyR+NlFHVdHoDdwG/D7pLFRx3UicFLi8a7Aupj8vLKBZonHLwPGxeTn1THx2JPAv4t/l6mKq1p/oHT/AoZQ8g9QT+DdMsaMBC5Oaq8gbCD5KnBqoq8JsLE24tql/z7gV3GICzgB+JBQ1R4NfBKTuP4CXJL02GygSxRxJb1XHnALYVPS4r73Ez/DqOPqnvj+VkonkljEldS3Gtg/LnERPqgMB26Jy88LOBV4LPl3maq4MnJqqxw9gBVmNsnM5pvZbxP9HSm9vUrxQVo7+919E7DJwlb4tcLMGgADgGfjEJe7f0DJwWNjKH3wWJQ/r7nAf5lZAwvHDhwEtIoqLis5qK3lLu9fsOv7RxTXnHKGxCauxDTgDHffEIe4zOzPwFpCNfxQYkjUceUDt5M09ZzKuJRISrQmnA8/gHAm/OVmdjTlH6RVoQO2atC5wFR3X5doRxqXmXWkcgePpern9RjhL8snwO+A+YRdolMel5U+qK2y75+SuDzxkbQMsYjLzLoCI4BL4xKXu18NNCN8qPtbHOIiVCF/cfe1uwxLSVxKJCX+D5ji7ms9nBs/lbAmUd5BWjv7zawxkJX4xFRbLgGeSGpHHdevSRw85u53AyfYng8eS0lc7r7N3S93927A5YRqZHmq47JdDmor4/07xCSu8kQel5llA+OAC919SVziAnD3HcAjhOnJOMR1AfA7M5tJ+HB3hZkNTlVcSiQlJgGnWLgSqBnhf5CPKf8grQmETwIAP6X0OfI1KvEXqgvwdlJ31HFV9uCxlMRl4VLghmZmhHWJV9x9SyrjsrIPapsADDKzJhauumsOzIxBXOWJNC4za0+4aOISD6ekxiWuI83sO4kh5wHT4xCXu3d09x7u3gN4lFCdjE1ZXDW98JMOX4QrZ/IJVzusT3x/cuIHu4Bwdc0vEmPrAQ8S5hNnAIcl+hsBzxE+Vb4DtK3FuG4nsaiXNDbquM4D/pV4/3nAL2MS16WEQ89WAM8A+0YQ1++BbxLxFH91Bm4CPk/8vL4Xk7iOSPx3DWEKMB84OAZxvZX0Oy3+ahiDuK5I/A7zCR8+u8Tk99g56fFbKVlsT0lc2iJFRESqRVNbIiJSLUokIiJSLUokIiJSLUokIiJSLUokIiJSLUokIiJSLUokIiJSLUokIilgZm3NbJ2ZuZmdX86YlmZWaGY7zKxPqmMUqSolEpEUcPeVhDvbAUaYWaMyho0gb
J3yv+7+XsqCE6km3dkukiJmVo9wGFgv4A/ufkfSY98D3iVspne4u6+PJkqRylMiEUkhMzsG+A9h08tD3X2FmdUn7IPUHejv7uOjjFGksjS1JZJC7v4x4TCkfYF7E91XE5LIeCURSUeqSERSzMyaAp8SzoM4jbBdOkA3d/8issBEqkgViUiKufvXhCrEgH8QziK/XklE0pUqEpEIJA7e+hzIJhwI1trdt0YblUjVqCIRicalhCSyilCR/C7acESqThWJSIqZWQ4wG9gOfB94H6gPHO3uC6OLTKRqVJGIpFBiSusJYD/gRnefTThKuRHwv1HGJlJVqkhEUsjMLgceBj4EvuvuO8ysAfAJcBhwsbs/EWWMIpWlRCKSIklTWo2BXHefmfTYKcAbwFrCZcCroohRpCo0tSWSAokprdGEKa2HkpMIgLu/CfwNOAD4c+ojFKk6VSQiKWBmvwL+AnwBHJa4l2TXMR2A+YS73n/g7v9KbZQiVaNEIiIi1aKpLRERqRYlEhERqRYlEhERqRYlEhERqRYlEhERqRYlEhERqRYlEhERqRYlEhERqRYlEhERqRYlEhERqRYlEhERqZb/D7wgc9QkjDp+AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "predict = model(Variable(x_train))\n", "predict = predict.data.numpy()\n", "plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')\n", "plt.plot(x_train.numpy(), predict, 'b-s', label='Fitting Line')\n", "plt.xlabel('X', fontsize= 20)\n", "plt.ylabel('y', fontsize= 20)\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "使用pytorch建立卷积神经网络并处理MNIST数据。\n", "https://computational-communication.com/pytorch-mnist/" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python [default]", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.4" }, "latex_envs": { "LaTeX_envs_menu_present": true, "autoclose": false, "autocomplete": true, "bibliofile": "biblio.bib", "cite_by": "apalike", "current_citInitial": 1, "eqLabelWithNumbers": true, "eqNumInitial": 1, "hotkeys": { "equation": "Ctrl-E", "itemize": "Ctrl-I" }, "labels_anchors": false, "latex_user_defs": false, "report_style_numbering": false, "user_envs_cfg": false }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }