{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Evaluating Models\n", "\n", "This code will implement two models. The first is a simple regression model, we will show how to call the loss function, MSE during training, and output it after for test and training sets.\n", "\n", "The second model will be a simple classification model. We will also show how to print percent classified for both the test and training sets.\n", "\n", "### Regression Model\n", "\n", "For the regression model we will generate 100 random samples from a Normal(mean=1, sd=0.1). The target will be an array of size 100 filled with the target value of 10.0.\n", "\n", "We will fit the linear model $y=A \\cdot x$ (no y intercept). The theoretical value of `A` is `10.0`.\n", "\n", "To start we load the necessary libraries and reset the computational graph." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import tensorflow as tf\n", "from tensorflow.python.framework import ops\n", "ops.reset_default_graph()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Start a graph session:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "sess = tf.Session()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Declare the batch size:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "batch_size = 25" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generate Data for Regression\n", "\n", "Here we generate the data required for the regression. We also specify the necessary placeholders.\n", "\n", "After we split the data into a 80-20 train-test split." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Create data\n", "x_vals = np.random.normal(1, 0.1, 100)\n", "y_vals = np.repeat(10., 100)\n", "x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)\n", "y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)\n", "\n", "# Split data into train/test = 80%/20%\n", "train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)\n", "test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))\n", "x_vals_train = x_vals[train_indices]\n", "x_vals_test = x_vals[test_indices]\n", "y_vals_train = y_vals[train_indices]\n", "y_vals_test = y_vals[test_indices]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model Variables and Operations\n", "\n", "We create the model variable `A` and the multiplication operation." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Create variable (one model parameter = A)\n", "A = tf.Variable(tf.random_normal(shape=[1,1]))\n", "\n", "# Add operation to graph\n", "my_output = tf.matmul(x_data, A)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Loss, Optimization Function, and Variable Initialization\n", "\n", "We use the L2 loss, and the standard Gradient Descent Optimization with a learning rate of 0.02." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Add L2 loss operation to graph\n", "loss = tf.reduce_mean(tf.square(my_output - y_target))\n", "\n", "# Create Optimizer\n", "my_opt = tf.train.GradientDescentOptimizer(0.02)\n", "train_step = my_opt.minimize(loss)\n", "\n", "# Initialize variables\n", "init = tf.global_variables_initializer()\n", "sess.run(init)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run Regression\n", "\n", "We iterate 100 times through the training step, selecting a random batch of data each time." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step #25 A = [[ 6.47693586]]\n", "Loss = 12.0633\n", "Step #50 A = [[ 8.68510914]]\n", "Loss = 2.556\n", "Step #75 A = [[ 9.50503254]]\n", "Loss = 1.21711\n", "Step #100 A = [[ 9.77385426]]\n", "Loss = 1.04426\n" ] } ], "source": [ "# Run Loop\n", "for i in range(100):\n", " rand_index = np.random.choice(len(x_vals_train), size=batch_size)\n", " rand_x = np.transpose([x_vals_train[rand_index]])\n", " rand_y = np.transpose([y_vals_train[rand_index]])\n", " sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})\n", " if (i+1)%25==0:\n", " print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))\n", " print('Loss = ' + str(sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})))\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluation of Regression Model\n", "\n", "For the regression model evaluation, we will run the loss wih the training and test set." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MSE on test:0.96\n", "MSE on train:1.16\n" ] } ], "source": [ "# Evaluate accuracy (loss) on test set\n", "mse_test = sess.run(loss, feed_dict={x_data: np.transpose([x_vals_test]), y_target: np.transpose([y_vals_test])})\n", "mse_train = sess.run(loss, feed_dict={x_data: np.transpose([x_vals_train]), y_target: np.transpose([y_vals_train])})\n", "print('MSE on test:' + str(np.round(mse_test, 2)))\n", "print('MSE on train:' + str(np.round(mse_train, 2)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Classification Example\n", "\n", "For the classification example, we generate data as follows:\n", "\n", "The input data will be a sample of size 50 from a Normal(mean = -1, sd = 1) and a sample of 50 from a Normal(mean = 1, sd = 1).\n", "\n", "The target data will be 50 values of 0 and 50 values of 1.\n", "\n", "We fit the binary classification model:\n", "\n", "- If $sigmoid(x+A)<0.5$ Then we predict class 0\n", "- If $sigmoid(x+A)>=0.5$ Then we predict class 1\n", "\n", "Theoretically A should be\n", "\n", "$$ - \\frac{mean1 + mean2}{2} = 0$$\n", "\n", "We start by resetting the computational graph:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "ops.reset_default_graph()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create a graph session:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "sess = tf.Session()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Declare the batch size:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], "source": [ "batch_size = 25" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "### Generate Classification Data and Targets\n", "\n", "We generate the classification data as described above. Then we create the necessary placeholders.\n", "\n", "After, we split the data in a 80-20 train-test split." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Create data\n", "x_vals = np.concatenate((np.random.normal(-1, 1, 50), np.random.normal(2, 1, 50)))\n", "y_vals = np.concatenate((np.repeat(0., 50), np.repeat(1., 50)))\n", "x_data = tf.placeholder(shape=[1, None], dtype=tf.float32)\n", "y_target = tf.placeholder(shape=[1, None], dtype=tf.float32)\n", "\n", "# Split data into train/test = 80%/20%\n", "train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)\n", "test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))\n", "x_vals_train = x_vals[train_indices]\n", "x_vals_test = x_vals[test_indices]\n", "y_vals_train = y_vals[train_indices]\n", "y_vals_test = y_vals[test_indices]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model Variables and Operations\n", "\n", "We create the model variable, `A`, and the model operation, which is the adding of `A` to the input data. Note that we do not put the `sigmoid()` function in here because it will be included in the loss function. This also means that for prediction, we will need to include the sigmoid function." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Create variable (one model parameter = A)\n", "A = tf.Variable(tf.random_normal(mean=10, shape=[1]))\n", "\n", "# Add operation to graph\n", "# Want to create the operstion sigmoid(x + A)\n", "# Note, the sigmoid() part is in the loss function\n", "my_output = tf.add(x_data, A)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Loss, Optimization Function, and Variable Initialization\n", "\n", "The loss will be the sigmoid-cross-entropy. We wrap that function in a `tf.reduce_mean()` so that we can reduce the sigmoid-cross-entropy over the whole batch.\n", "\n", "The optimization function we use is again the standard Gradient Descent Optimization with a learning rate of 0.05." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Add classification loss (cross entropy)\n", "xentropy = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(my_output, y_target))\n", "\n", "# Create Optimizer\n", "my_opt = tf.train.GradientDescentOptimizer(0.05)\n", "train_step = my_opt.minimize(xentropy)\n", "\n", "# Initialize variables\n", "init = tf.global_variables_initializer()\n", "sess.run(init)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run Classification\n", "\n", "We iterate the classification training operation for 1800 iterations and print off the `A` values along with the loss every 200 iterations" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step #200 A = [ 5.70815563]\n", "Loss = 2.20039\n", "Step #400 A = [ 1.35876882]\n", "Loss = 0.53517\n", "Step #600 A = [-0.13932848]\n", "Loss = 0.246674\n", "Step #800 A = [-0.50210857]\n", "Loss = 0.296652\n", "Step #1000 A = [-0.57548112]\n", "Loss = 0.308673\n", "Step #1200 A = [-0.55576968]\n", "Loss = 0.242491\n", "Step #1400 A = [-0.61009347]\n", "Loss = 0.269398\n", "Step #1600 A = [-0.60937202]\n", "Loss = 0.223748\n", "Step #1800 A = [-0.5841468]\n", "Loss = 0.245155\n" ] } ], "source": [ "# Run loop\n", "for i in range(1800):\n", " rand_index = np.random.choice(len(x_vals_train), size=batch_size)\n", " rand_x = [x_vals_train[rand_index]]\n", " rand_y = [y_vals_train[rand_index]]\n", " sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})\n", " if (i+1)%200==0:\n", " print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))\n", " print('Loss = ' + str(sess.run(xentropy, feed_dict={x_data: rand_x, y_target: rand_y})))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluation of Classification Results" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy on train set: 0.95\n", "Accuracy on test set: 0.95\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAWoAAAEKCAYAAAA7LB+5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl8VPW9//HXB1EBEwggiBtEUJaKlXpb61YZ0SpQF3p/\nXhEQpXrtldar/Kp4XSok9OdyXW5TrV3UgiLi1StK1eoVCw1utdqCIIIrhqBYBSGEsAXI5/fHOZlO\nkpnJZJnkJLyfj8c8MnPO93zP58yc+eQ737N8zd0REZHo6tDaAYiISHpK1CIiEadELSIScUrUIiIR\np0QtIhJxStQiIhGnRN2KzOzXZnZTa8eRCTP7k5ldmqW6DzezcjOz8HVvM3vZzDab2Z1mdoOZ3Z+N\ndYu0BUrUWWRmJWa2LUxCX5nZs2Z2aPV8d5/s7re0ZozVzGxfMyswsw/MbIuZrTazB82sb7bX7e5r\n3b2r/+Ok/h8CX7p7N3ef6u63ufsPsxmDmXUxswozey6b64kqMxtmZn81s61m9paZHZum7GAzW2hm\nZeH+MiZhXj8zqwr3+S3h3zbRGIkyJerscuB77t4VOBj4Erg32ys1s30asdg84GzgQqAbcCzwN+D0\nZgwtU/2Alc1RUXUrPQPnAzuAM83soOZYd6Ya+Xk15/r3BeYDs4G88O/vzaxjkrL7AL8HngG6A/8G\nzDGzIxOKOdDN3XPDf8CRaIy0ae6uR5YewCfAiITXo4D3El7PAmaEz4cDa4GfAF8AnwGTEsqOBpYA\nm4E1wPSEef2AKuDScF4x8BxwZa14lgHnJonzDGArcEiabfkTcGn4vD+wENhA8M9nDtA1oex/AJ8C\n5cAq4LRw+reAt8Jt+By4q1b8HcL3pBLYGS4/ApgOPJJQ/wnAa8AmYCkwvFac/w94Ndym/hl+VguB\nnwF/BX5Sa95hBP/IvgTWA/ckzLuc4J9KObACGBZOr0pcd4rP+rrwfXiYIEE+G67jq/D5IQnLdwdm\nhvvFV8BT4fR3CBoD1eU6hjF+vQH76XeBtbWmrQHOTFL2aKC81rQXgcJan+U+rf39a08PtahbiJl1\nAcYCf05TrA+QCxwC/Ctwn5l1C+dVABPdvRvwPeAKMzu31vKnAoOAswi+/BclrP/YsN7nk6z3dOBN\nd1+X6eYAt4bxDiFIZAXhegYCPwb+yYNfEmcBJeFyvwCKwm0YADyRUKcDuPsPgEeB//SgNbYocX7Y\ndfQcQdLrDlwLzDOzngl1XUTw/uUSJJz0GxN078TC9c4FLkmY1yFc3ydAX+BQ4L/Def8CTAMuCrf1\nXIIkGo83jT4EybkvQVdPB4JEfHg4bRtwX0L5OUBngve7N/DzcPpsYGJCue8B69x9eRjjJjPbGP5N\nfL7RzK4LlzkaWF4rvuXh9NqS/UIxYGjCawdKzKzUzGbW+mykEZSos2++mW0kaEWeAdyVpmwl8DN3\n3+PuLxAk50EA7v6yu78bPl9BkCyGJyzrBK3sHe6+k+Dn6ZFmNiCcfxHwuLvvTrLengQtu4y4+8fu\nvtDdd7v7VwRJozqWPcB+wFAz6+jupe7+ScL2HWlmPd19m7u/mek6E0wA/uDuL4axLCRoBY9OKPOQ\nu7/n7lXuvieDOi8Glrn7e8BjwNcS+miPJ+i2ui58byvd/fVw3mXAHe6+JIxltbuvDefV1+Wyh+Dz\n2uXuO919o7s/HT7fCtxG8I8XMzuY4B/ev7l7ebh/vBLWMwcYZWY54euLgEeqV+Lu3d29R/g38XkP\nd78jLJZDsH8m2kzwj66294AvzexaM+toZmcSfPZdwvkbCH459QP+Kazj0XreC6mHEnX2nefuPQiS\n178DL5tZ7xRlv3L3qoTX2wi+RJjZt81skZl9aWZlBH2DB9Za/tPqJ+5eSdBivSjspx1Hwhe49noJ\nklFGzKyXmT1mZp+GscypjsXdPwamELSwvzCzuWGigSCxDQLeM7O/mNn3Ml1ngn7ABWGLcKOZbQJO\nJmihVlubfNGUJhImE3f/HHiZf7SqDwfW1PpcSJj3cQPXVW29u++qfmFmnc3st+EB6DJgMZAXfnaH\nARvdvbx2JWG8rwH/J/z1NYqGJ8YKoGutaV2BLUnWtxsYQ3A843Pg/wKPE+577r7V3ZeE/yTXA1cS\n9Pvn1K5LMqdEnX0G4IGnCVpSpzSinkcJDvgc6u55wG+p22qr/XN7NkEL63Rgq7v/JUXdfwSON7ND\nMozlNoJ+yKFhLBclxuLu/+3u3yFIqgC3h9M/dvfx7t4LuAN40sw6Z7jOamuB2WGLsLp1mOvudyaU\nyfiWkGZ2InAUcIOZfW5mnxO0oseF3R5rgb7h82SxDEgyHYJ/sl0SXvepNb92jNeEcXwrfE9PrQ4x\nXE8PM6udTKtVd3/8C/B6mLyrt6/6zIvER/W068Ni7wJfr1Xn18Ppdbj7CnePuXsvdx9F8B6k+3Xk\n1P8LQ9JQom5BZnYeQb9kY85oyAE2ufsuMzseGF+7+toLuPsbBAn1blK3pqu7D14Cnjaz48xsHzPL\nMbN/M7NJSRbJJWiFlYd9xlPjQZgNNLPTzGw/gq6O7QT/nDCzCWZW/StgM8EXuLprItMv8hzgHDM7\n08w6mFknMxue7p+MmU03s0UpZk8CFhD0/R4bPo4BDiBonb5J0HK8PTyFb38zOylc9kHgWjM7LlzP\nADM7PJy3FBgfxjiSmt1UyeQSvFflZtaDsM8fwN3/DrwA/MrM8sIuh+8kLDsfOA64iiBpk7Bs9ZkX\niY/qabeHxYqBPWb272a2n5ldSfDZJH3PzOyY8H3oYmbXEvwTeiicd3y4D1jYN/0L4E/uXqd1LplT\nos6+Z8PWy2aCswouDvtCM5HY6voR8LOwnp8S/NxMVTbRbIIDPXPqWdf5BAcaHwfKCM4m+CeC1nbt\n+gvDeWUEZyfMS5i3P0ELej2wDugF3BjOGwm8a2blBP3aY8MumnTx1+DunwLnhXWuJzhYeC3/2JeT\n1XM4QfdADWa2P8F23+Pu6939y/BRQvC+XRJ2eZxD0NotJWjdXhDG8iRwCzA33KangR5h9VMIDi5u\nIuh2erqeTSsiaIFvAF6n7kHficBugj7iL4CrE96THQSfwRHAU/Wsp46wC2YMQXfPJoJ/XudVH8+w\n4IKjP9SK5XPg78BpwHcTunH6A/9LcBbMcoJTHms3KqSBzL3+74eZXU1wFB3gAXe/J6tRSbMxs4nA\n5e5+ar2F2ykzWwKc7u6bWjuWbDGzm4Gj3P3i1o5Fml+9LWozO5rgINA3gWEEPztT9ctJhFhwSuCP\nCPqz91ruflw7T9I9CL6je/Xn3J5l0vUxBHgjPG1oD8HR6O9nNyxpqvC0qS8JfqI+1srhSJaY2b8S\ndMn8wd3rdO9I+1Bv14eZDSY4WHEiwdVifwTecver0y4oIiLNos61/LW5+3tm9p8ECXoL8DbBQQ0R\nEWkBGR1MrLGA2S0E9wX4Ta3pGs5cRKSB3L3eU1MzOj3P
zHqFf/sS9E8n7fP0CNy8JBuP6dOnt3oM\n2r69Y/umX3IJPn16jcf0Sy5pN9vX3j+/hj4yVW/XR2heeGR5F/Ajd699XwAREcmSjBK178Xn4IqI\ntDZdmZiBWCzW2iFklbavbdP2tX9K1Blo7zuKtq9t0/a1fw0+6yNlRWbeXHWJtJT8/HzWrKl3bAGR\nJunXrx8lJSV1ppsZnsFZH0rUslcLvyitHYa0c6n2s0wTtbo+REQiTolaRCTilKhFRCJOiVpEJOKU\nqEX2ch06dCA3N5ebb765tUNps4488kj2339/Lr44O+M2KFGLRFwsFqNHjx7s2rWr/sKNYGYsX76c\nn/3sZynLLFy4kCFDhpCTk8Ppp59OaWlpyrL5+fl06dKFrl270rVrV0aOHBmfN3nyZHJzc+PzOnXq\nRLdu3eLz33vvPU4//XTy8vIYOHAg8+fPr1H3gw8+yFFHHUXXrl0ZPXo0n38eH8eXzZs3M2nSJA46\n6CD69OlDYWFhjWVff/11vv3tb9O1a1eGDRvGa6/VvH33LbfcQr9+/cjLy2P8+PFUVFTE523atImx\nY8fSq1cvevfuzcSJE2vM/+ijj7jxxhvJmma8uYiLtDVR329LSkp8n3328Z49e/qTTz6ZlXWYmX/8\n8ccp52/YsMG7devm8+bN8507d/rUqVP9hBNOSFk+Pz/fFy1alNG6J02a5Jdddpm7u+/evdsHDhzo\nRUVFXlVV5YsWLfIDDjjAP/zwQ3d3Ly4u9t69e/uqVat8165dPnnyZB8+fHiNui644ALfsWOHl5SU\n+IABA/yhhx5yd/eNGzf6gQce6PPmzfOqqiqfM2eOd+/e3cvKytzd/aGHHvIhQ4b4Z5995lu3bvXz\nzjvPL7nkknjdkydP9rPOOssrKiq8vLzczzjjDL/mmmtqbEtBQYFPnDgx6Xam2s/C6fXn10wKZVRR\nxHd4kWTq22+nT5/uBAPm1nhMnz494/KpymZixowZfsopp/g111zjZ599dqPrSae+RH3//ff7ySef\nHH+9detW79y5s7///vtJy+fn5/vChQvrXW9FRYXn5ub6K6+84u7uK1as8Nzc3BplzjzzTJ82bZq7\nu1977bV+5ZVXxuetW7fOzcxXr17t7u4HHnig/+1vf4vPv/XWW/3UU091d/fnnnvOhw4dWqPugQMH\n+syZM93d/fzzz/e77rorPu/111/3Tp06+fbt293dfdSoUf7rX/86Pv++++7zkSNH1qgvm4laXR8i\nETZ79mwuuugixo8fz4svvsj69etTlv3xj39M9+7d6dGjR/xv9fNhw4Y1OoZ3332XY489Nv66S5cu\nDBgwgHfffTflMhMmTOCggw5i5MiRLF++PGmZefPm0bt3b0455RSApBeEuDsrVqyIP08sU1VVBRCf\nnzit+nmqZTOpu7Kykg8//BAI3ttnn32WsrIyNm3axLx58xg9enTK7W9uStQiEfXqq69SWlrKBRdc\nwHHHHceRRx7J3LlzU5a/77772LRpExs3boz/rX7+9ttvNzqOioqKGv3IAN26dWPLli1Jy8+dO5eS\nkhLWrFlDLBbjrLPOory8vE652bNn1zj4NnjwYHr37s1dd93F7t27WbBgAYsXL2bbtm0AjB49miee\neIIVK1awfft2ZsyYQYcOHeLzR44cye23305FRQUfffQRs2bNis876aSTWLduHY8//ji7d+/m4Ycf\n5uOPP47PHzVqFA8++CBr1qxh8+bN3HHHHQDx+ccddxyVlZX07NmTXr160bFjRyZPntzo97ShlKhF\n0igoKEj6U7SgoCDj8qnK1mf27NmceeaZdO/eHYBx48bx8MMPN3JLGi8nJ6dOoi0vLyc3Nzdp+RNP\nPJH999+fTp06cf3115OXl8crr7xSo8zatWtZvHhxjUTdsWNH5s+fz3PPPcfBBx/Mz3/+c8aOHcth\nhx0GwIgRIygsLOSf//mfOeKII+jfvz+5ubnx+ffccw+dOnXiqKOO4vvf/z7jx4+Pz+vRowe///3v\nufvuu+nTpw8LFizgu9/9bnz+pZdeyrhx44jFYhxzzDGMGDECID7//PPPZ9CgQWzdupXy8nL69+/P\nhAkTmvrWZi6T/pFMHqiPWtqgqO6327dv927dunlubq736dPH+/Tp4z169PAOHTr48uXLky5zxRVX\neE5Ojufm5tZ45OTk1OmfTdTQPuqKigrv0qVLyj7q2oYMGeLPPvtsjWm33HJLjQOBqZx00kl+//33\nJ533wQcfeE5OTvyAYG033nijjx8/Pum83bt3e79+/XzBggVJ57/44ot++OGHx1/n5OTUeN/ffvvt\nOv3prX4wEfi/wApgOfAosF+SMkkDEYmyqO63c+fO9Z49e/qnn37qX3zxRfwxfPjwOmcbNFV9iXr9\n+vWel5fnTz31lO/YscOvu+46P/HEE5OWLS0t9ddee80rKyt9x44dfscdd3jv3r1948aNNcoNGjQo\nfkZGouXLl/uOHTt869atfuedd3r//v29srLS3d137NjhK1ascHf3NWvWeCwW85/+9KfxZT/++GP/\n6quvfM+ePf788897r169fNWqVfH5S5cu9V27dvnmzZv96quv9lNOOSU+b+PGjfH34N133/WhQ4f6\ngw8+GJ8/YsQIv+qqq3z79u2+bds2nzx5co3l3Vs5UQOHAKurkzPwOHBxknJJAxGJsqjutyNHjvSp\nU6fWmf7EE0/4wQcf7Hv27Gm2ddWXqN3dFy5c6IMHD/YuXbr4aaed5mvWrInPu+KKK3zy5MnuHiS5\nr3/9656Tk+MHHnign3HGGb5kyZIadf35z3/2nJwcr6ioqLOeqVOnevfu3T03N9dHjx5dI66ysrJ4\n3QcffLDfdNNNXlVVFZ//xBNP+CGHHOIHHHCAf+Mb3/CXXnqpRt3jxo3zbt26eV5enl944YW+fv36\n+LwPPvjABw0a5AcccIDn5+d7UVFRjWVLSkr8nHPO8Z49e3rPnj191KhR/tFHH9Uok81EXe9tTs3s\nEODPwDBgC/A08At3/2Otcl5fXSJRo9ucBmdx7L///lx11VV1LhKRzAwePJh169YxduxYHnjggTrz\nm3qb04zuR21mVwG3ANuABe4+MUkZJWppc5SopSU0NVHXO7itmeUB5wH9gM3Ak2Y23t3rnCeUeHQ7\nFotpCJ12rKioiLKysjrT8/LymDJlSitEJBJ9xcXFFBcXN3i5TLo+zgfOcvfLw9cTgW+7+5W1yqlF\nvRcpKChIetpZqulRpRa1tISWGOGlFDjBzDqZmQGnA6saHKmIiDRKvYna3d8EngSWAssAA+7Pclwi\nIhKqt48awN0LAR0OFhFpBbqEXEQk4pSoRdqw8ePH88wzz2R9PZWVlQwZMoQNGzZkfV1SV0ZdHyJ7\ni1SnHTaXhpy+mJ+fz44dO/jkk0/o3LkzAL/73e+YM2cOf/rTn1i+fDnLly9Pe0e9adOmMX/+fFat\nWsXNN9/MtGn
TUpYtLi5mxowZLFmyhB49erB69er4vP3224/LLruM22+/nbvuuivDrZXmokQtkqCs\nrCyrpxc2pG4zY8+ePRQVFXHDDTfUmA5w//3313sHt6OOOoo777yT3/zmN/Wu74ADDuCyyy5j/Pjx\n3HrrrXXmjxs3jmHDhnHbbbex7777Zrwd0nTq+hCJsKlTp3L33XcnvZ/zCy+8wPDhw9MuP3HiRM46\n6yxycnLqXde3vvUtJkyYwBFHHJF0/qGHHkqPHj144403Mgtemo0StUiEffOb3yQWi3HnnXfWmL5t\n2zY++eQTBg0a1KLxDB48mGXLlrXoOkWJWiTyCgsL+eUvf8lXX30Vn1bdj57q5v3Zkpubm9U+fElO\niVok4o4++mjOPvtsbrvttvi0vLw8gBrDYQ0dOpTc3Fy6du3Ka6+9lpVYtmzZEl+3tBwlapE2oKCg\ngAceeIDPPvsM+McAsx988EG8zIoVK9iyZQvl5eWcfPLJWYlj1apVNQa6lZahRC3SBgwYMICxY8dy\nzz33xKeNHj2axYsXp11u9+7d7Nixg6qqKnbt2sXOnTvjI3WvWbOGDh06UFpaCgSDiOzcuZPKykqq\nqqrYuXMnu3btite1bt06Nm3axAknnJCFLZR0dHqeSIK8vLysnp7XkG6D6tPwqk2bNo05c+bEp//w\nhz9k7NixXH/99SnruPzyy3n44Yfjy9x6663MmjWLiy++mNLSUvLz8zn00EMBePnllznttNPiZbt0\n6cLw4cNZtGgRAI8++iiXXHKJTs1rBUrUIgmidC/txAtOIBgRe9u2bfHXRx99NMOGDeOZZ57h3HPP\nTVrHrFmzmDVrVtJ5L7/8MjfccAP77LMPAMOHD4+3tmurrKxk1qxZvPzyy43ZFGkiJWqRNmzOnDmN\nXvamm27KuOx+++3HypUrG70uaRr1UYuIRJwStYhIxClRi4hEXL2J2swGmtlSM1sS/t0cjkouIiIt\noN6Die7+AfANADPrAHwKPJ3luEREJNTQro8zgI/dfW02ghERkboamqjHAo9lIxAREUku4/OozWxf\n4Fwg5WVQiVd0xWIxYrFYE0KT9i7VaCoNGQVlbzd+/HguvPDClBe8NNazzz7L3LlzeewxtcuaU3Fx\nMcXFxQ1eriEXvIwC/ubu61MVyOalt9L+pBpNpTX3o6Jp0ygL732RDXl9+zJlxoyMyjZ0KK7nn3+e\n2267jRUrVtC5c2fOOecc7r777pSDBqQbpuucc87hpptuYsWKFQwdOrSJWy3VajdgCwsLM1quIYl6\nHOr2kHaurLSUgvz8rNVfUFKScdmGDsVVXl7OzTffzKmnnsrOnTsZN24c1113Hb/61a+S1l/fMF0X\nXnghv/3tb7n33nszjlmyI6M+ajPrTHAg8anshiMiiRoyFNeFF17ImWeeSadOnejWrRuXX3552vtS\n1zdMVywW4w9/+EPTN0KaLKNE7e7b3b2Xu2+pv7SINJemDMW1ePFijj766Eave8iQIaxZs4aKiopG\n1yHNQzdlEom4wsJCTjnllBoHWOsbiuull17ikUce4c0332z0enNzc3F3ysrKMhocV7JHl5CLRFym\nQ3FVe+ONN5gwYQLz5s1jwIABjV7vli1bMDMNvRUBStQibUAmQ3EBLF26lDFjxvDQQw81+fTYVatW\nkZ+fr9Z0BChRi7QBmQzFtWLFCkaNGsW9997L6NGj69RRWFjIiBEj4q/TDdMFQR/3qFGjsrRF0hDq\noxZJkNe3b4NOoWtM/Zlq6FBc//Vf/8WGDRu47LLLuPTSS4HgXOx33nkHgLVr19YY9DbdMF0Ajz32\nGI8++mgjt1SakxK1SIJML0ZpCQ0dimvmzJnMnDkzZX1Llixh4cKF8dfphul67rnn+NrXvsYxxxzT\nxK2Q5qBELdKGNWQoriVLlmRc9uyzz+bss89uTEiSBeqjFhGJOCVqEZGIU6IWEYk4JWoRkYhTohYR\niTid9SF7tX79+tU5X1mkufXr169JyytRy16tpJ6LWxIHMWiJAQ0KJk2qcz/sgpISCh56KOvrluhS\nohZJQ6MWSRSoj1pEJOIyHeGlm5n9j5mtMrN3zezb2Q5MREQCmXZ9/AJ43t3/xcw6Al2yGJOIiCSo\nN1GbWS7wHXefBODuu4G6A7iJiEhWZNKi7g9sMLNZwLHAX4Gr3X17ViMTiYCWPutDJJlMEnVH4Djg\nx+7+VzMrAq4HptcumLgjx2KxJo8wIdLaCgsL48/bSqIumjaNstLSGtPy+vaN1C1c91bFxcUUFxc3\neLlMEvWnwFp3/2v4+kngP5IVbCs7skh7VlZamvRcbGl9tRuwiQ2BdOo968PdvwDWmtnAcNLpwMqG\nhygiIo2R6VkfVwGPmtm+wGrgB9kLSUREEmWUqN19GfCtLMciIiJJ6BJykTSmT69zzFykxSlRi6Sh\nA+QSBbrXh4hIxClRi4hEnBK1iEjEKVGLiEScDiaKpKF7fUgUKFGLpNEW7/Uh7Y+6PkREIk6JWkQk\n4pSoRUQiTolaRCTidDBRJA3d60OiQIlaJA2d6SFRoK4PEZGIy6hFbWYlwGagCtjl7sdnMygREfmH\nTLs+qoCYu2/KZjAiIlJXpl0f1oCyIiLSjDJtUTvwopk5cL+7P5DFmEQiQ/f6kCjINFGf5O5/N7Ne\nwEtmtsrdX81mYCJRoHt9SBRkOrjt38O/683saeB4oE6iTtyRY7EYsVisWYKUvcuyZcuSJsW8vDym\nTJnS8gGJNJPi4mKKi4sbvFy9idrMugAd3L3CzA4AzgQKk5VVi0Oag7sn3Ze0f0lbV7sBm/iLLZ1M\nWtQHAU+H/dMdgUfdfUEjYhQRkUaoN1G7+yfAsBaIRUREktAl5CJp6F4fEgVK1CJpqF9cokAXsYiI\nRJwStYhIxClRi4hEnBK1iEjE6WCiSBq614dEgRK1SBq614dEgbo+REQiTolaRCTilKhFRCJOiVpE\nJOJ0MFEkDd3rQ6JAiVokDZ3pIVGgrg8RkYhTohYRibiME7WZdTCzJWb2TDYDEhGRmhrSor4aWJmt\nQEREJLmMDiaa2WHAaOAW4CdZjUgkQnSvD4mCTM/6+DkwFeiWxVhEIkf3+pAoqLfrw8y+B3zh7m8D\nFj5ERKSFZNKiPhk418xGA52BXDOb7e4X1y6Y2OKIxWLEYrFmClNk77Vs6VIKJk2qMz2vb1+mzJjR\nYnVI0xUXF1NcXNzg5epN1O5+I3AjgJkNB65JlqRBPw1FssG3bqUgP7/O9IKSkhatQ5qudgM2sWst\nHZ1HLSIScQ26hNzdFwOLsxSLSOToXh8SBbrXh0ga6s6TKFDXh4hIxClRi4hEnBK1iEjEKVGLiESc\nDiaKpKF7fUgUKFGLpKF7fUgUqOtDRCTilKhFRCJOiVpEJOKUqEVEIk4HE0XS0L0+JAqUqEXS0Jke\nEgXq+hARiTglahGRiFOiFhGJuHr7qM1sf+BlYL+w/JPuntn4MSIi0mSZ
jJm408xOc/dtZrYP8JqZ\nveDub7ZAfCKtSvf6kCjI6KwPd98WPt0/XMazFpFIhOheHxIFGfVRm1kHM1sK/B14yd3fym5YIiJS\nLdMWdRXwDTPrCsw3s6+5+8ra5RJbHLWHRZfmUVRURFlZWZ3peXl5TJkyJSt1v//++wwaNKjGtJUr\n63z8Da67oXVEXdG0aZSVltaY9v7q1Qzq379O2by+fZkyY0ZLhZbSsqVLKZg0qc70VPEl28aobEtb\nUFxcTHFxcYOXa+go5OVmVgyMBNImasmOsrKypO9zc7z3qeoeM2ZMneljxoxpct0NrSPqykpLKcjP\nrzFtzKuvUjBiRJ2yBSUlLRNUPXzr1joxQ+r4km1jVLalLajdgE3sWkun3q4PMzvQzLqFzzsDZwDv\nNSpKERFpsExa1AcDD5tZB4LE/ri7P5/dsESiQff6kCjI5PS8d4DjWiAWkchRd55Ega5MFBGJOCVq\nEZGIU6IWEYk4JWoRkYjTwAEiaeheHxIFStQiaeheHxIF6voQEYk4JWoRkYhTohYRiTglahGRiNPB\nRJE0dK8PiQIlapE0dKaHRIG6PkREIk6JWkQk4pSoRUQiTolaRCTi6j2YaGaHAbOBPsAe4AF3vyfb\ngYlEge5bplVOAAAHyElEQVT1IVGQyVkfu4GfuPvbZpYD/M3MFri7xk2Udk/3+pAoqLfrw93/7u5v\nh88rgFXAodkOTEREAg3qozazfGAY8JdsBCMiInVlfMFL2O3xJHB12LKuI/GnYSwWIxaLNTE8aWuW\nLVuWtItg5cqVWak7Ly+PKVOm1ClbVFREWVlZnenvv/8+gwYNymj6+3/+c43XBZMmAbDwlVfomZtb\np47NX34JP/xhfZsBwLKlS+P1JVq5dCnk52dUR1tVNG0aZaWlNabl9e3LlBkzWimillNcXExxcXGD\nl8soUZtZR4Ik/Yi7/z5VOfXhibsn3Q/GjBmTlbpT7XNlZWUp48h0+phhw2quK0ygi555hvkTJ9ap\n49R7Mj/G7lu3xuursc5XX824jraqrLS0zrYXlJS0SiwtrXYDNvEYSDqZtqhnAivd/RcNjkykDZs+\nfHhrhyCS0el5JwMTgHfMbCngwI3u/r/ZDk6ktRWo+04ioN5E7e6vAfu0QCwiIpKErkwUEYk4JWoR\nkYhTohYRiTgNHCCSRkHCOa86sCitRYlaJI3CxYvjz5WopbWo60NEJOKUqEVEIk6JWkQk4pSoRUQi\nTgcTRdLQvT4kCpSoRdLQmR4SBer6EBGJOCVqEZGIU6IWEYk4JWoRkYjTwUSRNHSvD4mCTEZ4+R1w\nNvCFu389+yGJRIfu9SFRkEnXxyzgrGwHIiIiydWbqN39VWBTC8QiIiJJ6GCiiEjENevBxIKCgvjz\nWCxGTH16jVZUVERZWVmd6StXrkxaftmyZTXef4C8vDymTJnS5Lqj7MXHHuPt+fPrTN9QWVnn/Ujn\ni1WrKJg0qca0z1avrvG6ODywWFlZmbSOnTt3xstUS/Y+N5dlS5fWiRlg5dKlkJ/f6nUXTZtGWWlp\nRnWkWl9e375MmTEjo7pTlY2S4uLiOvtIJrKWqKVpysrKkr6fY8aMSVre3euUT/V5NLTuKNt3+3bm\njxtXZ/qpM2c2uJ6CWsnjhT17atzro7rh4W+9lbKe2o2TquXLGxRHQ/jWrXViBhjz6quRqLustDTj\nOlKtr6CkJOO6U5WNktoN2MLCwoyWyzRRW/gQ2avoTA+Jgnr7qM1sLvA6MNDMSs3sB9kPS0REqtXb\nonb38S0RiIiIJKezPkREIk6JWkQk4nSvD5E0dK8PiQIlapE0dK8PiQJ1fYiIRJwStYhIxClRi4hE\nnBK1iEjE6WCiSBqJ9/oQaS1K1CJp6EwPiQJ1fYiIRJwStYhIxClRi4hEnBK1iEjE6WCiSBq614dE\nQUaJ2sxGAkUELfDfuft/ZjUqkYjQvT4kCjIZ4aUD8EvgLOBoYJyZDc52YFHSmMEo2xJtX9tW3AbG\nCmyK9v75ZSKTPurjgQ/dfY277wL+Gzgvu2FFS3vfUbR9bZsSdfuXSaI+FFib8PrTcJqIiLSATPqo\nk40+7s0dSEM88cQTLE7oO6x23XXX0a9fv1aISEQke8w9fc41sxOAAncfGb6+HvDaBxTNrFWTt4hI\nW+TuyRrDNWSSqPcB3gdOBz4H3gTGufuq5ghSRETSq7frw933mNmVwAL+cXqekrSISAupt0UtIiKt\nq1kvITezfzez98zsHTO7vTnrjgozu9bMqsysR2vH0pzM7A4zW2Vmb5vZPDPr2toxNZWZjQz3xw/M\n7D9aO57mZGaHmdkiM1sZft+uau2YssHMOpjZEjN7prVjaW5m1s3M/if83r1rZt9OVbbZErWZxYBz\ngKHufgxwV3PVHRVmdhhwBrCmtWPJggXA0e4+DPgQuKGV42mSveBCrd3AT9z9a8CJwI/b2fZVuxpY\n2dpBZMkvgOfdfQhwLJCyS7k5W9STgdvdfTeAu29oxrqj4ufA1NYOIhvc/Y/uXhW+fAM4rDXjaQbt\n+kItd/+7u78dPq8g+JK3q+sbwobRaODB1o6luZlZLvAdd58F4O673b08VfnmTNQDgVPN7A0z+5OZ\nfbMZ6251ZnYOsNbd32ntWFrApcALrR1EE+01F2qZWT4wDPhL60bS7KobRu3xQFp/YIOZzQq7du43\ns86pCjfo7nlm9hJwUOIkgjfxp2Fdee5+gpl9C3giDKbNqGf7bgS+W2tem5Jm+25y92fDMjcBu9x9\nbiuE2Jwid6FWNphZDvAkcHXYsm4XzOx7wBfu/nbYrdrmvm/16AgcB/zY3f9qZkXA9cD0VIUz5u7f\nTTXPzK4AngrLvRUecOvp7l81ZB2tKdX2mdlQIB9YZmZG0C3wNzM73t2/bMEQmyTd5wdgZpcQ/NQc\n0TIRZdWnQN+E14cB61oplqwws44ESfoRd/99a8fTzE4GzjWz0UBnINfMZrv7xa0cV3P5lOAX+l/D\n108CKQ94N2fXx3yCi2Iws4HAvm0pSafj7ivcvY+793f3Iwje5G+0pSRdn/BWttcB57r7ztaOpxm8\nBRxpZv3MbD/gQqC9nTkwE1jp7r9o7UCam7vf6O593b0/wWe3qB0ladz9C2BtmCshyJ0pD5o258AB\ns4CZZvYOsBNoN29qEk77+yl2L7Af8FLwo4E33P1HrRtS47X3C7XM7GRgAvCOmS0l2CdvdPf/bd3I\npAGuAh41s32B1cAPUhXUBS8iIhGnMRNFRCJOiVpEJOKUqEVEIk6JWkQk4pSoRUQiTolaRCTilKhF\nRCJOiVpEJOL+PxVtB8tras//AAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Evaluate Predictions on test set\n", 
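"# The model output is a raw logit, so apply sigmoid() before thresholding;\n", "# rounding the sigmoid output gives the predicted class (0 or 1), which is then\n", "# compared with the targets and averaged into an accuracy.\n",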
"y_prediction = tf.squeeze(tf.round(tf.nn.sigmoid(tf.add(x_data, A))))\n", "correct_prediction = tf.equal(y_prediction, y_target)\n", "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n", "acc_value_test = sess.run(accuracy, feed_dict={x_data: [x_vals_test], y_target: [y_vals_test]})\n", "acc_value_train = sess.run(accuracy, feed_dict={x_data: [x_vals_train], y_target: [y_vals_train]})\n", "print('Accuracy on train set: ' + str(acc_value_train))\n", "print('Accuracy on test set: ' + str(acc_value_test))\n", "\n", "# Plot classification result\n", "A_result = -sess.run(A)\n", "bins = np.linspace(-5, 5, 50)\n", "plt.hist(x_vals[0:50], bins, alpha=0.5, label='N(-1,1)', color='white')\n", "plt.hist(x_vals[50:100], bins[0:50], alpha=0.5, label='N(2,1)', color='red')\n", "plt.plot((A_result, A_result), (0, 8), 'k--', linewidth=3, label='A = '+ str(np.round(A_result, 2)))\n", "plt.legend(loc='upper right')\n", "plt.title('Binary Classifier, Accuracy=' + str(np.round(acc_value_test, 2)))\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }