{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy dataset following y = 2x: three (x, y) pairs as (3, 1) float tensors.\n",
    "x_data = torch.tensor([[1], [2], [3]], dtype=torch.float)\n",
    "y_data = torch.tensor([[2], [4], [6]], dtype=torch.float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearRegression(torch.nn.Module):\n",
    "    \"\"\"Single-feature linear model built from one torch.nn.Linear(1, 1) layer.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(LinearRegression, self).__init__()\n",
    "        # One input feature mapped to one output feature.\n",
    "        self.linear = torch.nn.Linear(1, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the linear layer's prediction for the input tensor x.\"\"\"\n",
    "        return self.linear(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = LinearRegression()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reduction='sum' adds up the squared errors instead of averaging them.\n",
    "# It replaces the deprecated size_average=False argument and yields the same loss.\n",
    "criterion = torch.nn.MSELoss(reduction='sum')\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch history, filled by the training loop and plotted afterwards.\n",
    "epoch_list = []\n",
    "loss_list = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Before Training 3.9182913303375244\n"
     ]
    }
   ],
   "source": [
    "# Prediction for x = 4 from the freshly initialised (untrained) model.\n",
    "print(\"Before Training \", model(torch.tensor([[4]], dtype = torch.float)).item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(500):\n",
    "    # Forward pass and loss over the whole dataset.\n",
    "    y_pred = model(x_data)\n",
    "    loss = criterion(y_pred, y_data)\n",
    "    # Clear gradients accumulated by the previous iteration before backpropagating.\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    # Record this epoch's loss for the plot below.\n",
    "    epoch_list.append(epoch)\n",
    "    loss_list.append(loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
"data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAFMVJREFUeJzt3X+MHGd9x/HPZ3bvh88/Yie+hCQGHISFQKEFdKKhKZQGglxKoa34gwhoSlP5n/4IFRIkQirqPy2oFT+qtrRWkgJqFKoCFVEEhTQhokg0cE4C+WHSOCQFk5C7ECdO/Otub7/9Y2bPe+edHXt3fXfP3vslnWbm2dmZ57lcPvv4mWd2HBECAKQvW+0KAAAGg0AHgCFBoAPAkCDQAWBIEOgAMCQIdAAYEgQ6AAwJAh0AhgSBDgBDol61g+2bJL1D0kxEXFqU/Y2k35Y0J+lRSR+IiGerjrV9+/bYuXNnXxUGgPVm3759T0fEZNV+rrr13/abJL0g6Qttgf42SXdGRMP2JyQpIj5SdbKpqamYnp4+nfoDAAq290XEVNV+lUMuEfFtSc8sK/tmRDSKzf+RtKOnWgIABmYQY+h/KOnrAzgOAKAPfQW67Y9Kaki6ucs+e2xP256enZ3t53QAgC56DnTbVyu/WPre6DIQHxF7I2IqIqYmJyvH9AEAPaqc5dKJ7d2SPiLp1yPi6GCrBADoRWUP3fYtkr4r6RW2D9q+RtLfS9os6Xbb99n+p7NcTwBAhcoeekRc1aH4xrNQFwBAH5K4U/SO/U/pH+86sNrVAIA1LYlAv+vhWd3w34+tdjUAYE1LItBtqcnDrAGgqyQCPbNFngNAd0kEukQPHQCqJBHomS2R5wDQVRKBzhg6AFRLItAzOugAUCmJQLdNDx0AKqQR6BKzXACgQhqBzrRFAKiUSKBLwSg6AHSVRKBnZsgFAKokEegWF0UBoEoSgc60RQColkSgi4uiAFApiUDPnC+7PLoUANa9JALdyhO9SZ4DQKkkAp0eOgBUSyLQXQQ6PXQAKJdIoOeJzs1FAFAukUDPl4y4AEC5JAI9a/XQCXQAKJVEoBcddO4WBYAukgj0xR76KtcDANayykC3fZPtGdsPtJWda/t2248Uy21ns5InZ7kQ6QBQ5nR66J+TtHtZ2XWS7oiIXZLuKLbPGjOGDgCVKgM9Ir4t6Zllxe+S9Pli/fOSfmfA9VqiNYbOjUUAUK7XMfQLIuJJSSqW5w+uSqdi2iIAVDvrF0Vt77E9bXt6dna2p2NwURQAqvUa6E/ZvlCSiuVM2Y4RsTcipiJianJysqeTcVEUAKr1Gui3Srq6WL9a0lcHU53OuCgKANVOZ9riLZK+K+kVtg/avkbSxyVdafsRSVcW22cNF0UBoFq9aoeIuKrkpbcMuC6lGEMHgGpJ3CnKGDoAVEsi0DOmLQJApSQC/eQj6Eh0ACiTRqDTQweASokEOtMWAaBKEoG+OIbOPBcAKJVEoPOQaAColkSgn3wEHYkOAGWSCPQWeugAUC6JQG/10LlXFADKJRHojKEDQLU0Al1MWwSAKkkEesZ3uQBApSQCnTtFAaBaIoHe+vpcEh0AyqQR6MWSHjoAlEsi0DO+ywUAKiUR6DzgAgCqJRHoPIIOAKolEeiihw4AlZIIdMbQAaBaEoF+cpYLiQ4AZZIIdMbQAaBaEoG+OMuFb+cCgFJJBTpxDgDl+gp0239u+0HbD9i+xfb4oCq25DzFKDqzXACgXM+BbvtiSX8maSoiLpVUk/SeQVWsXcbzLQCgUr9DLnVJG2zXJU1IeqL/Kp2q9eVcDKEDQLmeAz0ifibpbyX9RNKTkp6LiG8OqmLtssUxdBIdAMr0M+SyTdK7JF0i6SJJG22/r8N+e2xP256enZ3t8Vz5kh46AJTrZ8jlrZIei4jZiJiX9BVJv7p8p4jYGxFTETE1OTnZ46lad4qS6ABQ
pp9A/4mky2xPOB/kfouk/YOp1lIZ0xYBoFI/Y+h3S/qSpHsk3V8ca++A6rXE4hOL6KEDQKl6P2+OiI9J+tiA6lIq45miAFApjTtFxbRFAKiSRqAv9tBJdAAok1Sg00MHgHJJBHrr63OZ5wIA5ZIIdHroAFAtiUDnEXQAUC2JQG8NuPD1uQBQLo1A5xF0AFApkUDPl0xbBIBySQQ6Y+gAUC2JQGcMHQCqJRHo9NABoFoSgX5yHjqJDgBlkgj0FuIcAMolEehZxvehA0CVJAJ98ZtcyHMAKJVEoGfcWAQAlZIIdC6KAkC1pAKdPAeAcmkEurgoCgBVkgj0xYdEr241AGBNSyLQW9+22OQJFwBQKolAp4cOANWSCPTWGDoddAAol0agF7XkoigAlOsr0G1vtf0l2z+yvd/2GwZVsSXnKZbkOQCUq/f5/s9I+s+IeLftUUkTA6jTKU7eKUqiA0CZngPd9hZJb5L0B5IUEXOS5gZTreXnypeMoQNAuX6GXF4maVbSv9i+1/YNtjcOqF5L8IALAKjWT6DXJb1O0mcj4rWSjki6bvlOtvfYnrY9PTs728fp+C4XAOimn0A/KOlgRNxdbH9JecAvERF7I2IqIqYmJyd7OlGrhw4AKNdzoEfEzyX91PYriqK3SHpoILVaZnEMnUF0ACjV7yyXP5V0czHD5ceSPtB/lU61OG3xbBwcAIZEX4EeEfdJmhpQXUpxURQAqqVxpygPuACASokEOo+gA4AqSQS6lPfS+S4XACiXTKBnNmPoANBFMoFuMYYOAN0kE+iZzXe5AEAXyQR6LTM9dADoIqlAbywQ6ABQJplAz8wYOgB0k0yg12uZFhhEB4BSyQR6ZqtBoANAqWQCvZbxbYsA0E0ygV7PMi0whg4ApZIJ9CwTY+gA0EUygV6zCXQA6CKZQM8yM+QCAF0kE+j1zFwUBYAukgl0pi0CQHfJBHqNHjoAdJVMoNcZQweArpIJ9CxjlgsAdJNMoDNtEQC6SyfQ6aEDQFcEOgAMibQCnYuiAFCq70C3XbN9r+3bBlGhMpmZtggA3Qyih36tpP0DOE5X9YwbiwCgm74C3fYOSb8l6YbBVKcc0xYBoLt+e+iflvRhSc0B1KWrms0zRQGgi54D3fY7JM1ExL6K/fbYnrY9PTs72+vpVKvRQweAbvrpoV8u6Z22H5f0RUlX2P7X5TtFxN6ImIqIqcnJyZ5Pxo1FANBdz4EeEddHxI6I2CnpPZLujIj3DaxmyzBtEQC6S2oeevOsj9QDQLrqgzhIRNwl6a5BHKtMzVaDRAeAUsn00PNpi6tdCwBYu5IJ9HrGtEUA6CaZQK9lVoMuOgCUSibQM1vMWgSAcskEei0T89ABoIuEAj0j0AGgi4QCXdxYBABdpBPo3PoPAF2lE+hZXlUecgEAnSUU6PmSh1wAQGfJBHqWWZK4uQgASiQT6PUi0BlHB4DOkgn0zHmgM+QCAJ0lE+i11pALgQ4AHSUX6MxFB4DO0gt0eugA0FE6gW4CHQC6SSbQM3roANBVMoE+UmOWCwB0k0ygj9ZqkqS5Bg+5AIBO0gn0el5VAh0AOksm0MeKQD/RWFjlmgDA2pRMoNNDB4Dukgn0xR46D4oGgI6SCfRWD/3EPIEOAJ30HOi2X2z7W7b3237Q9rWDrNhyY/Vilgs9dADoqN7HexuSPhQR99jeLGmf7dsj4qEB1W2JxSGXeS6KAkAnPffQI+LJiLinWH9e0n5JFw+qYsstXhSlhw4AHQ1kDN32TkmvlXT3II7XyRizXACgq74D3fYmSV+W9MGIONzh9T22p21Pz87O9nyexYuiBDoAdNRXoNseUR7mN0fEVzrtExF7I2IqIqYmJyd7PtdojR46AHTTzywXS7pR0v6I+OTgqtRZvZaplpk7RQGgRD899MslvV/SFbbvK37ePqB6dTRWz+ihA0CJnqctRsR3JHmAdak0SqADQKlk7hSV8h46F0UBoLOkAp0eOgCUSyrQx+o1
eugAUCKpQB+tMeQCAGXSCvR6xrRFACiRVKCPj2R8fS4AlEgq0LeMj+jw8fnVrgYArElJBfo5G0b03DECHQA6SSrQt2wY0WECHQA6SirQz9kwoiNzC5rnO9EB4BTJBbokeukA0EFSgb5lQ/7VM4ePN1a5JgCw9iQV6K0eOhdGAeBUSQX6lnGGXACgTFKB3uqhP0ugA8Apkgr0F50zLkn62aFjq1wTAFh7kgr0zeMj2r5pVI8/fWS1qwIAa05SgS5JO8/bqMd+QaADwHLpBfr2jfTQAaCD5AL90ou2aOb5E3qMUAeAJZIL9Le+6gJJ0tfuf3KVawIAa0tygb5j24TeuGu7PnvXo/qvh55Sg+91AQBJUn21K9CLv/69V+v9N35Pf/SFaW0eq2vXBZv08vM36aXnbdSObRuKnwlNbhpTlnm1qwsAKyLJQN+xbUJfv/aNuuvhGX3nwNM6MPOC7vzRjJ5+YW7JfqO1TBcXAX/x1g160TnjOn/zuC7YMqYLtozr/C1jOm/jmGqEPoAhkGSgS9L4SE27L71Quy+9cLHs6FxDPzt0TAcPHdPBQ0d18NnW+jHt3//UKYEvSbXM2r5pNA/4zeOa3Dym8zaOatvG0cXluROj2rZxROduHNXEaLK/MgBDrq90sr1b0mck1STdEBEfH0itejQxWteuCzZr1wWbO74+12jq6RdO6KnDxzXz/AnNHD6upw6f3D546Kju/ckhHTo6p2Z0Psf4SFYE/Ki2TYxqy4a6No+NaPN4XVs25MvN4yPaUiw3j9d1TlG+aayuei25yxYAEtFzoNuuSfoHSVdKOijp+7ZvjYiHBlW5QRutZ7po6wZdtHVD1/2azdDh4/N65sicDh2d0y9eyJfPHJkvlnM6dGROzxyd088PH9fzx+d1+FhDx+YXKuswVs80MVrTxGi9WNa0odjeMFrTxrb1iZGTr42PZBqr1zRWzzQ2kmm0lmlspNiu5+t5Wb49WstkM5QErCf99NBfL+lARPxYkmx/UdK7JK3ZQD9dWWZtnRjV1onRM3rf/EJTLxxv6PnjDR0+Pq/Dx+fz9WP58vnjDR2da+jo3IKOzi3o2HxDR04s6NjcgmaeP66jJxaK1/J9GmX/TDhNo62wLz4IRmpWvZapnlkjtUz1mjWS5ct6LdNI5mXrxXta+2RLX2uVZbZqWf6TubWtxfLMVpZZtbbyxdeK8sz5773WdrzMWnbsfNu2LMnFtop1F8exnG8rL2tfb70un3x/61iL72tbX/I6H5BY4/oJ9Isl/bRt+6CkX+mvOmkbqWX5UMzGM/sgKDPXaOrY3IKOzjd0Yr6pE42m5hpNnWgs6ERrOd/U3EKzeL1V3lzyemu9sRBqNJuaXwg1FppqNEPzC/kxj8wt5GULoflmvmwsNDXfjFPL+/ygSdmSDwoVHyiLZSc/AE7uf3JryceBO64u3f909llWt07v6HScpWVnds4lZ3HJeh/HLHPGH6dn+IZePq7PpA1/9buv1usvObeHs5y+fgK9U0tO+T/d9h5JeyTpJS95SR+nW39G65lG65nO0chqV2WJiNBCMw/2RjNfbzZDCxFqRqjZVL7ezLcXFpdaXG8vb4aWHGOhGYqi7ORxTh4zlL8ekf/BNYuVVnmzbT2K+ub7R7F/vp63JX9/qHW8tn3bjt9aV1Hf9uOfPH8+XLf4e1ryO2tbb3ulvXz577jX45Tt33plyb59HK99/5LV02rH6TjTLkSc4Ql66qKc4Zs2jtV6OcsZ6SfQD0p6cdv2DklPLN8pIvZK2itJU1NT67drN0Ts1tDMatcEQLt+plx8X9Iu25fYHpX0Hkm3DqZaAIAz1XMPPSIatv9E0jeUT1u8KSIeHFjNAABnpK956BHxNUlfG1BdAAB94C4XABgSBDoADAkCHQCGBIEOAEOCQAeAIeEzvaOqr5PZs5L+r8e3b5f09ACrkwLavD7Q5vWhnza/NCImq3Za0UDvh+3piJha7XqsJNq8PtDm9WEl
2syQCwAMCQIdAIZESoG+d7UrsApo8/pAm9eHs97mZMbQAQDdpdRDBwB0kUSg295t+2HbB2xft9r1GRTbN9mesf1AW9m5tm+3/Uix3FaU2/bfFb+DH9p+3erVvDe2X2z7W7b3237Q9rVF+dC2WZJsj9v+nu0fFO3+y6L8Ett3F+3+t+JrqGV7rNg+ULy+czXr3yvbNdv32r6t2B7q9kqS7cdt32/7PtvTRdmK/X2v+UBvexj1b0p6laSrbL9qdWs1MJ+TtHtZ2XWS7oiIXZLuKLalvP27ip89kj67QnUcpIakD0XEKyVdJumPi/+Ww9xmSToh6YqI+GVJr5G02/Zlkj4h6VNFuw9JuqbY/xpJhyLi5ZI+VeyXomsl7W/bHvb2tvxGRLymbYriyv1954/aWrs/kt4g6Rtt29dLun616zXA9u2U9EDb9sOSLizWL5T0cLH+z5Ku6rRfqj+SvirpynXW5glJ9yh//u7TkupF+eLfufJnDLyhWK8X+3m1636G7dxRhNcVkm5T/sjKoW1vW7sfl7R9WdmK/X2v+R66Oj+M+uJVqstKuCAinpSkYnl+UT5Uv4fin9WvlXS31kGbi+GH+yTNSLpd0qOSno2IRrFLe9sW2128/pyk81a2xn37tKQPS2oW2+dpuNvbEpK+aXtf8TxlaQX/vvt6wMUKOa2HUa8DQ/N7sL1J0pclfTAiDnd5cvrQtDkiFiS9xvZWSf8h6ZWddiuWSbfb9jskzUTEPttvbhV32HUo2rvM5RHxhO3zJd1u+0dd9h14u1PooZ/Ww6iHyFO2L5SkYjlTlA/F78H2iPIwvzkivlIUD3Wb20XEs5LuUn4NYavtVqeqvW2L7S5eP0fSMytb075cLumdth+X9EXlwy6f1vC2d1FEPFEsZ5R/cL9eK/j3nUKgr7eHUd8q6epi/Wrl48yt8t8vroxfJum51j/jUuG8K36jpP0R8cm2l4a2zZJke7Lomcv2BklvVX6x8FuS3l3strzdrd/HuyXdGcUgawoi4vqI2BERO5X//3pnRLxXQ9reFtsbbW9urUt6m6QHtJJ/36t9EeE0LzS8XdL/Kh93/Ohq12eA7bpF0pOS5pV/Wl+jfOzwDkmPFMtzi32tfLbPo5LulzS12vXvob2/pvyflD+UdF/x8/ZhbnPRjl+SdG/R7gck/UVR/jJJ35N0QNK/SxoryseL7QPF6y9b7Tb00fY3S7ptPbS3aN8Pip8HW1m1kn/f3CkKAEMihSEXAMBpINABYEgQ6AAwJAh0ABgSBDoADAkCHQCGBIEOAEOCQAeAIfH/ry1sfMzof9IAAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(epoch_list,loss_list)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "After Training 7.980043411254883\n" ] } ], "source": [ "print(\"After Training \", model(torch.tensor([[4]], dtype = torch.float)).item())" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Parameter containing:\n", " tensor([[ 1.9884]]), Parameter containing:\n", " tensor(1.00000e-02 *\n", " [ 2.6272])]" ] }, "execution_count": 11, "metadata": 
{}, "output_type": "execute_result" } ], "source": [ "list(model.parameters())" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }