{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Implementing simple linear regression with TensorFlow\n", "\n", "https://www.youtube.com/watch?v=mQGwjrStQgg&feature=youtu.be" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hypothesis and cost function\n", "\n", "\\begin{equation}\n", "H(x) = Wx+b\n", "\\end{equation}\n", "\n", "\\begin{equation}\n", "Cost(W,b) = \\frac{1}{m} \\sum_{i=1}^{m}(H(x^{(i)})-y^{(i)})^{2}\n", "\\end{equation}\n", "\n", "Training means finding the W and b that minimize the cost." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Build the graph using TF operations\n", "\n", "\\begin{equation}\n", "H(x) = Wx+b\n", "\\end{equation}" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "# X and Y data\n", "x_train = [1, 2, 3]\n", "y_train = [1, 2, 3]\n", "\n", "# Variables that TensorFlow updates during training (trainable)\n", "W = tf.Variable(tf.random_normal([1]), name='weight')\n", "b = tf.Variable(tf.random_normal([1]), name='bias')\n", "# Our hypothesis XW+b\n", "hypothesis = x_train * W + b" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\\begin{equation}\n", "Cost(W,b) = \\frac{1}{m} \\sum_{i=1}^{m}(H(x^{(i)})-y^{(i)})^{2}\n", "\\end{equation}" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# cost/loss function\n", "cost = tf.reduce_mean(tf.square(hypothesis - y_train))" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.5\n" ] } ], "source": [ "t = [1., 2., 3., 4.]\n", "with tf.Session() as sess:\n", "    print(sess.run(tf.reduce_mean(t)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "tf.reduce_mean(t) ==> 2.5\n", "\n", "`tf.reduce_mean` computes the mean of the values in t: (1 + 2 + 3 + 4) / 4 = 2.5." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
"### GradientDescent" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Minimize\n", "optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)\n", "train = optimizer.minimize(cost)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`optimizer.minimize(cost)` returns the `train` op. Each time it runs, gradient descent updates W and b to reduce the cost." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2 & 3. Run/update the graph and get results" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 13.1854 [-1.33570743] [ 1.58136511]\n", "20 0.735851 [-0.04721653] [ 2.02556849]\n", "40 0.566408 [ 0.11423562] [ 1.97974777]\n", "60 0.513496 [ 0.16655035] [ 1.89140821]\n", "80 0.466357 [ 0.20673704] [ 1.80296659]\n", "100 0.423553 [ 0.24411441] [ 1.71827638]\n", "120 0.384678 [ 0.27964747] [ 1.63752782]\n", "140 0.34937 [ 0.31350225] [ 1.56057036]\n", "160 0.317304 [ 0.3457652] [ 1.48722947]\n", "180 0.28818 [ 0.37651184] [ 1.41733515]\n", "200 0.26173 [ 0.40581352] [ 1.35072541]\n", "220 0.237707 [ 0.43373808] [ 1.28724647]\n", "240 0.21589 [ 0.46035028] [ 1.22675061]\n", "260 0.196075 [ 0.48571181] [ 1.16909802]\n", "280 0.178078 [ 0.50988138] [ 1.1141547]\n", "300 0.161733 [ 0.53291523] [ 1.06179345]\n", "320 0.146889 [ 0.55486649] [ 1.01189291]\n", "340 0.133407 [ 0.57578617] [ 0.96433771]\n", "360 0.121162 [ 0.59572256] [ 0.91901743]\n", "380 0.110041 [ 0.61472219] [ 0.87582701]\n", "400 0.0999414 [ 0.63282883] [ 0.83466631]\n", "420 0.0907683 [ 0.65008456] [ 0.79544002]\n", "440 0.0824372 [ 0.66652936] [ 0.7580573]\n", "460 0.0748708 [ 0.68220115] [ 0.72243142]\n", "480 0.0679988 [ 0.69713652] [ 0.68847978]\n", "500 0.0617577 [ 0.71137005] [ 0.65612382]\n", "520 0.0560893 [ 0.72493458] [ 0.62528843]\n", "540 0.0509412 [ 0.73786157] [ 0.5959022]\n", "560 0.0462656 [ 0.75018114] [ 0.56789696]\n", "580 0.0420192 [ 0.7619217] [ 
0.54120791]\n", "600 0.0381625 [ 0.77311057] [ 0.51577306]\n", "620 0.0346598 [ 0.7837736] [ 0.49153358]\n", "640 0.0314786 [ 0.79393542] [ 0.46843323]\n", "660 0.0285893 [ 0.80361968] [ 0.44641864]\n", "680 0.0259653 [ 0.81284875] [ 0.42543861]\n", "700 0.0235821 [ 0.82164425] [ 0.40544456]\n", "720 0.0214176 [ 0.83002627] [ 0.38639015]\n", "740 0.0194518 [ 0.83801442] [ 0.36823127]\n", "760 0.0176665 [ 0.84562713] [ 0.35092577]\n", "780 0.016045 [ 0.85288209] [ 0.33443353]\n", "800 0.0145723 [ 0.85979611] [ 0.31871641]\n", "820 0.0132348 [ 0.86638516] [ 0.30373788]\n", "840 0.0120201 [ 0.87266451] [ 0.28946337]\n", "860 0.0109168 [ 0.87864894] [ 0.27585965]\n", "880 0.00991482 [ 0.88435191] [ 0.26289526]\n", "900 0.0090048 [ 0.88978696] [ 0.25054011]\n", "920 0.0081783 [ 0.89496654] [ 0.23876567]\n", "940 0.00742766 [ 0.89990282] [ 0.22754458]\n", "960 0.00674593 [ 0.904607] [ 0.21685077]\n", "980 0.00612676 [ 0.9090901] [ 0.20665959]\n", "1000 0.00556442 [ 0.91336256] [ 0.19694737]\n", "1020 0.0050537 [ 0.91743416] [ 0.18769155]\n", "1040 0.00458985 [ 0.92131442] [ 0.17887077]\n", "1060 0.00416857 [ 0.92501235] [ 0.17046453]\n", "1080 0.00378597 [ 0.92853642] [ 0.16245335]\n", "1100 0.00343848 [ 0.93189502] [ 0.15481864]\n", "1120 0.00312288 [ 0.93509567] [ 0.14754273]\n", "1140 0.00283625 [ 0.938146] [ 0.1406088]\n", "1160 0.00257592 [ 0.94105291] [ 0.13400066]\n", "1180 0.0023395 [ 0.94382322] [ 0.12770312]\n", "1200 0.00212477 [ 0.94646329] [ 0.12170152]\n", "1220 0.00192974 [ 0.94897938] [ 0.11598199]\n", "1240 0.00175263 [ 0.95137715] [ 0.11053121]\n", "1260 0.00159176 [ 0.95366216] [ 0.10533665]\n", "1280 0.00144566 [ 0.95583993] [ 0.10038622]\n", "1300 0.00131298 [ 0.95791525] [ 0.09566843]\n", "1320 0.00119246 [ 0.95989305] [ 0.09117243]\n", "1340 0.00108302 [ 0.96177793] [ 0.08688767]\n", "1360 0.000983619 [ 0.96357429] [ 0.08280428]\n", "1380 0.000893336 [ 0.96528614] [ 0.07891276]\n", "1400 0.000811341 [ 0.96691757] [ 0.07520416]\n", "1420 0.000736873 
[ 0.96847236] [ 0.07166984]\n", "1440 0.00066924 [ 0.96995401] [ 0.06830166]\n", "1460 0.000607814 [ 0.97136605] [ 0.06509172]\n", "1480 0.000552024 [ 0.9727118] [ 0.06203262]\n", "1500 0.000501358 [ 0.9739942] [ 0.0591173]\n", "1520 0.000455342 [ 0.97521639] [ 0.05633902]\n", "1540 0.000413548 [ 0.97638112] [ 0.05369127]\n", "1560 0.000375591 [ 0.97749114] [ 0.05116797]\n", "1580 0.000341117 [ 0.97854894] [ 0.04876323]\n", "1600 0.00030981 [ 0.97955704] [ 0.04647155]\n", "1620 0.000281374 [ 0.98051786] [ 0.04428752]\n", "1640 0.000255547 [ 0.98143345] [ 0.04220613]\n", "1660 0.000232092 [ 0.982306] [ 0.04022259]\n", "1680 0.000210791 [ 0.98313755] [ 0.03833229]\n", "1700 0.000191443 [ 0.98393005] [ 0.03653081]\n", "1720 0.000173872 [ 0.9846853] [ 0.034814]\n", "1740 0.000157911 [ 0.98540503] [ 0.03317784]\n", "1760 0.000143417 [ 0.98609096] [ 0.03161858]\n", "1780 0.000130254 [ 0.98674464] [ 0.03013258]\n", "1800 0.000118299 [ 0.98736757] [ 0.02871647]\n", "1820 0.000107442 [ 0.98796123] [ 0.02736692]\n", "1840 9.75799e-05 [ 0.988527] [ 0.02608078]\n", "1860 8.8623e-05 [ 0.98906618] [ 0.02485507]\n", "1880 8.04903e-05 [ 0.98957992] [ 0.02368699]\n", "1900 7.31019e-05 [ 0.99006969] [ 0.02257388]\n", "1920 6.63927e-05 [ 0.99053645] [ 0.02151295]\n", "1940 6.02978e-05 [ 0.99098122] [ 0.02050187]\n", "1960 5.47633e-05 [ 0.99140507] [ 0.01953833]\n", "1980 4.97365e-05 [ 0.99180901] [ 0.01862007]\n", "2000 4.51714e-05 [ 0.992194] [ 0.01774498]\n" ] } ], "source": [ "# Launch the graph in a session.\n", "sess = tf.Session()\n", "# Initialize global variables in the graph.\n", "sess.run(tf.global_variables_initializer())\n", "\n", "# Fit the line\n", "for step in range(2001):\n", "    sess.run(train)\n", "    if step % 20 == 0:\n", "        print(step, sess.run(cost), sess.run(W), sess.run(b))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Placeholders\n", "\n", "A placeholder is a graph input whose value is supplied when the session is run, through `feed_dict`." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 23.297 [ 1.96590233] [ 1.62496638]\n", "20 0.00962304 [ 0.94277626] [ 1.32471764]\n", "40 0.00798899 [ 0.94218838] [ 1.30879986]\n", "60 0.00697683 [ 0.94595498] [ 1.29511988]\n", "80 0.0060929 [ 0.94949448] [ 1.282341]\n", "100 0.00532099 [ 0.95280206] [ 1.27039945]\n", "120 0.00464685 [ 0.9558931] [ 1.25923967]\n", "140 0.00405813 [ 0.95878166] [ 1.24881101]\n", "160 0.00354399 [ 0.96148115] [ 1.23906517]\n", "180 0.00309499 [ 0.96400386] [ 1.22995758]\n", "200 0.00270288 [ 0.96636122] [ 1.22144663]\n", "220 0.00236043 [ 0.96856421] [ 1.21349287]\n", "240 0.00206139 [ 0.97062302] [ 1.20606029]\n", "260 0.00180021 [ 0.97254694] [ 1.19911397]\n", "280 0.00157214 [ 0.97434491] [ 1.19262278]\n", "300 0.00137297 [ 0.9760251] [ 1.18655694]\n", "320 0.00119902 [ 0.97759521] [ 1.1808883]\n", "340 0.00104712 [ 0.9790625] [ 1.17559099]\n", "360 0.000914452 [ 0.98043376] [ 1.17064023]\n", "380 0.0007986 [ 0.9817152] [ 1.16601384]\n", "400 0.000697417 [ 0.98291272] [ 1.16169024]\n", "420 0.000609056 [ 0.9840318] [ 1.15764999]\n", "440 0.000531887 [ 0.98507762] [ 1.1538744]\n", "460 0.000464507 [ 0.98605484] [ 1.1503464]\n", "480 0.000405663 [ 0.9869681] [ 1.14704931]\n", "500 0.000354267 [ 0.98782152] [ 1.14396811]\n", "520 0.000309387 [ 0.98861903] [ 1.14108884]\n", "540 0.00027019 [ 0.98936439] [ 1.13839781]\n", "560 0.000235956 [ 0.99006093] [ 1.13588309]\n", "580 0.000206065 [ 0.99071187] [ 1.133533]\n", "600 0.000179956 [ 0.99132019] [ 1.13133681]\n", "620 0.000157154 [ 0.99188864] [ 1.1292845]\n", "640 0.000137247 [ 0.99241984] [ 1.12736666]\n", "660 0.000119859 [ 0.99291629] [ 1.12557435]\n", "680 0.000104673 [ 0.99338013] [ 1.12389958]\n", "700 9.14121e-05 [ 0.99381369] [ 1.12233436]\n", "720 7.98333e-05 [ 0.99421877] [ 1.1208719]\n", "740 6.97188e-05 [ 0.99459738] [ 1.11950505]\n", "760 6.08877e-05 [ 0.99495119] 
[ 1.11822772]\n", "780 5.31733e-05 [ 0.99528182] [ 1.11703408]\n", "800 4.6437e-05 [ 0.99559087] [ 1.1159184]\n", "820 4.0554e-05 [ 0.99587953] [ 1.11487603]\n", "840 3.54157e-05 [ 0.99614942] [ 1.11390185]\n", "860 3.09304e-05 [ 0.99640155] [ 1.11299145]\n", "880 2.70114e-05 [ 0.99663723] [ 1.11214054]\n", "900 2.35899e-05 [ 0.99685746] [ 1.11134553]\n", "920 2.06012e-05 [ 0.99706328] [ 1.11060238]\n", "940 1.79884e-05 [ 0.99725574] [ 1.10990739]\n", "960 1.57094e-05 [ 0.99743551] [ 1.10925865]\n", "980 1.37192e-05 [ 0.99760342] [ 1.10865223]\n", "1000 1.19814e-05 [ 0.99776036] [ 1.10808563]\n", "1020 1.04639e-05 [ 0.99790704] [ 1.10755622]\n", "1040 9.13835e-06 [ 0.99804407] [ 1.10706139]\n", "1060 7.98016e-06 [ 0.99817216] [ 1.10659885]\n", "1080 6.96877e-06 [ 0.99829185] [ 1.10616672]\n", "1100 6.08597e-06 [ 0.99840379] [ 1.10576284]\n", "1120 5.31547e-06 [ 0.99850827] [ 1.10538554]\n", "1140 4.64177e-06 [ 0.99860597] [ 1.10503268]\n", "1160 4.05328e-06 [ 0.99869728] [ 1.10470307]\n", "1180 3.54009e-06 [ 0.99878258] [ 1.10439515]\n", "1200 3.0916e-06 [ 0.99886233] [ 1.10410726]\n", "1220 2.69998e-06 [ 0.99893683] [ 1.10383832]\n", "1240 2.3579e-06 [ 0.99900645] [ 1.10358703]\n", "1260 2.05914e-06 [ 0.99907148] [ 1.10335219]\n", "1280 1.79858e-06 [ 0.99913228] [ 1.10313272]\n", "1300 1.57044e-06 [ 0.99918902] [ 1.10292757]\n", "1320 1.3717e-06 [ 0.99924219] [ 1.10273588]\n", "1340 1.19803e-06 [ 0.99929178] [ 1.10255671]\n", "1360 1.04596e-06 [ 0.99933821] [ 1.10238922]\n", "1380 9.13619e-07 [ 0.99938154] [ 1.10223258]\n", "1400 7.97728e-07 [ 0.99942201] [ 1.10208642]\n", "1420 6.96617e-07 [ 0.99945992] [ 1.10194981]\n", "1440 6.0862e-07 [ 0.99949527] [ 1.10182214]\n", "1460 5.31349e-07 [ 0.99952829] [ 1.10170281]\n", "1480 4.64173e-07 [ 0.9995591] [ 1.10159147]\n", "1500 4.05457e-07 [ 0.99958801] [ 1.1014874]\n", "1520 3.54099e-07 [ 0.99961495] [ 1.10138988]\n", "1540 3.09269e-07 [ 0.99964017] [ 1.10129905]\n", "1560 2.69982e-07 [ 0.99966371] [ 1.10121381]\n", 
"1580 2.35905e-07 [ 0.9996857] [ 1.10113442]\n", "1600 2.05917e-07 [ 0.99970639] [ 1.10106015]\n", "1620 1.79893e-07 [ 0.99972558] [ 1.10099077]\n", "1640 1.57064e-07 [ 0.99974358] [ 1.1009258]\n", "1660 1.37143e-07 [ 0.99976033] [ 1.10086513]\n", "1680 1.19873e-07 [ 0.99977601] [ 1.10080862]\n", "1700 1.04663e-07 [ 0.99979061] [ 1.10075557]\n", "1720 9.15154e-08 [ 0.99980432] [ 1.10070634]\n", "1740 7.99159e-08 [ 0.99981713] [ 1.10066009]\n", "1760 6.97393e-08 [ 0.99982917] [ 1.10061669]\n", "1780 6.08945e-08 [ 0.99984032] [ 1.10057628]\n", "1800 5.31413e-08 [ 0.99985081] [ 1.10053861]\n", "1820 4.64375e-08 [ 0.99986058] [ 1.10050344]\n", "1840 4.05344e-08 [ 0.9998697] [ 1.1004703]\n", "1860 3.54076e-08 [ 0.99987817] [ 1.10043943]\n", "1880 3.09444e-08 [ 0.99988616] [ 1.10041082]\n", "1900 2.7042e-08 [ 0.99989355] [ 1.10038412]\n", "1920 2.36139e-08 [ 0.99990052] [ 1.10035908]\n", "1940 2.06039e-08 [ 0.99990708] [ 1.10033524]\n", "1960 1.80461e-08 [ 0.9999131] [ 1.10031366]\n", "1980 1.57384e-08 [ 0.99991882] [ 1.10029304]\n", "2000 1.37607e-08 [ 0.999924] [ 1.10027397]\n" ] } ], "source": [ "X = tf.placeholder(tf.float32, shape=[None])  # 1-D input, any number of values\n", "Y = tf.placeholder(tf.float32, shape=[None])  # 1-D labels, any number of values\n", "\n", 
"# Variables that TensorFlow updates during training (trainable)\n", "W = tf.Variable(tf.random_normal([1]), name='weight')\n", "b = tf.Variable(tf.random_normal([1]), name='bias')\n", "# Our hypothesis XW+b\n", "hypothesis = X * W + b\n", "\n", "# cost/loss function\n", "cost = tf.reduce_mean(tf.square(hypothesis - Y))\n", "\n", "# Minimize\n", "optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)\n", "train = optimizer.minimize(cost)\n", "\n", "# Launch the graph in a session.\n", "sess = tf.Session()\n", "# Initialize global variables in the graph.\n", "sess.run(tf.global_variables_initializer())\n", "\n", "# Fit the line\n", "for step in range(2001):\n", "    cost_val, W_val, b_val, _ = sess.run([cost, W, b, train],\n", "        feed_dict={X: [1, 2, 3, 4, 5], Y: [2.1, 3.1, 4.1, 5.1, 6.1]})\n", "    if step % 20 == 0:\n", "        print(step, cost_val, W_val, b_val)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The training data is passed in through `feed_dict` each time `sess.run` is called.\n", "\n", "`tf.placeholder` takes a `shape` argument. `shape=[None]` declares a 1-D tensor of any length, so any number of values can be fed in." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 6.09989405]\n", "[ 3.60008383]\n", "[ 2.60016012 4.60000801]\n" ] } ], "source": [ "# Testing our model, H(x) = 1.0 x + 1.1\n", "print(sess.run(hypothesis, feed_dict={X: [5]})) # [6.1]\n", "print(sess.run(hypothesis, feed_dict={X: [2.5]})) # [3.6]\n", "print(sess.run(hypothesis, feed_dict={X: [1.5, 3.5]})) # [2.6, 4.6]" ] } ], "metadata": { "kernelspec": { "display_name": "Python [Root]", "language": "python", "name": "Python [Root]" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 0 }