{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Classifying primary tumor with DNA methylation data \n", "\n", "### Raw data from TCGA is processed in R \n", "- `getDNAM.R` \n", "- `tcga2stat.R` \n", "\n", "### All DNAm results were computed on the UIUC nano cluster:\n", "- `rf_dnam.py` for random forest classification \n", "- `knn_dnam.py` for k-nearest neighbors classification \n", "- `nn_loop_dnam.py` for tuning neural network hyperparameters \n", "- `nn_dnam.py` for neural net classification \n", "\n", "Data structures from the cluster are saved with cPickle and loaded into this Jupyter notebook for further visualization and analysis.\n", "\n", "### Cluster outputs can be viewed in the following files: \n", "- `py_knn_10_dnam.out`\n", "- `py_nn_dnam.out`\n", "- `py_nn_sigmoid_scale.out`\n", "- `py_rf_dnam.out`\n", "\n", "Files can be found under the `tumor-origin/dnam` folder." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import numpy\n", "import pandas as pd \n", "from tensorflow.keras import layers \n", "from tensorflow.keras.utils import to_categorical\n", "\n", "from numpy import random\n", "\n", "from sklearn import metrics\n", "from sklearn.metrics import confusion_matrix\n", "import matplotlib.pyplot as plt\n", "\n", "import cPickle as pickle\n", "\n", "# Files containing raw data were not loaded into ipynb \n", "# data = pd.read_csv('dnam/merged-rand.txt', sep='\\t')\n", "# types = pd.read_csv('dnam/types-rand.txt', sep='\\t')\n", "# labels = pd.read_csv('dnam/labels-rand.txt', sep='\\t')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tuning the learning rate\n", "\n", "The output `200e-2e5e-dnam` contains two dictionaries, both keyed by learning rate: the test accuracy reached at each learning rate, and the per-epoch training history (train and test accuracy) for each learning rate. \n", "\n", "Learning rates were generated with `numpy.geomspace`, which produces logarithmically spaced values. 
" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "#### CAREFUL\n", "#### You are about to load in new variables from the cluster \n", "#### To use for further visualization \n", "# The file holds two pickled objects, read back in the order they were dumped:\n", "#   learnloss -> {learning rate: test accuracy}\n", "#   histories -> {learning rate: per-epoch metrics ('acc', 'val_acc', ...)}\n", "# pickle can run arbitrary code on load: only open files you produced yourself.\n", "filename = 'db/200e-2e5e-dnam'\n", "with open(filename, 'rb') as fp:\n", "    learnloss, histories = pickle.load(fp), pickle.load(fp)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Visualize test accuracy at each learning rate to pick an ideal learning rate\n", "\n", "def visualizeLearnLossRange(learnloss):\n", "    \"\"\"Scatter-plot test accuracy against learning rate.\n", "\n", "    learnloss: dict mapping learning rate -> test accuracy.\n", "    The x-axis is logarithmic and inverted (largest rate on the left). The\n", "    figure is saved to accuracy_learningrate_range.pdf and then shown.\n", "    \"\"\"\n", "    rates = list(learnloss.keys())\n", "    accuracies = list(learnloss.values())  # same order as `rates`\n", "    plt.title('model test accuracy')\n", "    plt.xscale('log')\n", "    plt.xlabel('learning rate')\n", "    plt.ylabel('accuracy')\n", "    plt.scatter(rates, accuracies)\n", "    plt.gca().invert_xaxis()\n", "    plt.savefig(\"accuracy_learningrate_range.pdf\")\n", "    plt.show()\n", "\n", "\n", "visualizeLearnLossRange(learnloss)\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Max test accuracy: \n", "(0.9932432432432432, [0.00021544346900318845])\n", "Rate: 1e-05\tAccuracy: 0.9915540540540541\n", "Rate: 2.1544346900318867e-05\tAccuracy: 0.9915540540540541\n", "Rate: 4.641588833612782e-05\tAccuracy: 0.9915540540540541\n", "Rate: 0.0001\tAccuracy: 0.9915540540540541\n", "Rate: 0.00021544346900318845\tAccuracy: 0.9932432432432432\n", "Rate: 0.0004641588833612782\tAccuracy: 0.9915540540540541\n", "Rate: 0.001\tAccuracy: 0.9898648648648649\n", "Rate: 0.0021544346900318843\tAccuracy: 0.9847972972972973\n", "8\n" ] } ], "source": [ "## Look at rates and decide lower/upper bound\n", "max_acc = max(learnloss.values())  # best test accuracy\n", "max_lr = [k for k, v in learnloss.items() if v == max_acc]  # every rate that reached it\n", "print(\"Max test accuracy: \")\n", "print(max_acc, max_lr)\n", "\n", "# Learning rates in ascending order.\n", "orderedkeys = sorted(learnloss)\n", "\n", "# Keep the rates whose test accuracy clears 0.77; their accuracy over the\n", "# epochs is visualized in the next cell.\n", "goodrates = []\n", "for ll in orderedkeys:\n", "    if not learnloss[ll] > 0.77:\n", "        continue\n", "    goodrates.append(ll)\n", "    print(\"Rate: \" + str(ll) + \"\\tAccuracy: \" + str(learnloss[ll]))\n", "counter = len(goodrates)\n", "print(counter)" ] }, { "cell_type": 
"code", "execution_count": 5, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def visualizeEpochs(learnloss, histories, rates=None,\n", "                    exclude=(4.641588833612782e-05, 0.0001), min_acc=0.77):\n", "    \"\"\"Plot train/test accuracy per epoch, one panel per learning rate.\n", "\n", "    learnloss: dict mapping learning rate -> final test accuracy.\n", "    histories: dict mapping learning rate -> Keras history dict; its 'acc'\n", "        (train) and 'val_acc' (test) series are plotted.\n", "    rates: learning rates to plot, in panel order. Defaults to every rate in\n", "        learnloss whose accuracy is above min_acc, in ascending order (the\n", "        selection the previous cell stores in `goodrates`), so the function\n", "        no longer depends on that global.\n", "    exclude: rates to leave out of the figure (matched with ==, so pass the\n", "        exact dictionary keys).\n", "\n", "    The figure is saved to dnam/6-accuracy_epochs.pdf and then shown.\n", "    \"\"\"\n", "    if rates is None:\n", "        rates = sorted(lr for lr in learnloss if learnloss[lr] > min_acc)\n", "    toPlot = [lr for lr in rates if lr not in exclude]\n", "\n", "    nrCols = 3\n", "    # Keep the original 4x3 grid, but add rows (ceil division) instead of\n", "    # failing in add_subplot when more than 12 rates are selected.\n", "    nrRows = max(4, -(-len(toPlot) // nrCols))\n", "    fig = plt.figure(figsize=(15, 12))\n", "    for i, lr in enumerate(toPlot):\n", "        ax = fig.add_subplot(nrRows, nrCols, i + 1)\n", "        ax.plot(histories[lr]['acc'])\n", "        ax.plot(histories[lr]['val_acc'])\n", "        ax.set_title('learning rate: ' + '{:.3g}'.format(lr))\n", "        ax.set_ylabel('accuracy')\n", "        ax.set_xlabel('epoch')\n", "        ax.legend(['train', 'test'], loc='lower right')\n", "    plt.tight_layout()\n", "    plt.savefig(\"dnam/6-accuracy_epochs.pdf\")\n", "    plt.show()\n", "\n", "\n", "visualizeEpochs(learnloss, histories)" ] }, { "cell_type": "code", "execution_count": 102, "metadata": {}, "outputs": [], "source": [ "# Learning rate sweep range: 0.01 to 0.000001\n", "# encoded labels are one-hot encoded\n", "# Test labels are treated with ravel\n", "# NOTE: running this cell empties learnloss/histories, discarding any results\n", "# loaded from a pickle earlier in the notebook.\n", "learnloss = {}\n", "histories = {}\n", "\n", "def learnLoss(learningRate, epochs, train, encoded_train, test, encoded_test, test_labels,\n", "              nClasses=17):\n", "    \"\"\"Train a 128-128 sigmoid MLP at one learning rate and record the result.\n", "\n", "    learningRate: RMSProp learning rate.\n", "    epochs: number of training epochs (batch size 32).\n", "    train, encoded_train: training features and one-hot labels.\n", "    test, encoded_test: test features and one-hot labels; also passed to Keras\n", "        as validation data, so the history's 'val_*' series are test metrics.\n", "    test_labels: class-index labels for `test`, compared directly with the\n", "        argmax predictions.\n", "    nClasses: width of the softmax output layer (default 17, as before).\n", "\n", "    Side effects: sets learnloss[learningRate] to the test accuracy and\n", "    histories[learningRate] to the Keras per-epoch history dict (the\n", "    module-level dicts created above). Returns None.\n", "    \"\"\"\n", "    model = tf.keras.Sequential()\n", "    model.add(layers.Dense(128, activation='sigmoid'))\n", "    model.add(layers.Dense(128, activation='sigmoid'))\n", "    model.add(layers.Dense(nClasses, activation='softmax'))\n", "    model.compile(optimizer=tf.train.RMSPropOptimizer(learningRate),\n", "                  loss='categorical_crossentropy',\n", "                  metrics=['accuracy'])\n", "\n", "    history = model.fit(train, encoded_train, validation_data=(test, encoded_test),\n", "                        epochs=epochs, batch_size=32)\n", "\n", "    # Predicted class = index of the largest softmax output, which is what the\n", "    # deprecated Sequential.predict_classes computed for multi-class output.\n", "    # (The old confusion-matrix line was unused and read an undefined global.)\n", "    pred_y = numpy.argmax(model.predict(test), axis=-1)\n", "    accuracy = metrics.accuracy_score(test_labels, pred_y)\n", "    print(\"Accuracy: \", accuracy)\n", "    learnloss[learningRate] = accuracy\n", "    histories[learningRate] = history.history" ]
}, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sweep 50 log-spaced learning rates from 1e-2 down to 1e-6, 500 epochs each;\n", "# each learnLoss call records its result in learnloss / histories.\n", "learningRates = numpy.geomspace(0.01, 0.000001, num=50)\n", "print(learningRates)\n", "for lr in learningRates:\n", "    learnLoss(lr, 500, train, encoded_train, test, encoded_test, r_test_types)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "### CAREFUL\n", "### You are about to DUMP and replace the current saved variables\n", "\n", "filename = '500epoch50learn'\n", "\n", "# Written and read back in the same order: learnloss first, then histories.\n", "with open(filename, 'wb') as fp:\n", "    pickle.dump(learnloss, fp)\n", "    pickle.dump(histories, fp)\n", "\n", "# Reload right away so the in-memory dicts are exactly what was saved.\n", "with open(filename, 'rb') as fp:\n", "    learnloss, histories = pickle.load(fp), pickle.load(fp)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Sample code for debugging purposes. " ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Same architecture as in learnLoss, built from a layer list.\n", "model = tf.keras.Sequential([\n", "    layers.Dense(128, activation='sigmoid'),\n", "    layers.Dense(128, activation='sigmoid'),\n", "    layers.Dense(17, activation='softmax'),\n", "])\n", "model.compile(loss='categorical_crossentropy',\n", "              optimizer=tf.train.RMSPropOptimizer(5.1794746792312125e-05),\n", "              metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 5574 samples, validate on 2376 samples\n", "Epoch 1/200\n", "5574/5574 [==============================] - 1s 256us/step - loss: 0.8388 - acc: 0.7384 - val_loss: 0.8616 - val_acc: 0.7315\n", "Epoch 2/200\n", "5574/5574 [==============================] - 1s 169us/step - loss: 0.8350 - acc: 0.7363 - val_loss: 0.8543 - val_acc: 0.7319\n", "Epoch 3/200\n", "5574/5574 [==============================] - 1s 172us/step - loss: 0.8317 - acc: 0.7388 - val_loss: 0.8579 - val_acc: 0.7273\n", "Epoch 4/200\n", "5574/5574 [==============================] - 1s 
170us/step - loss: 0.8313 - acc: 0.7393 - val_loss: 0.8530 - val_acc: 0.7315\n", "Epoch 5/200\n", "5574/5574 [==============================] - 1s 176us/step - loss: 0.8236 - acc: 0.7417 - val_loss: 0.8508 - val_acc: 0.7306\n", "Epoch 6/200\n", "5574/5574 [==============================] - 1s 173us/step - loss: 0.8281 - acc: 0.7404 - val_loss: 0.8445 - val_acc: 0.7344\n", "Epoch 7/200\n", "5574/5574 [==============================] - 1s 171us/step - loss: 0.8233 - acc: 0.7413 - val_loss: 0.8467 - val_acc: 0.7285\n", "Epoch 8/200\n", "5574/5574 [==============================] - 1s 175us/step - loss: 0.8254 - acc: 0.7397 - val_loss: 0.8501 - val_acc: 0.7298\n", "Epoch 9/200\n", "5574/5574 [==============================] - 1s 174us/step - loss: 0.8309 - acc: 0.7393 - val_loss: 0.8462 - val_acc: 0.7327\n", "Epoch 10/200\n", "5574/5574 [==============================] - 1s 208us/step - loss: 0.8299 - acc: 0.7406 - val_loss: 0.8523 - val_acc: 0.7353\n", "Epoch 11/200\n", "5574/5574 [==============================] - 1s 183us/step - loss: 0.8269 - acc: 0.7443 - val_loss: 0.8453 - val_acc: 0.7336\n", "Epoch 12/200\n", "5574/5574 [==============================] - 1s 194us/step - loss: 0.8255 - acc: 0.7456 - val_loss: 0.8423 - val_acc: 0.7323\n", "Epoch 13/200\n", "5574/5574 [==============================] - 1s 200us/step - loss: 0.8235 - acc: 0.7451 - val_loss: 0.8446 - val_acc: 0.7319\n", "Epoch 14/200\n", "5574/5574 [==============================] - 1s 183us/step - loss: 0.8228 - acc: 0.7456 - val_loss: 0.8441 - val_acc: 0.7336\n", "Epoch 15/200\n", "5574/5574 [==============================] - 1s 173us/step - loss: 0.8234 - acc: 0.7458 - val_loss: 0.8427 - val_acc: 0.7370\n", "Epoch 16/200\n", "5574/5574 [==============================] - 1s 176us/step - loss: 0.8192 - acc: 0.7463 - val_loss: 0.8407 - val_acc: 0.7344\n", "Epoch 17/200\n", "5574/5574 [==============================] - 1s 179us/step - loss: 0.8194 - acc: 0.7449 - val_loss: 0.8375 - val_acc: 0.7374\n", 
"Epoch 18/200\n", "5574/5574 [==============================] - 1s 179us/step - loss: 0.8114 - acc: 0.7478 - val_loss: 0.8363 - val_acc: 0.7370\n", "Epoch 19/200\n", "5574/5574 [==============================] - 1s 173us/step - loss: 0.8113 - acc: 0.7485 - val_loss: 0.8340 - val_acc: 0.7340\n", "Epoch 20/200\n", "5574/5574 [==============================] - 1s 179us/step - loss: 0.8118 - acc: 0.7485 - val_loss: 0.8376 - val_acc: 0.7336\n", "Epoch 21/200\n", "5574/5574 [==============================] - 1s 174us/step - loss: 0.8136 - acc: 0.7474 - val_loss: 0.8357 - val_acc: 0.7277\n", "Epoch 22/200\n", "5574/5574 [==============================] - 1s 179us/step - loss: 0.8136 - acc: 0.7460 - val_loss: 0.8393 - val_acc: 0.7294\n", "Epoch 23/200\n", "5574/5574 [==============================] - 1s 173us/step - loss: 0.8150 - acc: 0.7445 - val_loss: 0.8378 - val_acc: 0.7273\n", "Epoch 24/200\n", "5574/5574 [==============================] - 1s 173us/step - loss: 0.8160 - acc: 0.7440 - val_loss: 0.8363 - val_acc: 0.7298\n", "Epoch 25/200\n", "5574/5574 [==============================] - 1s 183us/step - loss: 0.8157 - acc: 0.7449 - val_loss: 0.8407 - val_acc: 0.7273\n", "Epoch 26/200\n", "5574/5574 [==============================] - 1s 178us/step - loss: 0.8201 - acc: 0.7454 - val_loss: 0.8455 - val_acc: 0.7273\n", "Epoch 27/200\n", "5574/5574 [==============================] - 1s 179us/step - loss: 0.8222 - acc: 0.7418 - val_loss: 0.8426 - val_acc: 0.7260\n", "Epoch 28/200\n", "5574/5574 [==============================] - 1s 172us/step - loss: 0.8198 - acc: 0.7433 - val_loss: 0.8428 - val_acc: 0.7302\n", "Epoch 29/200\n", "5574/5574 [==============================] - 1s 175us/step - loss: 0.8177 - acc: 0.7404 - val_loss: 0.8411 - val_acc: 0.7269\n", "Epoch 30/200\n", "5574/5574 [==============================] - 1s 181us/step - loss: 0.8129 - acc: 0.7429 - val_loss: 0.8348 - val_acc: 0.7285\n", "Epoch 31/200\n", "5574/5574 [==============================] - 1s 
182us/step - loss: 0.8090 - acc: 0.7461 - val_loss: 0.8367 - val_acc: 0.7281\n", "Epoch 32/200\n", "5574/5574 [==============================] - 1s 178us/step - loss: 0.8114 - acc: 0.7431 - val_loss: 0.8361 - val_acc: 0.7277\n", "Epoch 33/200\n", "5574/5574 [==============================] - 1s 180us/step - loss: 0.8095 - acc: 0.7463 - val_loss: 0.8297 - val_acc: 0.7306\n", "Epoch 34/200\n", "5574/5574 [==============================] - 1s 177us/step - loss: 0.8093 - acc: 0.7445 - val_loss: 0.8315 - val_acc: 0.7336\n", "Epoch 35/200\n", "5574/5574 [==============================] - 1s 181us/step - loss: 0.8088 - acc: 0.7442 - val_loss: 0.8278 - val_acc: 0.7344\n", "Epoch 36/200\n", "5574/5574 [==============================] - 1s 176us/step - loss: 0.8046 - acc: 0.7470 - val_loss: 0.8226 - val_acc: 0.7382\n", "Epoch 37/200\n", "5574/5574 [==============================] - 1s 177us/step - loss: 0.8030 - acc: 0.7472 - val_loss: 0.8248 - val_acc: 0.7361\n", "Epoch 38/200\n", "5574/5574 [==============================] - 1s 175us/step - loss: 0.8013 - acc: 0.7483 - val_loss: 0.8249 - val_acc: 0.7391\n", "Epoch 39/200\n", "5574/5574 [==============================] - 1s 180us/step - loss: 0.7997 - acc: 0.7492 - val_loss: 0.8190 - val_acc: 0.7416\n", "Epoch 40/200\n", "5574/5574 [==============================] - 1s 176us/step - loss: 0.7950 - acc: 0.7526 - val_loss: 0.8222 - val_acc: 0.7365\n", "Epoch 41/200\n", "5574/5574 [==============================] - 1s 177us/step - loss: 0.7963 - acc: 0.7497 - val_loss: 0.8218 - val_acc: 0.7399\n", "Epoch 42/200\n", "5574/5574 [==============================] - 1s 179us/step - loss: 0.7951 - acc: 0.7531 - val_loss: 0.8212 - val_acc: 0.7403\n", "Epoch 43/200\n", "5574/5574 [==============================] - 1s 182us/step - loss: 0.7941 - acc: 0.7512 - val_loss: 0.8181 - val_acc: 0.7382\n", "Epoch 44/200\n", "5574/5574 [==============================] - 1s 181us/step - loss: 0.7933 - acc: 0.7519 - val_loss: 0.8181 - val_acc: 
0.7365\n", "Epoch 45/200\n", "5574/5574 [==============================] - 1s 177us/step - loss: 0.7941 - acc: 0.7551 - val_loss: 0.8224 - val_acc: 0.7374\n", "Epoch 46/200\n", "5574/5574 [==============================] - 1s 173us/step - loss: 0.7989 - acc: 0.7522 - val_loss: 0.8229 - val_acc: 0.7365\n", "Epoch 47/200\n", "5574/5574 [==============================] - 1s 181us/step - loss: 0.7962 - acc: 0.7549 - val_loss: 0.8189 - val_acc: 0.7399\n", "Epoch 48/200\n", "5574/5574 [==============================] - 1s 176us/step - loss: 0.7962 - acc: 0.7537 - val_loss: 0.8232 - val_acc: 0.7361\n", "Epoch 49/200\n", "5574/5574 [==============================] - 1s 174us/step - loss: 0.7965 - acc: 0.7546 - val_loss: 0.8238 - val_acc: 0.7382\n", "Epoch 50/200\n", "5574/5574 [==============================] - 1s 184us/step - loss: 0.7945 - acc: 0.7567 - val_loss: 0.8203 - val_acc: 0.7365\n", "Epoch 51/200\n", "5574/5574 [==============================] - 1s 184us/step - loss: 0.7905 - acc: 0.7553 - val_loss: 0.8162 - val_acc: 0.7428\n", "Epoch 52/200\n", "5574/5574 [==============================] - 1s 172us/step - loss: 0.7874 - acc: 0.7573 - val_loss: 0.8143 - val_acc: 0.7395\n", "Epoch 53/200\n", "5574/5574 [==============================] - 1s 174us/step - loss: 0.7884 - acc: 0.7546 - val_loss: 0.8200 - val_acc: 0.7365\n", "Epoch 54/200\n", "5574/5574 [==============================] - 1s 182us/step - loss: 0.7882 - acc: 0.7539 - val_loss: 0.8177 - val_acc: 0.7386\n", "Epoch 55/200\n", "5574/5574 [==============================] - 1s 188us/step - loss: 0.7841 - acc: 0.7560 - val_loss: 0.8144 - val_acc: 0.7403\n", "Epoch 56/200\n", "5574/5574 [==============================] - 1s 174us/step - loss: 0.7833 - acc: 0.7551 - val_loss: 0.8079 - val_acc: 0.7428\n", "Epoch 57/200\n", "5574/5574 [==============================] - 1s 185us/step - loss: 0.7803 - acc: 0.7582 - val_loss: 0.8053 - val_acc: 0.7428\n", "Epoch 58/200\n", "5574/5574 [==============================] - 
1s 180us/step - loss: 0.7731 - acc: 0.7601 - val_loss: 0.8068 - val_acc: 0.7416\n", "Epoch 59/200\n", "5574/5574 [==============================] - 1s 185us/step - loss: 0.7755 - acc: 0.7569 - val_loss: 0.8073 - val_acc: 0.7428\n", "Epoch 60/200\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "5574/5574 [==============================] - 1s 170us/step - loss: 0.7722 - acc: 0.7585 - val_loss: 0.8057 - val_acc: 0.7407\n", "Epoch 61/200\n", "5574/5574 [==============================] - 1s 168us/step - loss: 0.7736 - acc: 0.7530 - val_loss: 0.8013 - val_acc: 0.7428\n", "Epoch 62/200\n", "5574/5574 [==============================] - 1s 168us/step - loss: 0.7692 - acc: 0.7544 - val_loss: 0.7986 - val_acc: 0.7441\n", "Epoch 63/200\n", "5574/5574 [==============================] - 1s 167us/step - loss: 0.7744 - acc: 0.7537 - val_loss: 0.8065 - val_acc: 0.7445\n", "Epoch 64/200\n", "5574/5574 [==============================] - 1s 167us/step - loss: 0.7759 - acc: 0.7539 - val_loss: 0.8040 - val_acc: 0.7462\n", "Epoch 65/200\n", "5574/5574 [==============================] - 1s 170us/step - loss: 0.7732 - acc: 0.7524 - val_loss: 0.8011 - val_acc: 0.7424\n", "Epoch 66/200\n", "5574/5574 [==============================] - 1s 167us/step - loss: 0.7711 - acc: 0.7558 - val_loss: 0.7974 - val_acc: 0.7471\n", "Epoch 67/200\n", "5574/5574 [==============================] - 1s 168us/step - loss: 0.7690 - acc: 0.7565 - val_loss: 0.8017 - val_acc: 0.7420\n", "Epoch 68/200\n", "5574/5574 [==============================] - 1s 170us/step - loss: 0.7657 - acc: 0.7583 - val_loss: 0.8035 - val_acc: 0.7399\n", "Epoch 69/200\n", "5574/5574 [==============================] - 1s 166us/step - loss: 0.7662 - acc: 0.7573 - val_loss: 0.7995 - val_acc: 0.7433\n", "Epoch 70/200\n", "5574/5574 [==============================] - 1s 169us/step - loss: 0.7694 - acc: 0.7558 - val_loss: 0.7969 - val_acc: 0.7441\n", "Epoch 71/200\n", "5574/5574 [==============================] - 1s 168us/step - 
loss: 0.7667 - acc: 0.7625 - val_loss: 0.7992 - val_acc: 0.7479\n", "Epoch 72/200\n", "5574/5574 [==============================] - 1s 171us/step - loss: 0.7656 - acc: 0.7598 - val_loss: 0.7996 - val_acc: 0.7416\n", "Epoch 73/200\n", "5574/5574 [==============================] - 1s 168us/step - loss: 0.7609 - acc: 0.7605 - val_loss: 0.8004 - val_acc: 0.7445\n", "Epoch 74/200\n", "5574/5574 [==============================] - 1s 169us/step - loss: 0.7615 - acc: 0.7591 - val_loss: 0.8025 - val_acc: 0.7458\n", "Epoch 75/200\n", "5574/5574 [==============================] - 1s 167us/step - loss: 0.7650 - acc: 0.7583 - val_loss: 0.8020 - val_acc: 0.7445\n", "Epoch 76/200\n", "5574/5574 [==============================] - 1s 172us/step - loss: 0.7611 - acc: 0.7598 - val_loss: 0.8054 - val_acc: 0.7403\n", "Epoch 77/200\n", "5574/5574 [==============================] - 1s 172us/step - loss: 0.7639 - acc: 0.7580 - val_loss: 0.7995 - val_acc: 0.7504\n", "Epoch 78/200\n", "5574/5574 [==============================] - 1s 174us/step - loss: 0.7634 - acc: 0.7564 - val_loss: 0.8019 - val_acc: 0.7483\n", "Epoch 79/200\n", "5574/5574 [==============================] - 1s 170us/step - loss: 0.7604 - acc: 0.7610 - val_loss: 0.7970 - val_acc: 0.7466\n", "Epoch 80/200\n", "5574/5574 [==============================] - 1s 167us/step - loss: 0.7621 - acc: 0.7598 - val_loss: 0.7983 - val_acc: 0.7441\n", "Epoch 81/200\n", "5574/5574 [==============================] - 1s 178us/step - loss: 0.7666 - acc: 0.7576 - val_loss: 0.8008 - val_acc: 0.7428\n", "Epoch 82/200\n", "5574/5574 [==============================] - 1s 205us/step - loss: 0.7699 - acc: 0.7594 - val_loss: 0.8025 - val_acc: 0.7445\n", "Epoch 83/200\n", "5574/5574 [==============================] - 1s 199us/step - loss: 0.7676 - acc: 0.7610 - val_loss: 0.8059 - val_acc: 0.7428\n", "Epoch 84/200\n", "5574/5574 [==============================] - 1s 195us/step - loss: 0.7640 - acc: 0.7587 - val_loss: 0.8001 - val_acc: 0.7424\n", "Epoch 
85/200\n", "5574/5574 [==============================] - 1s 196us/step - loss: 0.7633 - acc: 0.7596 - val_loss: 0.8037 - val_acc: 0.7449\n", "Epoch 86/200\n", "5574/5574 [==============================] - 1s 171us/step - loss: 0.7637 - acc: 0.7610 - val_loss: 0.8016 - val_acc: 0.7437\n", "Epoch 87/200\n", "5574/5574 [==============================] - 1s 176us/step - loss: 0.7648 - acc: 0.7605 - val_loss: 0.7986 - val_acc: 0.7445\n", "Epoch 88/200\n", "5574/5574 [==============================] - 1s 171us/step - loss: 0.7653 - acc: 0.7576 - val_loss: 0.8026 - val_acc: 0.7479\n", "Epoch 89/200\n", "5574/5574 [==============================] - 1s 173us/step - loss: 0.7658 - acc: 0.7589 - val_loss: 0.7992 - val_acc: 0.7424\n", "Epoch 90/200\n", "5574/5574 [==============================] - 1s 169us/step - loss: 0.7625 - acc: 0.7591 - val_loss: 0.8009 - val_acc: 0.7424\n", "Epoch 91/200\n", "5574/5574 [==============================] - 1s 175us/step - loss: 0.7624 - acc: 0.7596 - val_loss: 0.8019 - val_acc: 0.7395\n", "Epoch 92/200\n", "5574/5574 [==============================] - 1s 167us/step - loss: 0.7621 - acc: 0.7614 - val_loss: 0.7999 - val_acc: 0.7365\n", "Epoch 93/200\n", "5574/5574 [==============================] - 1s 191us/step - loss: 0.7565 - acc: 0.7600 - val_loss: 0.7977 - val_acc: 0.7399\n", "Epoch 94/200\n", "5574/5574 [==============================] - 1s 193us/step - loss: 0.7532 - acc: 0.7603 - val_loss: 0.7901 - val_acc: 0.7433\n", "Epoch 95/200\n", "5574/5574 [==============================] - 1s 170us/step - loss: 0.7548 - acc: 0.7607 - val_loss: 0.7863 - val_acc: 0.7445\n", "Epoch 96/200\n", "5574/5574 [==============================] - 1s 176us/step - loss: 0.7493 - acc: 0.7632 - val_loss: 0.7876 - val_acc: 0.7420\n", "Epoch 97/200\n", "5574/5574 [==============================] - 1s 174us/step - loss: 0.7485 - acc: 0.7616 - val_loss: 0.7820 - val_acc: 0.7407\n", "Epoch 98/200\n", "5574/5574 [==============================] - 1s 179us/step - 
loss: 0.7477 - acc: 0.7634 - val_loss: 0.7910 - val_acc: 0.7395\n", "Epoch 99/200\n", "5574/5574 [==============================] - 1s 174us/step - loss: 0.7517 - acc: 0.7612 - val_loss: 0.7865 - val_acc: 0.7471\n", "Epoch 100/200\n", "5574/5574 [==============================] - 1s 190us/step - loss: 0.7501 - acc: 0.7634 - val_loss: 0.7844 - val_acc: 0.7424\n", "Epoch 101/200\n", "5574/5574 [==============================] - 1s 224us/step - loss: 0.7478 - acc: 0.7646 - val_loss: 0.7829 - val_acc: 0.7441\n", "Epoch 102/200\n", "5574/5574 [==============================] - 1s 186us/step - loss: 0.7510 - acc: 0.7632 - val_loss: 0.7812 - val_acc: 0.7475\n", "Epoch 103/200\n", "5574/5574 [==============================] - 1s 177us/step - loss: 0.7497 - acc: 0.7653 - val_loss: 0.7837 - val_acc: 0.7462\n", "Epoch 104/200\n", "5574/5574 [==============================] - 1s 179us/step - loss: 0.7514 - acc: 0.7628 - val_loss: 0.7822 - val_acc: 0.7437\n", "Epoch 105/200\n", "5574/5574 [==============================] - 1s 211us/step - loss: 0.7471 - acc: 0.7653 - val_loss: 0.7792 - val_acc: 0.7462\n", "Epoch 106/200\n", "5574/5574 [==============================] - 1s 183us/step - loss: 0.7447 - acc: 0.7664 - val_loss: 0.7889 - val_acc: 0.7416\n", "Epoch 107/200\n", "5574/5574 [==============================] - 1s 179us/step - loss: 0.7453 - acc: 0.7661 - val_loss: 0.7883 - val_acc: 0.7424\n", "Epoch 108/200\n", "5574/5574 [==============================] - 1s 195us/step - loss: 0.7429 - acc: 0.7673 - val_loss: 0.7812 - val_acc: 0.7454\n", "Epoch 109/200\n", "5574/5574 [==============================] - 1s 181us/step - loss: 0.7407 - acc: 0.7664 - val_loss: 0.7813 - val_acc: 0.7475\n", "Epoch 110/200\n", "5574/5574 [==============================] - 1s 204us/step - loss: 0.7411 - acc: 0.7704 - val_loss: 0.7807 - val_acc: 0.7433\n", "Epoch 111/200\n", "5574/5574 [==============================] - 1s 214us/step - loss: 0.7390 - acc: 0.7711 - val_loss: 0.7807 - val_acc: 
0.7475\n", "Epoch 112/200\n", "5574/5574 [==============================] - 1s 241us/step - loss: 0.7390 - acc: 0.7695 - val_loss: 0.7778 - val_acc: 0.7492\n", "Epoch 113/200\n", "5574/5574 [==============================] - 1s 215us/step - loss: 0.7402 - acc: 0.7664 - val_loss: 0.7727 - val_acc: 0.7496\n", "Epoch 114/200\n", "5574/5574 [==============================] - 1s 198us/step - loss: 0.7367 - acc: 0.7670 - val_loss: 0.7783 - val_acc: 0.7475\n", "Epoch 115/200\n", "5574/5574 [==============================] - 1s 238us/step - loss: 0.7368 - acc: 0.7679 - val_loss: 0.7751 - val_acc: 0.7500\n", "Epoch 116/200\n", "5574/5574 [==============================] - 1s 192us/step - loss: 0.7335 - acc: 0.7686 - val_loss: 0.7697 - val_acc: 0.7492\n", "Epoch 117/200\n", "5574/5574 [==============================] - 1s 188us/step - loss: 0.7311 - acc: 0.7679 - val_loss: 0.7672 - val_acc: 0.7555\n", "Epoch 118/200\n", "5574/5574 [==============================] - 1s 185us/step - loss: 0.7298 - acc: 0.7671 - val_loss: 0.7660 - val_acc: 0.7525\n", "Epoch 119/200\n", "5574/5574 [==============================] - 1s 215us/step - loss: 0.7276 - acc: 0.7709 - val_loss: 0.7620 - val_acc: 0.7546\n", "Epoch 120/200\n", "5574/5574 [==============================] - 2s 286us/step - loss: 0.7280 - acc: 0.7707 - val_loss: 0.7625 - val_acc: 0.7546\n", "Epoch 121/200\n", "5574/5574 [==============================] - 2s 339us/step - loss: 0.7263 - acc: 0.7707 - val_loss: 0.7598 - val_acc: 0.7542\n", "Epoch 122/200\n", "5574/5574 [==============================] - 2s 316us/step - loss: 0.7253 - acc: 0.7711 - val_loss: 0.7611 - val_acc: 0.7555\n", "Epoch 123/200\n", "5574/5574 [==============================] - 1s 251us/step - loss: 0.7224 - acc: 0.7718 - val_loss: 0.7614 - val_acc: 0.7542\n", "Epoch 124/200\n", "5574/5574 [==============================] - 2s 324us/step - loss: 0.7193 - acc: 0.7731 - val_loss: 0.7598 - val_acc: 0.7542\n", "Epoch 125/200\n", "5574/5574 
[==============================] - 1s 254us/step - loss: 0.7217 - acc: 0.7743 - val_loss: 0.7546 - val_acc: 0.7588\n", "Epoch 126/200\n", "5574/5574 [==============================] - 1s 236us/step - loss: 0.7198 - acc: 0.7729 - val_loss: 0.7577 - val_acc: 0.7601\n", "Epoch 127/200\n", "5574/5574 [==============================] - 1s 222us/step - loss: 0.7199 - acc: 0.7711 - val_loss: 0.7586 - val_acc: 0.7551\n", "Epoch 128/200\n", "5574/5574 [==============================] - 1s 222us/step - loss: 0.7141 - acc: 0.7770 - val_loss: 0.7544 - val_acc: 0.7626\n", "Epoch 129/200\n", "5574/5574 [==============================] - 1s 167us/step - loss: 0.7134 - acc: 0.7768 - val_loss: 0.7514 - val_acc: 0.7601\n", "Epoch 130/200\n", "5574/5574 [==============================] - 1s 163us/step - loss: 0.7129 - acc: 0.7783 - val_loss: 0.7565 - val_acc: 0.7555\n", "Epoch 131/200\n", "5574/5574 [==============================] - 1s 202us/step - loss: 0.7172 - acc: 0.7725 - val_loss: 0.7598 - val_acc: 0.7567\n", "Epoch 132/200\n", "5574/5574 [==============================] - 1s 164us/step - loss: 0.7119 - acc: 0.7759 - val_loss: 0.7574 - val_acc: 0.7588\n", "Epoch 133/200\n", "5574/5574 [==============================] - 1s 168us/step - loss: 0.7159 - acc: 0.7747 - val_loss: 0.7558 - val_acc: 0.7559\n", "Epoch 134/200\n", "5574/5574 [==============================] - 1s 220us/step - loss: 0.7146 - acc: 0.7731 - val_loss: 0.7587 - val_acc: 0.7576\n", "Epoch 135/200\n", "5574/5574 [==============================] - 1s 226us/step - loss: 0.7138 - acc: 0.7748 - val_loss: 0.7560 - val_acc: 0.7618\n", "Epoch 136/200\n", "5574/5574 [==============================] - 1s 208us/step - loss: 0.7111 - acc: 0.7774 - val_loss: 0.7550 - val_acc: 0.7609\n", "Epoch 137/200\n", "5574/5574 [==============================] - 1s 212us/step - loss: 0.7106 - acc: 0.7770 - val_loss: 0.7566 - val_acc: 0.7605\n", "Epoch 138/200\n", "5574/5574 [==============================] - 1s 189us/step - loss: 
0.7095 - acc: 0.7743 - val_loss: 0.7652 - val_acc: 0.7542\n", "Epoch 139/200\n", "5574/5574 [==============================] - 1s 161us/step - loss: 0.7084 - acc: 0.7761 - val_loss: 0.7570 - val_acc: 0.7555\n", "Epoch 140/200\n", "5574/5574 [==============================] - 1s 158us/step - loss: 0.7115 - acc: 0.7713 - val_loss: 0.7623 - val_acc: 0.7563\n", "Epoch 141/200\n", "5574/5574 [==============================] - 1s 160us/step - loss: 0.7105 - acc: 0.7732 - val_loss: 0.7556 - val_acc: 0.7563\n", "Epoch 142/200\n", "5574/5574 [==============================] - 1s 182us/step - loss: 0.7104 - acc: 0.7745 - val_loss: 0.7578 - val_acc: 0.7572\n", "Epoch 143/200\n", "5574/5574 [==============================] - 1s 170us/step - loss: 0.7119 - acc: 0.7718 - val_loss: 0.7565 - val_acc: 0.7567\n", "Epoch 144/200\n", "5574/5574 [==============================] - 1s 171us/step - loss: 0.7134 - acc: 0.7731 - val_loss: 0.7600 - val_acc: 0.7551\n", "Epoch 145/200\n", "5574/5574 [==============================] - 1s 170us/step - loss: 0.7112 - acc: 0.7750 - val_loss: 0.7598 - val_acc: 0.7572\n", "Epoch 146/200\n", "5574/5574 [==============================] - 1s 230us/step - loss: 0.7170 - acc: 0.7680 - val_loss: 0.7588 - val_acc: 0.7576\n", "Epoch 147/200\n", "5574/5574 [==============================] - 1s 225us/step - loss: 0.7168 - acc: 0.7687 - val_loss: 0.7630 - val_acc: 0.7513\n", "Epoch 148/200\n", "5574/5574 [==============================] - 1s 248us/step - loss: 0.7157 - acc: 0.7709 - val_loss: 0.7659 - val_acc: 0.7538\n", "Epoch 149/200\n", "5574/5574 [==============================] - 1s 223us/step - loss: 0.7152 - acc: 0.7684 - val_loss: 0.7662 - val_acc: 0.7576\n", "Epoch 150/200\n", "5574/5574 [==============================] - 1s 214us/step - loss: 0.7177 - acc: 0.7702 - val_loss: 0.7649 - val_acc: 0.7563\n", "Epoch 151/200\n", "5574/5574 [==============================] - 1s 203us/step - loss: 0.7152 - acc: 0.7711 - val_loss: 0.7683 - val_acc: 0.7525\n", 
"Epoch 152/200\n", "5574/5574 [==============================] - 1s 216us/step - loss: 0.7203 - acc: 0.7705 - val_loss: 0.7710 - val_acc: 0.7538\n", "Epoch 153/200\n", "5574/5574 [==============================] - 1s 220us/step - loss: 0.7244 - acc: 0.7682 - val_loss: 0.7667 - val_acc: 0.7513\n", "Epoch 154/200\n", "5574/5574 [==============================] - 1s 206us/step - loss: 0.7190 - acc: 0.7680 - val_loss: 0.7657 - val_acc: 0.7601\n", "Epoch 155/200\n", "5574/5574 [==============================] - 1s 179us/step - loss: 0.7123 - acc: 0.7693 - val_loss: 0.7659 - val_acc: 0.7584\n", "Epoch 156/200\n", "5574/5574 [==============================] - 1s 246us/step - loss: 0.7098 - acc: 0.7722 - val_loss: 0.7565 - val_acc: 0.7555\n", "Epoch 157/200\n", "5574/5574 [==============================] - 1s 189us/step - loss: 0.7070 - acc: 0.7727 - val_loss: 0.7544 - val_acc: 0.7576\n", "Epoch 158/200\n", "5574/5574 [==============================] - 1s 205us/step - loss: 0.7081 - acc: 0.7734 - val_loss: 0.7497 - val_acc: 0.7563\n", "Epoch 159/200\n", "5574/5574 [==============================] - 1s 211us/step - loss: 0.7111 - acc: 0.7731 - val_loss: 0.7510 - val_acc: 0.7555\n", "Epoch 160/200\n", "5574/5574 [==============================] - 1s 247us/step - loss: 0.7120 - acc: 0.7720 - val_loss: 0.7559 - val_acc: 0.7546\n", "Epoch 161/200\n", "5574/5574 [==============================] - 2s 271us/step - loss: 0.7110 - acc: 0.7720 - val_loss: 0.7559 - val_acc: 0.7546\n", "Epoch 162/200\n", "5574/5574 [==============================] - 1s 198us/step - loss: 0.7093 - acc: 0.7745 - val_loss: 0.7606 - val_acc: 0.7542\n", "Epoch 163/200\n", "5574/5574 [==============================] - 1s 192us/step - loss: 0.7083 - acc: 0.7757 - val_loss: 0.7568 - val_acc: 0.7635\n", "Epoch 164/200\n", "5574/5574 [==============================] - 1s 214us/step - loss: 0.7062 - acc: 0.7736 - val_loss: 0.7535 - val_acc: 0.7605\n", "Epoch 165/200\n", "5574/5574 [==============================] 
- 1s 177us/step - loss: 0.7045 - acc: 0.7757 - val_loss: 0.7479 - val_acc: 0.7601\n", "Epoch 166/200\n", "5574/5574 [==============================] - 1s 162us/step - loss: 0.7014 - acc: 0.7736 - val_loss: 0.7463 - val_acc: 0.7626\n", "Epoch 167/200\n", "5574/5574 [==============================] - 1s 162us/step - loss: 0.6955 - acc: 0.7786 - val_loss: 0.7383 - val_acc: 0.7618\n", "Epoch 168/200\n", "5574/5574 [==============================] - 1s 168us/step - loss: 0.6942 - acc: 0.7761 - val_loss: 0.7377 - val_acc: 0.7656\n", "Epoch 169/200\n", "5574/5574 [==============================] - 1s 189us/step - loss: 0.6938 - acc: 0.7781 - val_loss: 0.7398 - val_acc: 0.7652\n", "Epoch 170/200\n", "5574/5574 [==============================] - 1s 209us/step - loss: 0.6899 - acc: 0.7793 - val_loss: 0.7360 - val_acc: 0.7643\n", "Epoch 171/200\n", "5574/5574 [==============================] - 1s 165us/step - loss: 0.6876 - acc: 0.7797 - val_loss: 0.7374 - val_acc: 0.7677\n", "Epoch 172/200\n", "5574/5574 [==============================] - 1s 164us/step - loss: 0.6866 - acc: 0.7801 - val_loss: 0.7386 - val_acc: 0.7614\n", "Epoch 173/200\n", "5574/5574 [==============================] - 1s 167us/step - loss: 0.6879 - acc: 0.7783 - val_loss: 0.7307 - val_acc: 0.7668\n", "Epoch 174/200\n", "5574/5574 [==============================] - 1s 165us/step - loss: 0.6837 - acc: 0.7793 - val_loss: 0.7330 - val_acc: 0.7652\n", "Epoch 175/200\n", "5574/5574 [==============================] - 1s 210us/step - loss: 0.6845 - acc: 0.7786 - val_loss: 0.7300 - val_acc: 0.7689\n", "Epoch 176/200\n", "5574/5574 [==============================] - 1s 160us/step - loss: 0.6853 - acc: 0.7784 - val_loss: 0.7333 - val_acc: 0.7643\n", "Epoch 177/200\n", "5574/5574 [==============================] - 1s 161us/step - loss: 0.6824 - acc: 0.7779 - val_loss: 0.7291 - val_acc: 0.7652\n", "Epoch 178/200\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "5574/5574 [==============================] - 
1s 160us/step - loss: 0.6830 - acc: 0.7757 - val_loss: 0.7300 - val_acc: 0.7673\n", "Epoch 179/200\n", "5574/5574 [==============================] - 1s 175us/step - loss: 0.6857 - acc: 0.7781 - val_loss: 0.7277 - val_acc: 0.7681\n", "Epoch 180/200\n", "5574/5574 [==============================] - 1s 187us/step - loss: 0.6854 - acc: 0.7784 - val_loss: 0.7351 - val_acc: 0.7664\n", "Epoch 181/200\n", "5574/5574 [==============================] - 1s 187us/step - loss: 0.6842 - acc: 0.7783 - val_loss: 0.7317 - val_acc: 0.7652\n", "Epoch 182/200\n", "5574/5574 [==============================] - 1s 178us/step - loss: 0.6850 - acc: 0.7797 - val_loss: 0.7367 - val_acc: 0.7639\n", "Epoch 183/200\n", "5574/5574 [==============================] - 1s 176us/step - loss: 0.6889 - acc: 0.7783 - val_loss: 0.7395 - val_acc: 0.7605\n", "Epoch 184/200\n", "5574/5574 [==============================] - 1s 221us/step - loss: 0.6884 - acc: 0.7777 - val_loss: 0.7391 - val_acc: 0.7601\n", "Epoch 185/200\n", "5574/5574 [==============================] - 1s 249us/step - loss: 0.6870 - acc: 0.7793 - val_loss: 0.7343 - val_acc: 0.7660\n", "Epoch 186/200\n", "5574/5574 [==============================] - 1s 197us/step - loss: 0.6858 - acc: 0.7788 - val_loss: 0.7262 - val_acc: 0.7681\n", "Epoch 187/200\n", "5574/5574 [==============================] - 1s 189us/step - loss: 0.6810 - acc: 0.7788 - val_loss: 0.7270 - val_acc: 0.7689\n", "Epoch 188/200\n", "5574/5574 [==============================] - 1s 180us/step - loss: 0.6825 - acc: 0.7799 - val_loss: 0.7328 - val_acc: 0.7635\n", "Epoch 189/200\n", "5574/5574 [==============================] - 1s 215us/step - loss: 0.6831 - acc: 0.7802 - val_loss: 0.7392 - val_acc: 0.7597\n", "Epoch 190/200\n", "5574/5574 [==============================] - 1s 193us/step - loss: 0.6851 - acc: 0.7806 - val_loss: 0.7414 - val_acc: 0.7588\n", "Epoch 191/200\n", "5574/5574 [==============================] - 1s 177us/step - loss: 0.6834 - acc: 0.7788 - val_loss: 0.7355 
- val_acc: 0.7652\n", "Epoch 192/200\n", "5574/5574 [==============================] - 1s 182us/step - loss: 0.6863 - acc: 0.7808 - val_loss: 0.7313 - val_acc: 0.7635\n", "Epoch 193/200\n", "5574/5574 [==============================] - 1s 199us/step - loss: 0.6886 - acc: 0.7792 - val_loss: 0.7373 - val_acc: 0.7572\n", "Epoch 194/200\n", "5574/5574 [==============================] - 1s 224us/step - loss: 0.6885 - acc: 0.7804 - val_loss: 0.7428 - val_acc: 0.7572\n", "Epoch 195/200\n", "5574/5574 [==============================] - 1s 223us/step - loss: 0.6862 - acc: 0.7827 - val_loss: 0.7428 - val_acc: 0.7546\n", "Epoch 196/200\n", "5574/5574 [==============================] - 1s 202us/step - loss: 0.6869 - acc: 0.7817 - val_loss: 0.7448 - val_acc: 0.7546\n", "Epoch 197/200\n", "5574/5574 [==============================] - 1s 182us/step - loss: 0.6869 - acc: 0.7820 - val_loss: 0.7458 - val_acc: 0.7618\n", "Epoch 198/200\n", "5574/5574 [==============================] - 1s 199us/step - loss: 0.6839 - acc: 0.7818 - val_loss: 0.7401 - val_acc: 0.7605\n", "Epoch 199/200\n", "5574/5574 [==============================] - 1s 188us/step - loss: 0.6815 - acc: 0.7826 - val_loss: 0.7375 - val_acc: 0.7609\n", "Epoch 200/200\n", "5574/5574 [==============================] - 1s 195us/step - loss: 0.6792 - acc: 0.7844 - val_loss: 0.7357 - val_acc: 0.7630\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit(train, encoded_train, validation_data=(test, encoded_test), epochs=200, batch_size=32)" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "dense_3 (Dense) multiple 307712 \n", 
"_________________________________________________________________\n", "dense_4 (Dense) multiple 16512 \n", "_________________________________________________________________\n", "dense_5 (Dense) multiple 2193 \n", "=================================================================\n", "Total params: 326,417\n", "Trainable params: 326,417\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "model.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cross-validation (in-progress)" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "# Load libraries\n", "import numpy as np\n", "from tensorflow.keras import models\n", "from tensorflow.keras import layers\n", "from tensorflow.keras.wrappers.scikit_learn import KerasClassifier\n", "from sklearn.model_selection import cross_val_score\n", "from sklearn.datasets import make_classification" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "# Create function returning a compiled network\n", "def create_network():\n", " model = tf.keras.Sequential()\n", " model.add(layers.Dense(128, activation='sigmoid'))\n", " model.add(layers.Dense(128, activation='sigmoid'))\n", " model.add(layers.Dense(17, activation='softmax'))\n", " model.compile(optimizer=tf.train.RMSPropOptimizer(5.1794746792312125e-05),\n", " loss='categorical_crossentropy',\n", " metrics=['accuracy'])\n", " # Return compiled network\n", " return model" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [], "source": [ "# Wrap Keras model so it can be used by scikit-learn\n", "neural_network = KerasClassifier(build_fn=create_network, \n", " epochs=100, \n", " batch_size=32, \n", " verbose=0)" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.65285253, 0.66469322, 0.66953714])" ] }, 
"execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluate neural network using three-fold cross-validation\n", "cross_val_score(neural_network, train, encoded_train, cv=3)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.10" } }, "nbformat": 4, "nbformat_minor": 2 }