{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Linear regression with TensorFlow 1.x\n",
    "\n",
    "Fits a single weight `w` so that `y ~= w * x` on synthetic data generated with a known slope, using per-sample gradient descent on the squared error. After training, `w` should be close to that slope (`TRUE_SLOPE`).\n",
    "\n",
    "This notebook uses the TensorFlow 1.x graph API (`tf.placeholder`, `tf.Session`, `tf.train`); under TensorFlow 2 these are only available via `tf.compat.v1` with eager execution disabled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 42              # seeds the synthetic noise so Restart & Run All reproduces the fit\n",
    "N_POINTS = 101         # number of training samples, evenly spaced in [-1, 1]\n",
    "TRUE_SLOPE = 2.0       # slope used to generate the data\n",
    "NOISE_STD = 0.33       # standard deviation of the Gaussian noise added to y\n",
    "LEARNING_RATE = 0.01   # gradient-descent step size\n",
    "N_EPOCHS = 100         # full passes over the training data\n",
    "W_TOLERANCE = 0.5      # max allowed |w - TRUE_SLOPE| after training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Synthetic data\n",
    "\n",
    "Points on the line `y = TRUE_SLOPE * x`, with Gaussian noise added to `y`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trX = np.linspace(-1, 1, N_POINTS)\n",
    "# a dedicated, seeded generator keeps this cell idempotent: re-running it yields the same noise\n",
    "noise = np.random.RandomState(SEED).randn(*trX.shape) * NOISE_STD\n",
    "trY = TRUE_SLOPE * trX + noise  # approximately linear, with random noise"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "A bias-free linear model `y_model = X * w`, trained by minimizing the squared error of one sample at a time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# start from an empty graph so re-running this cell does not add duplicate ops/variables\n",
    "tf.reset_default_graph()\n",
    "\n",
    "X = tf.placeholder(\"float\")  # one input value, fed per training step\n",
    "Y = tf.placeholder(\"float\")  # the matching target value\n",
    "\n",
    "\n",
    "def model(X, w):\n",
    "    \"\"\"Return the linear-regression prediction X * w (no bias term).\"\"\"\n",
    "    return tf.multiply(X, w)\n",
    "\n",
    "\n",
    "w = tf.Variable(0.0, name=\"weights\")  # the single scalar weight to learn, initialized to 0\n",
    "y_model = model(X, w)\n",
    "\n",
    "cost = tf.square(Y - y_model)  # squared error for the fed sample\n",
    "\n",
    "train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cost)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train\n",
    "\n",
    "Each epoch runs one gradient step per sample. The fitted weight should land close to `TRUE_SLOPE`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.Session() as sess:\n",
    "    # initialize the graph's variables (here just w)\n",
    "    tf.global_variables_initializer().run()\n",
    "\n",
    "    for _ in range(N_EPOCHS):\n",
    "        for x_i, y_i in zip(trX, trY):\n",
    "            sess.run(train_op, feed_dict={X: x_i, Y: y_i})\n",
    "\n",
    "    trained_w = sess.run(w)\n",
    "\n",
    "# sanity check: the fit should recover the slope used to generate the data\n",
    "assert abs(trained_w - TRUE_SLOPE) < W_TOLERANCE, trained_w\n",
    "trained_w"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}