{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# What are Tensors?" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 29261998.9383\n", "1 23584624.4749\n", "2 21318274.0133\n", "3 19389745.5408\n", "4 16479856.1687\n", "5 12805039.2482\n", "6 9059166.91546\n", "7 6042659.8759\n", "8 3908408.60775\n", "9 2553920.39789\n", "10 1723204.06721\n", "11 1219705.10145\n", "12 906659.056268\n", "13 704582.301008\n", "14 567415.897123\n", "15 469502.722688\n", "16 396243.703489\n", "17 339183.787367\n", "18 293384.908371\n", "19 255753.24473\n", "20 224375.289442\n", "21 197817.587324\n", "22 175121.073496\n", "23 155577.723508\n", "24 138727.89154\n", "25 124054.575745\n", "26 111219.330545\n", "27 99943.0384346\n", "28 90002.3975585\n", "29 81206.7719005\n", "30 73409.0380627\n", "31 66473.3112012\n", "32 60296.3106408\n", "33 54785.7768329\n", "34 49859.0677676\n", "35 45441.2604793\n", "36 41474.757966\n", "37 37900.6254289\n", "38 34674.8838041\n", "39 31756.4912462\n", "40 29118.0035071\n", "41 26731.4581525\n", "42 24563.2300185\n", "43 22591.0640449\n", "44 20795.1155897\n", "45 19157.9008332\n", "46 17663.3262804\n", "47 16297.9786927\n", "48 15048.7541864\n", "49 13904.2761665\n", "50 12855.5370557\n", "51 11893.5831871\n", "52 11009.8840228\n", "53 10198.5198944\n", "54 9452.97741562\n", "55 8766.63119037\n", "56 8134.73416199\n", "57 7552.22098812\n", "58 7015.3292248\n", "59 6519.75327917\n", "60 6061.87657874\n", "61 5638.66314253\n", "62 5247.43856354\n", "63 4885.45681905\n", "64 4550.32597631\n", "65 4239.90233531\n", "66 3952.19792216\n", "67 3685.42375585\n", "68 3437.99282102\n", "69 3208.32364349\n", "70 2995.0123591\n", "71 2796.90365889\n", "72 2612.77458434\n", "73 2441.61153781\n", "74 2282.34987655\n", "75 2134.12359301\n", "76 1996.15517682\n", "77 1867.71870124\n", "78 1748.02217447\n", "79 1636.43626907\n", "80 1532.41129335\n", "81 1435.3645166\n", "82 1344.82770496\n", "83 1260.38274162\n", "84 1181.52556187\n", "85 1107.86359066\n", "86 1039.03117769\n", "87 974.722224799\n", "88 914.826976987\n", "89 858.928900525\n", "90 806.610252238\n", "91 757.653403984\n", "92 711.841667209\n", "93 668.955674145\n", "94 628.786703973\n", "95 591.166198315\n", "96 555.91219558\n", "97 522.862764788\n", "98 491.872428375\n", "99 462.809720147\n", "100 435.556280818\n", "101 409.988779832\n", "102 385.991214644\n", "103 363.455455472\n", "104 342.303748325\n", "105 322.439528753\n", "106 303.77815213\n", "107 286.25152638\n", "108 269.777995546\n", "109 254.289212316\n", "110 239.728038759\n", "111 226.036798421\n", "112 213.162389858\n", "113 201.050099682\n", "114 189.657042403\n", "115 178.930752507\n", "116 168.839022075\n", "117 159.335671601\n", "118 150.386617916\n", "119 141.960890819\n", "120 134.028933026\n", "121 126.552752973\n", "122 119.508618317\n", "123 112.871599105\n", "124 106.615401142\n", "125 100.718185857\n", "126 95.1601396318\n", "127 89.9172743676\n", "128 84.9725273635\n", "129 80.3094850327\n", "130 75.9095180326\n", "131 71.7595557249\n", "132 67.8449877975\n", "133 64.1482616471\n", "134 60.6589306595\n", "135 57.3666798893\n", "136 54.257828578\n", "137 51.3223084883\n", "138 48.5512904041\n", "139 45.9332963261\n", "140 43.4597612423\n", "141 41.1234738552\n", "142 38.9164620437\n", "143 36.8310989551\n", "144 34.8614349703\n", "145 32.9989944077\n", "146 31.2389494061\n", "147 29.5751657677\n", "148 28.0018533674\n", "149 26.5143723378\n", 
"150 25.1090290368\n", "151 23.7789949676\n", "152 22.5209671013\n", "153 21.3316066277\n", "154 20.2065400681\n", "155 19.1417900359\n", "156 18.1352582169\n", "157 17.1825260105\n", "158 16.2808499372\n", "159 15.4275713396\n", "160 14.6199024945\n", "161 13.8558667482\n", "162 13.1326730867\n", "163 12.4476228187\n", "164 11.7992270904\n", "165 11.1852772281\n", "166 10.6039207665\n", "167 10.053368564\n", "168 9.53252947626\n", "169 9.0388943525\n", "170 8.57125552508\n", "171 8.12835146959\n", "172 7.70876529188\n", "173 7.31119727339\n", "174 6.93478135637\n", "175 6.57803922866\n", "176 6.23990443082\n", "177 5.91946669864\n", "178 5.61584117512\n", "179 5.32809852758\n", "180 5.05546438442\n", "181 4.79691466999\n", "182 4.55190170806\n", "183 4.31959471325\n", "184 4.0993672564\n", "185 3.89053306571\n", "186 3.69272848442\n", "187 3.5049605073\n", "188 3.32690498544\n", "189 3.15811754743\n", "190 2.99800272266\n", "191 2.84612083648\n", "192 2.70218120603\n", "193 2.56559847877\n", "194 2.43598627756\n", "195 2.31303903422\n", "196 2.19641792427\n", "197 2.08576686945\n", "198 1.98084058378\n", "199 1.88122059939\n", "200 1.78671463098\n", "201 1.69700035\n", "202 1.61185950487\n", "203 1.53108818397\n", "204 1.45446845079\n", "205 1.38168499608\n", "206 1.31259415864\n", "207 1.24704228715\n", "208 1.18479076767\n", "209 1.12569966367\n", "210 1.06963559835\n", "211 1.01638970173\n", "212 0.965825057948\n", "213 0.91780839967\n", "214 0.872233916761\n", "215 0.828936944529\n", "216 0.787844706919\n", "217 0.748810185424\n", "218 0.711734247058\n", "219 0.676516900494\n", "220 0.643068334746\n", "221 0.611308116249\n", "222 0.581137457877\n", "223 0.552687236904\n", "224 0.525699566311\n", "225 0.500057688773\n", "226 0.475694935078\n", "227 0.452550412934\n", "228 0.430568156872\n", "229 0.409673907048\n", "230 0.389806573715\n", "231 0.370932381921\n", "232 0.352983443987\n", "233 0.335920213981\n", "234 0.319708565298\n", "235 0.304296629709\n", "236 0.289633110114\n", "237 0.275692202732\n", "238 0.262442493992\n", "239 0.249835379143\n", "240 0.23784557329\n", "241 0.226452367874\n", "242 0.215608440517\n", "243 0.205291982434\n", "244 0.195480341079\n", "245 0.186148690515\n", "246 0.177270139497\n", "247 0.168825227757\n", "248 0.160792294175\n", "249 0.153144502297\n", "250 0.145867509368\n", "251 0.138945619351\n", "252 0.132357061522\n", "253 0.126087562687\n", "254 0.120122854459\n", "255 0.114444458516\n", "256 0.109038268771\n", "257 0.103892027299\n", "258 0.098995553132\n", "259 0.0943323574882\n", "260 0.0898952639829\n", "261 0.0856698048795\n", "262 0.0816465504529\n", "263 0.0778154553274\n", "264 0.0741678325738\n", "265 0.0706947627847\n", "266 0.0673880124412\n", "267 0.064238458385\n", "268 0.0612391624809\n", "269 0.0583816564353\n", "270 0.0556601823579\n", "271 0.053069427739\n", "272 0.0506003196627\n", "273 0.0482491518176\n", "274 0.0460090812469\n", "275 0.0438747302262\n", "276 0.0418408493085\n", "277 0.0399035163839\n", "278 0.0380575463064\n", "279 0.036299097922\n", "280 0.0346233229447\n", "281 0.0330262284609\n", "282 0.031503877848\n", "283 0.0300531029465\n", "284 0.0286707587626\n", "285 0.0273528067848\n", "286 0.0260970026726\n", "287 0.0249004123643\n", "288 0.0237589699625\n", "289 0.0226708018116\n", "290 0.021633820431\n", "291 0.0206448922209\n", "292 0.0197019772237\n", "293 0.0188035993027\n", "294 0.0179464410754\n", "295 0.0171291505117\n", "296 0.0163498514344\n", "297 0.0156066053911\n", "298 0.0148977251492\n", "299 
0.014221758725\n", "300 0.0135773068144\n", "301 0.012962249886\n", "302 0.0123757145187\n", "303 0.0118163369769\n", "304 0.0112825023947\n", "305 0.0107732945335\n", "306 0.0102876221896\n", "307 0.00982427812751\n", "308 0.00938206972078\n", "309 0.00896026628877\n", "310 0.00855771803633\n", "311 0.00817356616853\n", "312 0.00780707377096\n", "313 0.00745720068811\n", "314 0.00712345076459\n", "315 0.00680492145638\n", "316 0.0065007764508\n", "317 0.00621051064797\n", "318 0.0059335532637\n", "319 0.00566905975894\n", "320 0.0054165748066\n", "321 0.00517568064705\n", "322 0.00494562595128\n", "323 0.00472595134395\n", "324 0.004516286571\n", "325 0.00431606748656\n", "326 0.00412485701368\n", "327 0.00394233640381\n", "328 0.00376798636873\n", "329 0.00360153498254\n", "330 0.0034425879938\n", "331 0.00329072705371\n", "332 0.00314572164753\n", "333 0.00300724330687\n", "334 0.00287492825219\n", "335 0.00274855064807\n", "336 0.0026278818479\n", "337 0.0025125794175\n", "338 0.00240242046594\n", "339 0.00229721881855\n", "340 0.00219666068919\n", "341 0.00210058824046\n", "342 0.00200881735521\n", "343 0.00192110469496\n", "344 0.00183730831055\n", "345 0.00175725174098\n", "346 0.00168073646648\n", "347 0.0016076003348\n", "348 0.00153771359982\n", "349 0.00147090447885\n", "350 0.0014070576511\n", "351 0.00134603797493\n", "352 0.00128770467549\n", "353 0.00123197628558\n", "354 0.00117868085568\n", "355 0.00112772121847\n", "356 0.00107901471112\n", "357 0.00103244508717\n", "358 0.000987915862711\n", "359 0.000945353945088\n", "360 0.000904654397546\n", "361 0.000865735304185\n", "362 0.000828532206376\n", "363 0.000792939748134\n", "364 0.000758904541284\n", "365 0.000726362131352\n", "366 0.000695237114628\n", "367 0.000665469366721\n", "368 0.000636998610063\n", "369 0.000609763414406\n", "370 0.000583720319318\n", "371 0.000558805930044\n", "372 0.000534968044314\n", "373 0.000512173467121\n", "374 0.000490358710071\n", "375 0.000469488609342\n", "376 0.000449526687963\n", "377 0.000430422479868\n", "378 0.000412146230894\n", "379 0.00039466322235\n", "380 0.000377932326339\n", "381 0.000361920938159\n", "382 0.00034659873668\n", "383 0.000331934151076\n", "384 0.000317903619901\n", "385 0.000304474691838\n", "386 0.000291623739083\n", "387 0.000279323742851\n", "388 0.000267548222855\n", "389 0.000256278541801\n", "390 0.000245493571213\n", "391 0.000235166229351\n", "392 0.000225284297347\n", "393 0.00021582135579\n", "394 0.000206761919814\n", "395 0.00019809028382\n", "396 0.000189786130103\n", "397 0.000181836438281\n", "398 0.00017422642703\n", "399 0.000166938234581\n", "400 0.000159960525744\n", "401 0.000153278944243\n", "402 0.000146880325055\n", "403 0.000140753610841\n", "404 0.000134885748538\n", "405 0.000129266522935\n", "406 0.000123884484981\n", "407 0.000118729332324\n", "408 0.000113793045321\n", "409 0.000109064256962\n", "410 0.000104535895706\n", "411 0.000100198028952\n", "412 9.60419382144e-05\n", "413 9.20618879821e-05\n", "414 8.82483093451e-05\n", "415 8.45950482713e-05\n", "416 8.10960474992e-05\n", "417 7.77430445771e-05\n", "418 7.45311414851e-05\n", "419 7.14534254255e-05\n", "420 6.85046590699e-05\n", "421 6.56795112e-05\n", "422 6.29724733817e-05\n", "423 6.03785890755e-05\n", "424 5.7892722874e-05\n", "425 5.55106718637e-05\n", "426 5.32282361494e-05\n", "427 5.10406825078e-05\n", "428 4.89448352152e-05\n", "429 4.6935677075e-05\n", "430 4.50100354238e-05\n", "431 4.31646099667e-05\n", "432 4.13956138296e-05\n", "433 3.97004647487e-05\n", 
"434 3.80755850664e-05\n", "435 3.65178860272e-05\n", "436 3.50247961582e-05\n", "437 3.35933714026e-05\n", "438 3.22215532328e-05\n", "439 3.09063452968e-05\n", "440 2.9645585819e-05\n", "441 2.84367405606e-05\n", "442 2.72777668962e-05\n", "443 2.61667329407e-05\n", "444 2.51013501797e-05\n", "445 2.40802574239e-05\n", "446 2.31009022636e-05\n", "447 2.21618921174e-05\n", "448 2.12615223444e-05\n", "449 2.03981352425e-05\n", "450 1.95703954971e-05\n", "451 1.877653602e-05\n", "452 1.8015266419e-05\n", "453 1.72851702352e-05\n", "454 1.65850141998e-05\n", "455 1.59137366858e-05\n", "456 1.52708808101e-05\n", "457 1.46533379366e-05\n", "458 1.40608956291e-05\n", "459 1.34927646819e-05\n", "460 1.29478110483e-05\n", "461 1.24251453509e-05\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "462 1.19238477354e-05\n", "463 1.14429095618e-05\n", "464 1.098163353e-05\n", "465 1.05391013727e-05\n", "466 1.01146537383e-05\n", "467 9.70748474576e-06\n", "468 9.31687028387e-06\n", "469 8.94210531477e-06\n", "470 8.5825945558e-06\n", "471 8.23769054161e-06\n", "472 7.90680480541e-06\n", "473 7.58936772541e-06\n", "474 7.28473974087e-06\n", "475 6.99252291117e-06\n", "476 6.71209205042e-06\n", "477 6.4430614509e-06\n", "478 6.18492658291e-06\n", "479 5.93721360396e-06\n", "480 5.69951164811e-06\n", "481 5.47142875981e-06\n", "482 5.25255539132e-06\n", "483 5.04254555597e-06\n", "484 4.84100017805e-06\n", "485 4.64756344346e-06\n", "486 4.46194790682e-06\n", "487 4.28379534403e-06\n", "488 4.11286394305e-06\n", "489 3.9487860965e-06\n", "490 3.79132018994e-06\n", "491 3.64017549663e-06\n", "492 3.49512942501e-06\n", "493 3.35590547435e-06\n", "494 3.22230253639e-06\n", "495 3.0940368917e-06\n", "496 2.97092430139e-06\n", "497 2.85274840761e-06\n", "498 2.73931894319e-06\n", "499 2.63045198276e-06\n" ] } ], "source": [ "# -*- coding: utf-8 -*-\n", "import numpy as np\n", "\n", "# N is batch size; D_in is input dimension;\n", "# H is hidden dimension; D_out is output dimension.\n", "N, D_in, H, D_out = 64, 1000, 100, 10\n", "\n", "# Create random input and output data\n", "x = np.random.randn(N, D_in)\n", "y = np.random.randn(N, D_out)\n", "\n", "# Randomly initialize weights\n", "w1 = np.random.randn(D_in, H)\n", "w2 = np.random.randn(H, D_out)\n", "\n", "learning_rate = 1e-6\n", "for t in range(500):\n", " # Forward pass: compute predicted y\n", " h = x.dot(w1)\n", " h_relu = np.maximum(h, 0)\n", " y_pred = h_relu.dot(w2)\n", "\n", " # Compute and print loss\n", " loss = np.square(y_pred - y).sum()\n", " print(t, loss)\n", "\n", " # Backprop to compute gradients of w1 and w2 with respect to loss\n", " grad_y_pred = 2.0 * (y_pred - y)\n", " grad_w2 = h_relu.T.dot(grad_y_pred)\n", " grad_h_relu = grad_y_pred.dot(w2.T)\n", " grad_h = grad_h_relu.copy()\n", " grad_h[h < 0] = 0\n", " grad_w1 = x.T.dot(grad_h)\n", "\n", " # Update weights\n", " w1 -= learning_rate * grad_w1\n", " w2 -= learning_rate * grad_w2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# PyTorch Tensors\n", "\n", "Clearly modern deep neural networks are in need of more than what our beloved numpy can offer.\n", "\n", "Here we introduce the most fundamental PyTorch concept: the *Tensor*. A PyTorch Tensor is conceptually identical to a numpy array: a Tensor is an n-dimensional array, and PyTorch provides many functions for operating on these Tensors. 
Like numpy arrays, PyTorch Tensors do not know anything about deep learning or computational graphs or gradients; they are a generic tool for scientific computing.\n", "\n", "However unlike numpy, PyTorch Tensors can utilize GPUs to accelerate their numeric computations. To run a PyTorch Tensor on GPU, you simply need to cast it to a new datatype.\n", "\n", "Here we use PyTorch Tensors to fit a two-layer network to random data. Like the numpy example above we need to manually implement the forward and backward passes through the network:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 28214897.691271067\n", "1 25380405.792548403\n", "2 26288556.067442656\n", "3 27187362.93774879\n", "4 25326431.49736169\n", "5 20070726.423171997\n", "6 13438367.445337629\n", "7 7935834.941528201\n", "8 4453037.240495725\n", "9 2567232.1655493514\n", "10 1604364.933374187\n", "11 1106295.9881061036\n", "12 831370.3628886025\n", "13 664479.3320915042\n", "14 552383.0191260207\n", "15 470307.21917449264\n", "16 406323.70261433884\n", "17 354377.92758273566\n", "18 311124.330613622\n", "19 274515.3858363455\n", "20 243215.3152763464\n", "21 216254.64485477417\n", "22 192876.48988408546\n", "23 172511.55349881982\n", "24 154696.59197369026\n", "25 139077.64419030334\n", "26 125326.40331724554\n", "27 113168.27359832195\n", "28 102388.44114990594\n", "29 92802.5217316554\n", "30 84252.873688431\n", "31 76614.83165994265\n", "32 69777.57502200827\n", "33 63643.55059441269\n", "34 58122.45377116208\n", "35 53149.5297017009\n", "36 48661.48595352931\n", "37 44605.11924878636\n", "38 40936.86403570355\n", "39 37612.1624785422\n", "40 34589.84976270138\n", "41 31842.023658028404\n", "42 29339.426460701798\n", "43 27055.76113430076\n", "44 24971.357019224655\n", "45 23066.443739543673\n", "46 21322.47401335786\n", "47 19723.635119302293\n", "48 18257.593847584038\n", "49 16911.857851812914\n", "50 15676.901621120574\n", "51 14541.234468931158\n", "52 13495.409479309936\n", "53 12531.688689953091\n", "54 11642.912229433834\n", "55 10823.396586809435\n", "56 10066.756259321584\n", "57 9368.032180714887\n", "58 8722.69606901206\n", "59 8125.648766315389\n", "60 7573.201240118957\n", "61 7061.110367171321\n", "62 6586.37116674181\n", "63 6146.2882079665205\n", "64 5738.038465707713\n", "65 5359.140022478372\n", "66 5007.218182836571\n", "67 4679.821472426938\n", "68 4375.545563822141\n", "69 4092.5672241640546\n", "70 3829.2538478331044\n", "71 3583.9982694811915\n", "72 3355.5044128053214\n", "73 3142.522992099788\n", "74 2944.0477814121896\n", "75 2758.9096522632453\n", "76 2586.098822437947\n", "77 2424.783757172412\n", "78 2274.162424382146\n", "79 2133.472267201043\n", "80 2001.943758391455\n", "81 1879.0327707577635\n", "82 1764.1420179859847\n", "83 1656.6765891071607\n", "84 1556.0836619963645\n", "85 1461.960692876407\n", "86 1373.8333980444747\n", "87 1291.3251499255507\n", "88 1214.077519632569\n", "89 1141.6407961478803\n", "90 1073.7122116708274\n", "91 1010.047731572995\n", "92 950.3514467849104\n", "93 894.3273352336082\n", "94 841.7842243861196\n", "95 792.4734904819334\n", "96 746.1964596283701\n", "97 702.7443149700078\n", "98 661.9300986860596\n", "99 623.5698773736967\n", "100 587.5490628035759\n", "101 553.7059624342619\n", "102 521.87459074208\n", "103 491.95096067483627\n", "104 463.81437045894427\n", "105 437.3619707183019\n", "106 412.4634959739533\n", "107 389.03197571304185\n", "108 366.98684185984854\n", 
"109 346.2511910920458\n", "110 326.7168373228138\n", "111 308.3204062757866\n", "112 291.0116837719783\n", "113 274.708204616996\n", "114 259.3530469133465\n", "115 244.89019768539188\n", "116 231.2530311334451\n", "117 218.40625489775357\n", "118 206.29494907575645\n", "119 194.87408803031087\n", "120 184.10918123054637\n", "121 173.95670181258504\n", "122 164.3796226045149\n", "123 155.34510760042053\n", "124 146.82159396161046\n", "125 138.78259947243896\n", "126 131.19334880439965\n", "127 124.03206751540091\n", "128 117.27145023435516\n", "129 110.8922230492454\n", "130 104.86954096430226\n", "131 99.17821632714708\n", "132 93.80321977845797\n", "133 88.73180206294792\n", "134 83.93784170194142\n", "135 79.41110559695994\n", "136 75.13158900665832\n", "137 71.09239467909009\n", "138 67.27303860367512\n", "139 63.661083760649944\n", "140 60.24877063365615\n", "141 57.02609438798197\n", "142 53.97754052526591\n", "143 51.095064315871184\n", "144 48.37058476978203\n", "145 45.794669434952766\n", "146 43.357149485835066\n", "147 41.053389754456276\n", "148 38.87377426878407\n", "149 36.81266361362863\n", "150 34.86269390504242\n", "151 33.019418138638315\n", "152 31.27372891445308\n", "153 29.623019411905716\n", "154 28.06002445682043\n", "155 26.58085827334935\n", "156 25.182066968294635\n", "157 23.856794249429644\n", "158 22.603709343965306\n", "159 21.416727958537756\n", "160 20.294056803979785\n", "161 19.230226081371868\n", "162 18.22286712818012\n", "163 17.26986351281531\n", "164 16.36696748965665\n", "165 15.512043060681435\n", "166 14.70277965469339\n", "167 13.936395793035047\n", "168 13.21058501636503\n", "169 12.522788125846773\n", "170 11.871329149475358\n", "171 11.254089594353673\n", "172 10.669772994995135\n", "173 10.115961444046548\n", "174 9.591341026183215\n", "175 9.094685662630582\n", "176 8.623675345308872\n", "177 8.177212815510206\n", "178 7.754271122965591\n", "179 7.354052512118528\n", "180 6.974105103205304\n", "181 6.613862763094033\n", "182 6.273167739637028\n", "183 5.949956651557034\n", "184 5.643680276344654\n", "185 5.353149802081873\n", "186 5.077408776123896\n", "187 4.8164800713806315\n", "188 4.568759942421966\n", "189 4.334537105201893\n", "190 4.112015826773195\n", "191 3.9009179414881707\n", "192 3.7012154731272986\n", "193 3.511612145634661\n", "194 3.331681329765537\n", "195 3.1611259816769106\n", "196 2.9996718148188464\n", "197 2.8461790457236766\n", "198 2.7007109757500025\n", "199 2.562890156220522\n", "200 2.4321240546360414\n", "201 2.308078948186587\n", "202 2.1904870139545665\n", "203 2.0787757790351513\n", "204 1.972721255352237\n", "205 1.8724816279031096\n", "206 1.776974327720918\n", "207 1.6867990743287722\n", "208 1.601016306899063\n", "209 1.5197483114327683\n", "210 1.442438605099003\n", "211 1.369157533522884\n", "212 1.2998218626227995\n", "213 1.2339273899186163\n", "214 1.17146151531626\n", "215 1.1119642766772915\n", "216 1.0557099815853666\n", "217 1.0022163466049716\n", "218 0.9514815204819733\n", "219 0.9033794087224507\n", "220 0.8576643044382202\n", "221 0.8143504655566967\n", "222 0.7732258198716373\n", "223 0.7342158760394923\n", "224 0.6971644104382229\n", "225 0.6619967066271535\n", "226 0.6285948940725881\n", "227 0.5968096996362284\n", "228 0.5667985106167974\n", "229 0.5382560311909526\n", "230 0.5111128765857158\n", "231 0.48532747688128897\n", "232 0.4609265227778163\n", "233 0.4378205148075356\n", "234 0.4157447156268157\n", "235 0.39488582392669613\n", "236 0.3749829100757234\n", "237 0.35613537196222556\n", 
"238 0.3382650005067456\n", "239 0.32128029946794356\n", "240 0.30518656196033334\n", "241 0.2898877071115251\n", "242 0.2753985487457893\n", "243 0.26155083612243324\n", "244 0.2484203549989168\n", "245 0.23601150551252115\n", "246 0.22414258684202437\n", "247 0.21293119006192796\n", "248 0.20228290632133167\n", "249 0.19217315565631465\n", "250 0.18254562652399353\n", "251 0.17339564711978817\n", "252 0.16472807684149715\n", "253 0.15650172744047652\n", "254 0.14871153441717966\n", "255 0.141252334180054\n", "256 0.13420798837495873\n", "257 0.1275148040973093\n", "258 0.12115617519511047\n", "259 0.11513308130563794\n", "260 0.10940513697614086\n", "261 0.10392053471471474\n", "262 0.09873369084591621\n", "263 0.0938313782580984\n", "264 0.08916601624925224\n", "265 0.08473324384660685\n", "266 0.08052892574508519\n", "267 0.0765146490751043\n", "268 0.07271481811263403\n", "269 0.06908563553494673\n", "270 0.06569151383855609\n", "271 0.06243397875781054\n", "272 0.059288574380887527\n", "273 0.056329070592443964\n", "274 0.053548219642252315\n", "275 0.050900103240740124\n", "276 0.04838548394463982\n", "277 0.0459671132623376\n", "278 0.0437052950628003\n", "279 0.04152813333603267\n", "280 0.03946723512516925\n", "281 0.037533394479267956\n", "282 0.03568910868574027\n", "283 0.03390143972920545\n", "284 0.03224024862347741\n", "285 0.030643039031359953\n", "286 0.029140048215710534\n", "287 0.027684004166977916\n", "288 0.026314751884853438\n", "289 0.025033289714943285\n", "290 0.023797433226710796\n", "291 0.02262336258232206\n", "292 0.021504991155510966\n", "293 0.02045539104525429\n", "294 0.01946322911578169\n", "295 0.018508058725132337\n", "296 0.017591603927018307\n", "297 0.0167319201745022\n", "298 0.01589690324248605\n", "299 0.015139217233493873\n", "300 0.014406465492699028\n", "301 0.013696118249792555\n", "302 0.013027044818570366\n", "303 0.012396137516559824\n", "304 0.011789959562407637\n", "305 0.011218404833911066\n", "306 0.010683310567349003\n", "307 0.01017237445438407\n", "308 0.009680823881585532\n", "309 0.009207409667161714\n", "310 0.008765478168209662\n", "311 0.008343950057827287\n", "312 0.007942077827435279\n", "313 0.007565309388805952\n", "314 0.007207875892412341\n", "315 0.006865805795088331\n", "316 0.006540033828446479\n", "317 0.006230358566105432\n", "318 0.005935793194971506\n", "319 0.0056569668626278435\n", "320 0.005392627609972722\n", "321 0.005142164502647928\n", "322 0.004903280543345323\n", "323 0.0046710665883187286\n", "324 0.0044596153714386855\n", "325 0.004254983110094868\n", "326 0.004053112385123736\n", "327 0.0038716554215878496\n", "328 0.003693314754828869\n", "329 0.0035276847824143864\n", "330 0.0033722945089047496\n", "331 0.003224334560254949\n", "332 0.0030797497759889048\n", "333 0.0029394659957044933\n", "334 0.002811574650088522\n", "335 0.002687472415199954\n", "336 0.0025703362047908573\n", "337 0.002459747191192241\n", "338 0.002356102818997119\n", "339 0.002255833670969709\n", "340 0.002162025463318118\n", "341 0.0020706938958603427\n", "342 0.0019804429030327864\n", "343 0.001900624922958949\n", "344 0.0018219768719105467\n", "345 0.0017470823932357327\n", "346 0.00167403269153521\n", "347 0.0016074784101616224\n", "348 0.0015434483287186662\n", "349 0.0014800083688997212\n", "350 0.0014237042363174357\n", "351 0.0013673234525823919\n", "352 0.001313357088474909\n", "353 0.001260321802043718\n", "354 0.0012110123492108382\n", "355 0.0011649015229149295\n", "356 0.0011204107047679823\n", "357 
0.0010796027719008894\n", "358 0.0010387537130391866\n", "359 0.0010000977633572994\n", "360 0.0009621068961130352\n", "361 0.0009280829840783711\n", "362 0.0008925718210586187\n", "363 0.000861557630299048\n", "364 0.0008311890107142728\n", "365 0.0008009163000355368\n", "366 0.0007749579994773548\n", "367 0.0007484465107412408\n", "368 0.0007222093569896337\n", "369 0.0006973693140788217\n", "370 0.0006740144106807122\n", "371 0.0006497241727248526\n", "372 0.0006289278786107411\n", "373 0.0006074793578099702\n", "374 0.0005876308963297938\n", "375 0.0005683213433220757\n", "376 0.0005498372268836205\n", "377 0.0005331871459254289\n", "378 0.000514447040892041\n", "379 0.0004995928681320039\n", "380 0.0004830170838928116\n", "381 0.00046774268823562837\n", "382 0.0004539691694929737\n", "383 0.000439916381955785\n", "384 0.0004264477815108525\n", "385 0.00041461180957180765\n", "386 0.0004007517174830638\n", "387 0.0003885748281060031\n", "388 0.00037587470802746825\n", "389 0.00036685178849288347\n", "390 0.0003563211267975097\n", "391 0.00034580960975350017\n", "392 0.0003373833671450055\n", "393 0.00032734962350788877\n", "394 0.0003172536302761264\n", "395 0.0003087773417211892\n", "396 0.0003003043221507795\n", "397 0.00029157922913469747\n", "398 0.0002838640100005785\n", "399 0.0002773871556948082\n", "400 0.0002693207433572403\n", "401 0.00026274296119056795\n", "402 0.00025629517436694116\n", "403 0.00024865622459442627\n", "404 0.00024191564002837285\n", "405 0.00023580017504160056\n", "406 0.00023021821139053433\n", "407 0.00022492894727565993\n", "408 0.0002187235492969869\n", "409 0.0002130760149380989\n", "410 0.00020846465683301008\n", "411 0.00020339997672215449\n", "412 0.00019872765675885856\n", "413 0.0001944256389800475\n", "414 0.00019000826152877626\n", "415 0.00018525978015550282\n", "416 0.00018077204148654602\n", "417 0.00017703581449693417\n", "418 0.00017288089631457837\n", "419 0.00016859326936614905\n", "420 0.00016520088735737237\n", "421 0.00016153575890172356\n", "422 0.00015850395108210624\n", "423 0.00015459131808931437\n", "424 0.00015143964321755188\n", "425 0.00014752541333357128\n", "426 0.00014445116156402982\n", "427 0.000141356974335205\n", "428 0.00013808374274071355\n", "429 0.0001353787969475273\n", "430 0.00013243990439421038\n", "431 0.00012966755098536842\n", "432 0.00012675479128350375\n", "433 0.0001242446317876872\n", "434 0.00012147723341672523\n", "435 0.00011897251544801257\n", "436 0.00011670839108181286\n", "437 0.00011433415206264785\n", "438 0.00011219214203966876\n", "439 0.00011027887981625295\n", "440 0.00010749669199967837\n", "441 0.00010553591091314041\n", "442 0.00010367609742491235\n", "443 0.00010109258945595334\n", "444 9.932886336866398e-05\n", "445 9.749088251564952e-05\n", "446 9.520718197066069e-05\n", "447 9.345653584887093e-05\n", "448 9.182967794658936e-05\n", "449 8.988625574320175e-05\n", "450 8.847123125822753e-05\n", "451 8.713054826821331e-05\n", "452 8.549155933869346e-05\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "453 8.406219949108618e-05\n", "454 8.268637916582222e-05\n", "455 8.105226215794625e-05\n", "456 7.962860940187444e-05\n", "457 7.804827419660709e-05\n", "458 7.654020379489757e-05\n", "459 7.561118885025808e-05\n", "460 7.45004278926431e-05\n", "461 7.300097409869422e-05\n", "462 7.156340142840112e-05\n", "463 7.07410268018932e-05\n", "464 6.95669895526968e-05\n", "465 6.816465251008319e-05\n", "466 6.710678036121742e-05\n", "467 6.58694634003143e-05\n", "468 
6.480972211669878e-05\n", "469 6.364145772613794e-05\n", "470 6.294719978224006e-05\n", "471 6.208284610231818e-05\n", "472 6.100992742692768e-05\n", "473 6.030514397714626e-05\n", "474 5.950513809528657e-05\n", "475 5.863842784931128e-05\n", "476 5.7600118031853054e-05\n", "477 5.656285652166915e-05\n", "478 5.575245490059555e-05\n", "479 5.4907743583479385e-05\n", "480 5.450484055091742e-05\n", "481 5.38567654812111e-05\n", "482 5.3109503177953266e-05\n", "483 5.230015415125244e-05\n", "484 5.148949327894725e-05\n", "485 5.0706583465745525e-05\n", "486 4.996987102404149e-05\n", "487 4.9266432966016405e-05\n", "488 4.862910026705303e-05\n", "489 4.8046019087769065e-05\n", "490 4.737298535380241e-05\n", "491 4.6690329870216485e-05\n", "492 4.610047588500532e-05\n", "493 4.554242994432578e-05\n", "494 4.503223066189277e-05\n", "495 4.456991744473948e-05\n", "496 4.387548054092527e-05\n", "497 4.330432420614205e-05\n", "498 4.256601870983312e-05\n", "499 4.199306232773037e-05\n" ] } ], "source": [ "import torch\n", "dtype = torch.FloatTensor\n", "# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU\n", "\n", "# N is batch size; D_in is input dimension;\n", "# H is hidden dimension; D_out is output dimension.\n", "N, D_in, H, D_out = 64, 1000, 100, 10\n", "\n", "# Create random input and output data\n", "x = torch.randn(N, D_in).type(dtype)\n", "y = torch.randn(N, D_out).type(dtype)\n", "\n", "# Randomly initialize weights\n", "w1 = torch.randn(D_in, H).type(dtype)\n", "w2 = torch.randn(H, D_out).type(dtype)\n", "\n", "learning_rate = 1e-6\n", "for t in range(500):\n", " # Forward pass: compute predicted y\n", " h = x.mm(w1)\n", " h_relu = h.clamp(min=0)\n", " y_pred = h_relu.mm(w2)\n", "\n", " # Compute and print loss\n", " loss = (y_pred - y).pow(2).sum()\n", " print(t, loss)\n", "\n", " # Backprop to compute gradients of w1 and w2 with respect to loss\n", " grad_y_pred = 2.0 * (y_pred - y)\n", " grad_w2 = h_relu.t().mm(grad_y_pred)\n", " grad_h_relu = grad_y_pred.mm(w2.t())\n", " grad_h = grad_h_relu.clone()\n", " grad_h[h < 0] = 0\n", " grad_w1 = x.t().mm(grad_h)\n", "\n", " # Update weights using gradient descent\n", " w1 -= learning_rate * grad_w1\n", " w2 -= learning_rate * grad_w2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Autograd\n", "PyTorch Variables and autograd. The autograd package provides this functionality: the forward pass of your network defines a computational graph whose nodes are Tensors and whose edges are functions that produce output Tensors from input Tensors. Backpropagating through this graph then allows us to easily compute gradients.\n", "\n", "Here we wrap the PyTorch Tensor in a Variable object, where a Variable represents a node in the computational graph. If x is a Variable then x.data is a Tensor and x.grad is another Variable holding the gradient of x with respect to some scalar value.\n", "\n", "PyTorch Variables have the same API as PyTorch Tensors: any operation that you can do with a Tensor also works with a Variable; the only difference is that a Variable defines a computational graph, allowing us to automatically compute gradients."
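, "\n", "A minimal sketch of the Variable / .data / .grad relationship on a toy function, using the same pre-0.4 Variable API as the code cell below (the names x and y here are illustrative only):\n", "\n", "```python\n", "import torch\n", "from torch.autograd import Variable\n", "\n", "x = Variable(torch.ones(2, 2), requires_grad=True)\n", "y = (x * x).sum()   # y is a Variable; the graph x -> x*x -> sum is recorded\n", "y.backward()        # populates x.grad\n", "print(x.data)       # the underlying Tensor\n", "print(x.grad.data)  # dy/dx = 2*x, also a Tensor\n", "```"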
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 35878500.0\n", "1 33502642.0\n", "2 31638146.0\n", "3 26216880.0\n", "4 18097450.0\n", "5 10643111.0\n", "6 5868223.0\n", "7 3356485.0\n", "8 2129793.5\n", "9 1508282.875\n", "10 1160753.375\n", "11 940967.3125\n", "12 785975.375\n", "13 668166.125\n", "14 574389.8125\n", "15 497736.9375\n", "16 433985.8125\n", "17 380330.0\n", "18 334801.71875\n", "19 295919.6875\n", "20 262469.40625\n", "21 233624.078125\n", "22 208602.84375\n", "23 186785.21875\n", "24 167705.6875\n", "25 150947.25\n", "26 136179.03125\n", "27 123118.4375\n", "28 111543.015625\n", "29 101252.5\n", "30 92084.828125\n", "31 83890.2109375\n", "32 76550.203125\n", "33 69970.4609375\n", "34 64056.62109375\n", "35 58728.921875\n", "36 53917.6015625\n", "37 49565.42578125\n", "38 45631.03515625\n", "39 42059.48046875\n", "40 38813.0390625\n", "41 35858.09765625\n", "42 33163.74609375\n", "43 30702.73828125\n", "44 28452.41796875\n", "45 26393.0234375\n", "46 24505.55078125\n", "47 22772.90234375\n", "48 21181.724609375\n", "49 19717.416015625\n", "50 18369.517578125\n", "51 17127.080078125\n", "52 15980.390625\n", "53 14921.2587890625\n", "54 13942.4697265625\n", "55 13036.6015625\n", "56 12197.0810546875\n", "57 11419.048828125\n", "58 10696.9755859375\n", "59 10026.7861328125\n", "60 9403.92578125\n", "61 8824.4482421875\n", "62 8285.025390625\n", "63 7782.73583984375\n", "64 7314.56591796875\n", "65 6878.09619140625\n", "66 6470.642578125\n", "67 6090.18603515625\n", "68 5734.34912109375\n", "69 5401.63623046875\n", "70 5090.5068359375\n", "71 4799.25390625\n", "72 4526.57177734375\n", "73 4271.24853515625\n", "74 4031.948974609375\n", "75 3807.61865234375\n", "76 3597.07275390625\n", "77 3399.272216796875\n", "78 3213.479248046875\n", "79 3038.89013671875\n", "80 2874.7353515625\n", "81 2720.358642578125\n", "82 2575.041015625\n", "83 2438.2685546875\n", "84 2309.4228515625\n", "85 2188.03125\n", "86 2073.676513671875\n", "87 1965.752685546875\n", "88 1863.9710693359375\n", "89 1767.9381103515625\n", "90 1677.3052978515625\n", "91 1591.765869140625\n", "92 1510.904541015625\n", "93 1434.5537109375\n", "94 1362.3419189453125\n", "95 1294.068115234375\n", "96 1229.5435791015625\n", "97 1168.46240234375\n", "98 1110.672607421875\n", "99 1055.9576416015625\n", "100 1004.15771484375\n", "101 955.1239013671875\n", "102 908.6265869140625\n", "103 864.5731811523438\n", "104 822.8087158203125\n", "105 783.23779296875\n", "106 745.7089233398438\n", "107 710.0859985351562\n", "108 676.2913818359375\n", "109 644.2283935546875\n", "110 613.8003540039062\n", "111 584.9096069335938\n", "112 557.466796875\n", "113 531.3948364257812\n", "114 506.61962890625\n", "115 483.0841369628906\n", "116 460.7171936035156\n", "117 439.4493408203125\n", "118 419.2239074707031\n", "119 399.9928283691406\n", "120 381.705322265625\n", "121 364.3070373535156\n", "122 347.7393798828125\n", "123 331.97283935546875\n", "124 316.9671630859375\n", "125 302.68853759765625\n", "126 289.0873718261719\n", "127 276.12823486328125\n", "128 263.7820129394531\n", "129 252.02195739746094\n", "130 240.82232666015625\n", "131 230.1421661376953\n", "132 219.9631805419922\n", "133 210.26112365722656\n", "134 201.00880432128906\n", "135 192.19129943847656\n", "136 183.77542114257812\n", "137 175.74905395507812\n", "138 168.08860778808594\n", "139 160.7852325439453\n", "140 153.8109893798828\n", "141 147.15760803222656\n", "142 
140.80355834960938\n", "143 134.73838806152344\n", "144 128.94873046875\n", "145 123.42112731933594\n", "146 118.14532470703125\n", "147 113.09927368164062\n", "148 108.27923583984375\n", "149 103.67829132080078\n", "150 99.27864837646484\n", "151 95.07437133789062\n", "152 91.05731964111328\n", "153 87.2179183959961\n", "154 83.54911804199219\n", "155 80.0405044555664\n", "156 76.684814453125\n", "157 73.47515106201172\n", "158 70.40677642822266\n", "159 67.47286224365234\n", "160 64.66546630859375\n", "161 61.97762680053711\n", "162 59.40721893310547\n", "163 56.94810485839844\n", "164 54.59606170654297\n", "165 52.34416961669922\n", "166 50.18893814086914\n", "167 48.12546157836914\n", "168 46.15074920654297\n", "169 44.2588996887207\n", "170 42.44817352294922\n", "171 40.71442794799805\n", "172 39.05426788330078\n", "173 37.46466064453125\n", "174 35.94228744506836\n", "175 34.483211517333984\n", "176 33.08485794067383\n", "177 31.74663734436035\n", "178 30.463415145874023\n", "179 29.23353385925293\n", "180 28.055572509765625\n", "181 26.927295684814453\n", "182 25.845623016357422\n", "183 24.808975219726562\n", "184 23.814783096313477\n", "185 22.86162757873535\n", "186 21.947311401367188\n", "187 21.071706771850586\n", "188 20.23161506652832\n", "189 19.426319122314453\n", "190 18.65383529663086\n", "191 17.913501739501953\n", "192 17.202938079833984\n", "193 16.521442413330078\n", "194 15.867642402648926\n", "195 15.240697860717773\n", "196 14.638861656188965\n", "197 14.062265396118164\n", "198 13.50815200805664\n", "199 12.976459503173828\n", "200 12.466854095458984\n", "201 11.977287292480469\n", "202 11.508007049560547\n", "203 11.057541847229004\n", "204 10.624938011169434\n", "205 10.209487915039062\n", "206 9.811256408691406\n", "207 9.42843246459961\n", "208 9.060935020446777\n", "209 8.708433151245117\n", "210 8.369855880737305\n", "211 8.044754028320312\n", "212 7.732644081115723\n", "213 7.432569980621338\n", "214 7.144754886627197\n", "215 6.868185997009277\n", "216 6.602710723876953\n", "217 6.347681522369385\n", "218 6.103111267089844\n", "219 5.8678483963012695\n", "220 5.64181661605835\n", "221 5.424355506896973\n", "222 5.215670108795166\n", "223 5.015231132507324\n", "224 4.822762966156006\n", "225 4.637856483459473\n", "226 4.460170269012451\n", "227 4.28932523727417\n", "228 4.125042915344238\n", "229 3.967371940612793\n", "230 3.8158133029937744\n", "231 3.670203924179077\n", "232 3.530052661895752\n", "233 3.3955368995666504\n", "234 3.2662713527679443\n", "235 3.141659736633301\n", "236 3.022263526916504\n", "237 2.9074623584747314\n", "238 2.7971487045288086\n", "239 2.6909444332122803\n", "240 2.5889720916748047\n", "241 2.4906997680664062\n", "242 2.3964099884033203\n", "243 2.305689573287964\n", "244 2.2184183597564697\n", "245 2.134580612182617\n", "246 2.053964376449585\n", "247 1.9763656854629517\n", "248 1.9018102884292603\n", "249 1.830053687095642\n", "250 1.7611491680145264\n", "251 1.6948093175888062\n", "252 1.6310014724731445\n", "253 1.5695786476135254\n", "254 1.5105568170547485\n", "255 1.453961730003357\n", "256 1.3993768692016602\n", "257 1.3468921184539795\n", "258 1.296314001083374\n", "259 1.2476303577423096\n", "260 1.2009528875350952\n", "261 1.1559319496154785\n", "262 1.1127043962478638\n", "263 1.071056604385376\n", "264 1.0309712886810303\n", "265 0.9924764037132263\n", "266 0.9553502202033997\n", "267 0.9197673797607422\n", "268 0.8854739665985107\n", "269 0.8523190021514893\n", "270 0.8206137418746948\n", "271 
0.7899705767631531\n", "272 0.7606277465820312\n", "273 0.73226398229599\n", "274 0.7050141096115112\n", "275 0.678828239440918\n", "276 0.6536497473716736\n", "277 0.6293545365333557\n", "278 0.6060177683830261\n", "279 0.5834728479385376\n", "280 0.561832070350647\n", "281 0.5409616231918335\n", "282 0.5209071636199951\n", "283 0.5015677213668823\n", "284 0.48301637172698975\n", "285 0.46511176228523254\n", "286 0.4478737413883209\n", "287 0.43125540018081665\n", "288 0.4153783321380615\n", "289 0.3999677896499634\n", "290 0.3852301836013794\n", "291 0.3709765374660492\n", "292 0.3573264181613922\n", "293 0.34412410855293274\n", "294 0.3314245641231537\n", "295 0.3191765248775482\n", "296 0.3074215054512024\n", "297 0.296056866645813\n", "298 0.28519463539123535\n", "299 0.2746979594230652\n", "300 0.26459330320358276\n", "301 0.25485098361968994\n", "302 0.24548396468162537\n", "303 0.23641419410705566\n", "304 0.22769007086753845\n", "305 0.21935638785362244\n", "306 0.21129179000854492\n", "307 0.20352837443351746\n", "308 0.19601596891880035\n", "309 0.18884600698947906\n", "310 0.1819111704826355\n", "311 0.17525847256183624\n", "312 0.16882126033306122\n", "313 0.16263249516487122\n", "314 0.15666751563549042\n", "315 0.15092256665229797\n", "316 0.14540348947048187\n", "317 0.14006322622299194\n", "318 0.13494110107421875\n", "319 0.13001160323619843\n", "320 0.12526486814022064\n", "321 0.1206771731376648\n", "322 0.11626103520393372\n", "323 0.11202862858772278\n", "324 0.10792234539985657\n", "325 0.10398980975151062\n", "326 0.1001921221613884\n", "327 0.09651487320661545\n", "328 0.09299999475479126\n", "329 0.08962738513946533\n", "330 0.08636265993118286\n", "331 0.08319984376430511\n", "332 0.08016426116228104\n", "333 0.07726199924945831\n", "334 0.07444174587726593\n", "335 0.07173093408346176\n", "336 0.0690966546535492\n", "337 0.06658675521612167\n", "338 0.06416875869035721\n", "339 0.06185092404484749\n", "340 0.05959108844399452\n", "341 0.057433173060417175\n", "342 0.0553474947810173\n", "343 0.053334783762693405\n", "344 0.051402702927589417\n", "345 0.04955539479851723\n", "346 0.04775090888142586\n", "347 0.04602515324950218\n", "348 0.044351741671562195\n", "349 0.04274814575910568\n", "350 0.041199490427970886\n", "351 0.03970283269882202\n", "352 0.03825933113694191\n", "353 0.036878347396850586\n", "354 0.0355573333799839\n", "355 0.03427095338702202\n", "356 0.03304218873381615\n", "357 0.03185059875249863\n", "358 0.03069448284804821\n", "359 0.029578279703855515\n", "360 0.028519731014966965\n", "361 0.02748997136950493\n", "362 0.026511413976550102\n", "363 0.025562353432178497\n", "364 0.02463531121611595\n", "365 0.023760242387652397\n", "366 0.02290988340973854\n", "367 0.022078199312090874\n", "368 0.021282397210597992\n", "369 0.020530449226498604\n", "370 0.019799597561359406\n", "371 0.019093159586191177\n", "372 0.018408171832561493\n", "373 0.017752142623066902\n", "374 0.017124634236097336\n", "375 0.0165147352963686\n", "376 0.0159267857670784\n", "377 0.015361725352704525\n", "378 0.014817671850323677\n", "379 0.014293329790234566\n", "380 0.013794535771012306\n", "381 0.013302133418619633\n", "382 0.012831700965762138\n", "383 0.012388093397021294\n", "384 0.011951521039009094\n", "385 0.011539716273546219\n", "386 0.011130605824291706\n", "387 0.010747026652097702\n", "388 0.010368922725319862\n", "389 0.010003476403653622\n", "390 0.009659701958298683\n", "391 0.009321259334683418\n", "392 0.009002749808132648\n", "393 
0.008687980473041534\n", "394 0.008389437571167946\n", "395 0.008094895631074905\n", "396 0.00781853124499321\n", "397 0.00755091430619359\n", "398 0.007287600077688694\n", "399 0.007037787232547998\n", "400 0.006801604758948088\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "401 0.006566936150193214\n", "402 0.0063502490520477295\n", "403 0.0061323679983615875\n", "404 0.005927860736846924\n", "405 0.00572930509224534\n", "406 0.005533153191208839\n", "407 0.005349132232367992\n", "408 0.0051721855998039246\n", "409 0.004995665047317743\n", "410 0.004830839112401009\n", "411 0.00467012170702219\n", "412 0.004517041612416506\n", "413 0.0043692924082279205\n", "414 0.004221327602863312\n", "415 0.004083284642547369\n", "416 0.003950608428567648\n", "417 0.0038255397230386734\n", "418 0.003699194174259901\n", "419 0.003574871923774481\n", "420 0.0034629858564585447\n", "421 0.003353290958330035\n", "422 0.003244665451347828\n", "423 0.003142776433378458\n", "424 0.00304229324683547\n", "425 0.002945477142930031\n", "426 0.0028537893667817116\n", "427 0.002764546312391758\n", "428 0.002679029945284128\n", "429 0.002593627432361245\n", "430 0.002514220541343093\n", "431 0.002438138471916318\n", "432 0.0023646263871341944\n", "433 0.0022932353895157576\n", "434 0.002221874427050352\n", "435 0.0021538916043937206\n", "436 0.00209184642881155\n", "437 0.0020262051839381456\n", "438 0.0019666266161948442\n", "439 0.0019075790187343955\n", "440 0.0018491483060643077\n", "441 0.0017957690870389342\n", "442 0.0017438435461372137\n", "443 0.0016942177899181843\n", "444 0.001644242205657065\n", "445 0.0015961473109200597\n", "446 0.0015512119280174375\n", "447 0.0015078299911692739\n", "448 0.0014637272106483579\n", "449 0.0014234762638807297\n", "450 0.001381349633447826\n", "451 0.0013440074399113655\n", "452 0.001305867568589747\n", "453 0.0012704429682344198\n", "454 0.0012343706330284476\n", "455 0.0012007243931293488\n", "456 0.0011682230979204178\n", "457 0.0011356750037521124\n", "458 0.0011055029463022947\n", "459 0.0010760502191260457\n", "460 0.0010467303218320012\n", "461 0.0010180269600823522\n", "462 0.0009923577308654785\n", "463 0.0009656138136051595\n", "464 0.0009410750935785472\n", "465 0.0009165913797914982\n", "466 0.0008935595978982747\n", "467 0.0008697272278368473\n", "468 0.0008474260102957487\n", "469 0.0008268379024229944\n", "470 0.0008065475849434733\n", "471 0.0007843846688047051\n", "472 0.0007656721863895655\n", "473 0.0007456142921000719\n", "474 0.000727273290976882\n", "475 0.0007085043471306562\n", "476 0.0006915377452969551\n", "477 0.0006745709688402712\n", "478 0.0006604917580261827\n", "479 0.0006438849377445877\n", "480 0.0006294006016105413\n", "481 0.0006143326172605157\n", "482 0.0006005939794704318\n", "483 0.0005858491058461368\n", "484 0.000572391611058265\n", "485 0.0005596061819233\n", "486 0.0005467765731737018\n", "487 0.0005332881701178849\n", "488 0.0005208597867749631\n", "489 0.0005100099369883537\n", "490 0.000498197041451931\n", "491 0.0004874719597864896\n", "492 0.00047716329572722316\n", "493 0.00046697931247763336\n", "494 0.0004567308642435819\n", "495 0.0004460803756956011\n", "496 0.00043638842180371284\n", "497 0.0004275768587831408\n", "498 0.0004182616830803454\n", "499 0.0004095847543794662\n" ] } ], "source": [ "# Use of Vaiables and Autograd in a 2-layer network with no need to manually implement backprop!\n", "import torch\n", "from torch.autograd import Variable\n", "dtype = torch.FloatTensor\n", "\n", "# N is batch 
size; D_in is input dimension;\n", "# H is hidden dimension; D_out is output dimension.\n", "N, D_in, H, D_out = 64, 1000, 100, 10\n", "\n", "# Create random Tensors to hold input and outputs and wrap them in Variables.\n", "\n", "x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False) # requires_grad=False means no need to compute gradients\n", "y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False)\n", "\n", "# Create random Tensors to hold weights and wrap them in Variables.\n", "# requires_grad=True here to compute gradients w.r.t Variables during a backprop pass.\n", "\n", "w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True) # requires_grad=True means we want gradients w.r.t. these weights\n", "w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True)\n", "\n", "learning_rate = 1e-6\n", "for t in range(500):\n", " # Forward pass: compute predicted y using operations on Variables; these\n", " # are exactly the same operations we used to compute the forward pass using\n", " # Tensors, but we do not need to keep references to intermediate values since\n", " # we are not implementing the backward pass by hand.\n", " y_pred = x.mm(w1).clamp(min=0).mm(w2)\n", "\n", " # Compute and print loss using operations on Variables.\n", " # Now loss is a Variable of shape (1,) and loss.data is a Tensor of shape\n", " # (1,); loss.data[0] is a scalar value holding the loss.\n", " loss = (y_pred - y).pow(2).sum()\n", " print(t, loss.data[0])\n", "\n", " # Use autograd to compute the backward pass. This call will compute the\n", " # gradient of loss with respect to all Variables with requires_grad=True.\n", " # After this call w1.grad and w2.grad will be Variables holding the gradient\n", " # of the loss with respect to w1 and w2 respectively.\n", " loss.backward()\n", "\n", " # Update weights using gradient descent; w1.data and w2.data are Tensors,\n", " # w1.grad and w2.grad are Variables and w1.grad.data and w2.grad.data are\n", " # Tensors.\n", " w1.data -= learning_rate * w1.grad.data\n", " w2.data -= learning_rate * w2.grad.data\n", "\n", " # Manually zero the gradients after updating weights\n", " w1.grad.data.zero_()\n", " w2.grad.data.zero_()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# PyTorch: Defining new autograd functions\n", "Under the hood, each primitive autograd operator is really two functions that operate on Tensors. The forward function computes output Tensors from input Tensors. The backward function receives the gradient of the output Tensors with respect to some scalar value, and computes the gradient of the input Tensors with respect to that same scalar value.\n", "\n", "In PyTorch we can easily define our own autograd operator by defining a subclass of torch.autograd.Function and implementing the forward and backward functions. 
We can then use our new autograd operator by constructing an instance and calling it like a function, passing Variables containing input data.\n", "\n", "In this example we define our own custom autograd function for performing the ReLU nonlinearity, and use it to implement our two-layer network:\n", "\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 37267740.0\n", "1 35764716.0\n", "2 35199480.0\n", "3 30134798.0\n", "4 20876230.0\n", "5 11940865.0\n", "6 6248357.5\n", "7 3411474.25\n", "8 2109929.75\n", "9 1486262.5\n", "10 1147416.0\n", "11 933659.25\n", "12 781843.3125\n", "13 665453.875\n", "14 572186.3125\n", "15 495587.15625\n", "16 431790.71875\n", "17 378165.53125\n", "18 332707.5\n", "19 293927.75\n", "20 260602.078125\n", "21 231805.1875\n", "22 206837.75\n", "23 185119.71875\n", "24 166225.28125\n", "25 149666.796875\n", "26 135110.765625\n", "27 122273.4140625\n", "28 110904.5703125\n", "29 100806.375\n", "30 91813.03125\n", "31 83787.9921875\n", "32 76605.484375\n", "33 70168.1953125\n", "34 64381.640625\n", "35 59169.4453125\n", "36 54462.21875\n", "37 50205.0546875\n", "38 46347.6328125\n", "39 42845.19921875\n", "40 39659.359375\n", "41 36757.4609375\n", "42 34109.640625\n", "43 31690.05078125\n", "44 29476.259765625\n", "45 27446.322265625\n", "46 25583.3671875\n", "47 23871.591796875\n", "48 22297.443359375\n", "49 20848.37109375\n", "50 19513.296875\n", "51 18281.951171875\n", "52 17144.08203125\n", "53 16093.103515625\n", "54 15118.5703125\n", "55 14213.7685546875\n", "56 13372.318359375\n", "57 12589.6640625\n", "58 11861.248046875\n", "59 11182.630859375\n", "60 10549.470703125\n", "61 9958.626953125\n", "62 9406.4560546875\n", "63 8890.7177734375\n", "64 8409.7861328125\n", "65 7959.3427734375\n", "66 7536.99755859375\n", "67 7140.72119140625\n", "68 6768.61767578125\n", "69 6418.8779296875\n", "70 6090.21875\n", "71 5781.171875\n", "72 5489.9794921875\n", "73 5215.68603515625\n", "74 4957.0380859375\n", "75 4713.0322265625\n", "76 4482.77490234375\n", "77 4265.33984375\n", "78 4059.960205078125\n", "79 3865.8935546875\n", "80 3682.364990234375\n", "81 3508.66455078125\n", "82 3344.263427734375\n", "83 3188.444580078125\n", "84 3040.791015625\n", "85 2900.77099609375\n", "86 2767.92041015625\n", "87 2641.862060546875\n", "88 2522.1953125\n", "89 2408.60107421875\n", "90 2300.690185546875\n", "91 2198.137451171875\n", "92 2100.65966796875\n", "93 2007.908447265625\n", "94 1919.6483154296875\n", "95 1835.62548828125\n", "96 1755.661865234375\n", "97 1679.4940185546875\n", "98 1606.9317626953125\n", "99 1537.7978515625\n", "100 1471.900146484375\n", "101 1409.1009521484375\n", "102 1349.2012939453125\n", "103 1292.034423828125\n", "104 1237.503662109375\n", "105 1185.4390869140625\n", "106 1135.744384765625\n", "107 1088.301513671875\n", "108 1042.9739990234375\n", "109 999.6783447265625\n", "110 958.3234252929688\n", "111 918.8175659179688\n", "112 881.0344848632812\n", "113 844.9067993164062\n", "114 810.3673095703125\n", "115 777.3306274414062\n", "116 745.7218627929688\n", "117 715.4826049804688\n", "118 686.55517578125\n", "119 658.874755859375\n", "120 632.377197265625\n", "121 606.9979858398438\n", "122 582.705322265625\n", "123 559.4349975585938\n", "124 537.1435546875\n", "125 515.7855834960938\n", "126 495.33056640625\n", "127 475.72637939453125\n", "128 456.9494323730469\n", "129 438.9481506347656\n", "130 421.68475341796875\n", "131 405.1681823730469\n", "132 
389.3519287109375\n", "133 374.1879577636719\n", "134 359.64312744140625\n", "135 345.6932678222656\n", "136 332.3123779296875\n", "137 319.4747009277344\n", "138 307.150634765625\n", "139 295.32281494140625\n", "140 283.9774169921875\n", "141 273.079833984375\n", "142 262.6209411621094\n", "143 252.57992553710938\n", "144 242.93710327148438\n", "145 233.6918487548828\n", "146 224.80181884765625\n", "147 216.26210021972656\n", "148 208.05758666992188\n", "149 200.1812744140625\n", "150 192.6105499267578\n", "151 185.34080505371094\n", "152 178.35305786132812\n", "153 171.64593505859375\n", "154 165.1945343017578\n", "155 158.99392700195312\n", "156 153.0364532470703\n", "157 147.30772399902344\n", "158 141.80587768554688\n", "159 136.51861572265625\n", "160 131.43280029296875\n", "161 126.54752349853516\n", "162 121.8475112915039\n", "163 117.32698822021484\n", "164 112.97860717773438\n", "165 108.79804992675781\n", "166 104.7764892578125\n", "167 100.90899658203125\n", "168 97.18679809570312\n", "169 93.61103820800781\n", "170 90.16815948486328\n", "171 86.85511779785156\n", "172 83.66851043701172\n", "173 80.60234832763672\n", "174 77.6517562866211\n", "175 74.81199645996094\n", "176 72.0802230834961\n", "177 69.45210266113281\n", "178 66.92032623291016\n", "179 64.4840087890625\n", "180 62.139137268066406\n", "181 59.881126403808594\n", "182 57.707550048828125\n", "183 55.61520004272461\n", "184 53.60150146484375\n", "185 51.66324996948242\n", "186 49.795257568359375\n", "187 47.99642562866211\n", "188 46.26518249511719\n", "189 44.59707260131836\n", "190 42.99076843261719\n", "191 41.4441032409668\n", "192 39.95560073852539\n", "193 38.5213623046875\n", "194 37.1389274597168\n", "195 35.80757522583008\n", "196 34.524593353271484\n", "197 33.28947067260742\n", "198 32.09878158569336\n", "199 30.953550338745117\n", "200 29.848520278930664\n", "201 28.78449249267578\n", "202 27.759124755859375\n", "203 26.77107810974121\n", "204 25.818552017211914\n", "205 24.90070152282715\n", "206 24.016618728637695\n", "207 23.165571212768555\n", "208 22.34392738342285\n", "209 21.552274703979492\n", "210 20.789934158325195\n", "211 20.05422019958496\n", "212 19.345535278320312\n", "213 18.662813186645508\n", "214 18.00432777404785\n", "215 17.369583129882812\n", "216 16.757579803466797\n", "217 16.167564392089844\n", "218 15.598984718322754\n", "219 15.05059814453125\n", "220 14.521815299987793\n", "221 14.012593269348145\n", "222 13.521288871765137\n", "223 13.04751205444336\n", "224 12.590230941772461\n", "225 12.14965534210205\n", "226 11.725013732910156\n", "227 11.314926147460938\n", "228 10.919958114624023\n", "229 10.539161682128906\n", "230 10.171355247497559\n", "231 9.817110061645508\n", "232 9.474847793579102\n", "233 9.145466804504395\n", "234 8.827065467834473\n", "235 8.520536422729492\n", "236 8.224677085876465\n", "237 7.939168453216553\n", "238 7.663815021514893\n", "239 7.398035049438477\n", "240 7.141916751861572\n", "241 6.894689083099365\n", "242 6.656137943267822\n", "243 6.42626428604126\n", "244 6.204166889190674\n", "245 5.989865303039551\n", "246 5.7830491065979\n", "247 5.583761692047119\n", "248 5.391234397888184\n", "249 5.205582618713379\n", "250 5.026453018188477\n", "251 4.853508472442627\n", "252 4.686666011810303\n", "253 4.525530815124512\n", "254 4.370087146759033\n", "255 4.220409393310547\n", "256 4.075485706329346\n", "257 3.935882568359375\n", "258 3.800877332687378\n", "259 3.6708292961120605\n", "260 3.5451159477233887\n", "261 3.423879623413086\n", "262 
3.306863307952881\n", "263 3.194089651107788\n", "264 3.084909439086914\n", "265 2.979776382446289\n", "266 2.8781092166900635\n", "267 2.7799901962280273\n", "268 2.6851348876953125\n", "269 2.593947172164917\n", "270 2.50581693649292\n", "271 2.4204869270324707\n", "272 2.3381664752960205\n", "273 2.2585439682006836\n", "274 2.18206787109375\n", "275 2.1080169677734375\n", "276 2.036513566970825\n", "277 1.9676604270935059\n", "278 1.9008911848068237\n", "279 1.8365049362182617\n", "280 1.7742794752120972\n", "281 1.714285135269165\n", "282 1.6563860177993774\n", "283 1.6004620790481567\n", "284 1.5464372634887695\n", "285 1.494178295135498\n", "286 1.4437446594238281\n", "287 1.3950902223587036\n", "288 1.3480584621429443\n", "289 1.3026072978973389\n", "290 1.258941411972046\n", "291 1.2163642644882202\n", "292 1.1755170822143555\n", "293 1.1359773874282837\n", "294 1.0978280305862427\n", "295 1.06088387966156\n", "296 1.0252677202224731\n", "297 0.990941047668457\n", "298 0.9577256441116333\n", "299 0.9256399273872375\n", "300 0.8946157097816467\n", "301 0.8645932674407959\n", "302 0.8356277346611023\n", "303 0.8076440095901489\n", "304 0.7806621789932251\n", "305 0.7545725703239441\n", "306 0.7293268442153931\n", "307 0.7049664855003357\n", "308 0.6814430952072144\n", "309 0.658761739730835\n", "310 0.6367996335029602\n", "311 0.6155775785446167\n", "312 0.5950506925582886\n", "313 0.5752754211425781\n", "314 0.5561575889587402\n", "315 0.5376402139663696\n", "316 0.5197628736495972\n", "317 0.5025719404220581\n", "318 0.48587459325790405\n", "319 0.46972161531448364\n", "320 0.4541400969028473\n", "321 0.43905285000801086\n", "322 0.4244878590106964\n", "323 0.4104618728160858\n", "324 0.39690709114074707\n", "325 0.38367366790771484\n", "326 0.37096306681632996\n", "327 0.35865893959999084\n", "328 0.34682440757751465\n", "329 0.3353811800479889\n", "330 0.3242852985858917\n", "331 0.3135945498943329\n", "332 0.3032459616661072\n", "333 0.29322201013565063\n", "334 0.2835618853569031\n", "335 0.2741992473602295\n", "336 0.26517972350120544\n", "337 0.25643640756607056\n", "338 0.2479628622531891\n", "339 0.23979204893112183\n", "340 0.23193234205245972\n", "341 0.2242671400308609\n", "342 0.2168930619955063\n", "343 0.2097444087266922\n", "344 0.2028605192899704\n", "345 0.1962338536977768\n", "346 0.18978500366210938\n", "347 0.1835404634475708\n", "348 0.17750103771686554\n", "349 0.17172792553901672\n", "350 0.16605545580387115\n", "351 0.16060921549797058\n", "352 0.15534861385822296\n", "353 0.15024448931217194\n", "354 0.14536623656749725\n", "355 0.14057379961013794\n", "356 0.13597580790519714\n", "357 0.13154900074005127\n", "358 0.1272878497838974\n", "359 0.12308774888515472\n", "360 0.1190754845738411\n", "361 0.11518353223800659\n", "362 0.11140834540128708\n", "363 0.1077791303396225\n", "364 0.10426264256238937\n", "365 0.10088305175304413\n", "366 0.09758631885051727\n", "367 0.09441451728343964\n", "368 0.09134842455387115\n", "369 0.08836644142866135\n", "370 0.08551565557718277\n", "371 0.0827186331152916\n", "372 0.0800226628780365\n", "373 0.07741819322109222\n", "374 0.0749141052365303\n", "375 0.0724942609667778\n", "376 0.07012748718261719\n", "377 0.06782606244087219\n", "378 0.06563150137662888\n", "379 0.06352625042200089\n", "380 0.06144631654024124\n", "381 0.05948108434677124\n", "382 0.05756570026278496\n", "383 0.055704500526189804\n", "384 0.05389977991580963\n", "385 0.052152279764413834\n", "386 0.050462037324905396\n", "387 0.0488210991024971\n", 
"388 0.04726396128535271\n", "389 0.045748334378004074\n", "390 0.04426315799355507\n", "391 0.04283710569143295\n", "392 0.041462723165750504\n", "393 0.04012970253825188\n", "394 0.03883979097008705\n", "395 0.0375945121049881\n", "396 0.036392029374837875\n", "397 0.035216353833675385\n", "398 0.0340835303068161\n", "399 0.0330033153295517\n", "400 0.0319441556930542\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "401 0.03092949651181698\n", "402 0.029925189912319183\n", "403 0.028983892872929573\n", "404 0.028051339089870453\n", "405 0.027159346267580986\n", "406 0.026293398812413216\n", "407 0.02546057477593422\n", "408 0.024651827290654182\n", "409 0.023875653743743896\n", "410 0.02311556600034237\n", "411 0.022379254922270775\n", "412 0.021676119416952133\n", "413 0.0209877360612154\n", "414 0.020319685339927673\n", "415 0.01967649720609188\n", "416 0.019053490832448006\n", "417 0.018447471782565117\n", "418 0.017870858311653137\n", "419 0.01730620674788952\n", "420 0.01676577515900135\n", "421 0.01624283567070961\n", "422 0.0157344788312912\n", "423 0.015246149152517319\n", "424 0.014768954366445541\n", "425 0.01429677102714777\n", "426 0.013850965537130833\n", "427 0.013418878428637981\n", "428 0.013010782189667225\n", "429 0.012603594921529293\n", "430 0.012216034345328808\n", "431 0.011834518052637577\n", "432 0.01147628203034401\n", "433 0.011118064634501934\n", "434 0.0107742203399539\n", "435 0.010442456230521202\n", "436 0.010112447664141655\n", "437 0.009803086519241333\n", "438 0.009505126625299454\n", "439 0.009210986085236073\n", "440 0.008932799100875854\n", "441 0.008662850596010685\n", "442 0.008402058854699135\n", "443 0.008139402605593204\n", "444 0.007891377434134483\n", "445 0.007653705310076475\n", "446 0.007423435337841511\n", "447 0.007203694898635149\n", "448 0.0069829924032092094\n", "449 0.006775957066565752\n", "450 0.0065697873942554\n", "451 0.006375055760145187\n", "452 0.006185355130583048\n", "453 0.0060003213584423065\n", "454 0.005824679974466562\n", "455 0.005652319174259901\n", "456 0.005483253858983517\n", "457 0.0053228093311190605\n", "458 0.005165347829461098\n", "459 0.005012849345803261\n", "460 0.00486498000100255\n", "461 0.004725072532892227\n", "462 0.004582775291055441\n", "463 0.0044530704617500305\n", "464 0.004323487635701895\n", "465 0.004199502058327198\n", "466 0.004082949832081795\n", "467 0.00396241107955575\n", "468 0.003854303155094385\n", "469 0.0037417884450405836\n", "470 0.003637957852333784\n", "471 0.003536310512572527\n", "472 0.0034359132405370474\n", "473 0.0033391271717846394\n", "474 0.003243305953219533\n", "475 0.003150815377011895\n", "476 0.003066697157919407\n", "477 0.0029837233014404774\n", "478 0.002901113824918866\n", "479 0.0028215814381837845\n", "480 0.002746968762949109\n", "481 0.002671067835763097\n", "482 0.002599822822958231\n", "483 0.0025278807152062654\n", "484 0.0024597335141152143\n", "485 0.0023916594218462706\n", "486 0.002332271309569478\n", "487 0.0022717046085745096\n", "488 0.002212217776104808\n", "489 0.0021560348104685545\n", "490 0.002097273012623191\n", "491 0.0020442584063857794\n", "492 0.00199206848628819\n", "493 0.001940639573149383\n", "494 0.0018914203392341733\n", "495 0.0018432828364893794\n", "496 0.0017941630212590098\n", "497 0.0017496095970273018\n", "498 0.00170668784994632\n", "499 0.0016636578366160393\n" ] } ], "source": [ "# -*- coding: utf-8 -*-\n", "import torch\n", "from torch.autograd import Variable\n", "\n", "\n", "class 
MyReLU(torch.autograd.Function):\n", "    \"\"\"\n", "    We can implement our own custom autograd Functions by subclassing\n", "    torch.autograd.Function and implementing the forward and backward passes\n", "    which operate on Tensors.\n", "    \"\"\"\n", "\n", "    def forward(self, input):\n", "        \"\"\"\n", "        In the forward pass we receive a Tensor containing the input and return a\n", "        Tensor containing the output. You can cache arbitrary Tensors for use in the\n", "        backward pass using the save_for_backward method.\n", "        \"\"\"\n", "        self.save_for_backward(input)\n", "        return input.clamp(min=0)\n", "\n", "    def backward(self, grad_output):\n", "        \"\"\"\n", "        In the backward pass we receive a Tensor containing the gradient of the loss\n", "        with respect to the output, and we need to compute the gradient of the loss\n", "        with respect to the input.\n", "        \"\"\"\n", "        input, = self.saved_tensors\n", "        grad_input = grad_output.clone()\n", "        grad_input[input < 0] = 0\n", "        return grad_input\n", "\n", "\n", "dtype = torch.FloatTensor\n", "# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU\n", "\n", "# N is batch size; D_in is input dimension;\n", "# H is hidden dimension; D_out is output dimension.\n", "N, D_in, H, D_out = 64, 1000, 100, 10\n", "\n", "# Create random Tensors to hold inputs and outputs, and wrap them in Variables.\n", "x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False)\n", "y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False)\n", "\n", "# Create random Tensors for weights, and wrap them in Variables.\n", "w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True)\n", "w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True)\n", "\n", "learning_rate = 1e-6\n", "for t in range(500):\n", "    # Construct an instance of our MyReLU class to use in our network\n", "    relu = MyReLU()\n", "\n", "    # Forward pass: compute predicted y using operations on Variables; we compute\n", "    # ReLU using our custom autograd operation.\n", "    y_pred = relu(x.mm(w1)).mm(w2)\n", "\n", "    # Compute and print loss\n", "    loss = (y_pred - y).pow(2).sum()\n", "    print(t, loss.data[0])\n", "\n", "    # Use autograd to compute the backward pass.\n", "    loss.backward()\n", "\n", "    # Update weights using gradient descent\n", "    w1.data -= learning_rate * w1.grad.data\n", "    w2.data -= learning_rate * w2.grad.data\n", "\n", "    # Manually zero the gradients after updating weights\n", "    w1.grad.data.zero_()\n", "    w2.grad.data.zero_()" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## What is the nn module?\n", "When building neural networks we frequently think of arranging the computation into layers, some of which have learnable parameters that will be optimized during learning.\n", "\n", "In TensorFlow, packages like Keras, TensorFlow-Slim, and TFLearn provide higher-level abstractions over raw computational graphs that are useful for building neural networks.\n", "\n", "In PyTorch, the nn package serves this same purpose. The nn package defines a set of Modules, which are roughly equivalent to neural network layers. A Module receives input Variables and computes output Variables, but may also hold internal state such as Variables containing learnable parameters. 
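\n", "\n", "For instance, a single Linear Module holds a weight and a bias as its learnable parameters and maps an input Variable to an output Variable when it is called (a minimal sketch, not part of the original example):\n", "\n", "```python\n", "import torch\n", "from torch.autograd import Variable\n", "\n", "lin = torch.nn.Linear(3, 2)             # a Module holding weight and bias Variables\n", "out = lin(Variable(torch.randn(4, 3)))  # input Variable in, output Variable out\n", "print(out.size())                       # torch.Size([4, 2])\n", "```\n", "\n", "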
The nn package also defines a set of useful loss functions that are commonly used when training neural networks.\n", "\n", "In this example we use the nn package to implement our two-layer network:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 680.08154296875\n", "1 628.8499755859375\n", "2 584.1482543945312\n", "3 544.8362426757812\n", "4 509.60052490234375\n", "5 477.94586181640625\n", "6 449.4169616699219\n", "7 423.177734375\n", "8 398.7761535644531\n", "9 376.30096435546875\n", "10 355.3580322265625\n", "11 335.78729248046875\n", "12 317.5242614746094\n", "13 300.2922058105469\n", "14 283.9698181152344\n", "15 268.5042419433594\n", "16 253.79812622070312\n", "17 239.88864135742188\n", "18 226.66624450683594\n", "19 214.0961151123047\n", "20 202.15972900390625\n", "21 190.82431030273438\n", "22 180.07672119140625\n", "23 169.84349060058594\n", "24 160.177734375\n", "25 151.01641845703125\n", "26 142.33706665039062\n", "27 134.12139892578125\n", "28 126.35624694824219\n", "29 119.03170776367188\n", "30 112.15123748779297\n", "31 105.64317321777344\n", "32 99.51466369628906\n", "33 93.72786712646484\n", "34 88.27943420410156\n", "35 83.15460205078125\n", "36 78.34161376953125\n", "37 73.80974578857422\n", "38 69.55596923828125\n", "39 65.56876373291016\n", "40 61.8112907409668\n", "41 58.278228759765625\n", "42 54.950050354003906\n", "43 51.81352615356445\n", "44 48.865779876708984\n", "45 46.08952713012695\n", "46 43.48237609863281\n", "47 41.032230377197266\n", "48 38.72626495361328\n", "49 36.560401916503906\n", "50 34.52272415161133\n", "51 32.603416442871094\n", "52 30.797603607177734\n", "53 29.0950984954834\n", "54 27.492034912109375\n", "55 25.980615615844727\n", "56 24.55516242980957\n", "57 23.21516990661621\n", "58 21.952560424804688\n", "59 20.762468338012695\n", "60 19.642051696777344\n", "61 18.584585189819336\n", "62 17.585891723632812\n", "63 16.644123077392578\n", "64 15.755452156066895\n", "65 14.916620254516602\n", "66 14.125481605529785\n", "67 13.380037307739258\n", "68 12.675837516784668\n", "69 12.011648178100586\n", "70 11.383655548095703\n", "71 10.7904634475708\n", "72 10.233356475830078\n", "73 9.707324981689453\n", "74 9.2102689743042\n", "75 8.740628242492676\n", "76 8.298053741455078\n", "77 7.879985809326172\n", "78 7.484074592590332\n", "79 7.109870910644531\n", "80 6.756218433380127\n", "81 6.421789646148682\n", "82 6.105353355407715\n", "83 5.8055315017700195\n", "84 5.52149772644043\n", "85 5.252496242523193\n", "86 4.997690677642822\n", "87 4.756503582000732\n", "88 4.527294158935547\n", "89 4.309894561767578\n", "90 4.103759765625\n", "91 3.9084742069244385\n", "92 3.7226548194885254\n", "93 3.5464107990264893\n", "94 3.37937068939209\n", "95 3.2209417819976807\n", "96 3.070256471633911\n", "97 2.9272148609161377\n", "98 2.791459083557129\n", "99 2.6623592376708984\n", "100 2.5399575233459473\n", "101 2.4237000942230225\n", "102 2.313190460205078\n", "103 2.2082247734069824\n", "104 2.1084320545196533\n", "105 2.0137369632720947\n", "106 1.9235584735870361\n", "107 1.8378454446792603\n", "108 1.7563918828964233\n", "109 1.6788384914398193\n", "110 1.6050490140914917\n", "111 1.5347329378128052\n", "112 1.4678232669830322\n", "113 1.4041197299957275\n", "114 1.3435639142990112\n", "115 1.2858960628509521\n", "116 1.2308998107910156\n", "117 1.1784309148788452\n", "118 1.1284685134887695\n", "119 1.0808899402618408\n", "120 1.0353796482086182\n", "121 
0.9919747710227966\n", "122 0.9504337310791016\n", "123 0.9108781814575195\n", "124 0.8731096386909485\n", "125 0.8370406627655029\n", "126 0.8026205897331238\n", "127 0.7697471380233765\n", "128 0.7383762001991272\n", "129 0.7084004282951355\n", "130 0.679768979549408\n", "131 0.6523947715759277\n", "132 0.626214861869812\n", "133 0.6011669635772705\n", "134 0.5772068500518799\n", "135 0.5543199777603149\n", "136 0.5323988795280457\n", "137 0.5114548206329346\n", "138 0.49139833450317383\n", "139 0.4721781611442566\n", "140 0.45376724004745483\n", "141 0.43613412976264954\n", "142 0.41924476623535156\n", "143 0.40306606888771057\n", "144 0.38756275177001953\n", "145 0.3727158010005951\n", "146 0.35854285955429077\n", "147 0.34496134519577026\n", "148 0.33194538950920105\n", "149 0.31945526599884033\n", "150 0.3074859082698822\n", "151 0.29600128531455994\n", "152 0.2849758565425873\n", "153 0.2743982970714569\n", "154 0.2642490863800049\n", "155 0.2545064389705658\n", "156 0.2451404333114624\n", "157 0.23615539073944092\n", "158 0.22752535343170166\n", "159 0.21923471987247467\n", "160 0.2112797051668167\n", "161 0.20362289249897003\n", "162 0.19626261293888092\n", "163 0.1891934722661972\n", "164 0.18240153789520264\n", "165 0.17586886882781982\n", "166 0.16958405077457428\n", "167 0.16354581713676453\n", "168 0.15773969888687134\n", "169 0.15215399861335754\n", "170 0.14677844941616058\n", "171 0.14159975945949554\n", "172 0.13661399483680725\n", "173 0.13181616365909576\n", "174 0.1271965205669403\n", "175 0.12274857610464096\n", "176 0.11846866458654404\n", "177 0.11434680968523026\n", "178 0.11037838459014893\n", "179 0.10655605047941208\n", "180 0.10287366062402725\n", "181 0.09932510554790497\n", "182 0.0959087535738945\n", "183 0.09261389821767807\n", "184 0.08943838626146317\n", "185 0.08637817949056625\n", "186 0.08342777192592621\n", "187 0.0805828794836998\n", "188 0.07784154266119003\n", "189 0.07519736886024475\n", "190 0.0726480707526207\n", "191 0.07019388675689697\n", "192 0.06782343983650208\n", "193 0.06553709506988525\n", "194 0.06333591789007187\n", "195 0.06121001020073891\n", "196 0.059156980365514755\n", "197 0.05717690289020538\n", "198 0.055265795439481735\n", "199 0.053421154618263245\n", "200 0.05164136365056038\n", "201 0.04992407187819481\n", "202 0.04826623946428299\n", "203 0.046666938811540604\n", "204 0.04512207210063934\n", "205 0.04363016411662102\n", "206 0.04219016805291176\n", "207 0.04079950973391533\n", "208 0.03945689648389816\n", "209 0.038160402327775955\n", "210 0.036908261477947235\n", "211 0.03569883480668068\n", "212 0.034530479460954666\n", "213 0.03340240567922592\n", "214 0.032312240451574326\n", "215 0.031259190291166306\n", "216 0.030241671949625015\n", "217 0.02925860695540905\n", "218 0.028309127315878868\n", "219 0.02739332616329193\n", "220 0.026506047695875168\n", "221 0.025648759678006172\n", "222 0.02482016757130623\n", "223 0.024018865078687668\n", "224 0.023244598880410194\n", "225 0.022496270015835762\n", "226 0.02177284099161625\n", "227 0.021073248237371445\n", "228 0.020396927371621132\n", "229 0.01974322460591793\n", "230 0.019111426547169685\n", "231 0.0185005571693182\n", "232 0.017909277230501175\n", "233 0.017337419092655182\n", "234 0.016784587875008583\n", "235 0.016250004991889\n", "236 0.01573282666504383\n", "237 0.015232610516250134\n", "238 0.014748821035027504\n", "239 0.014280819334089756\n", "240 0.013828235678374767\n", "241 0.013390136882662773\n", "242 0.012966613285243511\n", "243 0.012556690722703934\n", 
"244 0.012161037884652615\n", "245 0.0117774223908782\n", "246 0.011406159959733486\n", "247 0.011046788655221462\n", "248 0.010699193924665451\n", "249 0.010362686589360237\n", "250 0.010037034749984741\n", "251 0.00972204003483057\n", "252 0.009417165070772171\n", "253 0.009121924638748169\n", "254 0.008836277760565281\n", "255 0.00855974666774273\n", "256 0.008291949518024921\n", "257 0.00803307630121708\n", "258 0.007782185450196266\n", "259 0.007539329584687948\n", "260 0.007304261904209852\n", "261 0.007076835259795189\n", "262 0.00685644056648016\n", "263 0.006643133703619242\n", "264 0.006436587776988745\n", "265 0.006236482877284288\n", "266 0.006042997818440199\n", "267 0.005855487193912268\n", "268 0.005673850420862436\n", "269 0.005498659797012806\n", "270 0.005328337661921978\n", "271 0.005163470283150673\n", "272 0.0050038304179906845\n", "273 0.004849148914217949\n", "274 0.004699397832155228\n", "275 0.00455437321215868\n", "276 0.004413901828229427\n", "277 0.0042778486385941505\n", "278 0.004146031104028225\n", "279 0.004018353298306465\n", "280 0.0038947267457842827\n", "281 0.0037749370094388723\n", "282 0.003658916801214218\n", "283 0.0035464726388454437\n", "284 0.003437580307945609\n", "285 0.0033320633228868246\n", "286 0.0032298287842422724\n", "287 0.003130849450826645\n", "288 0.003034892724826932\n", "289 0.0029419672209769487\n", "290 0.0028519199695438147\n", "291 0.002764653880149126\n", "292 0.0026801559142768383\n", "293 0.002598251448944211\n", "294 0.002518962835893035\n", "295 0.002442245604470372\n", "296 0.0023677151184529066\n", "297 0.0022955138701945543\n", "298 0.0022255151998251677\n", "299 0.002157705370336771\n", "300 0.002091981703415513\n", "301 0.002028322545811534\n", "302 0.00196659192442894\n", "303 0.0019068144029006362\n", "304 0.0018488289788365364\n", "305 0.001792663475498557\n", "306 0.0017382020596414804\n", "307 0.0016854503192007542\n", "308 0.0016342989401891828\n", "309 0.0015847444301471114\n", "310 0.0015366816660389304\n", "311 0.0014901245012879372\n", "312 0.0014449851587414742\n", "313 0.001401250483468175\n", "314 0.0013588427100330591\n", "315 0.0013177691726014018\n", "316 0.001277906121686101\n", "317 0.0012393008219078183\n", "318 0.0012018403504043818\n", "319 0.001165554509498179\n", "320 0.0011303661158308387\n", "321 0.001096343039534986\n", "322 0.0010632890043780208\n", "323 0.001031213440001011\n", "324 0.0010001431219279766\n", "325 0.0009700573864392936\n", "326 0.0009408452315256\n", "327 0.000912512477952987\n", "328 0.0008850548765622079\n", "329 0.0008584235911257565\n", "330 0.0008326029637828469\n", "331 0.0008075683144852519\n", "332 0.0007833010167814791\n", "333 0.0007597761577926576\n", "334 0.0007369595696218312\n", "335 0.0007148331496864557\n", "336 0.0006933821714483202\n", "337 0.0006725748535245657\n", "338 0.0006523994379676878\n", "339 0.0006328612216748297\n", "340 0.0006138771423138678\n", "341 0.0005954877706244588\n", "342 0.0005776527686975896\n", "343 0.0005603585159406066\n", "344 0.0005436040228232741\n", "345 0.0005273337010294199\n", "346 0.000511560239829123\n", "347 0.0004962603561580181\n", "348 0.00048146690824069083\n", "349 0.0004670854832511395\n", "350 0.0004531279264483601\n", "351 0.0004396018339321017\n", "352 0.0004264691669959575\n", "353 0.00041373888961970806\n", "354 0.0004013977595604956\n", "355 0.0003894170222338289\n", "356 0.00037781387800350785\n", "357 0.0003665469994302839\n", "358 0.0003556182491593063\n", "359 0.0003450293734204024\n", "360 
0.00033474891097284853\n", "361 0.0003247867280151695\n", "362 0.00031511206179857254\n", "363 0.0003057269495911896\n", "364 0.00029663191526196897\n", "365 0.00028781587025150657\n", "366 0.00027925128233619034\n", "367 0.0002709476975724101\n", "368 0.00026289623929187655\n", "369 0.00025508779799565673\n", "370 0.0002475187066011131\n", "371 0.00024016370298340917\n", "372 0.00023303109628614038\n", "373 0.00022611598251387477\n", "374 0.00021940314036328346\n", "375 0.00021291013399604708\n", "376 0.0002065994485747069\n", "377 0.00020047678845003247\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "378 0.00019453163258731365\n", "379 0.0001887652324512601\n", "380 0.0001831761037465185\n", "381 0.00017774860316421837\n", "382 0.00017248367657884955\n", "383 0.00016738024714868516\n", "384 0.00016242482524830848\n", "385 0.0001576153008500114\n", "386 0.0001529530854895711\n", "387 0.00014842944801785052\n", "388 0.00014404467947315425\n", "389 0.00013978010974824429\n", "390 0.00013565001427195966\n", "391 0.00013164129632059485\n", "392 0.00012775001232512295\n", "393 0.0001239780249306932\n", "394 0.00012031249207211658\n", "395 0.00011676689609885216\n", "396 0.00011332175199640915\n", "397 0.00010997249773936346\n", "398 0.00010673052747733891\n", "399 0.00010358005238231272\n", "400 0.00010052488505607471\n", "401 9.756081271916628e-05\n", "402 9.469027281738818e-05\n", "403 9.190698619931936e-05\n", "404 8.919845276977867e-05\n", "405 8.657083526486531e-05\n", "406 8.402206731261685e-05\n", "407 8.154464012477547e-05\n", "408 7.914425077615306e-05\n", "409 7.681240822421387e-05\n", "410 7.455307786585763e-05\n", "411 7.236027886392549e-05\n", "412 7.022878708085045e-05\n", "413 6.816528912167996e-05\n", "414 6.616451719310135e-05\n", "415 6.4219391788356e-05\n", "416 6.233102612895891e-05\n", "417 6.049991861800663e-05\n", "418 5.872180190635845e-05\n", "419 5.699725079466589e-05\n", "420 5.532385330297984e-05\n", "421 5.369873542804271e-05\n", "422 5.212163523538038e-05\n", "423 5.0590115279192105e-05\n", "424 4.910368807031773e-05\n", "425 4.766530400956981e-05\n", "426 4.626710870070383e-05\n", "427 4.490738865570165e-05\n", "428 4.3591026042122394e-05\n", "429 4.231411367072724e-05\n", "430 4.1074410546571016e-05\n", "431 3.9871807530289516e-05\n", "432 3.870454747811891e-05\n", "433 3.757094600587152e-05\n", "434 3.646806362667121e-05\n", "435 3.539905446814373e-05\n", "436 3.436301994952373e-05\n", "437 3.3357628126395866e-05\n", "438 3.237837881897576e-05\n", "439 3.1431503884959966e-05\n", "440 3.0511764634866267e-05\n", "441 2.961848076665774e-05\n", "442 2.875354402931407e-05\n", "443 2.7911590223084204e-05\n", "444 2.7094889446743764e-05\n", "445 2.6300884201191366e-05\n", "446 2.553403828642331e-05\n", "447 2.4784965717117302e-05\n", "448 2.4059936549747363e-05\n", "449 2.3357155441772193e-05\n", "450 2.267470335937105e-05\n", "451 2.2012040062691085e-05\n", "452 2.1369320165831596e-05\n", "453 2.0743538698297925e-05\n", "454 2.013810990320053e-05\n", "455 1.9549743228708394e-05\n", "456 1.8979713786393404e-05\n", "457 1.842488200054504e-05\n", "458 1.78881709871348e-05\n", "459 1.736673220875673e-05\n", "460 1.6858253729878925e-05\n", "461 1.6368272554245777e-05\n", "462 1.588999839441385e-05\n", "463 1.542721474834252e-05\n", "464 1.4977055798226502e-05\n", "465 1.4539912626787554e-05\n", "466 1.4116209058556706e-05\n", "467 1.3703524928132538e-05\n", "468 1.330458871962037e-05\n", "469 1.2917284038849175e-05\n", "470 1.254109520232305e-05\n", "471 
1.2174677976872772e-05\n", "472 1.182032883662032e-05\n", "473 1.14756403490901e-05\n", "474 1.1141963113914244e-05\n", "475 1.081790560419904e-05\n", "476 1.0501042197574861e-05\n", "477 1.0197520168730989e-05\n", "478 9.900572877086233e-06\n", "479 9.611635505279992e-06\n", "480 9.332347872259561e-06\n", "481 9.060167940333486e-06\n", "482 8.796627298579551e-06\n", "483 8.540252565580886e-06\n", "484 8.29136752145132e-06\n", "485 8.050311407714617e-06\n", "486 7.816828656359576e-06\n", "487 7.589314918732271e-06\n", "488 7.368240403593518e-06\n", "489 7.1535123424837366e-06\n", "490 6.947231213416671e-06\n", "491 6.7452187977323774e-06\n", "492 6.549066711158957e-06\n", "493 6.358453902066685e-06\n", "494 6.174030204419978e-06\n", "495 5.9948902162432205e-06\n", "496 5.820296792080626e-06\n", "497 5.651726041833172e-06\n", "498 5.487122507474851e-06\n", "499 5.328512997948565e-06\n" ] } ], "source": [ "# -*- coding: utf-8 -*-\n", "import torch\n", "from torch.autograd import Variable\n", "\n", "# N is batch size; D_in is input dimension;\n", "# H is hidden dimension; D_out is output dimension.\n", "N, D_in, H, D_out = 64, 1000, 100, 10\n", "\n", "# Create random Tensors to hold inputs and outputs, and wrap them in Variables.\n", "x = Variable(torch.randn(N, D_in))\n", "y = Variable(torch.randn(N, D_out), requires_grad=False)\n", "\n", "# Use the nn package to define our model as a sequence of layers. nn.Sequential\n", "# is a Module which contains other Modules, and applies them in sequence to\n", "# produce its output. Each Linear Module computes output from input using a\n", "# linear function, and holds internal Variables for its weight and bias.\n", "model = torch.nn.Sequential(\n", " torch.nn.Linear(D_in, H),\n", " torch.nn.ReLU(),\n", " torch.nn.Linear(H, D_out),\n", ")\n", "\n", "# The nn package also contains definitions of popular loss functions; in this\n", "# case we will use Mean Squared Error (MSE) as our loss function.\n", "loss_fn = torch.nn.MSELoss(size_average=False)\n", "\n", "learning_rate = 1e-4\n", "for t in range(500):\n", " # Forward pass: compute predicted y by passing x to the model. Module objects\n", " # override the __call__ operator so you can call them like functions. When\n", " # doing so you pass a Variable of input data to the Module and it produces\n", " # a Variable of output data.\n", " y_pred = model(x)\n", "\n", " # Compute and print loss. We pass Variables containing the predicted and true\n", " # values of y, and the loss function returns a Variable containing the\n", " # loss.\n", " loss = loss_fn(y_pred, y)\n", " print(t, loss.data[0])\n", "\n", " # Zero the gradients before running the backward pass.\n", " model.zero_grad()\n", "\n", " # Backward pass: compute gradient of the loss with respect to all the learnable\n", " # parameters of the model. Internally, the parameters of each Module are stored\n", " # in Variables with requires_grad=True, so this call will compute gradients for\n", " # all learnable parameters in the model.\n", " loss.backward()\n", "\n", " # Update the weights using gradient descent. 
Each parameter is a Variable, so\n", " # we can access its data and gradients like we did before.\n", " for param in model.parameters():\n", " param.data -= learning_rate * param.grad.data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## PyTorch - optim\n", "Up to this point we have updated the weights of our models by manually mutating the .data member of the Variables that hold the learnable parameters. The optim package abstracts this idea and provides implementations of commonly used optimization algorithms.\n", "\n", "Here we build the same two-layer network, but optimize it with the Adam algorithm from the optim package, using a learning rate of $10^{-4}$.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import torch\n", "from torch.autograd import Variable\n", "\n", "# N is batch size; D_in is input dimension;\n", "# H is hidden dimension; D_out is output dimension.\n", "N, D_in, H, D_out = 64, 1000, 100, 10\n", "\n", "# Create random Tensors to hold inputs and outputs, and wrap them in Variables.\n", "x = Variable(torch.randn(N, D_in))\n", "y = Variable(torch.randn(N, D_out), requires_grad=False)\n", "\n", "# Define the model and loss function using the nn package, as before.\n", "model = torch.nn.Sequential(\n", "    torch.nn.Linear(D_in, H),\n", "    torch.nn.ReLU(),\n", "    torch.nn.Linear(H, D_out),\n", ")\n", "\n", "loss_fxn = torch.nn.MSELoss(size_average=False)\n", "\n", "# Use the optim package to define an Adam optimizer that will update the\n", "# weights of the model for us.\n", "learning_rate = 1e-4\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 
5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 
5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 
5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n", "499 5.173239514988381e-06\n" ] } ], "source": [ "# We loop\n", "\n", "for i in range(500):\n", " y_pred = model(x)\n", " loss = loss_fxn(y_pred, y)\n", " print(t, loss.data[0])\n", " \n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Custom nn module\n", "\n", "For more complex computation, you can define your own module by subclassing nn.Module" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 
656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 
656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 
656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 
656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n", "0 656.797607421875\n" ] } ], "source": [ "import torch\n", "from torch.autograd import Variable\n", "\n", "\n", "class DoubleLayerNet(torch.nn.Module):\n", "    def __init__(self, D_in, H, D_out):\n", "        # Initialize two instances of nn.Linear and assign them as members.\n", "        super(DoubleLayerNet, self).__init__()\n", "        self.linear1 = torch.nn.Linear(D_in, H)\n", "        self.linear2 = torch.nn.Linear(H, D_out)\n", "\n", "    def forward(self, x):\n", "        # In the forward pass we accept a Variable of input data and\n", "        # return a Variable of output data.\n", "        h_relu = self.linear1(x).clamp(min=0)\n", "        y_pred = self.linear2(h_relu)\n", "        return y_pred\n", "\n", "\n", "# As usual, define batch size, input dimension, hidden dimension and output dimension.\n", "N, D_in, H, D_out = 64, 1000, 100, 10\n", "\n", "# Create random Tensors to hold inputs and outputs, and wrap them in Variables.\n", "x = Variable(torch.randn(N, D_in))\n", "y = Variable(torch.randn(N, D_out), requires_grad=False)\n", "\n", "# Build our model by instantiating the class defined above.\n", "my_model = DoubleLayerNet(D_in, H, D_out)\n", "\n", "# Construct the loss function and the optimizer. Note that the optimizer must be\n", "# given the parameters of my_model, the model we are actually training.\n", "criterion = torch.nn.MSELoss(size_average=False)\n", "optimizer = torch.optim.SGD(my_model.parameters(), lr=1e-4)\n", "\n", "# Training loop\n", "for i in range(500):\n", "    # Forward pass: compute predicted y by passing x to the model.\n", "    y_pred = my_model(x)\n", "\n", "    # Compute and print loss.\n", "    loss = criterion(y_pred, y)\n", "    print(i, loss.data[0])\n", "\n", "    # Zero gradients, perform the backward pass, and update the weights.\n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "    optimizer.step()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.1" } }, "nbformat": 4, "nbformat_minor": 2 }