{ "cells": [ { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [], "source": [ "from random import randint\n", "\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import numpy as np\n", "from sklearn.base import BaseEstimator\n", "from sklearn.datasets import fetch_mldata\n", "from sklearn.linear_model import SGDClassifier\n", "from sklearn.model_selection import cross_val_score, cross_val_predict\n", "from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, precision_recall_curve\n", "import tensorflow as tf\n", "from tensorflow.contrib.layers import fully_connected\n", "from tensorflow.examples.tutorials.mnist import input_data\n", "\n", "sns.set(font_scale=1.5, palette='colorblind')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# MNIST classification" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "mnist = fetch_mldata('MNIST original', data_home='~/workspace/ds/MNIST')" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "X, y = mnist['data'], mnist['target']" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(70000, 784)" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X.shape" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "5.0" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "digit = X[36000]\n", "digit_image = digit.reshape(28, 28)\n", "y[36000]" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(-0.5, 27.5, 27.5, -0.5)" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPkAAAD3CAYAAADfRfLgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABgtJREFUeJzt3a9rVX8cx/E7EYMsDF0awsbAWQzivzHEpha1mRRhGkyWFUG0WQXFpEFENC6IQWxjpiHzJw6EK8gtC+r95vHlvA+7d9fN1x6P+uLs3uCTT/hwrmP9fr/fAWLt2+kvAIyWyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCGcyCHc/p3+ArBVjx8/LveVlZXG7eHDh9v9dTb59OnTSP/+IJzkEE7kEE7kEE7kEE7kEE7kEE7kEM49OSPR6/Uat9evX5fPLi4ulvubN2/KfWxsrNz3Gic5hBM5hBM5hBM5hBM5hBM5hHOFFurXr1/lvr6+PtTfb7vm+vDhQ+O2tLQ01GeP0uTkZLmfO3fuL32T7eMkh3Aih3Aih3Aih3Aih3Aih3Aih3DuyUO13YPPzMyUe7/fL/fd/DrniRMnGrfz58+Xz87Pz5f70aNHB/pOO8lJDuFEDuFEDuFEDuFEDuFEDuFEDuHck4e6fv16ubfdg7ftbaamphq3S5culc/evHlzqM9mMyc5hBM5hBM5hBM5hBM5hBM5hBM5hHNP/g+7f/9+4/by5cvy2WHfB297vtvtNm5tvwm/urpa7nNzc+XOZk5yCCdyCCdyCCdyCCdyCCdyCCdyCDfWH/bFYUamugfvdDqdhYWFxq3X6w312Tv5u+vT09Plvra2NrLPTuQkh3Aih3Aih3Aih3Aih3Aih3Cu0Haxtqukr1+/Dvy3JyYmyn18fLzc9+2rz4eNjY3G7fv37+WzbX7//j3U83uNkxzCiRzCiRzCiRzCiRzCiRzCiRzC+UnmXez06dPlfu/evcbt4sWL5bOXL18u95MnT5Z7m/X19cZtfn6+fHZ5eXmoz2YzJzmEEzmEEzmEEzmEEzmEEzmEEzmE8z45I/Ht27fGbdh78j9//gz0nfYqJzmEEzmEEzmEEzmEEzmEEzmEEzmE2/Pvk3/58qXcDx482LgdPnx4u79OjOquu+2/PW7bnz17Vu5t7+HvNU5yCCdyCCdyCCdyCCdyCCdyCCdyCBd/T37r1q1yf/DgQbkfOHCgcZudnS2fffr0abn/y7rdbrnfuHGjcXv37l357MzMzCBfiQZOcggncggncggncggncggncggXf4X29u3bcl9dXR34b3/+/Lncr127Vu537twZ+LNHre0V3BcvXpR7dU22f3/9z+748ePl7lXSrXGSQziRQziRQziRQziRQziRQziRQ7j4e/JRmpiYKPfdfA/e5urVq+Xe9rPIlampqZH9bf7PSQ7hRA7hRA7hRA7hRA7hRA7hRA7h4u/J237ed3x8vNx7vV7jdurUqUG+0l9x9uzZcn/y5Em59/v9cm/774Urt2/fHvhZts5JDuFEDuFEDuFEDuFEDuFEDuFEDuHi78nv3r1b7u/fvy/36vfFNzY2ymfb7qLbLC4ulvvPnz8btx8/fpTPtt1zHzt2rNwvXLgw8H7o0KHyWbaXkxzCiRzCiRzCiRzCiRzCiRzCjfXb3ikMt7S0VO4LCwuNW/UaaqfT6Xz8+LHcR/k659zcXLlPTk6W+6NHj8p9enp6y9+JneEkh3Aih3Aih3Aih3Aih3Aih3Aih3B7/p68TbfbbdzaXudcXl4u91evXpX78+fPy/3KlSuN25kzZ8pnjxw5Uu7kcJJDOJFDOJFDOJFDOJFDOJFDOJF
DOPfkEM5JDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuFEDuH+A7up+am9jbSbAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.imshow(digit_image, cmap=matplotlib.cm.binary, interpolation='nearest')\n", "plt.axis('off')" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Seeded generator so the shuffle (and every score computed after it) is reproducible.\n", "shuffle_index = np.random.RandomState(42).permutation(len(X_train))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 7837, 33903, 24440, ..., 12235, 14155, 56369])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "shuffle_index" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Binary classifier: 5 or not 5?" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "y_train_5 = (y_train == 5)\n", "y_test_5 = (y_test == 5)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# max_iter=5 / tol=None are the defaults this scikit-learn version applies anyway;\n", "# passing them explicitly keeps the results unchanged and silences the FutureWarning.\n", "sgd_clf = SGDClassifier(loss='log', max_iter=5, tol=None, random_state=42)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/srom/workspace/ds/env/lib/python3.6/site-packages/sklearn/linear_model/stochastic_gradient.py:128: FutureWarning: max_iter and tol parameters have been added in in 0.19. If both are left unset, they default to max_iter=5 and tol=None. If tol is not None, max_iter defaults to max_iter=1000. 
From 0.21, default max_iter will be 1000, and default tol will be 1e-3.\n", " \"and default tol will be 1e-3.\" % type(self), FutureWarning)\n", "/Users/srom/workspace/ds/env/lib/python3.6/site-packages/sklearn/linear_model/stochastic_gradient.py:128: FutureWarning: max_iter and tol parameters have been added in in 0.19. If both are left unset, they default to max_iter=5 and tol=None. If tol is not None, max_iter defaults to max_iter=1000. From 0.21, default max_iter will be 1000, and default tol will be 1e-3.\n", " \"and default tol will be 1e-3.\" % type(self), FutureWarning)\n", "/Users/srom/workspace/ds/env/lib/python3.6/site-packages/sklearn/linear_model/stochastic_gradient.py:128: FutureWarning: max_iter and tol parameters have been added in in 0.19. If both are left unset, they default to max_iter=5 and tol=None. If tol is not None, max_iter defaults to max_iter=1000. From 0.21, default max_iter will be 1000, and default tol will be 1e-3.\n", " \"and default tol will be 1e-3.\" % type(self), FutureWarning)\n" ] }, { "data": { "text/plain": [ "array([0.95825, 0.9601 , 0.95145])" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "class Never5Classifier(BaseEstimator):\n", "    \"\"\"Baseline that predicts 'not 5' for every sample.\n", "\n", "    Gives a reference accuracy to compare the real classifier against.\n", "    \"\"\"\n", "\n", "    def fit(self, X, y=None):\n", "        \"\"\"Nothing to learn; return self so calls can be chained.\"\"\"\n", "        return self\n", "\n", "    def predict(self, X):\n", "        \"\"\"Return an all-False boolean array of shape (len(X), 1).\"\"\"\n", "        return np.zeros((len(X), 1), dtype=bool)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.91065, 0.9101 , 0.9082 ])" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "never_5 = Never5Classifier()\n", "cross_val_score(never_5, X_train, y_train_5, cv=3, scoring=\"accuracy\")" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": 
"stream", "text": [ "/Users/srom/workspace/ds/env/lib/python3.6/site-packages/sklearn/linear_model/stochastic_gradient.py:128: FutureWarning: max_iter and tol parameters have been added in in 0.19. If both are left unset, they default to max_iter=5 and tol=None. If tol is not None, max_iter defaults to max_iter=1000. From 0.21, default max_iter will be 1000, and default tol will be 1e-3.\n", " \"and default tol will be 1e-3.\" % type(self), FutureWarning)\n", "/Users/srom/workspace/ds/env/lib/python3.6/site-packages/sklearn/linear_model/stochastic_gradient.py:128: FutureWarning: max_iter and tol parameters have been added in in 0.19. If both are left unset, they default to max_iter=5 and tol=None. If tol is not None, max_iter defaults to max_iter=1000. From 0.21, default max_iter will be 1000, and default tol will be 1e-3.\n", " \"and default tol will be 1e-3.\" % type(self), FutureWarning)\n", "/Users/srom/workspace/ds/env/lib/python3.6/site-packages/sklearn/linear_model/stochastic_gradient.py:128: FutureWarning: max_iter and tol parameters have been added in in 0.19. If both are left unset, they default to max_iter=5 and tol=None. If tol is not None, max_iter defaults to max_iter=1000. 
From 0.21, default max_iter will be 1000, and default tol will be 1e-3.\n", " \"and default tol will be 1e-3.\" % type(self), FutureWarning)\n" ] } ], "source": [ "y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[53996, 583],\n", " [ 2021, 3400]])" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "confusion_matrix(y_train_5, y_train_pred)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[54579, 0],\n", " [ 5421, 0]])" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "confusion_matrix(y_train_5, cross_val_predict(never_5, X_train, y_train_5, cv=3))" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Precision: 0.8536279186542807\n", "Recall: 0.6271905552481092\n" ] } ], "source": [ "precision = precision_score(y_train_5, y_train_pred)\n", "recall = recall_score(y_train_5, y_train_pred)\n", "\n", "print('Precision:', precision)\n", "print('Recall:', recall)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.7230965546575924" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f1_score(y_train_5, y_train_pred)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/srom/workspace/ds/env/lib/python3.6/site-packages/sklearn/linear_model/stochastic_gradient.py:128: FutureWarning: max_iter and tol parameters have been added in in 0.19. If both are left unset, they default to max_iter=5 and tol=None. If tol is not None, max_iter defaults to max_iter=1000. 
From 0.21, default max_iter will be 1000, and default tol will be 1e-3.\n", " \"and default tol will be 1e-3.\" % type(self), FutureWarning)\n", "/Users/srom/workspace/ds/env/lib/python3.6/site-packages/sklearn/linear_model/stochastic_gradient.py:128: FutureWarning: max_iter and tol parameters have been added in in 0.19. If both are left unset, they default to max_iter=5 and tol=None. If tol is not None, max_iter defaults to max_iter=1000. From 0.21, default max_iter will be 1000, and default tol will be 1e-3.\n", " \"and default tol will be 1e-3.\" % type(self), FutureWarning)\n", "/Users/srom/workspace/ds/env/lib/python3.6/site-packages/sklearn/linear_model/stochastic_gradient.py:128: FutureWarning: max_iter and tol parameters have been added in in 0.19. If both are left unset, they default to max_iter=5 and tol=None. If tol is not None, max_iter defaults to max_iter=1000. From 0.21, default max_iter will be 1000, and default tol will be 1e-3.\n", " \"and default tol will be 1e-3.\" % type(self), FutureWarning)\n" ] } ], "source": [ "y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method='decision_function')" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0,0.5,'recall')" ] }, "execution_count": 72, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "f, ax = plt.subplots(1, figsize=(10, 6))\n", "ax.plot(precisions, recalls)\n", "ax.set_xlabel('precision')\n", "ax.set_ylabel('recall')" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 71, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "f, ax = plt.subplots(1, figsize=(20, 8))\n", "ax.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\")\n", "ax.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\")\n", "ax.set_xlabel(\"Threshold\")\n", "ax.set_ylim([0, 1])\n", "# A single legend call: the former trailing ax.legend() overrode loc=\"upper left\".\n", "ax.legend(loc=\"upper left\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multiclass classifier" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "n_inputs = 28*28\n", "n_hidden1 = 300\n", "n_hidden2 = 100\n", "n_outputs = 10" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# NOTE: X and y now shadow the NumPy arrays loaded at the top of the notebook;\n", "# the names are kept because the training/prediction cells feed these placeholders.\n", "X = tf.placeholder(tf.float32, shape=(None, n_inputs), name=\"X\")\n", "# (None,) is a 1-D batch of labels; the former (None) was just None, i.e. an unknown shape.\n", "y = tf.placeholder(tf.int64, shape=(None,), name=\"y\")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "with tf.name_scope(\"dnn\"):\n", "    hidden1 = fully_connected(X, n_hidden1)\n", "    hidden_2 = fully_connected(hidden1, n_hidden2)\n", "    # The output layer sits on top of the second hidden layer and emits one raw\n", "    # score per class (n_outputs); softmax is applied inside the loss.\n", "    logits = fully_connected(hidden_2, n_outputs, activation_fn=None)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "with tf.name_scope(\"loss\"):\n", "    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)\n", "    loss = tf.reduce_mean(xentropy, name=\"loss\")" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "learning_rate = 0.01\n", "\n", "with tf.name_scope(\"train\"):\n", "    optimizer = tf.train.GradientDescentOptimizer(learning_rate)\n", "    training_op = optimizer.minimize(loss)\n", "\n", "with tf.name_scope(\"eval\"):\n", "    correct = tf.nn.in_top_k(logits, y, 1)\n", "    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "init = tf.global_variables_initializer()\n", "saver = tf.train.Saver()" ] 
}, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From :1: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n", "WARNING:tensorflow:From /Users/srom/workspace/ds/env/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please write your own downloading logic.\n", "WARNING:tensorflow:From /Users/srom/workspace/ds/env/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please use tf.data to implement this functionality.\n", "Extracting /tmp/data/train-images-idx3-ubyte.gz\n", "WARNING:tensorflow:From /Users/srom/workspace/ds/env/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please use tf.data to implement this functionality.\n", "Extracting /tmp/data/train-labels-idx1-ubyte.gz\n", "Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n", "Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n", "WARNING:tensorflow:From /Users/srom/workspace/ds/env/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", 
"Instructions for updating:\n", "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n" ] } ], "source": [ "mnist = input_data.read_data_sets(\"/tmp/data/\")" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "n_epochs = 50\n", "batch_size = 50" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 Train accuracy: 0.96 Test accuracy: 0.8897\n", "1 Train accuracy: 0.92 Test accuracy: 0.91\n", "2 Train accuracy: 0.86 Test accuracy: 0.9157\n", "3 Train accuracy: 0.88 Test accuracy: 0.9242\n", "4 Train accuracy: 0.96 Test accuracy: 0.926\n", "5 Train accuracy: 0.94 Test accuracy: 0.929\n", "6 Train accuracy: 0.9 Test accuracy: 0.9317\n", "7 Train accuracy: 0.94 Test accuracy: 0.9359\n", "8 Train accuracy: 0.96 Test accuracy: 0.9381\n", "9 Train accuracy: 0.9 Test accuracy: 0.9403\n", "10 Train accuracy: 0.94 Test accuracy: 0.9427\n", "11 Train accuracy: 0.96 Test accuracy: 0.9424\n", "12 Train accuracy: 0.94 Test accuracy: 0.946\n", "13 Train accuracy: 0.98 Test accuracy: 0.9484\n", "14 Train accuracy: 0.92 Test accuracy: 0.9494\n", "15 Train accuracy: 0.98 Test accuracy: 0.9513\n", "16 Train accuracy: 0.96 Test accuracy: 0.9529\n", "17 Train accuracy: 0.88 Test accuracy: 0.953\n", "18 Train accuracy: 0.98 Test accuracy: 0.9547\n", "19 Train accuracy: 0.98 Test accuracy: 0.9562\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_epochs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0miteration\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmnist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_examples\u001b[0m \u001b[0;34m//\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mX_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmnist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnext_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraining_op\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mX_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0my_batch\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/workspace/ds/env/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py\u001b[0m in \u001b[0;36mnext_batch\u001b[0;34m(self, batch_size, fake_data, shuffle)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[0mperm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_examples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 212\u001b[0m 
\u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshuffle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mperm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 213\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_images\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimages\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mperm\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 214\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_labels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mperm\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[0;31m# Start next epoch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "with tf.Session() as sess:\n", "    init.run()\n", "    for epoch in range(n_epochs):\n", "        for iteration in range(mnist.train.num_examples // batch_size):\n", "            # Draw mini-batches from the same TF MNIST dataset that is used for evaluation.\n", "            X_batch, y_batch = mnist.train.next_batch(batch_size)\n", "            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})\n", "\n", "        # Train accuracy is measured on the last mini-batch of the epoch only.\n", "        acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})\n", "        acc_test = accuracy.eval(feed_dict={X: mnist.test.images, y: mnist.test.labels})\n", "\n", "        print(epoch, 'Train accuracy:', acc_train, 'Test accuracy:', acc_test)\n", "\n", "    save_path = saver.save(sess, \"./my_model_final.ckpt\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# randint() includes its upper bound, hence the -1 to stay inside the test set.\n", "test_instance_nb = randint(0, mnist.test.num_examples - 1)\n", "\n", "with tf.Session() as sess:\n", "    saver.restore(sess, \"./my_model_final.ckpt\")\n", "    X_new_scaled = [mnist.test.images[test_instance_nb]]\n", "    Z = logits.eval(feed_dict={X: 
X_new_scaled})\n", "    y_pred = np.argmax(Z, axis=1)\n", "\n", "label = mnist.test.labels[test_instance_nb]\n", "\n", "print('Prediction:', y_pred[0], '\\nExpected:', label, '\\n{}'.format('Yaaay' if label == y_pred[0] else 'Noooes'))\n", "\n", "x = mnist.test.images[test_instance_nb].reshape([28, 28])\n", "plt.gray()\n", "plt.axis('off')\n", "_ = plt.imshow(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with tf.Session() as sess:\n", "    saver.restore(sess, \"./my_model_final.ckpt\")\n", "    all_X_test = mnist.test.images\n", "    Z = logits.eval(feed_dict={X: all_X_test})\n", "    y_pred = np.argmax(Z, axis=1)\n", "    \n", "    # Indices of the misclassified test images.\n", "    y_nope = [i for i, y_hat in enumerate(y_pred) if y_hat != mnist.test.labels[i]]\n", "    \n", "len(y_nope)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# randint() includes its upper bound, hence the -1 to stay inside y_nope.\n", "test_instance_nb = randint(0, len(y_nope) - 1)\n", "idx = y_nope[test_instance_nb]\n", "\n", "with tf.Session() as sess:\n", "    saver.restore(sess, \"./my_model_final.ckpt\")\n", "    X_new_scaled = [mnist.test.images[idx]]\n", "    Z = logits.eval(feed_dict={X: X_new_scaled})\n", "    y_pred = np.argmax(Z, axis=1)\n", "\n", "label = mnist.test.labels[idx]\n", "\n", "print('Prediction:', y_pred[0], '\\nExpected:', label, '\\n{}'.format('Yaaay' if label == y_pred[0] else 'Noooes'))\n", "\n", "x = X_new_scaled[0].reshape([28, 28])\n", "plt.gray()\n", "plt.axis('off')\n", "_ = plt.imshow(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }