{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [
 "## INTRO \n",
 "- Run a simple linear regression with the TensorFlow deep-learning framework.\n",
 "  The idea is to show how TensorFlow works using the simplest possible example, the Boston housing dataset.\n",
 "- General training process:\n",
 "    - Step 1) Load the data, transform it to the right shape, and split it into train, validation, and test sets \n",
 "    - Step 2) Define the model variables (y = w*x + b), e.g. the weight and the bias \n",
 "    - Step 3) Define the cost, the optimizer, and the hyperparameters (the learning rate, for example)\n",
 "    - Step 4) Initialize the variables and the optimizer\n",
 "    - Step 5) Run the model (with tf.Session() as sess, sess.run(init))\n",
 " \n", "\n", "## REF \n", "- https://medium.com/@saxenarohan97/intro-to-tensorflow-solving-a-simple-regression-problem-e87b42fd4845" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/yennanliu/anaconda3/envs/ds_dash/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: compiletime version 3.6 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.5\n", " return f(*args, **kwds)\n" ] } ], "source": [ "# OP\n", "from sklearn.datasets import load_boston\n", "from sklearn.preprocessing import scale\n", "from matplotlib import pyplot as plt\n", "\n", "# DL \n", "import tensorflow as tf\n", "from tensorflow.python.ops.metrics_impl import accuracy" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# LOAD THE DATA \n", "total_features, total_prices = load_boston(True)\n", "\n", "# SPLIT TO TRAIN, VALIDATION, TEST SET \n", "\n", "# Keep 300 samples for training\n", "train_features = scale(total_features[:300])\n", "train_prices = total_prices[:300]\n", "\n", "# Keep 100 samples for validation\n", "valid_features = scale(total_features[300:400])\n", "valid_prices = 
total_prices[300:400]\n", "\n", "# Keep remaining samples as test set\n", "test_features = scale(total_features[400:])\n", "test_prices = total_prices[400:]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ -6.41131126e-01, 1.00803985e-01, -1.03067021e+00,\n", " -3.14485451e-01, 2.17757002e-01, 2.19427174e-01,\n", " 8.26098070e-02, -9.55971610e-02, -2.15826599e+00,\n", " -2.32544281e-01, -1.00268807e+00, 4.20545707e-01,\n", " -9.23483688e-01],\n", " [ -6.09771235e-01, -5.93509179e-01, -2.83216038e-01,\n", " -3.14485451e-01, -4.17103062e-01, 9.45806783e-04,\n", " 5.53520080e-01, 3.48016313e-01, -1.53926045e+00,\n", " -1.01526137e+00, 9.27468208e-02, 4.20545707e-01,\n", " -2.52348069e-01],\n", " [ -6.09801116e-01, -5.93509179e-01, -2.83216038e-01,\n", " -3.14485451e-01, -4.17103062e-01, 1.08484038e+00,\n", " -5.83195447e-02, 3.48016313e-01, -1.53926045e+00,\n", " -1.01526137e+00, 9.27468208e-02, 3.26456136e-01,\n", " -1.07674783e+00]])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# sample X train data \n", "train_features[:3]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 24. 
, 21.6, 34.7])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# sample y train data \n", "train_prices[:3]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train X dataset shape (300, 13)\n", "train y dataset shape (300,)\n" ] } ], "source": [ "print ('train X dataset shape',train_features.shape)\n", "print ('train y dataset shape',train_prices.shape )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# DEFINE THE VARIABLES \n", "# y= x*w + b \n", "\n", "# 1) weight (w) size : [data_feature_size, output_size]\n", "# tf.truncated_normal() as an initial value, which generates a regularised set of numbers from the normal probability distribution\n", "# or you can use tf.zeros as well -> i.e. weights = tf.Variable(tf.zeros([13, 1]))\n", "w = tf.Variable(tf.truncated_normal([13, 1], mean=0.0, stddev=1.0, dtype=tf.float64))\n", "\n", "# 2) bias size : [output_size]\n", "# select tf.float64 as bias dtype here \n", "b = tf.Variable(tf.zeros(1, dtype = tf.float64))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
 "# HELP FUNC \n",
 "\n",
 "def calc(x, y):\n",
 "    \"\"\"Build the prediction and mean-squared-error ops of the linear model.\n",
 "\n",
 "    Args:\n",
 "        x: feature matrix; it is multiplied by the (13, 1) weight variable `w`.\n",
 "        y: target values, one per row of `x`.\n",
 "\n",
 "    Returns:\n",
 "        [predictions, error]: `predictions` is the column of model outputs\n",
 "        `x @ w + b`; `error` is the scalar mean of the squared residuals.\n",
 "    \"\"\"\n",
 "    predictions = tf.add(b, tf.matmul(x, w))\n",
 "    # Reshape the targets to a column so they line up element-wise with the\n",
 "    # (n, 1) predictions. Subtracting a flat (n,) vector instead would broadcast\n",
 "    # to an (n, n) matrix of all pairwise differences and corrupt the loss.\n",
 "    targets = tf.reshape(y, [-1, 1])\n",
 "    error = tf.reduce_mean(tf.square(targets - predictions))\n",
 "    return [predictions, error]"
 ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# SET UP THE CALCULATE, SUPER PARAMETERS \n", "\n", "y, cost = calc(train_features, train_prices)\n", "# Feel free to tweak these 2 values:\n", "learning_rate = 0.025\n", "epochs = 300\n", "output = [[], []] # get the model train history \n", "accuracy = []" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], 
"source": [ "# INIT THE VARIABLES AND THE optimizer\n", "init = tf.global_variables_initializer()\n", "optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch : 0 error : 671.141817236\n", "epoch : 100 error : 78.8004630724\n", "epoch : 200 error : 78.7280935877\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAEICAYAAABWJCMKAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3XmcVNWZ//HPI6sgOw0iiEDEIISICIhLNOJuYjBxiSa/QBwSkl8csy8kTtQkTtRMMi6Z/DSMGtGfMRpJXjKJSUDUuII2CIii0iJIy76DQdZn/jin0kV1dXd1dzW3lu/79apX3Trn1K3n1K3up85dTpm7IyIi5eeQpAMQEZFkKAGIiJQpJQARkTKlBCAiUqaUAEREypQSgIhImVICEEmAmT1nZscnHMOLZjYsyRgkWUoAkjMzW25mu82sZ0b5AjNzMxuQVnaymT1hZtvNbKuZ/Y+ZDU2r/6iZ7TezHfFWbWYPm9nojHW7mb2X1m6HmX031l1vZv8/x9jdzI6uo66fmT1gZhvja71oZh/PaDM+9nObmW0ws9mp/ppZVzO7x8zWxP6+aWbfqyeWC4Ht7v5yLrG3oJ8DP044BkmQEoA01tvAFakHZjYcODS9gZmdBMwEHgWOAAYCC4HnzGxQWtNV7n4Y0AkYC7wOPGNmZ2a85nHuflja7Wf56oyZdQeeBXYDw4CewC3Ab83sktjmaOA+4FtAl9if/wfsj6u5BTgMODbWfwJ4q56X/TJwfxPjbZ1LWY7rmAGcYWZ9mhKLFD8lAGms+4EJaY8nEv45pvsZcJ+73+bu2919k7v/GzAHuD5zhR5Uu/u1wF3AzS0TelbfAHYAk9x9jbvvdPcHgX8HfmFmBowA3nb32THW7e4+3d3fiesYDfzW3Te7+353f93dH8n2YmbWFhgH/D2t7BAzm2Jmb8VRyMMxMWFmA+LoZZKZvQM8ka0stv2Emb1qZlvM7CkzOzbtNZab2ffMbBHwnpm1dvf3gXnAOXl9R6VoKAFIY80BOpvZsWbWCvg08M/dMGbWATgZ+H2W5z4MnN3A+v8AjDSzjnmKtyFnA9PdfX9G+cNAf+AYYD4wxMxuMbMzzOywjLZzgH83syvNbHADrzcY2O/u1WllXwUuAk4njJg2A7/KeN7phBHGudnKzOwY4EHg60AF8BjwPzHhpFwBfAzo6u57Y9kS4LgGYpYSpQQgTZEaBZxN2G3zblpdd8LnanWW560m7GKpzyrAgK5pZfPjt9rU7dw6ntsUPak7VoCe7r4M+CjQl5AYNpjZvWmJ4GrgAeBfgdfMrMrMzq/j9boC2zPKvgRcE0dBuwijpEsydu1c7+7vufvOOso+DfzZ3We5+x7C/v1DCck45XZ3X5mxju0c+F5LGVECkKa4H/gM8Hlq7/7ZTNg3nm2/ch9gQwPr7gs4sCWtbKS7d027/a1JUWe3gbpjTdXj7nPc/TJ3rwA+ApwGXBPrdrr7T939BKAHIUn8PrUbJ8NmwjG
PdEcBf0wlOMK38n1A77Q2K7OsK73sCGBF6kEc0awkvJ/1raMTB77XUkaUAKTR3H0F4WDwBYRdNul17wEvAJdmeeplwOwGVv9JYH5cz8HwOHCxmWX+LVxG+If5ZuYT3P0lQr8/lKVuG/BToCPhYHGmpYCZWeY/5vMzklx7d08fWWWbtje9bBUhkUB8AeBIDhydZVvHsYQD9FKGlACkqSYB4+r4Rz0FmGhmXzWzTmbWzcxuAE4CfpTZ2IK+ZnYd8AXgB42I4xAza592a1dP27YZbVsRzuDpDNxtZofH8isI3+6/4+5uZqea2RfNrFeMdwjhTJ858fEPzWy0mbU1s/bA1wjfqt/IDCDunnmcsP8+5U7CMYSj4voqzGx8I94DCKOOj5nZmWbWhnDG0i7g+bqeEN+rE4BZjXwtKRFKANIk7v6Wu1fWUfcs4WDlpwj70lcAxwOnuvvStKZHmNkOwlk4LwHDgY+6+8yMVS60A68DuDWt7gpgZ9qtvtMvX81oe6W7bwROBdoDrwEbgW8Cn3P3h+LzthD+4b8S4/0r8EfC2U4Qvln/hrC7aBXh2MjH3H1HHXH8Gvhc2uPbCKdkzjSz7YTEcmI9/ajF3d8A/g/wyxjHhcCF7r67nqd9AnjK3Vc15rWkdJh+EEbk4DOzZ4Grk7wYzMzmEk5/XZxUDJIsJQARkTKV0y6geKn7I2b2upktMbOTzKy7mc0ys6Xxvltsa2Z2ezwVbpGZjWzZLoiISFPkegzgNuCv7j6EcNHIEsKBvtnuPphwZseU2PZ8wsUug4HJwB15jVhERPKiwV1AZtaZcJrYIE9rbGZvEA7YrY5ziTzl7h80s1/H5Qcz27VYL0REpNFymURqELAe+I2ZHUeYO+RrQO/UP/WYBHrF9n058IKT6lh2QAIws8mEEQIdO3Y8YciQIbB3L7z/PnTsCGbN6ZeISMmbN2/ehnhxYpPkkgBaAyMJZyzMNbPbqNndk022/9y1hhnuPhWYCjBq1CivrKyEu+6CL34RVqyA/v1zCE1EpHyZ2YqGW9Utl2MA1UC1u8+Njx8hJIS1cdcP8X5dWvsj057fj3BudMN69Aj3Gzfm1FxERJquwQTg7muAlWb2wVh0JuGCmRmEqYCJ94/G5RnAhHg20Fhga877/7vHqVOUAEREWlyuPyRxNfBAnFp2GXAlIXk8bGaTgHeomfvlMcIcMVXAP2Lb3KRGAJs25fwUERFpmpwSgLsvAEZlqcr85SbimUJXNSka7QISETloGvVTci2uZ0946CEYPbrhtiIi0iyFlQDatIHLLks6ChGRslB4s4G+8ALMndtwOxERaZbCGgEAXH019OoFjz2WdCQiIiWt8EYAPXroILCIyEGgBCAiUqaUAEREylRhJoAtW8LEcCIi0mIKLwFMmADPPKPZQEVEWljhnQU0aFC4iYhIiyq8EcD69XD//bAqtwlERUSkaQovASxfHnYDzZuXdCQiIiWt8BKAJoQTETkolABERMpU4SWAzp2hdWslABGRFlZ4CcAs/DKYEoCISIsqvNNAAR5/HCqa/EP3IiKSg8JMAMOHJx2BiEjJK7xdQBBGAHfdlXQUIiIlrTATwO9+B9dem3QUIiIlrTATQGpGUPekIxERKVmFmwB274b33ks6EhGRklW4CQB0KqiISAsqzATQvXu4VwIQEWkxhXka6DnnwLvvhh+HFxGRFlGYCaBjx3ATEZEWU5i7gN5/H3784/DLYCIi0iIKcwTQqhVcd11Y/shHko1FRKREFeYIoE2bMCuoDgKLiLSYwkwAUHMxmIiItAglABGRMpVTAjCz5Wb2ipktMLPKWNbdzGaZ2dJ43y2Wm5ndbmZVZrbIzEY2KbIePWDTpiY9VUREGtaYEcAZ7j7C3UfFx1OA2e4+GJgdHwOcDwyOt8nAHU2KbPp0eP75Jj1VREQa1pxdQOOBaXF5GnBRWvl9HswBuppZn0avvWPHcDaQiIi0iFwTgAMzzWyemU2OZb3dfTV
AvE9dttsXWJn23OpYdgAzm2xmlWZWuX79+tqvOHs2fPnLsHdvjiGKiEhj5JoATnH3kYTdO1eZ2Wn1tLUsZbXmdXb3qe4+yt1HVWT7+cfXXoNf/xo2b84xRBERaYycEoC7r4r364A/AmOAtaldO/F+XWxeDRyZ9vR+wKpGR6YZQUVEWlSDCcDMOppZp9QycA6wGJgBTIzNJgKPxuUZwIR4NtBYYGtqV1GjKAGIiLSoXKaC6A380cxS7X/r7n81s5eAh81sEvAOcGls/xhwAVAF/AO4skmRKQGIiLSoBhOAuy8DjstSvhE4M0u5A1c1O7IePaBtW/0qmIhICynMyeAABgwIs4JatmPKIiLSXIWbAPSPX0SkRRXuXEAA3/wm/Nd/JR2FiEhJKuwEMHMmPPFE0lGIiJSkwk4AmhFURKTFKAGIiJQpJQARkTJV2AmgXz/o0iXpKERESlJhJ4DrroPXX086ChGRklTYCUBERFpMYSeAefPgnHNgyZKkIxERKTmFnQB27oRZs2DlyobbiohIoxR2AtCMoCIiLUYJQESkTBV2AujePdwrAYiI5F1hJ4DWrWH0aOjcOelIRERKTuFOB53y4otJRyAiUpIKewQgIiItpvATwHe/C1dckXQUIiIlp/ATwLvvwty5SUchIlJyCj8B9OgBmzYlHYWISMkpjgSwdSvs3Zt0JCIiJaU4EgBoFCAikmeFnwAGD4Zx42DPnqQjEREpKYV/HcC554abiIjkVeGPAEREpEUUfgJYsybsBnrggaQjEREpKYWfAA47DKqqwvUAIiKSN4WfADp2hLZtNSOoiEieFX4CMAungioBiIjkVc4JwMxamdnLZvan+Higmc01s6Vm9pCZtY3l7eLjqlg/oNlRKgGIiORdY0YAXwPSf539ZuAWdx8MbAYmxfJJwGZ3Pxq4JbZrngsugFGjmr0aERGpkVMCMLN+wMeAu+JjA8YBj8Qm04CL4vL4+JhYf2Zs33Q33wzXXNOsVYiIyIFyHQHcCnwX2B8f9wC2uHtqgp5qoG9c7gusBIj1W2P7A5jZZDOrNLPK9evXNzF8ERFpqgYTgJl9HFjn7vPSi7M09Rzqagrcp7r7KHcfVVFRUX8QN90E3bqB11qNiIg0US5TQZwCfMLMLgDaA50JI4KuZtY6fsvvB6yK7auBI4FqM2sNdAGaN5Nb69awZQvs2AGdOjVrVSIiEjQ4AnD377t7P3cfAFwOPOHunwWeBC6JzSYCj8blGfExsf4J92Z+dU/NCLphQ7NWIyIiNZpzHcD3gG+aWRVhH//dsfxuoEcs/yYwpXkhUpMAdCqoiEjeNGo2UHd/CngqLi8DxmRp8z5waR5iq6EEICKSd4V/JTDAUUfBpEnQu3fSkYiIlIzC/z0AgH794K67ko5CRKSkFMcIAMIpoLt3Jx2FiEjJKJ4E0KsXfOtbSUchIlIyiicBdOmig8AiInlUPAmge3clABGRPCqeBKApoUVE8koJQESkTBXHaaAAF10Ew4YlHYWISMkongRwySUNtxERkZwVzy6gvXth7VrYsyfpSERESkLxJIDp0+Hww+HNN5OORESkJBRPAtCEcCIieaUEICJSppQARETKlBKAiEiZKp4E0KED3HgjnH560pGIiJSE4rkOwAymNP/XJUVEJCieEQDA6tXw9ttJRyEiUhKKZwQA8JnPwL598PTTSUciIlL0imsE0KMHbNiQdBQiIiWh+BKAzgISEcmL4ksAmzaF3wcWEZFmKb4EsHcvbNuWdCQiIkWvuA4Cn3deSAJt2iQdiYhI0SuuBDBsmH4URkQkT4prF9DOnTBnDqxbl3QkIiJFr7gSwDvvwEknwaxZSUciIlL0iisBaEI4EZG8Ka4E0K1bmBNICUBEpNmKKwG0agVduyoBiIjkQYMJwMzam9mLZrbQzF41sx/F8oFmNtfMlprZQ2bWNpa3i4+rYv2AvEbcvbsSgIhIHuQyAtgFjHP344ARwHlmNha4GbjF3Qc
Dm4FJsf0kYLO7Hw3cEtvlz69+Bd/+dl5XKSJSjhpMAB7siA/bxJsD44BHYvk04KK4PD4+JtafaWaWt4jPPRdOOCFvqxMRKVc5HQMws1ZmtgBYB8wC3gK2uPve2KQa6BuX+wIrAWL9VqBHlnVONrNKM6tcv3597hEvWQJ/+Uvu7UVEJKucEoC773P3EUA/YAxwbLZm8T7bt/1as7e5+1R3H+XuoyoqKnKNF6ZOhcsuy729iIhk1aizgNx9C/AUMBboamapqST6AavicjVwJECs7wJsykewQLgWYMcO2L07b6sUESlHuZwFVGFmXePyocBZwBLgSeCS2Gwi8GhcnhEfE+ufcM/j/M2p0cKaNXlbpYhIOcplMrg+wDQza0VIGA+7+5/M7DXgd2Z2A/AycHdsfzdwv5lVEb75X57XiIcPD/cvvwz9++d11SIi5aTBBODui4Djs5QvIxwPyCx/H7g0L9FlM2JEuCDspZdg/PgWexkRkVJXXNNBA3ToAH//OwwdmnQkIiJFrfgSAMAppyQdgYhI0SuuuYBSVq6EG26A6uqkIxERKVrFmQA2boQf/hCeeSbpSEREilZxJoBhw+DQQ8OBYBERaZLiTABt2sDxx8OLLyYdiYhI0SrOBAAwejTMnw979zbcVkREaineBDBmDLjD228nHYmISFEq3gRw8cWwbRsMHpx0JCIiRak4rwMAaNcu6QhERIpa8Y4AAO64AyZObLidiIjUUtwJYOVKeOAB2Lkz6UhERIpOcSeAMWNg3z5YsCDpSEREik5xJ4DRo8O9LggTEWm04k4AfftCnz5KACIiTVC8ZwGlfPzj0L590lGIiBSd4k8AU6cmHYGISFEq7l1A6fbvTzoCEZGiUvwJYOdOGDQI/uM/ko5ERKSoFH8COPRQOOQQzQwqItJIxZ8AIFwPoAQgItIopZEARo8OPw+5Zk3SkYiIFI3SSQCg6wFERBqhNBLAyJHwpS/BEUckHYmISNEo/usAADp0gDvvTDoKEZGiUhojAAiTwi1ZEn4lTEREGlQ6CeCee2DoUP1EpIhIjkonAZxwQrjXgWARkZyUTgIYPjz8TKSuBxARyUnpJIA2beD44zUCEBHJUYMJwMyONLMnzWyJmb1qZl+L5d3NbJaZLY333WK5mdntZlZlZovMbGRLd+KfRo+G+fPDAWEREalXLiOAvcC33P1YYCxwlZkNBaYAs919MDA7PgY4Hxgcb5OBO/IedV3+5V/gwQc1M6iISA4aTADuvtrd58fl7cASoC8wHpgWm00DLorL44H7PJgDdDWzPnmPPJsRI+DCC8PuIBERqVejjgGY2QDgeGAu0NvdV0NIEkCv2KwvsDLtadWxLHNdk82s0swq169f3/jI6/LcczBzZv7WJyJSonJOAGZ2GDAd+Lq7b6uvaZayWldnuftUdx/l7qMqKipyDaNh114LP/hB/tYnIlKickoAZtaG8M//AXf/Qyxem9q1E+/XxfJq4Mi0p/cDVuUn3ByMHg2LFsH77x+0lxQRKUa5nAVkwN3AEnf/z7SqGcDEuDwReDStfEI8G2gssDW1q+igGDMG9uyBhQsP2kuKiBSjXEYApwCfA8aZ2YJ4uwC4CTjbzJYCZ8fHAI8By4Aq4L+Br+Q/7HpoamgRkZw0OBuouz9L9v36AGdmae/AVc2Mq+n69YPDD4fKysRCEBEpBqUxHXQ6M3j6aejfP+lIREQKWuklAIDBg5OOQESk4JXOXEDp1qyBb387TAshIiJZlWYCaN0afvELePzxpCMRESlYpZkAevaEgQN1JpCISD1KMwFAuB5Avw0gIlKn0k0Ao0fDO+/AunUNtxURKUOlnQB69oTly5OORESkIJXmaaAAp54avv1bXdewiYiUt9JNAIeU7uBGRCQfSvu/5F13wckng9eajVpEpOyVdgLYswdeeAFWrEg6EhGRglPaCeCMM8L9nXcmG4eISAEq7QQwZAhceSXccgtUVSUdjYhIQSntBADw059
C27bwne8kHYmISEEp3bOAUg4/HH7zGxg6NOlIREQKSuknAIBLLqlZdte1ASIilMMuoJSdO+Gyy+CXv0w6EhGRglA+CaB9e9iyBa67DjZsSDoaEZHElU8CMINbb4Xt2+GHP0w6GhGRxJVPAoBwIPiqq2DqVFi4MOloREQSVV4JAOD666FbN40CRKTslcdZQOm6dYPp03VaqIiUvfJLAACnnx7u9++HffugTZtk4xERSUD57QJKee89OOkkuPnmpCMREUlE+SaAjh3hyCPhxhuhujrpaEREDrryTQAAP/952AU0ZUrSkYiIHHTlnQAGDAiTxD3wADz/fNLRiIgcVOWdACB8++/bN0wZLSJSRsrzLKB0HTvC3/4GRx+ddCQiIgeVEgDAsGHh/h//CMcEOnVKNh4RkYOgwV1AZnaPma0zs8VpZd3NbJaZLY333WK5mdntZlZlZovMbGRLBp9X27eHi8NGjYKnnko6GhGRFpfLMYB7gfMyyqYAs919MDA7PgY4Hxgcb5OBO/IT5kHQqVOYI2jPnvBbwp//vGYNFZGS1mACcPengU0ZxeOBaXF5GnBRWvl9HswBuppZn3wF2+LOOQcWL4bvfz+cGfTBD8KrryYdlYhIi2jqWUC93X01QLzvFcv7AivT2lXHslrMbLKZVZpZ5fr165sYRgvo0CH8jvCCBXDppSEJQPhBGRGREpLv00Cz/daiZ2vo7lPdfZS7j6qoqMhzGHkwbBjceSe0bg2bNsHgwXDNNUoEIlIympoA1qZ27cT7dbG8GjgyrV0/YFXTwysgZ54ZRgYf+hDMnJl0NCIizdbUBDADmBiXJwKPppVPiGcDjQW2pnYVFbXu3WHaNJg9O4wIzj0XrrgCdu1KOjIRkSZr8DoAM3sQ+CjQ08yqgeuAm4CHzWwS8A5waWz+GHABUAX8A7iyBWJOzrhxsGgR3HQTvPYatGsXyr/0JejdG04+GcaOha5dk41TRCQH5p51F/1BNWrUKK+srEw6jMZxD78zvHcvnHhiOGi8f3+oGzoUvvEN+MIXQjsIbUVE8sjM5rn7qKY+X1cCN1XqH3rr1jBvXriQ7KWXwqRyzz8PrVqF+pUrYcSIkBR69YKKinC77DL48IfD8956K5T17FkzqhARaWFKAPnSqVPYRTRu3IHl+/bBpz4V/sm/+SY89xxs3AjHHRcSwIsvwlln1bTv3DnsQrr/fjjtNHj6abj22pAY2revuf3bv4Uzkyor4ZFHQsJJv02eHBLOyy/D3/9eU24Wbp/9bIh5wYKQwFLlqdvll4fXfPnlsLsLapKeGXz603DIIeG5VVUHjnBatYKLLw7LL70EK1Yc+J60awcXXhiWX3gBVmWcJ9CxI5wXrz189llYt+7A+i5dwkF5CFdtb958YH337jW/+jZ7dkiy6Xr1CrvrIMwDlXlm1xFHwJgxYfnPfw4XB6br3x9GxovcZ8yoGeWlDBwYtu2+ffCnP1HL4MHhC8GuXfDXv9auP/ZYOOaY8KNFjz9eu374cBg0CLZuzX7V+vHHhxg3bgzvX6bRo0Mf166FOXNq1590UniPVq0K2y/TRz4S3uN33gmfj0xnnBE+x8uWwSuv1K4/++xwuvWbb8KSJbXrzz8f2rYNn7ulS2vXX3hh+OwtWgRvv31g3SGH1Hy25s8PX8DStW0b1g/hb291xiHKDh1CfBC+yGWeot65c+gfwDPP1P7sdesW3h+AJ5+s/dmrqAjvL8CsWbU/e336hO1zsLh74rcTTjjBy8q+fe579oTltWvdp093v/NO95/8xP2rX3WfMMH9lVdC/ZNPup9+uvuJJ7ofd5z7kCHuAwa4z58f6u+9171tW/dWrdzDv6JwW7w41N9664Hlqdvy5aH+Jz/JXr9xY6j/3vey1+/eHeq/8pXade3a1fR1woTa9T171tR/8pO16wcMqKk/66za9cOH19SfeGLt+pNPrqkfNqx2/bnn1tT371+7/uKLa+q7datd//nP19S3bl2
7/uqrQ93Ondnfux/8INSvX5+9/sYbQ/2yZdnrf/nLUL9wYfb6e+8N9c8+m71++vRQ/5e/ZK+fOTPU//732etfeCHU33NP9vrUZ/e22+r/7N1wQ/M+e1dd1fjPXo8exfPZywFQ6d70/706BlBqUr9z3KpV+Da0a1fNJHf79tV81Hr1Cm22bg23zI/iUUeF+g0bwrec1OckdX/MMeFb/5o14TqJTEOHhvvqatiypabcPaw3Vb9iBWzbduBz27SBIUPC8rJlsGPHgfXt2tVcoLd0aehfug4dwrdsgNdfh/ffP7C+Uyf4wAfC8quv1v6G36VL+BYP4Rvsvn0H1nfrFt4fyP4NuKIC+vUL22Lhwtr1vXuHb+B792b/hnzEEaHNrl01o690/fqF19i5M/QvU//+0KNHeN+yfYMeODCMMrdtCyPTTB/4QPimu3kzLF9eu/6YY8IobePGMArINGQIHHpoGLm9+27t+mHDwjfxNWtqfwOHMMJp3To8N3P0B2GXqln4dp85XYtZqIfw2cr8bLZqFUZnED5bW7ceWN+2bc3kkFVVtb/Bt28fRmgQ3vvMz17HjjWfzddeq/3Z69y5ZubhxYth9+4D67t2DaO7HDX3GIASgIhIkWpuAtAPwoiIlCklABGRMqUEICJSppQARETKlBKAiEiZUgIQESlTSgAiImVKCUBEpEwpAYiIlKmCuBLYzLYDbyQdRwvqCWxosFXxKuX+lXLfQP0rdh90905NfXKhzAb6RnMuZy50Zlap/hWnUu4bqH/FzsyaNYeOdgGJiJQpJQARkTJVKAlgatIBtDD1r3iVct9A/St2zepfQRwEFhGRg69QRgAiInKQKQGIiJSpxBOAmZ1nZm+YWZWZTUk6nuYys+Vm9oqZLUidomVm3c1slpktjffdko4zV2Z2j5mtM7PFaWVZ+2PB7XFbLjKzkclFnps6+ne9mb0bt+ECM7sgre77sX9vmNm5yUSdOzM70syeNLMlZvaqmX0tlhf9NqynbyWx/cysvZm9aGYLY/9+FMsHmtncuO0eMrO2sbxdfFwV6wc0+CLN+UHh5t6AVsBbwCCgLbAQGJpkTHno03KgZ0bZz4ApcXkKcHPScTaiP6cBI4HFDfUHuAD4C2DAWGBu0vE3sX/XA9/O0nZo/Iy2AwbGz26rpPvQQP/6ACPjcifgzdiPot+G9fStJLZf3AaHxeU2wNy4TR4GLo/ldwL/Ny5/BbgzLl8OPNTQayQ9AhgDVLn7MnffDfwOGJ9wTC1hPDAtLk8DLkowlkZx96eBzF99r6s/44H7PJgDdDWzPgcn0qapo391GQ/8zt13ufvbQBXhM1yw3H21u8+Py9uBJUBfSmAb1tO3uhTV9ovbYEd82CbeHBgHPBLLM7ddaps+ApxpZlbfaySdAPoCK9MeV1P/BiwGDsw0s3lmNjmW9Xb31RA+tECvxKLLj7r6U0rb81/jLpB70nbZFXX/4i6B4wnfJEtqG2b0DUpk+5lZKzNbAKwDZhFGLVvcfW9skt6Hf/Yv1m8FetS3/qQTQLbsVOznpZ7i7iOB84GrzOy0pAM6iEple94BfAAYAawGfhHLi7Z/ZnYYMB34urs6SHZ+AAAByElEQVRvq69plrKC7mOWvpXM9nP3fe4+AuhHGK0cm61ZvG90/5JOANXAkWmP+wGrEoolL9x9VbxfB/yRsNHWpobR8X5dchHmRV39KYnt6e5r4x/efuC/qdlNUJT9M7M2hH+QD7j7H2JxSWzDbH0rte0H4O5bgKcIxwC6mllqHrf0Pvyzf7G+Cw3s3kw6AbwEDI5HtdsSDlzMSDimJjOzjmbWKbUMnAMsJvRpYmw2EXg0mQjzpq7+zAAmxDNJxgJbU7sZiknGPu9PErYhhP5dHs+2GAgMBl482PE1RtwHfDewxN3/M62q6LdhXX0rle1nZhVm1jUuHwqcRTjO8SRwSWyWue1S2/QS4AmPR4TrVABHui8gHL1/C7gm6Xia2ZdBhLMMFgK
vpvpD2A83G1ga77snHWsj+vQgYRi9h/ANY1Jd/SEMQX8Vt+UrwKik429i/+6P8S+Kf1R90tpfE/v3BnB+0vHn0L9TCbsBFgEL4u2CUtiG9fStJLYf8GHg5diPxcC1sXwQIXFVAb8H2sXy9vFxVawf1NBraCoIEZEylfQuIBERSYgSgIhImVICEBEpU0oAIiJlSglARKRMKQGIiJQpJQARkTL1vwpXXcq6l86qAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Validation error = 104.344204453 \n", "\n", "Test error = 123.334899326 \n", "\n" ] } ], "source": [ "# TRAIN THE MODEL \n", "\n", "with tf.Session() as sess:\n", "\n", " sess.run(init)\n", "\n", " for i in list(range(epochs)):\n", "\n", " sess.run(optimizer)\n", "\n", " if i % 10 == 0.:\n", " output[0].append(i+1)\n", " output[1].append(sess.run(cost))\n", "\n", " if i % 100 == 0:\n", " print('epoch : ',i ,'error : ', sess.run(cost))\n", " \n", " # for accuracy plot \n", " #valid_predict = calc(valid_features, valid_prices)[0]\n", " #accuracy.append.(tf.metrics.accuracy(labels, valid_accuracy))\n", " \n", " \n", " plt.title('MODEL LOSS (error) ')\n", " plt.plot(output[0], output[1], 'r--')\n", " plt.axis([0, epochs, 50, 600])\n", " plt.show()\n", " \n", " #plt.title('MODEL ACCURACY ')\n", " #plt.plot(output[0], output[1], 'r--')\n", " #plt.axis([0, epochs, 50, 600])\n", " #plt.show()\n", "\n", " valid_cost = calc(valid_features, valid_prices)[1]\n", "\n", " print('Validation error =', sess.run(valid_cost), '\\n')\n", "\n", " test_cost = calc(test_features, test_prices)[1]\n", "\n", " print('Test error =', sess.run(test_cost), '\\n')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# end of demo\n", "# TODO :\n", " # 1) add accuracy,\n", " # 2) learning curve plots \n", " # 3) grid search \n", " " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Environment (conda_ds_dash)", "language": "python", "name": "conda_ds_dash" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.4" } }, "nbformat": 4, "nbformat_minor": 2 }