{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MNIST Analysis with Distributed Keras\n", "\n", "**Joeri Hermans** (Technical Student, IT-DB-SAS, CERN) \n", "*Department of Knowledge Engineering* \n", "*Maastricht University, The Netherlands*" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "18 January 2017\r\n" ] } ], "source": [ "!(date +%d\\ %B\\ %Y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook we show how to process the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset using Distributed Keras. As in the [workflow](https://github.com/JoeriHermans/dist-keras/blob/master/examples/workflow.ipynb) notebook, we guide you through the complete machine learning pipeline.\n", "\n", "## Preparation\n", "\n", "To get started, we first load all the required imports. Please make sure that `dist-keras` and `seaborn` are installed. Furthermore, we assume that you have access to an Apache Spark installation.\n", "\n", "Before you start this notebook, please make sure you have run the \"MNIST preprocessing\" notebook first, since this notebook works with a manually \"enlarged\" dataset." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "%matplotlib inline\n", "\n", "import numpy as np\n", "\n", "from keras.optimizers import *\n", "from keras.models import Sequential\n", "from keras.layers.core import *\n", "from keras.layers.convolutional import *\n", "\n", "from pyspark import SparkContext\n", "from pyspark import SparkConf\n", "\n", "from matplotlib import pyplot as plt\n", "\n", "from pyspark import StorageLevel\n", "\n", "from pyspark.ml.feature import StandardScaler\n", "from pyspark.ml.feature import VectorAssembler\n", "from pyspark.ml.feature import OneHotEncoder\n", "from pyspark.ml.feature import MinMaxScaler\n", "from pyspark.ml.feature import StringIndexer\n", "from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n", "\n", "from distkeras.trainers import *\n", "from distkeras.predictors import *\n", "from distkeras.transformers import *\n", "from distkeras.evaluators import *\n", "from distkeras.utils import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the following cell, adapt the parameters to fit your personal requirements." 
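, "\n", "\n", "The total number of workers (model trainers) is the number of executors multiplied by the number of processes (cores) per executor, as computed by `num_workers = num_executors * num_processes` in the cells below. Writing $E$ for `num_executors` and $P$ for `num_processes` (notation used only in this note), the two configurations give:\n", "\n", "$$\n", "W = E \\times P, \\qquad W_{\\text{local}} = 1 \\times 3 = 3, \\qquad W_{\\text{YARN}} = 30 \\times 1 = 30\n", "$$"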
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Modify these variables according to your needs.\n", "application_name = \"Distributed Keras MNIST Analysis\"\n", "using_spark_2 = False\n", "local = False\n", "path = \"mnist.parquet\"\n", "if local:\n", " # Tell master to use local resources.\n", " master = \"local[*]\"\n", " num_processes = 3\n", " num_executors = 1\n", "else:\n", " # Tell master to use YARN.\n", " master = \"yarn-client\"\n", " num_executors = 30\n", " num_processes = 1" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of desired executors: 30\n", "Number of desired processes / executor: 1\n", "Total number of workers: 30\n" ] } ], "source": [ "# Derived from the number of executors and the number of processes (cores) per executor; this determines the number of model trainers.\n", "num_workers = num_executors * num_processes\n", "\n", "print(\"Number of desired executors: \" + str(num_executors))\n", "print(\"Number of desired processes / executor: \" + str(num_processes))\n", "print(\"Total number of workers: \" + str(num_workers))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "conf = SparkConf()\n", "conf.set(\"spark.app.name\", application_name)\n", "conf.set(\"spark.master\", master)\n", "conf.set(\"spark.executor.cores\", str(num_processes))\n", "conf.set(\"spark.executor.instances\", str(num_executors))\n", "conf.set(\"spark.locality.wait\", \"0\")\n", "conf.set(\"spark.executor.memory\", \"5g\")\n", "conf.set(\"spark.serializer\", \"org.apache.spark.serializer.KryoSerializer\")\n", "\n", "# Check if the user is running Spark 2.0+.\n", "if using_spark_2:\n", " # SparkSession only exists in Spark 2.0+, so import it here.\n", " from pyspark.sql import SparkSession\n", " sc = SparkSession.builder.config(conf=conf) \\\n", " .appName(application_name) \\\n", " .getOrCreate()\n", "else:\n", " # Create the Spark context.\n", " sc = 
SparkContext(conf=conf)\n", " # Add the missing imports\n", " from pyspark import SQLContext\n", " sqlContext = SQLContext(sc)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Check if we are using Spark 2.0\n", "if using_spark_2:\n", " reader = sc\n", "else:\n", " reader = sqlContext\n", "# Read the training and test set.\n", "training_set = reader.read.parquet('data/mnist_train_big.parquet') \\\n", " .select(\"features_normalized_dense\", \"label_encoded\", \"label\")\n", "test_set = reader.read.parquet('data/mnist_test_preprocessed.parquet') \\\n", " .select(\"features_normalized_dense\", \"label_encoded\", \"label\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- features_normalized_dense: vector (nullable = true)\n", " |-- label_encoded: vector (nullable = true)\n", " |-- label: long (nullable = true)\n", "\n" ] } ], "source": [ "# Print the schema of the dataset.\n", "training_set.printSchema()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Development" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multilayer Perceptron" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "mlp = Sequential()\n", "mlp.add(Dense(1000, input_shape=(784,)))\n", "mlp.add(Activation('relu'))\n", "mlp.add(Dropout(0.2))\n", "mlp.add(Dense(200))\n", "mlp.add(Activation('relu'))\n", "mlp.add(Dropout(0.2))\n", "mlp.add(Dense(10))\n", "mlp.add(Activation('softmax'))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "____________________________________________________________________________________________________\n", "Layer (type) Output Shape Param # Connected to \n", 
"====================================================================================================\n", "dense_1 (Dense) (None, 1000) 785000 dense_input_1[0][0] \n", "____________________________________________________________________________________________________\n", "activation_1 (Activation) (None, 1000) 0 dense_1[0][0] \n", "____________________________________________________________________________________________________\n", "dropout_1 (Dropout) (None, 1000) 0 activation_1[0][0] \n", "____________________________________________________________________________________________________\n", "dense_2 (Dense) (None, 200) 200200 dropout_1[0][0] \n", "____________________________________________________________________________________________________\n", "activation_2 (Activation) (None, 200) 0 dense_2[0][0] \n", "____________________________________________________________________________________________________\n", "dropout_2 (Dropout) (None, 200) 0 activation_2[0][0] \n", "____________________________________________________________________________________________________\n", "dense_3 (Dense) (None, 10) 2010 dropout_2[0][0] \n", "____________________________________________________________________________________________________\n", "activation_3 (Activation) (None, 10) 0 dense_3[0][0] \n", "====================================================================================================\n", "Total params: 987210\n", "____________________________________________________________________________________________________\n" ] } ], "source": [ "mlp.summary()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "optimizer_mlp = 'adam'\n", "loss_mlp = 'categorical_crossentropy'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training\n", "\n", "Prepare the training and test set for evaluation and training." 
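, "\n", "\n", "As a rough sketch of what this means for training, assume the 6,060,000 training instances reported by the next cell are spread over 30 equally sized partitions (Spark's `repartition` only makes them approximately equal), and assume that ADAG's `communication_window` counts the mini-batches a worker processes between two updates to the parameter server (our reading of dist-keras, not something this notebook checks). One epoch with the `batch_size=4` and `communication_window=5` used in the ADAG cell below then amounts to roughly:\n", "\n", "$$\n", "\\frac{6\\,060\\,000}{30} = 202\\,000 \\text{ instances per worker}, \\qquad \\frac{202\\,000}{4} = 50\\,500 \\text{ mini-batches}, \\qquad \\frac{50\\,500}{5} = 10\\,100 \\text{ updates per worker}\n", "$$"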
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of training instances: 6060000\n", "Number of testing instances: 10000\n" ] } ], "source": [ "training_set = training_set.repartition(num_workers)\n", "test_set = test_set.repartition(num_workers)\n", "training_set.cache()\n", "test_set.cache()\n", "print(\"Number of training instances: \" + str(training_set.count()))\n", "print(\"Number of testing instances: \" + str(test_set.count()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation\n", "\n", "We define a utility function that computes the classification accuracy of a model on a given test set." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def evaluate_accuracy(model, test_set, features=\"features_normalized_dense\"):\n", " # Compares the predicted class index with the true label.\n", " evaluator = AccuracyEvaluator(prediction_col=\"prediction_index\", label_col=\"label\")\n", " # Applies the Keras model to the feature column of every row.\n", " predictor = ModelPredictor(keras_model=model, features_col=features)\n", " # Maps each 10-dimensional prediction vector to a class index.\n", " transformer = LabelIndexTransformer(output_dim=10)\n", " test_set = test_set.select(features, \"label\")\n", " test_set = predictor.predict(test_set)\n", " test_set = transformer.transform(test_set)\n", " score = evaluator.evaluate(test_set)\n", " \n", " return score" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### ADAG" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "trainer = ADAG(keras_model=mlp, worker_optimizer=optimizer_mlp, loss=loss_mlp, num_workers=num_workers,\n", " batch_size=4, communication_window=5, num_epoch=1,\n", " features_col=\"features_normalized_dense\", label_col=\"label_encoded\")\n", "# Train the model; this returns the trained Keras model.\n", "trained_model = trainer.train(training_set)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ 
"[array([[-0.02490237, -0.01861665, 0.03102627, ..., 0.01722135,\n", " 0.02223415, -0.04933412],\n", " [-0.02634868, 0.03564246, -0.05392314, ..., -0.02999102,\n", " -0.01270337, -0.03888189],\n", " [ 0.00727941, 0.04553502, -0.01856072, ..., 0.0319587 ,\n", " -0.00354035, -0.03581727],\n", " ..., \n", " [-0.03245988, -0.01220334, 0.019447 , ..., 0.05723321,\n", " -0.05618715, -0.0248918 ],\n", " [-0.02532675, -0.01772211, 0.05514754, ..., 0.03839124,\n", " -0.05036234, -0.03766601],\n", " [ 0.04610632, 0.01409597, 0.03790993, ..., -0.02038677,\n", " -0.03649681, 0.04742099]], dtype=float32),\n", " array([ -1.29682487e-02, 1.38744503e-01, -3.10007334e-01,\n", " -3.04996595e-02, -1.39434069e-01, -4.05185074e-02,\n", " -2.09797233e-01, -4.62490469e-01, -6.72216356e-01,\n", " -1.83647368e-02, -2.93090612e-01, 5.11649624e-02,\n", " -2.74094105e-01, -9.03906003e-02, -7.21242726e-01,\n", " -2.51375604e-02, -1.40052319e-01, -1.31754786e-01,\n", " -1.88921779e-01, -3.18406552e-01, -3.45931239e-02,\n", " -1.89292878e-01, 3.80539931e-02, 3.54425013e-02,\n", " -6.34538352e-01, -2.27093436e-02, -5.49978614e-01,\n", " -2.85222325e-02, -4.87636119e-01, -2.94719964e-01,\n", " -4.62469608e-01, -4.31859016e-01, -4.95594800e-01,\n", " -7.55963206e-01, -7.07836151e-01, 5.50588481e-02,\n", " 1.01570776e-02, -3.62383217e-01, -2.37895608e-01,\n", " -3.48139226e-01, -5.14193960e-02, -4.49353665e-01,\n", " -2.04702299e-02, -1.28980473e-01, -6.01515993e-02,\n", " -4.11046803e-01, -2.73511171e-01, -4.22501177e-01,\n", " 6.57678917e-02, -3.77899945e-01, -3.68858546e-01,\n", " -3.45079124e-01, -1.21501423e-01, -2.59954304e-01,\n", " -2.77339309e-01, 7.24700987e-02, -1.75704360e-01,\n", " -1.79602101e-01, -3.49472016e-01, -4.22441006e-01,\n", " -3.98772031e-01, 4.78056073e-02, 1.63912345e-02,\n", " -1.73481293e-02, 2.03711018e-01, -1.66458517e-01,\n", " -2.50248574e-02, -4.33256328e-01, -1.77355483e-02,\n", " -6.68845698e-02, -6.33655787e-02, -2.07219645e-01,\n", " -2.81381667e-01, 
-2.10354477e-01, 9.65033993e-02,\n", " 1.45252123e-01, -1.62108362e-01, -4.10078391e-02,\n", " -5.01093924e-01, 6.61657602e-02, -3.54006797e-01,\n", " -2.72664815e-01, -4.63590562e-01, -2.76888013e-01,\n", " 5.67168836e-03, -1.63264722e-02, -5.64372167e-02,\n", " -3.27719487e-02, -1.25738844e-01, -3.16582769e-02,\n", " -3.16652000e-01, 2.20678657e-01, -4.90398854e-01,\n", " -3.87180448e-01, 4.62217331e-02, -3.87124509e-01,\n", " 3.44271868e-01, -6.47646427e-01, -4.47504744e-02,\n", " -3.12687427e-01, -3.64519686e-01, -1.19691178e-01,\n", " -1.22579239e-01, -1.74031451e-01, -3.50467891e-01,\n", " -3.85930926e-01, -1.01258140e-02, 1.65355578e-01,\n", " 2.38174275e-02, -3.86843532e-01, -2.11541757e-01,\n", " -1.60455573e-02, -3.41660500e-01, -2.41097137e-01,\n", " -3.58184397e-01, -3.74646991e-01, -5.68306029e-01,\n", " 6.03663735e-02, -2.25287676e-01, -3.33954960e-01,\n", " -3.21863830e-01, -5.74063025e-02, -9.54797715e-02,\n", " -1.69863552e-01, 5.25663458e-02, -1.78944767e-01,\n", " -4.96068239e-01, -9.37457308e-02, -4.91037033e-02,\n", " -5.45800686e-01, -4.19147074e-01, -3.63402218e-01,\n", " -9.55256671e-02, -6.56951070e-02, -4.74279895e-02,\n", " 3.94136347e-02, -6.89108312e-01, -6.40569270e-01,\n", " -2.92730868e-01, -4.21674043e-01, -9.05798003e-02,\n", " -9.85799953e-02, -3.34262311e-01, -2.91352630e-01,\n", " -1.20481804e-01, -1.30824670e-01, -3.15101117e-01,\n", " -3.82897407e-01, -3.67818296e-01, -2.51174152e-01,\n", " -4.45220284e-02, -3.63316804e-01, -5.95236719e-01,\n", " -3.27549487e-01, -5.18906057e-01, -1.80942759e-01,\n", " -1.93147764e-01, -1.63675278e-01, 5.25709763e-02,\n", " -1.69222236e-01, -1.66612849e-01, -1.89764783e-01,\n", " 9.59388837e-02, -1.79865390e-01, -2.87416220e-01,\n", " -1.37040511e-01, -3.68917108e-01, -1.97503880e-01,\n", " -4.80307907e-01, -9.74704884e-03, -1.62035048e-01,\n", " -4.33685966e-02, -3.75206321e-01, -2.71574229e-01,\n", " -2.51338482e-01, -1.91602707e-01, -4.66123730e-01,\n", " -3.09535444e-01, -3.18885483e-02, 
-3.23637798e-02,\n", " -3.71796012e-01, -2.26407617e-01, -4.69909385e-02,\n", " -3.70391518e-01, -5.37406743e-01, -5.00004053e-01,\n", " -4.49130647e-02, 1.55784473e-01, -3.39550585e-01,\n", " -5.15295863e-01, -5.79936266e-01, 4.80024889e-03,\n", " -1.23718642e-01, -6.55675307e-02, -2.74233013e-01,\n", " -2.67147571e-01, -4.20176655e-01, -2.30046362e-02,\n", " -2.80579627e-01, -6.52074635e-01, -2.07271874e-01,\n", " -3.34823787e-01, -5.11079669e-01, -4.89039391e-01,\n", " -1.69896662e-01, -6.09769404e-01, 1.67333558e-01,\n", " -1.52619872e-02, -1.82103708e-01, -1.59035064e-02,\n", " -2.82586038e-01, -4.48576622e-02, -2.77401984e-01,\n", " -1.18868940e-01, -3.09958905e-01, -4.54939663e-01,\n", " -6.84868218e-03, -1.78479820e-01, -4.12694991e-01,\n", " -4.86943096e-01, -4.83419180e-01, -2.92061418e-01,\n", " -3.56696308e-01, -2.38492072e-01, -1.99521467e-01,\n", " -6.62643433e-01, -6.58789635e-01, -3.13386142e-01,\n", " -2.39210613e-02, 3.81695509e-01, 3.89514342e-02,\n", " -4.21914130e-01, -1.78643346e-01, -3.58139843e-01,\n", " -2.31155585e-02, -5.25866091e-01, -2.01350115e-02,\n", " 1.34515122e-01, -4.72941786e-01, 1.28511051e-02,\n", " -1.92628369e-01, -2.94919074e-01, -1.21810228e-01,\n", " -2.63900816e-01, -1.77175865e-01, -3.85966711e-02,\n", " -3.91167760e-01, -3.54940116e-01, -4.08377945e-02,\n", " -2.46946454e-01, -1.70614153e-01, 9.64559093e-02,\n", " -1.58487067e-01, -1.40857771e-01, -2.60191988e-02,\n", " -2.16996279e-02, -2.01046526e-01, 1.07773796e-01,\n", " -7.25519285e-02, -4.59324010e-02, -3.97602469e-01,\n", " -2.86683738e-01, -2.06594560e-02, -2.32254282e-01,\n", " -1.47455707e-01, -2.11738929e-01, -3.97648931e-01,\n", " -1.92232862e-01, -4.22664315e-01, -2.10082695e-01,\n", " -3.69767874e-01, -3.35989922e-01, -2.50372291e-02,\n", " -2.56772131e-01, -7.55918026e-01, -1.45749766e-02,\n", " -5.94904542e-01, -1.83992922e-01, -1.98239967e-01,\n", " 2.28624657e-01, -3.67346585e-01, -2.17467710e-01,\n", " -8.19451883e-02, -5.01424968e-02, 
-3.00576668e-02,\n", " 2.42029456e-03, -6.11475348e-01, -2.48637870e-01,\n", " -1.25368005e-02, -1.07831452e-02, 3.56794626e-01,\n", " -2.73973256e-01, -5.00894673e-02, -3.93987626e-01,\n", " -6.70151055e-01, 5.03201634e-02, -3.47819924e-01,\n", " 2.21592330e-04, -9.35477093e-02, -4.01370734e-01,\n", " -5.17268419e-01, -2.08003540e-02, -1.58300679e-02,\n", " 1.09454863e-01, 4.86627640e-03, -4.40006703e-01,\n", " 1.10145152e-01, -3.08435559e-01, -2.27646939e-02,\n", " -6.15591705e-02, -6.83150813e-02, 1.51192188e-01,\n", " -2.93954074e-01, 1.76271528e-01, -5.47897398e-01,\n", " -2.94454783e-01, -4.87583935e-01, -2.25682836e-02,\n", " -2.61891991e-01, -2.05876276e-01, -2.91871820e-02,\n", " -4.65158612e-01, -1.10427953e-01, 2.59957045e-01,\n", " -6.44603491e-01, -5.89241982e-01, -2.40099952e-01,\n", " -2.48620026e-02, 2.60877088e-02, -3.69062722e-01,\n", " -5.85998118e-01, 6.35902397e-04, 1.52950898e-01,\n", " -1.31705374e-01, -6.95600629e-01, -6.93177283e-02,\n", " -3.34524751e-01, -2.05166377e-02, -4.04433101e-01,\n", " -3.34488690e-01, 4.12484966e-02, -1.07743412e-01,\n", " -2.31767640e-01, -5.87181449e-01, -1.24916852e-01,\n", " -2.45317779e-02, -4.82061923e-01, 4.29915352e-04,\n", " -2.29062542e-01, -1.53157920e-01, -8.75511765e-02,\n", " -1.93034634e-01, -2.39149824e-01, -2.81021118e-01,\n", " -1.92091212e-01, 4.84096706e-02, -3.15482467e-01,\n", " -9.38970945e-04, -7.32823536e-02, 1.46180347e-01,\n", " -7.48398662e-01, -2.95927972e-01, -1.01935327e-01,\n", " -2.25223079e-02, -3.76603395e-01, -3.72446418e-01,\n", " -5.44973463e-02, -3.04856654e-02, -8.12882781e-01,\n", " -6.35300994e-01, 1.01717256e-01, 1.15769980e-02,\n", " 1.94745436e-01, -4.62203443e-01, -1.94413647e-01,\n", " -1.19787067e-01, 5.01835823e-01, -1.22532628e-01,\n", " -4.83275265e-01, -5.72950900e-01, -1.68230399e-01,\n", " -2.53478941e-02, -8.93718377e-02, -2.09907755e-01,\n", " 1.15736432e-01, 7.35889524e-02, -2.25963101e-01,\n", " -1.25411734e-01, -1.58686683e-01, 3.05348307e-01,\n", " 
-4.07805927e-02, -6.87129676e-01, -1.78614125e-01,\n", " -6.12517297e-02, -1.26590893e-01, -5.44444025e-01,\n", " -2.87909880e-02, -1.61622658e-01, -6.28022432e-01,\n", " -3.93144011e-01, -4.14166540e-01, -3.36472809e-01,\n", " -2.14290902e-01, -1.57012552e-01, -6.99233487e-02,\n", " -1.79140717e-01, -3.44865173e-01, -4.32067961e-01,\n", " -4.17658724e-02, -1.92612112e-01, -4.07513529e-01,\n", " -2.00688168e-01, -3.12940218e-02, -5.83245270e-02,\n", " -3.02525491e-01, -6.36755228e-01, -2.01398991e-02,\n", " -1.94140598e-01, -5.85560381e-01, -2.78204322e-01,\n", " -4.92228866e-01, 2.85394281e-01, -5.29185772e-01,\n", " -5.80944479e-01, -4.82267290e-01, -3.02456468e-01,\n", " -2.17350312e-02, -2.27617443e-01, -8.41379631e-03,\n", " -5.19459188e-01, -1.92483932e-01, -6.69973344e-02,\n", " -3.18294495e-01, -4.43626344e-01, 1.03083804e-01,\n", " -1.43494621e-01, -3.98965865e-01, -2.91880131e-01,\n", " -1.15407094e-01, -2.33865350e-01, -3.48333865e-01,\n", " -3.13846886e-01, -2.00329088e-02, -2.08419889e-01,\n", " -6.56257868e-02, -3.15933287e-01, -2.66032100e-01,\n", " -2.17209011e-01, -2.57886738e-01, -3.74219060e-01,\n", " -3.42252910e-01, -3.02372843e-01, -2.70351022e-01,\n", " -4.19028729e-01, -2.16944158e-01, 1.65465083e-02,\n", " -1.38239786e-01, 8.82068649e-03, -5.47306299e-01,\n", " -6.58184737e-02, -1.07372276e-01, -1.99595578e-02,\n", " -3.04633468e-01, -2.42436364e-01, -9.85036939e-02,\n", " 8.13045427e-02, -6.01692021e-01, -7.83374131e-01,\n", " -3.54873002e-01, -1.54401422e-01, -1.99920405e-02,\n", " -6.02073036e-03, -7.46182263e-01, -5.17743170e-01,\n", " -1.43411651e-01, 1.35698587e-01, -4.32992607e-01,\n", " -3.22256982e-01, 2.01625749e-01, -1.68692529e-01,\n", " 9.03868079e-02, -7.36883581e-02, -2.26779003e-02,\n", " 7.53887817e-02, -3.51618379e-01, -6.96502507e-01,\n", " -1.97232455e-01, -2.19720408e-01, -1.76197141e-01,\n", " -3.31067145e-01, 2.52920628e-01, -5.32557011e-01,\n", " -9.84433852e-03, -2.28284430e-02, -2.18466327e-01,\n", " 
-2.50813589e-02, -1.22822799e-01, -6.21357895e-02,\n", " -1.85140949e-02, 1.55188337e-01, -2.91802138e-01,\n", " -1.76329892e-02, -3.60844210e-02, -5.81378281e-01,\n", " -6.11039221e-01, -3.28095675e-01, -2.83731908e-01,\n", " -1.66193381e-01, 5.52292354e-02, 6.29878119e-02,\n", " -3.41305107e-01, -1.39835373e-01, 1.71938047e-01,\n", " -1.84613727e-02, 7.50863180e-02, -3.44148017e-02,\n", " -3.53854299e-01, -5.12476027e-01, 1.22042328e-01,\n", " -5.39535470e-02, 3.05281021e-03, -1.19409911e-01,\n", " -2.89323032e-01, -6.71940520e-02, -2.19452642e-02,\n", " -2.90004104e-01, -1.76387712e-01, -4.56134796e-01,\n", " -8.09880495e-01, -1.83778346e-01, -2.31890544e-01,\n", " -4.52327728e-01, -2.06816241e-01, -1.38748497e-01,\n", " -4.18441355e-01, -5.38856745e-01, -5.05130768e-01,\n", " -1.75971299e-01, -1.19080685e-01, -9.46213081e-02,\n", " -3.64823714e-02, -3.22997957e-01, -1.34447142e-01,\n", " -1.27073288e-01, 1.64654911e-01, -9.78678912e-02,\n", " -4.47389364e-01, -2.54144296e-02, 1.73969138e-02,\n", " -2.04480872e-01, -4.30503398e-01, -1.67036086e-01,\n", " -2.49711365e-01, -3.37412119e-01, -6.02359474e-01,\n", " -6.62094355e-01, -1.16948448e-01, 9.77696292e-03,\n", " -5.21902740e-01, -2.33485606e-02, -6.64649755e-02,\n", " -6.00027978e-01, -5.42070754e-02, -2.38561943e-01,\n", " -4.47000265e-01, 1.17274612e-01, -1.11540303e-01,\n", " -1.02203742e-01, -6.74192980e-02, -1.72974497e-01,\n", " -2.43933983e-02, -2.18470603e-01, -1.02555685e-01,\n", " -5.01730680e-01, -1.63745075e-01, -2.48166338e-01,\n", " 4.25796956e-02, -8.81046131e-02, -4.94634926e-01,\n", " -2.48743445e-01, 8.22583865e-03, -2.14855313e-01,\n", " -5.94667614e-01, 1.23224966e-01, -2.28983104e-01,\n", " -4.89580818e-02, -3.53976309e-01, -1.02518976e-01,\n", " -2.80924350e-01, 2.18932718e-01, -9.42684943e-04,\n", " -2.78814733e-01, -2.43697301e-01, -4.07780051e-01,\n", " -1.57622676e-02, -4.32732075e-01, 2.76384447e-02,\n", " -2.56971091e-01, -1.39276221e-01, -2.89412320e-01,\n", " -7.84103293e-03, 
-5.75612962e-01, -2.65779234e-02,\n", " -2.83633530e-01, -2.42152084e-02, -3.54716778e-01,\n", " -5.25303543e-01, -6.30853772e-02, -2.22892091e-01,\n", " -3.32897723e-01, -8.58137235e-02, -1.35768950e-01,\n", " -4.00102228e-01, -6.81776628e-02, -1.11637965e-01,\n", " 8.71941745e-02, 7.97185600e-02, -4.74733919e-01,\n", " -5.36120776e-03, -2.00053956e-02, 2.74125468e-02,\n", " -5.23373425e-01, -3.52810740e-01, -5.75067937e-01,\n", " -1.27765425e-02, -2.41196215e-01, 1.35370884e-02,\n", " -3.42776716e-01, -2.61937886e-01, -1.73471346e-01,\n", " -7.74265826e-01, -3.25414896e-01, -6.52070194e-02,\n", " -1.75177939e-02, -2.78512776e-01, -1.26804650e-01,\n", " -1.54330492e-01, -2.43354395e-01, -5.10048628e-01,\n", " -5.22104055e-02, -4.48061913e-01, -2.54915148e-01,\n", " -3.71145964e-01, -2.34785691e-01, -5.76828778e-01,\n", " -5.20584345e-01, -2.01370478e-01, -3.43574703e-01,\n", " -3.95394504e-01, -7.02085435e-01, 3.80159239e-03,\n", " -5.05006194e-01, -6.66690245e-02, -2.13820174e-01,\n", " -1.86356172e-01, -1.98591515e-01, -2.26664558e-01,\n", " -9.84562710e-02, 9.10461769e-02, -1.63858235e-01,\n", " -6.71461642e-01, -2.07045935e-02, -1.84064224e-01,\n", " -1.52253630e-02, -6.44623414e-02, -1.90693051e-01,\n", " -3.26317549e-01, -3.90465967e-02, -4.31612767e-02,\n", " -2.69320831e-02, -2.61054486e-01, -5.56032240e-01,\n", " -1.39396250e-01, -3.04626554e-01, -4.00418974e-02,\n", " -5.22964954e-01, -2.74515212e-01, -2.05182180e-01,\n", " -4.55017984e-01, -4.10655349e-01, -3.91681463e-01,\n", " -2.95707285e-01, -1.75162852e-02, -1.80232033e-01,\n", " -9.38054398e-02, -4.48614866e-01, -1.20916396e-01,\n", " -1.26026660e-01, -6.13098264e-01, -9.16779786e-02,\n", " -1.24931745e-01, -1.14639051e-01, -5.89349389e-01,\n", " -2.86892831e-01, -4.32475626e-01, -4.53839451e-01,\n", " -5.40873766e-01, -3.22011739e-01, -1.04171380e-01,\n", " -2.03116417e-01, -7.34383706e-03, -2.95767933e-01,\n", " 3.77100818e-02, -3.95163864e-01, -9.11748350e-01,\n", " -2.14269429e-01, 
-4.47106093e-01, -1.02919694e-02,\n", " -1.46425188e-01, 1.30215868e-01, 3.46448004e-01,\n", " -7.53604919e-02, -3.68188143e-01, -1.75004661e-01,\n", " -3.42096955e-01, -1.19322361e-02, 9.38493479e-03,\n", " -5.18787801e-01, -1.09108455e-01, 6.15557991e-02,\n", " -8.33496079e-03, -6.41730651e-02, -1.36719868e-02,\n", " -3.73748362e-01, -3.73859495e-01, 2.80248914e-02,\n", " -3.09117913e-01, -2.88713902e-01, -4.28494245e-01,\n", " -5.13740003e-01, -1.57594740e-01, -4.70732421e-01,\n", " -1.38654308e-02, -6.85215056e-01, -3.66586596e-01,\n", " -1.41351402e-01, -1.13854766e-01, -5.36643863e-01,\n", " -4.75565642e-01, -5.00832915e-01, -4.08477843e-01,\n", " -3.66504490e-01, -1.15367234e-01, -2.48915218e-02,\n", " -4.96757418e-01, 1.17366053e-01, -2.26039514e-01,\n", " -5.49678802e-01, -2.75789142e-01, -5.08426309e-01,\n", " 1.07284091e-01, -2.54364550e-01, -3.72139484e-01,\n", " -3.34391892e-01, 2.10764147e-02, -1.33560911e-01,\n", " -9.50245783e-02, -3.13357562e-01, -2.62188077e-01,\n", " -5.32095313e-01, -5.31459413e-03, -3.21489833e-02,\n", " -7.84164011e-01, -1.10715240e-01, -2.87352562e-01,\n", " -5.71807444e-01, -2.04134420e-01, 7.85130933e-02,\n", " -3.69185776e-01, -1.98006928e-02, 6.63151639e-03,\n", " -2.87224799e-01, 5.36596589e-02, -7.96930939e-02,\n", " -2.82612413e-01, -1.87133670e-01, -6.54792845e-01,\n", " -8.59472081e-02, -1.13062121e-01, -1.83315545e-01,\n", " -2.58277714e-01, -5.51701725e-01, -5.59242129e-01,\n", " -1.50169775e-01, 4.73141856e-02, -1.68764800e-01,\n", " -2.75284111e-01, -4.43699747e-01, -2.76820183e-01,\n", " -3.51191200e-02, -1.07176892e-01, -4.73967902e-02,\n", " -4.53751475e-01, -2.84370124e-01, -4.89342690e-01,\n", " -3.81000303e-02, -5.29655755e-01, -1.50656566e-01,\n", " -4.64593619e-01, -1.58045471e-01, -7.06188157e-02,\n", " -4.04648870e-01, -3.15317452e-01, -2.87708908e-01,\n", " -1.71832666e-01, -2.27938369e-01, -2.11054739e-02,\n", " -3.29687774e-01, -1.82581544e-01, -2.17228252e-02,\n", " 2.08218992e-02, -1.46109968e-01, 
-7.96382129e-02,\n", " -3.17795098e-01, -5.75634658e-01, -3.44916396e-02,\n", " -4.36014533e-01, -2.85244137e-02, -5.68732560e-01,\n", " -5.59068859e-01, -1.22407533e-01, -2.56792486e-01,\n", " -2.97368616e-01, -3.03129584e-01, -1.62084669e-01,\n", " -2.64727145e-01, -4.05563980e-01, 3.00995618e-01,\n", " -1.86940640e-01, -9.05097499e-02, -1.19438395e-01,\n", " -1.88409179e-01, -3.68620992e-01, 3.19603570e-02,\n", " -5.20787895e-01, -2.95364499e-01, -1.96136490e-01,\n", " 1.30156171e+00, -3.09764799e-02, -1.63758829e-01,\n", " -1.63395420e-01, -1.06308326e-01, -3.37606370e-01,\n", " -4.02779371e-01, -1.04163669e-01, -3.29879135e-01,\n", " -6.24738149e-02, 7.57394284e-02, -6.51596487e-01,\n", " -2.37611696e-01, -5.25772333e-01, 1.44061729e-01,\n", " -2.59940475e-01, -2.72920489e-01, -3.10522407e-01,\n", " -8.48866284e-01, -5.29746771e-01, -1.75354518e-02,\n", " -8.73476788e-02, -4.62230533e-01, -3.12623024e-01,\n", " -4.66565102e-01, -2.35941991e-01, -4.72842991e-01,\n", " -8.59152302e-02, -3.31128508e-01, -1.34016275e-01,\n", " -6.82140663e-02, -1.31053597e-01, 3.27668451e-02,\n", " -4.59252357e-01, -7.40645081e-02, -2.32884094e-01,\n", " -2.48913141e-03, -5.38118541e-01, -6.48121983e-02,\n", " -2.82097995e-01, -4.83397216e-01, -3.75957131e-01,\n", " -1.20243065e-01, -2.91992631e-02, -2.34807402e-01,\n", " -8.57004896e-02, -1.76332936e-01, -4.79596853e-01,\n", " -3.59954983e-01, -3.86393666e-01, -1.49604112e-01,\n", " 9.89474952e-02, -1.43513409e-02, -5.00253379e-01,\n", " -2.31766224e-01, -2.78296471e-01, -1.47517323e-01,\n", " -2.70760179e-01, 5.62180728e-02, 1.26814142e-01,\n", " -2.58570649e-02, -3.02321255e-01, -5.06240189e-01,\n", " -3.60810488e-01, -1.61365643e-01, -1.28059566e-01,\n", " -2.62734950e-01, -1.67697724e-02, 9.22571719e-02,\n", " -7.30941415e-01, -3.17986846e-01, -3.49215209e-01,\n", " -4.75899428e-01, -5.54573357e-01, -2.22814456e-01,\n", " -9.33618564e-03, -4.88777943e-02, -2.79946309e-02,\n", " -2.43498668e-01, 1.63741887e-01, 
-8.86490270e-02,\n", " -1.80582032e-02, 5.81286959e-02, -5.06547272e-01,\n", " -2.36781448e-01, -2.82066971e-01, 3.62231545e-02,\n", " 5.59952706e-02, -5.27004182e-01, -5.63789010e-02,\n", " -6.33812070e-01, -7.20118701e-01, -3.27905029e-01,\n", " -1.09615184e-01, -1.97968498e-01, -3.48774903e-02,\n", " -4.36178327e-01, -1.90760285e-01, -2.00712010e-01,\n", " -4.05785292e-02, -7.98018798e-02, -6.48312092e-01,\n", " -5.16030610e-01, -1.82418972e-02, -3.22774321e-01,\n", " -1.91510841e-01, -1.31354675e-01, -5.67911983e-01,\n", " -4.27046567e-01, -2.61492878e-01, -7.63690919e-02,\n", " -3.53502780e-01, -2.86672637e-02, 6.57036155e-02,\n", " -2.32697666e-01, -2.25740999e-01, -2.21521795e-01,\n", " 3.64017077e-02, -4.65820670e-01, -1.67809874e-01,\n", " -2.34040041e-02, -3.40095460e-01, 5.10562137e-02,\n", " -2.80955017e-01, 2.17410009e-02, -2.25610495e-01,\n", " -2.61850543e-02, -1.18860357e-01, 9.67218876e-02,\n", " -6.98161423e-01, -4.03901875e-01, -2.49750782e-02,\n", " -1.49894670e-01, -1.55417640e-02, -2.35045440e-02,\n", " -1.22158304e-02, -3.60701740e-01, -5.72664201e-01,\n", " -4.56410229e-01, -9.86423045e-02, -5.59065938e-01,\n", " -2.43323550e-01, 1.14932351e-01, -1.32146357e-02,\n", " -1.13701306e-01, -2.43878905e-02, 3.04878563e-01,\n", " -2.93137670e-01, -4.26690668e-01, -1.90759376e-01,\n", " -5.80423713e-01, 1.61198322e-02, -3.25486124e-01,\n", " -3.21475148e-01, -2.53617167e-01, -1.20874017e-01,\n", " -4.76823658e-01, -3.47528964e-01, -2.89901286e-01,\n", " 2.24457998e-02, -4.97344643e-01, 1.08718812e+00,\n", " -2.79220223e-01], dtype=float32),\n", " array([[ 0.03900816, 0.00785677, -0.06511776, ..., 0.00776991,\n", " -0.05963232, -0.05985177],\n", " [-0.20750827, 0.08817152, 0.40323174, ..., 0.20854132,\n", " -0.11089708, 0.14705186],\n", " [-0.24851227, 0.36102909, 0.07329425, ..., 0.12305254,\n", " 0.02824712, 0.2746895 ],\n", " ..., \n", " [-0.27076459, 0.04397521, 0.10150083, ..., -0.02952144,\n", " 0.35495111, 0.01788467],\n", " [-0.22880824, 
-0.14765862, -0.01148497, ..., -0.04802479,\n", " -0.11898327, 0.16021334],\n", " [-0.01458607, 0.51388001, 0.25630933, ..., 0.10885861,\n", " -0.15997633, 0.01113635]], dtype=float32),\n", " array([-0.36252829, -0.41307127, -0.37561458, -0.790694 , -0.7867986 ,\n", " -0.39656818, -0.49989551, -0.56961799, -0.67535901, -0.78190619,\n", " -0.64679927, -0.62336636, -0.73334086, -0.51707494, -0.80007225,\n", " -0.57039291, -0.43117863, -0.57423478, -1.01204598, -0.99576569,\n", " -0.45388478, -0.9715423 , -0.57562113, -0.85434681, -0.4783178 ,\n", " -0.65333492, -0.56394655, -0.51519966, -0.87941819, -0.9431147 ,\n", " -0.52889907, -0.51141596, -1.04037309, -0.87605566, -0.5586676 ,\n", " -0.67145008, -0.62178028, -0.74712718, -0.47700772, -0.81794 ,\n", " -0.94796181, -1.03332078, -0.99911004, -0.35762793, -0.41830212,\n", " -0.44990394, -0.54796964, -0.64622766, -0.36980084, -0.62949306,\n", " -0.73081511, -0.92071664, -0.96040893, -0.17141432, -0.50711352,\n", " -0.68742466, -0.58205402, -0.60873783, -0.51237881, -0.42307621,\n", " -0.59278268, -0.77905166, -0.70859444, -0.99470675, -0.68357819,\n", " -0.45728955, -0.98573047, -0.7740072 , -0.76561183, -0.38337517,\n", " -0.78785807, -0.9682638 , -0.41092423, -0.81709141, -0.4595961 ,\n", " -0.45476505, -0.89052409, -0.95178139, -0.920165 , -0.83498871,\n", " -0.54309958, -0.62142682, -0.10648966, -0.55824465, -0.51698029,\n", " -0.65391433, -0.73073816, -0.63968295, -0.73563075, -0.37823838,\n", " -0.83874625, -0.35336301, -0.72945499, -0.61786187, -1.04557991,\n", " -0.58565521, -0.35223064, -0.30662736, -0.66361117, -0.74605358,\n", " -0.79575521, -1.12011874, -0.65195775, -0.66316205, -0.30292839,\n", " -0.97478765, -0.30300212, -0.98781288, -0.88087404, -0.56088251,\n", " -0.82704026, -0.57432526, -0.44808209, -0.65736598, -0.7800023 ,\n", " -0.43863136, -0.71997589, -0.79668957, -0.58597511, -0.79392022,\n", " -0.91689253, -0.17079359, -0.70273119, -0.31935337, -0.99297088,\n", " -1.21429086, -0.54536754, 
-0.66847122, -1.0803057 , -0.02116329,\n", " -0.36946481, -0.78094089, -0.67028719, -0.63478422, -0.56762469,\n", " -0.59048861, -0.40834036, -0.76510531, -0.86944491, -0.26183733,\n", " -0.64363545, -0.21043499, -0.80520427, -0.98543239, -1.02239132,\n", " -0.87130302, -1.06532812, -0.47601402, -0.55352145, -0.75008106,\n", " -0.57477021, -0.73686802, -0.44472244, -0.64302158, -0.61648601,\n", " -1.09791934, -0.83204991, -0.40939972, -0.82405424, -0.57132626,\n", " -0.85813493, -0.84275389, -0.53043413, -1.03980398, -0.41696942,\n", " -0.99465734, -0.70751721, -0.94126099, -0.70646006, -0.85644752,\n", " -0.75323451, -0.62099051, -0.99225199, -0.81427616, -0.72105873,\n", " -0.3865678 , -0.71929121, -0.85359961, -0.47467613, -0.49992275,\n", " -0.78395241, -0.66783226, -0.85084015, -0.37230313, -0.74241304,\n", " -0.52368313, -0.57518154, -0.88761586, -0.78079957, -0.84552658,\n", " -0.60064358, -0.58771318, -0.68866116, -0.7030834 , -0.8059988 ,\n", " -0.71570534, -0.56441271, -0.89694452, -0.83912975, -0.46641162], dtype=float32),\n", " array([[-0.78751951, 0.02826324, -0.07172652, ..., -0.27620244,\n", " -0.47863257, -0.49731782],\n", " [-0.49682441, 0.04474993, -0.77598727, ..., -0.54524791,\n", " -0.21792939, -0.47720003],\n", " [-0.2323969 , -0.88028777, -0.2349651 , ..., -0.14491257,\n", " -0.17279406, -0.64144588],\n", " ..., \n", " [-0.7111882 , -0.30641097, -0.66904122, ..., -0.0798426 ,\n", " -0.57756215, -0.08725328],\n", " [ 0.11830693, 0.07352046, 0.08562858, ..., 0.09446803,\n", " -0.41451645, -0.35526502],\n", " [-0.92134595, 0.0993112 , -0.0636774 , ..., -0.0216356 ,\n", " -0.54615569, -0.05519475]], dtype=float32),\n", " array([-0.28950188, -0.33981469, -0.49054769, -0.24692491, -0.54108179,\n", " -0.53850734, -0.51629019, -0.45034203, 0.94987106, 0.34385717], dtype=float32)]" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# View the weights of the trained model.\n", "trained_model.get_weights()" ] 
}, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training time: 22619.2383449\n", "Accuracy: 0.9859\n" ] } ], "source": [ "print(\"Training time: \" + str(trainer.get_training_time()))\n", "print(\"Accuracy: \" + str(evaluate_accuracy(trained_model, test_set)))" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python [conda root]", "language": "python", "name": "conda-root-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.12" } }, "nbformat": 4, "nbformat_minor": 1 }