{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Handwritten Digit Classification using Deep Feed Forward Neural Network\n", "\n", "Using Apache Spark and Intel BigDL" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Handwritten Digit Classification using Deep Feed Forward Neural Network Using Spark and BigDL\n", "***Author: Riccardo Castellotti***  \n", "***Contact: Riccardo Castellotti / Luca Canali / Prasanth Kothuri***  \n", "\n", "To run this notebook we used the following configuration:\n", "* *Software stack*: LCG 94 (includes Spark 2.3.1)\n", "* *Platform*: centos7-gcc7\n", "* *Spark cluster*: Analytix\n", "\n", "This tutorial tackles the MNIST digit classification problem. We build a deep feed-forward neural network, specifically a multilayer perceptron (MLP) with two hidden layers. In a feed-forward network, information moves in one direction only, from input to output; the network contains no loops or cycles. It is the simplest type of neural network, which makes it a good starting point for learning how to use BigDL.\n", "\n", "Link to Intel BigDL framework - https://bigdl-project.github.io/0.7.0/" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "sc.addPyFile(\"/eos/project/s/swan/public/BigDL/bigdl-0.7.0-python-api.zip\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/spark-f9321949-7729-4507-9930-2e5642682f77/userFiles-ebfa5df3-2d3e-47a1-8b1d-cfcf8324de47/bigdl-0.7.0-python-api.zip/bigdl/util/engine.py:41: UserWarning: Find both SPARK_HOME and pyspark. You may need to check whether they match with each other. SPARK_HOME environment variable is set to: /cvmfs/sft.cern.ch/lcg/releases/spark/2.3.1-e21e3/x86_64-centos7-gcc7-opt, and pyspark is found in: /cvmfs/sft.cern.ch/lcg/views/LCG_94/x86_64-centos7-gcc7-opt/lib/python2.7/site-packages/pyspark/__init__.py. If they are unmatched, please use one source only to avoid conflict. 
For example, you can unset SPARK_HOME and use pyspark only.\n" ] } ], "source": [ "from __future__ import print_function\n", "import matplotlib\n", "import pandas\n", "import numpy as np\n", "import datetime as dt\n", "from bigdl.nn.layer import *\n", "from bigdl.nn.criterion import *\n", "from bigdl.optim.optimizer import *\n", "from bigdl.util.common import *\n", "from bigdl.util import common\n", "from bigdl.dataset.transformer import *\n", "from bigdl.dataset import mnist\n", "import matplotlib.pyplot as plt\n", "from pyspark import SparkContext\n", "from matplotlib.pyplot import imshow" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "init_engine()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Load MNIST dataset\n", "First, we load MNIST and store it as an RDD of `Sample` objects.\n", "\n", "Note: *edit `mnist_path` as needed. If the `mnist_path` directory does not contain the MNIST data, the `mnist.read_data_sets` method downloads the dataset into that directory.*" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def get_mnist(sc, mnist_path):\n", "    # MNIST labels start from 0; they are shifted by +1 below because\n", "    # BigDL's ClassNLLCriterion expects 1-based class labels.\n", "    (train_images, train_labels) = mnist.read_data_sets(mnist_path, \"train\")\n", "    (test_images, test_labels) = mnist.read_data_sets(mnist_path, \"test\")\n", "    # Normalize both sets with the mean and standard deviation of the training set.\n", "    training_mean = np.mean(train_images)\n", "    training_std = np.std(train_images)\n", "    rdd_train_images = sc.parallelize(train_images)\n", "    rdd_train_labels = sc.parallelize(train_labels)\n", "    rdd_test_images = sc.parallelize(test_images)\n", "    rdd_test_labels = sc.parallelize(test_labels)\n", "\n", "    rdd_train_sample = rdd_train_images.zip(rdd_train_labels).map(lambda features_label:\n", "        common.Sample.from_ndarray(\n", "            (features_label[0] - training_mean) / training_std,\n", "            features_label[1] + 1))\n", "    rdd_test_sample = rdd_test_images.zip(rdd_test_labels).map(lambda features_label1:\n", "        common.Sample.from_ndarray(\n", "            (features_label1[0] - training_mean) / training_std,\n", "            features_label1[1] + 1))\n", "    return (rdd_train_sample, rdd_test_sample)\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('Extracting', 'datasets/mnist/train-images-idx3-ubyte.gz')\n", "('Extracting', 'datasets/mnist/train-labels-idx1-ubyte.gz')\n", "('Extracting', 'datasets/mnist/t10k-images-idx3-ubyte.gz')\n", "('Extracting', 'datasets/mnist/t10k-labels-idx1-ubyte.gz')\n", "60000\n", "10000\n" ] } ], "source": [ "mnist_path = \"datasets/mnist\"\n", "(train_data, test_data) = get_mnist(sc, mnist_path)\n", "print(train_data.count())\n", "print(test_data.count())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Hyperparameter setup\n", "NOTE: `batch_size` must be a multiple of the total number of cores available so that BigDL can distribute the workload evenly. With the Spark configuration used here (8 executors with 4 cores each, 32 cores in total), 2048 satisfies this. If you change the configuration of the SparkSession, adjust `batch_size` accordingly." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "learning_rate = 0.2\n", "training_epochs = 15\n", "batch_size = 2048\n", "display_step = 1\n", "# Network Parameters\n", "n_hidden_1 = 256 # number of units in the 1st hidden layer\n", "n_hidden_2 = 256 # number of units in the 2nd hidden layer\n", "n_input = 784 # MNIST data input (img shape: 28*28)\n", "n_classes = 10 # MNIST total classes (0-9 digits)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Model creation\n", "Let's define our multilayer perceptron (MLP) model with two hidden layers."
 ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "creating: createSequential\n", "creating: createReshape\n", "creating: createLinear\n", "creating: createReLU\n", "creating: createLinear\n", "creating: createReLU\n", "creating: createLinear\n", "creating: createLogSoftMax\n" ] } ], "source": [ "def multilayer_perceptron(n_hidden_1, n_hidden_2, n_input, n_classes):\n", "    # Initialize a sequential container\n", "    model = Sequential()\n", "    # Flatten each 28x28 image into a 784-element vector\n", "    model.add(Reshape([28*28]))\n", "    # First hidden layer with ReLU activation\n", "    model.add(Linear(n_input, n_hidden_1).set_name('mlp_fc1'))\n", "    model.add(ReLU())\n", "    # Second hidden layer with ReLU activation\n", "    model.add(Linear(n_hidden_1, n_hidden_2).set_name('mlp_fc2'))\n", "    model.add(ReLU())\n", "    # Output layer: one log-probability per class\n", "    model.add(Linear(n_hidden_2, n_classes).set_name('mlp_fc3'))\n", "    model.add(LogSoftMax())\n", "    return model\n", "\n", "model = multilayer_perceptron(n_hidden_1, n_hidden_2, n_input, n_classes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Optimizer setup and training\n", "Let's create an optimizer for training. As the code shows, we minimize a [ClassNLLCriterion](https://bigdl-project.github.io/master/#APIGuide/Losses/#classnllcriterion) loss and use stochastic gradient descent (SGD) to update the weights. To enable visualization support, we also need to [generate summary info in BigDL](https://bigdl-project.github.io/master/#ProgrammingGuide/visualization/)." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "creating: createClassNLLCriterion\n", "creating: createDefault\n", "creating: createSGD\n", "creating: createMaxEpoch\n", "creating: createDistriOptimizer\n", "creating: createEveryEpoch\n", "creating: createTop1Accuracy\n", "creating: createTrainSummary\n", "creating: createSeveralIteration\n", "creating: createValidationSummary\n", "saving logs to multilayer_perceptron-20181206-201037\n" ] } ], "source": [ "optimizer = Optimizer(\n", "    model=model,\n", "    training_rdd=train_data,\n", "    criterion=ClassNLLCriterion(),\n", "    optim_method=SGD(learningrate=learning_rate),\n", "    end_trigger=MaxEpoch(training_epochs),\n", "    batch_size=batch_size)\n", "\n", "# Set the validation logic: evaluate top-1 accuracy on the test set after every epoch\n", "optimizer.set_validation(\n", "    batch_size=batch_size,\n", "    val_rdd=test_data,\n", "    trigger=EveryEpoch(),\n", "    val_method=[Top1Accuracy()]\n", ")\n", "\n", "# Write training and validation summaries under a timestamped application name\n", "app_name = 'multilayer_perceptron-' + dt.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", "train_summary = TrainSummary(log_dir='/tmp/bigdl_summaries',\n", "                             app_name=app_name)\n", "train_summary.set_summary_trigger(\"Parameters\", SeveralIteration(50))\n", "val_summary = ValidationSummary(log_dir='/tmp/bigdl_summaries',\n", "                                app_name=app_name)\n", "optimizer.set_train_summary(train_summary)\n", "optimizer.set_val_summary(val_summary)\n", "print(\"saving logs to\", app_name)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Optimization Done.\n", "CPU times: user 22.5 s, sys: 11.4 s, total: 33.9 s\n", "Wall time: 2min 36s\n" ] } ], "source": [ "%%time\n", "# Start the training process\n", "trained_model = optimizer.optimize()\n", "print(\"Optimization Done.\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Loss visualization\n", "\n", "After training, we can draw the performance curves from the `train_summary` and `val_summary` variables created above." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Each summary row holds the iteration number in column 0 and the value in column 1\n", "loss = np.array(train_summary.read_scalar(\"Loss\"))\n", "top1 = np.array(val_summary.read_scalar(\"Top1Accuracy\"))\n", "\n", "plt.figure(figsize=(12, 12))\n", "\n", "# Training loss per iteration\n", "ax = plt.subplot(2, 1, 1)\n", "ax.plot(loss[:, 0], loss[:, 1])\n", "ax.set_xlim(0, loss.shape[0] + 10)\n", "ax.grid(True)\n", "ax.set_xlabel(\"Iterations\", fontsize=15)\n", "ax.set_ylabel(\"Loss\", fontsize=15)\n", "ax.set_title(\"Loss\", fontsize=15)\n", "\n", "# Validation accuracy, recorded once per epoch\n", "ax = plt.subplot(2, 1, 2)\n", "ax.plot(top1[:, 0], top1[:, 1])\n", "ax.set_xlim(0, loss.shape[0] + 10)\n", "ax.set_title(\"Validation accuracy\", fontsize=15)\n", "ax.set_xlabel(\"Iterations\", fontsize=15)\n", "ax.set_ylabel(\"Validation accuracy\", fontsize=15)\n", "ax.grid(True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Prediction on test data\n", "\n", "Now let's look at the trained model's predictions on the test data." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def map_predict_label(l):\n", "    # Index of the largest log-probability, i.e. the predicted digit (0-9)\n", "    return np.array(l).argmax()\n", "def map_groundtruth_label(l):\n", "    # Undo the +1 shift applied to the labels in get_mnist\n", "    return int(l[0] - 1)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Ground Truth labels:\n", "7, 2, 1, 0, 4, 1, 4, 9\n", "Predicted labels:\n", "7, 2, 1, 0, 4, 1, 4, 9\n", "CPU times: user 83 ms, sys: 29.7 ms, total: 113 ms\n", "Wall time: 1.06 s\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAABMCAYAAAB9PUwnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEH1JREFUeJzt3XlwzPf/wPFnvhGtuNqQxBlHjKKqlB5DStODtIjGrSftOFrUKOpqURVmVNvpIU3NUKWGSBO3Okovk2LQhCLuo0bjqBJ1TGX3/ftjf++33TRhN3b3I5++HjOfsbvZ3c/bHq99f17v1/v9CVFKIYQQovT7n9UNEEII4R8S0IUQwiYkoAshhE1IQBdCCJuQgC6EEDYhAV0IIWxCAroQQtiEBHQhhLAJCehCCGETZYK5s5CQEJmWKoQQPlJKhXhzP+mhCyGETUhAF0IIm5CALoQQNhHUHPp/wciRIylXrhwAzZo1o3v37uZvn3/+OQC//PIL8+fPt6R9Qgj7kh66EELYhVIqaBug7LqlpaWptLQ05XA4brrt379fxcTEqJiYGMvbfaOtYcOGyul0KqfTqYYOHWp5e8qXL69SUlJUSkqKcjgcauvWrWrr1q2qTp06lrdNNtkCuXkbYyXl4gdpaWkeqRUtNzeXtWvXAlC/fn06d+4MQGxsLM8//zwA06ZNC15DfdSiRQucTicAJ06csLg1UL16dfr37w+A0+mkZcuWAHTq1ImZM2da1q4HHngAgMzMTOrWrev149q3b8/evXsB+P333wPRNJ/pz+iyZcsYOnQoAKmpqTgcjqDsPyoqisWLFwOQlZXFrFmzADh69KhPz1O5cmXatm0LwJo1a7h27Zpf23m7kpSLEELYhPTQb0GrVq0ASEpKMrft3r2bxMREAM6ePcvff/8NQNmyZdm8eTMA999/P1WqVAlya33XvHlzLl26BMCSJUssa0dkZCQAX331lWVtuJEOHToAcMcdd/j0uM6dO/PKK68A0Lt3b7+3y1dVqlQhJSXFXP/0008BmDNnDleuXAnovu+++27A9f2pXLkyAKdOnSpRzxxg+/bt5nPTsmVLDh486L/GFlKpUiVzpN20aVOefPJJAEuOCm67gN69e3dzWH3y5EmuXr0KwIIFC8jLywMI6Jvji+rVqwMQEhLC7t27AdeX+48//vjXfUeMGEGTJk3M9VWrVgWnkSXQtGlTAIYMGWJ5Nc4bb7zBs88+C8BDDz1U5H3atm3L//7nOtjMycnhp59+Clr7ypQpwzPPPFOix27fvp0333wTgPLly5sfT6u0bduWmjVrmusLFy4EMN/BQKlatSppaWkAREREmB8VnfLxxdtvvw1AvXr1GDhwIBC4eKHTpsnJydSuXdvcXqlSJQD+/PPPgOz3RiTlIoQQNhHy/9UnwdmZF2u5HD58uNiBpYsXLwKY3rC3Tpw4wfTp0wHYtm2bT4/1Rp06dUzbzp07V+R9cnJyTM8XMIdl33//vd/bc6v0AO/ixYuJj48H4Mcff7SkLQ6HwwzMFqZ75e5/P3bsGL169QJcPeBAe+qpp/j2228BmD59OuPGjfP6scOHD+f9998HXEd7Z86cCUgbb0anijZt2mQGmgE6duwIYP5/gdK+fXuPfVSrVg3A59fj3nvvZdeuXYArRdi3b1/getzwp1q1avHrr78CrlSVexzVRxtDhgwpNh74ytu1XG67lEv//v1p1qwZAHv37qVx48aAq5LgscceA+CRRx4xVQHuhzoABQUFgOvDoFMiAMePHwcCE9CPHTtW7N9GjRoFQMOGDc1tW7ZsYcuWLX5vh7+89dZbgOv/FYjXyxurV68GrgftouhD2r///ps6deoArkPtrVu3AhAaGhqw9ukf54ULF3Lo0CEApk6d6tNzdOnSxe/tKon77rsPwCOYFxQUBDyQR0VFAdCtWzdz26uvvlqiQA7w3XffmduWLFkSkECujRw5koiIiCL
/pjsUCQkJJCcnA67xiH/++Sdg7dEk5SKEEDZx2/XQN2zYwIYNG8z1NWvWmMt6JLx58+bmcPrBBx/0eLwewNm/f7+p8Y2IiDC9qGDq1KkTkydPBlxVLqdPnwZg7NixXL58Oejt8UbdunVN9c7+/fstGahr164d99xzD+BKpxSVcklNTWXdunUAXLhwgccffxyA8ePHm/u89tprZrkFf9ODb+XLlychIQHAVDTdjO7ZtWvXrth0UjB17dr1X7fp+ROB9MEHHwDwwgsvmO9zenq6z8/z6KOPAhAdHc3cuXMB+Prrr/3TyEL0kWC/fv3MbTt37uTUqVPA9VQquCpuRo4cCXgWdQTSbRfQb+Svv/4CPPPO7sHfXbdu3cwPwK5du0xeK5hatWpF2bJlzXXdBqvy0d5o166duRzsnK4eO1m0aBFVq1b919+PHTtGRkYGAO+++67Hj6JOew0YMMCUq02fPp0777wTgM8++8xvZWTdu3c3lS0HDx70OS2lf3ScTic//PADAOfPn/dL20pCB0TApAXeeeedgO9X552dTicnT5702P/N6PWSxo0bx+uvv26eT5eBBkrz5s0BqFixIj///DPg+s7oz1mfPn3MOEpsbKwZD1i2bBlPP/00UPw4mz9IykUIIWyiVPXQvaEHWlJSUsyA2uTJkwP6q1jY0qVLAdfovTZv3jxzmH470wNkgKkMCpYyZVwfx8K9c31E07t3b86ePVvkY3UPfdq0aXz44YcAhIeHm//D8uXL/ZZ269GjB+Hh4QAeE3G8UbduXVO/7HA4mDJlCmDNJBSA1q1b06ZNG3Ndp9iys7OD2g5dUbNu3TpztFJcuqxdu3YeBRLaN998E9hGcr0iSCnFRx99ZG7Xqd4vv/ySHj16AK7lPrTLly8HZVDUdgF98ODBgGt2oU7R7Nu3L2j7r169Oq1btwZcb74OQFOmTPE6x2oF/cXo16+fKcdav369lU0CXFVJ+jC6uGDubvny5SZgFh5fuVV6FqJ7EPE1Rz9gwADzg7V3717Ly1YLv0a+/kDdio8//hiA+Ph4atSoAbgmN4WEuCr09IzrwkJCQjzKBA8fPgzgU8loSfXp08dc1j9CugOn6TEod5s3bw7K919SLkIIYRO26qG3adOGMWPGmOt6yvhvv/0WtDZkZGR4rNOiR9utqLLxhR6dj4iIMJVFgZ7yXRz32vOHH37Yp8eGhISYx7s/z6RJk3jxxRdvqV36cLtmzZpmWryvYmNjzeVgfi6L496bPH/+PF988UXQ9q0rW5o1a2YGGxMSEszcjTNnzhS5fs/8+fPJyckx17OysoDgfMf0+56YmGiObho1amRSlUlJSaYY4/z58+Zy//79zTIae/bsCVwD7bQeenJyslm/e/369SosLEyFhYUFZb3ixMRElZiYqK5evWrWPd+wYYOqUKGCqlChguXrKd9sS09PV+np6crpdKqkpCSVlJQU9DbMmDFDzZgxQ127ds1svj7H0KFDzWMdDoe5HBsbe8vtK1eunCpXrpzatm2bysnJUTk5OSoiIsKrx0ZFRamoqCiPdfEHDx5s2fsdFxen4uLiVEFBgfnOHDlyxPLPoTdb/fr1TZt37NihIiMjVWRkZFD2HRERoSIiItS5c+fM++h0Oj3e17Vr16q1a9eqBg0aqNzcXJWbm6scDodKTU1VqampJdqvtzFWUi5CCGETtki56JrUhIQEM5I8ceLEoFUOVKlSxQzIhIWFmduzs7Nv64FQrVq1aqYWed++fZYtlatPruCryMhIs5Jl4YExXUvvj8+CXkL20KFDZrr6qlWrTFVNYXp5gPr165sae/fBPCsnFem0oHta6nYYBPfGhAkTzOs4evTooM6X0NVyPXv2NFU1erAcXFP8R48eDbhSlpmZmQCMGTPGLLMcGxsbsPSQLQK6zrm1aNHC5H91Xi0YRowY4VEtoEe9J06cGLQ23Iq+ffuacs9Ar98RCOPHjzfVTe6OHj3Kyy+/DFxfy8c
fJk6caCoxOnbsWGw+XVflKKWKnCilZzVawf0MW7pMUJ8d6HalywFfeukls06LFUvUgmvdGP0aPvfcc+Y1nDBhgsfY03vvvQdA48aNTdXOhAkTzOfS3yTlIoQQNlHqe+gdO3Y005Tz8/PN2inBpE9SoA0ZMgTwfm0Pq+n1KeD68gqlgV6RUa/7UtiePXvYtGmT3/ebm5tLz549AddU8AYNGhR5P/eJLrpaQ9fIAwE/C1BxatWq5VFPrc8Xa9XKmt7SU+cBVq5cCcCOHTusao5Z3dF9lcfC9HuclpZmeujx8fFmPR9/T3gstQFd5wA/+eQTs0zq6tWrzWnerKTfrOLythcuXDB/CwsL88jB3XXXXcC/fyT0SXpHjx7t94W9OnXqZC6vWLHCr8/tC53GcM/run+JZ82aZSaguN+vuFx0SXPyvsjOzvZqVqWe/OKuadOmlpQutm7d2uM1XrZsWdDbUBL6s3Dp0iWzsFdpsXjxYhPQe/XqZTp9/u6ASspFCCFsolT20ENDQ83gZ7169cyIcTBWiPPGzp07b/j39PR0c97R6OhosyC+N/Ly8syi+bcqLi4OuH6GGKvpafTua8isXLnSowdeVG+88G2pqakBamHJ6aMP/S9YN7HIfeLb2bNnzRT829mgQYOIjo4G4PTp05amWkrC6XSaz3WXLl1MwcSiRYvYv3+/3/ZTKgN6bGysx9lVdHrCqtmYq1ev9unsM3q0vrCCggKP4LR8+XLAM7epl+z0h6SkJMD1A6nXbwnmCZYL0yVeo0aNMkvgeuPMmTNm7fsBAwYUeZJuq+kyu2Ce8rE47ovGHT9+nAsXLljYGu8MGjTIvHbuJ1ivWLGimY3pz0qmQNCpuQkTJphTD06dOtXMYPbHmIqkXIQQwiZKVQ9dV2PoM9WAqzenR7yt0rVrV3MeTveJRXD9fIeF0ypz5swBXLXSWkZGBrm5uQFs6XXh4eHmJA1wvSJDD75aQS+B27t3b7MOz7Bhw276uOTkZGbOnBnQtt0qfQIEsK66RX823atyrly5YtnSvSXlcDhMtdDw4cPNSeMDVdvtb/PmzWPgwIGAK3bogdGbpWq9UprWcklOTvZYr8XpdKpWrVpZvrZEadzCwsJUVlaWysrKUkuXLlXh4eEqPDzc8nYV3hISElRmZqbKzMxU165dUxkZGSojI0N16NBBJSQkqISEBBUTE2N5O2+25eXlqby8PHX27Fk1bNgwNWzYsKC3ITQ0VIWGhqo5c+aY78/cuXMtf2282bKzs4tcO2XWrFmqdu3aqnbt2pa30ZctJiZGxcTEKKfTqRYsWKAWLFhww/vLWi5CCPFfU1p66HFxcSo/P1/l5+dLD122UretWLFCrVixQsXHx1velho1aqjZs2er2bNnW7rioy9bXFyc2rhxo9q4caOaNGmSio6OVtHR0aps2bKWt+1WtnXr1qmLFy+qixcvqiZNmhR7P69jbGkJ6GPHjvUI5AcOHFAHDhxQjRo1svxNkU022WQryVapUiV15MgRdeTIEZWYmFjs/STlIoQQ/zGlqspFy8nJ4YknngD8vxaCEEIES35+PvXq1fPb84UEc6JDSEhI8HYmhBA2oZQKufm9ZGKREELYRlB76EIIIQJHeuhCCGETEtCFEMImJKALIYRNSEAXQgibkIAuhBA2IQFdCCFsQgK6EELYhAR0IYSwCQnoQghhExLQhRDCJiSgCyGETUhAF0IIm5CALoQQNiEBXQghbEICuhBC2IQEdCGEsAkJ6EIIYRMS0IUQwiYkoAshhE1IQBdCCJuQgC6EEDYhAV0IIWxCAroQQtjE/wFIye0NfFiqDQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%time\n", "predictions = trained_model.predict(test_data)\n", "imshow(np.column_stack([np.array(s.features[0].to_ndarray()).reshape(28,28) for s in test_data.take(8)]),cmap='gray'); plt.axis('off')\n", "print('Ground Truth labels:')\n", "print(', '.join(str(map_groundtruth_label(s.label.to_ndarray())) for s in test_data.take(8)))\n", "print('Predicted labels:')\n", "print(', '.join(str(map_predict_label(s)) for s in predictions.take(8)))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.15" }, "sparkconnect": { "bundled_options": [], "list_of_options": [ { "name": "spark.jars", "value": "/eos/project/s/swan/public/BigDL/bigdl-SPARK_2.3-0.7.0-jar-with-dependencies.jar" }, { "name": "spark.scheduler.minRegisteredResourcesRatio", "value": "1.0" }, { "name": "spark.shuffle.reduceLocality.enabled", "value": "false" }, { "name": "spark.shuffle.blockTransferService", "value": "nio" }, { "name": "spark.dynamicAllocation.enabled", "value": "false" }, { "name": "spark.speculation", "value": "false" }, { "name": "spark.executor.instances", "value": "8" }, { "name": "spark.executor.cores", "value": "4" }, { "name": "spark.yarn.dist.files", "value": "/eos/project/s/swan/public/BigDL/bigdl-0.7.0-python-api.zip" } ] } }, "nbformat": 4, "nbformat_minor": 2 }