{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\" " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "using Keras version: 2.2.4\n" ] } ], "source": [ "import ktrain\n", "from ktrain import graph as gr" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Node Classification in Graphs\n", "\n", "\n", "In this notebook, we will use *ktrain* to perform node classification on the Cora citation graph. Each node represents a paper belonging to one of seven topics, and links represent citations between papers. The features assigned to each node form a multi-hot vector indicating which words appear in the paper. The dataset is available [here](https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz).\n", "\n", "The dataset is already in the form expected by *ktrain*, so let's begin." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### STEP 1: Load and Preprocess Data\n", "\n", "We will hold out 10% of the nodes as a test set. Since we set `holdout_for_inductive=False`, the held-out nodes remain in the graph, but only their features (not their labels) are visible to the model. This is referred to as transductive inference. Of the remaining nodes, 10% are used for training and the rest for validation (also transductive inference): as with the holdout nodes, the features (but not the labels) of the validation nodes are available to the model during training. The return value `df_holdout` contains the features and targets of the held-out nodes (the targets are used here only to measure test accuracy), and `G_complete` is the original graph including the holdout nodes." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Largest subgraph statistics: 2485 nodes, 5069 edges\n", "Size of training graph: 2485 nodes\n", "Training nodes: 223\n", "Validation nodes: 2013\n", "Nodes treated as unlabeled for testing/inference: 249\n", "Holdout node features are visible during training (transductive inference)\n", "\n" ] } ], "source": [ "(train_data, val_data, preproc, df_holdout, G_complete) = gr.graph_nodes_from_csv(\n", " 'data/cora/cora.content', # node attributes/labels\n", " 'data/cora/cora.cites', # edge list\n", " sample_size=20, \n", " holdout_pct=0.1, holdout_for_inductive=False,\n", " train_pct=0.1, sep='\\t')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `preproc` object holds a reference to the training graph and a dataframe with the features and target of each node in that graph. Because the holdout nodes remain in the graph here, this covers the training, validation, and held-out nodes (the class counts below sum to all 2485 nodes)." 
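, "\n", "\n", "As a rough illustration of the split described in STEP 1, here is a minimal sketch that assumes only NumPy. The helper `transductive_split` is hypothetical and not part of *ktrain*, and its rounding rule (ceiling for the holdout share, floor for the training share) is an assumption chosen because it reproduces the node counts printed above; *ktrain*'s actual implementation may differ. No node is removed from the graph: the split only decides whose labels the model may see.\n", "
```python
import math

import numpy as np


def transductive_split(node_ids, holdout_pct=0.1, train_pct=0.1, seed=0):
    # Hypothetical helper (not part of ktrain) that partitions node IDs as in STEP 1.
    # All nodes stay in the graph; only the labels of the holdout and
    # validation nodes are hidden from the model during training.
    rng = np.random.default_rng(seed)
    ids = rng.permutation(np.asarray(node_ids))
    # Assumed rounding: ceiling for the holdout share, floor for the training share.
    n_holdout = math.ceil(holdout_pct * len(ids))
    holdout, rest = ids[:n_holdout], ids[n_holdout:]
    n_train = math.floor(train_pct * len(rest))
    train, val = rest[:n_train], rest[n_train:]
    return train, val, holdout


# 2485 is the size of the largest connected subgraph reported above.
train_ids, val_ids, holdout_ids = transductive_split(np.arange(2485))
print(len(train_ids), len(val_ids), len(holdout_ids))  # 223 2013 249
```
"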
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Neural_Networks 726\n", "Genetic_Algorithms 406\n", "Probabilistic_Methods 379\n", "Theory 344\n", "Case_Based 285\n", "Reinforcement_Learning 214\n", "Rule_Learning 131\n", "Name: target, dtype: int64" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preproc.df.target.value_counts()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### STEP 2: Build a Model and Wrap in Learner Object" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "graphsage: GraphSAGE: http://arxiv.org/pdf/1607.01759.pdf\n" ] } ], "source": [ "gr.print_node_classifiers()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "done\n" ] } ], "source": [ "learner = ktrain.get_learner(model=gr.graph_node_classifier('graphsage', train_data, ), \n", " train_data=train_data, \n", " val_data=val_data, \n", " batch_size=64)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### STEP 3: Estimate LR \n", "Given the small number of batches per epoch, a larger number of epochs is required to estimate the learning rate. We will cap it at 100 here." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... 
this may take a few moments...\n", "Epoch 1/100\n", "3/3 [==============================] - 1s 441ms/step - loss: 1.9648 - acc: 0.1302\n", "Epoch 2/100\n", "3/3 [==============================] - 0s 158ms/step - loss: 2.0053 - acc: 0.0873\n", "Epoch 3/100\n", "3/3 [==============================] - 1s 191ms/step - loss: 1.9632 - acc: 0.1510\n", "Epoch 4/100\n", "3/3 [==============================] - 0s 138ms/step - loss: 1.9682 - acc: 0.1411\n", "Epoch 5/100\n", "3/3 [==============================] - 1s 176ms/step - loss: 1.9776 - acc: 0.1304\n", "Epoch 6/100\n", "3/3 [==============================] - 0s 153ms/step - loss: 1.9632 - acc: 0.1536\n", "Epoch 7/100\n", "3/3 [==============================] - 1s 186ms/step - loss: 1.9682 - acc: 0.1562\n", "Epoch 8/100\n", "3/3 [==============================] - 0s 130ms/step - loss: 1.9509 - acc: 0.1275\n", "Epoch 9/100\n", "3/3 [==============================] - 1s 175ms/step - loss: 1.9662 - acc: 0.1137\n", "Epoch 10/100\n", "3/3 [==============================] - 1s 174ms/step - loss: 1.9605 - acc: 0.1919\n", "Epoch 11/100\n", "3/3 [==============================] - 1s 198ms/step - loss: 1.9962 - acc: 0.1042\n", "Epoch 12/100\n", "3/3 [==============================] - 0s 140ms/step - loss: 1.9700 - acc: 0.1179\n", "Epoch 13/100\n", "3/3 [==============================] - 1s 178ms/step - loss: 1.9776 - acc: 0.1578\n", "Epoch 14/100\n", "3/3 [==============================] - 1s 193ms/step - loss: 1.9769 - acc: 0.1406\n", "Epoch 15/100\n", "3/3 [==============================] - 1s 170ms/step - loss: 1.9816 - acc: 0.1510\n", "Epoch 16/100\n", "3/3 [==============================] - 0s 133ms/step - loss: 1.9620 - acc: 0.1481\n", "Epoch 17/100\n", "3/3 [==============================] - 0s 162ms/step - loss: 1.9662 - acc: 0.1591\n", "Epoch 18/100\n", "3/3 [==============================] - 0s 166ms/step - loss: 1.9790 - acc: 0.1288\n", "Epoch 19/100\n", "3/3 [==============================] - 1s 193ms/step - loss: 
1.9705 - acc: 0.1198\n", "Epoch 20/100\n", "3/3 [==============================] - 0s 132ms/step - loss: 1.9617 - acc: 0.1508\n", "Epoch 21/100\n", "3/3 [==============================] - 1s 167ms/step - loss: 1.9818 - acc: 0.1317\n", "Epoch 22/100\n", "3/3 [==============================] - 0s 161ms/step - loss: 1.9709 - acc: 0.1288\n", "Epoch 23/100\n", "3/3 [==============================] - 1s 190ms/step - loss: 1.9560 - acc: 0.1615\n", "Epoch 24/100\n", "3/3 [==============================] - 0s 132ms/step - loss: 1.9715 - acc: 0.1508\n", "Epoch 25/100\n", "3/3 [==============================] - 1s 171ms/step - loss: 1.9741 - acc: 0.1851\n", "Epoch 26/100\n", "3/3 [==============================] - 1s 187ms/step - loss: 1.9706 - acc: 0.1406\n", "Epoch 27/100\n", "3/3 [==============================] - 0s 165ms/step - loss: 1.9826 - acc: 0.1398\n", "Epoch 28/100\n", "3/3 [==============================] - 0s 126ms/step - loss: 1.9698 - acc: 0.1262\n", "Epoch 29/100\n", "3/3 [==============================] - 1s 183ms/step - loss: 1.9711 - acc: 0.1523\n", "Epoch 30/100\n", "3/3 [==============================] - 0s 161ms/step - loss: 1.9680 - acc: 0.1549\n", "Epoch 31/100\n", "3/3 [==============================] - 1s 190ms/step - loss: 1.9472 - acc: 0.1615\n", "Epoch 32/100\n", "3/3 [==============================] - 0s 127ms/step - loss: 1.9847 - acc: 0.1646\n", "Epoch 33/100\n", "3/3 [==============================] - 0s 162ms/step - loss: 1.9565 - acc: 0.1411\n", "Epoch 34/100\n", "3/3 [==============================] - 1s 168ms/step - loss: 1.9785 - acc: 0.1549\n", "Epoch 35/100\n", "3/3 [==============================] - 1s 197ms/step - loss: 1.9499 - acc: 0.1927\n", "Epoch 36/100\n", "3/3 [==============================] - 0s 130ms/step - loss: 1.9497 - acc: 0.1578\n", "Epoch 37/100\n", "3/3 [==============================] - 0s 163ms/step - loss: 1.9379 - acc: 0.1880\n", "Epoch 38/100\n", "3/3 [==============================] - 1s 193ms/step - loss: 
1.9216 - acc: 0.1823\n", "Epoch 39/100\n", "3/3 [==============================] - 1s 167ms/step - loss: 1.9734 - acc: 0.1358\n", "Epoch 40/100\n", "3/3 [==============================] - 0s 126ms/step - loss: 1.9371 - acc: 0.1481\n", "Epoch 41/100\n", "3/3 [==============================] - 1s 175ms/step - loss: 1.9302 - acc: 0.1468\n", "Epoch 42/100\n", "3/3 [==============================] - 0s 163ms/step - loss: 1.9158 - acc: 0.2099\n", "Epoch 43/100\n", "3/3 [==============================] - 0s 141ms/step - loss: 1.8992 - acc: 0.2222\n", "Epoch 44/100\n", "3/3 [==============================] - 1s 181ms/step - loss: 1.8642 - acc: 0.3021\n", "Epoch 45/100\n", "3/3 [==============================] - 1s 178ms/step - loss: 1.8753 - acc: 0.2552\n", "Epoch 46/100\n", "3/3 [==============================] - 1s 186ms/step - loss: 1.8553 - acc: 0.3281\n", "Epoch 47/100\n", "3/3 [==============================] - 1s 169ms/step - loss: 1.8448 - acc: 0.3155\n", "Epoch 48/100\n", "3/3 [==============================] - 0s 122ms/step - loss: 1.8037 - acc: 0.3582\n", "Epoch 49/100\n", "3/3 [==============================] - 0s 166ms/step - loss: 1.7770 - acc: 0.4334\n", "Epoch 50/100\n", "3/3 [==============================] - 1s 181ms/step - loss: 1.7460 - acc: 0.4323\n", "Epoch 51/100\n", "3/3 [==============================] - 0s 164ms/step - loss: 1.6978 - acc: 0.4980\n", "Epoch 52/100\n", "3/3 [==============================] - 0s 128ms/step - loss: 1.6504 - acc: 0.5324\n", "Epoch 53/100\n", "3/3 [==============================] - 1s 183ms/step - loss: 1.6264 - acc: 0.5573\n", "Epoch 54/100\n", "3/3 [==============================] - 1s 176ms/step - loss: 1.5451 - acc: 0.5914\n", "Epoch 55/100\n", "3/3 [==============================] - 1s 172ms/step - loss: 1.4829 - acc: 0.7040\n", "Epoch 56/100\n", "3/3 [==============================] - 0s 127ms/step - loss: 1.4272 - acc: 0.8013\n", "Epoch 57/100\n", "3/3 [==============================] - 0s 160ms/step - loss: 
1.3344 - acc: 0.8698\n", "Epoch 58/100\n", "3/3 [==============================] - 0s 157ms/step - loss: 1.2562 - acc: 0.8808\n", "Epoch 59/100\n", "3/3 [==============================] - 1s 188ms/step - loss: 1.2021 - acc: 0.8646\n", "Epoch 60/100\n", "3/3 [==============================] - 0s 120ms/step - loss: 1.0503 - acc: 0.9575\n", "Epoch 61/100\n", "3/3 [==============================] - 0s 164ms/step - loss: 0.9593 - acc: 0.9562\n", "Epoch 62/100\n", "3/3 [==============================] - 1s 189ms/step - loss: 0.8614 - acc: 0.9479\n", "Epoch 63/100\n", "3/3 [==============================] - 1s 169ms/step - loss: 0.7299 - acc: 0.9836\n", "Epoch 64/100\n", "3/3 [==============================] - 0s 125ms/step - loss: 0.6011 - acc: 0.9781\n", "Epoch 65/100\n", "3/3 [==============================] - 1s 178ms/step - loss: 0.4877 - acc: 0.9836\n", "Epoch 66/100\n", "3/3 [==============================] - 1s 204ms/step - loss: 0.4136 - acc: 0.9740\n", "Epoch 67/100\n", "3/3 [==============================] - 0s 132ms/step - loss: 0.2811 - acc: 0.9941\n", "Epoch 68/100\n", "3/3 [==============================] - 1s 173ms/step - loss: 0.2441 - acc: 0.9896\n", "Epoch 69/100\n", "3/3 [==============================] - 1s 184ms/step - loss: 0.1701 - acc: 0.9948\n", "Epoch 70/100\n", "3/3 [==============================] - 1s 173ms/step - loss: 0.1220 - acc: 0.9945\n", "Epoch 71/100\n", "3/3 [==============================] - 1s 170ms/step - loss: 0.0776 - acc: 0.9945\n", "Epoch 72/100\n", "3/3 [==============================] - 0s 135ms/step - loss: 0.0630 - acc: 0.9945\n", "Epoch 73/100\n", "3/3 [==============================] - 1s 200ms/step - loss: 0.0780 - acc: 0.9844\n", "Epoch 74/100\n", "3/3 [==============================] - 1s 182ms/step - loss: 0.0392 - acc: 0.9945\n", "Epoch 75/100\n", "3/3 [==============================] - 0s 166ms/step - loss: 0.0540 - acc: 0.9836\n", "Epoch 76/100\n", "3/3 [==============================] - 0s 124ms/step - loss: 
0.0416 - acc: 0.9945\n", "Epoch 77/100\n", "3/3 [==============================] - 0s 165ms/step - loss: 0.0482 - acc: 0.9945\n", "Epoch 78/100\n", "3/3 [==============================] - 1s 167ms/step - loss: 0.0385 - acc: 1.0000\n", "Epoch 79/100\n", "3/3 [==============================] - 0s 144ms/step - loss: 0.0917 - acc: 0.9643\n", "Epoch 80/100\n", "3/3 [==============================] - 1s 170ms/step - loss: 0.1521 - acc: 0.9427\n", "Epoch 81/100\n", "3/3 [==============================] - 0s 161ms/step - loss: 0.1830 - acc: 0.9286\n", "Epoch 82/100\n", "3/3 [==============================] - 1s 183ms/step - loss: 0.2672 - acc: 0.9115\n", "Epoch 83/100\n", "3/3 [==============================] - 0s 157ms/step - loss: 0.1182 - acc: 0.9671\n", "Epoch 84/100\n", "3/3 [==============================] - 0s 124ms/step - loss: 0.0851 - acc: 0.9726\n", "Epoch 85/100\n", "3/3 [==============================] - 1s 186ms/step - loss: 0.1062 - acc: 0.9688\n", "Epoch 86/100\n", "3/3 [==============================] - 1s 170ms/step - loss: 0.0700 - acc: 0.9684\n", "Epoch 87/100\n", "3/3 [==============================] - 0s 165ms/step - loss: 0.0589 - acc: 0.9890\n", "Epoch 88/100\n", "3/3 [==============================] - 0s 127ms/step - loss: 0.0858 - acc: 0.9849\n", "Epoch 89/100\n", "3/3 [==============================] - 1s 188ms/step - loss: 0.0380 - acc: 0.9794\n", "Epoch 90/100\n", "3/3 [==============================] - 1s 191ms/step - loss: 0.1058 - acc: 0.9688\n", "Epoch 91/100\n", "3/3 [==============================] - 0s 164ms/step - loss: 0.1064 - acc: 0.9739\n", "Epoch 92/100\n", "3/3 [==============================] - 0s 130ms/step - loss: 0.0653 - acc: 0.9836\n", "Epoch 93/100\n", "3/3 [==============================] - 1s 179ms/step - loss: 0.1252 - acc: 0.9507\n", "Epoch 94/100\n", "3/3 [==============================] - 1s 179ms/step - loss: 0.0929 - acc: 0.9643\n", "Epoch 95/100\n", "3/3 [==============================] - 1s 190ms/step - loss: 
0.1500 - acc: 0.9583\n", "Epoch 96/100\n", "3/3 [==============================] - 0s 134ms/step - loss: 0.2589 - acc: 0.9343\n", "Epoch 97/100\n", "3/3 [==============================] - 0s 162ms/step - loss: 0.3288 - acc: 0.9246\n", "Epoch 98/100\n", "3/3 [==============================] - 0s 161ms/step - loss: 0.3882 - acc: 0.8931\n", "Epoch 99/100\n", "3/3 [==============================] - 1s 189ms/step - loss: 0.6683 - acc: 0.8854\n", "Epoch 100/100\n", "3/3 [==============================] - 0s 130ms/step - loss: 0.6474 - acc: 0.8957\n", "\n", "\n", "done.\n", "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n" ] } ], "source": [ "learner.lr_find(max_epochs=100)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### STEP 4: Train the Model\n", "We will train the model using `autofit`, which uses a triangular learning rate policy. The training will automatically stop when the validation loss no longer improves. We save the weights of the model during training in case we would like to reload the weights from any epoch." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "early_stopping automatically enabled at patience=5\n", "reduce_on_plateau automatically enabled at patience=2\n", "\n", "\n", "begin training using triangular learning rate policy with max lr of 0.01...\n", "Epoch 1/1024\n", "4/4 [==============================] - 7s 2s/step - loss: 1.9479 - acc: 0.2029 - val_loss: 1.7514 - val_acc: 0.3060\n", "Epoch 2/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 1.6925 - acc: 0.4066 - val_loss: 1.6553 - val_acc: 0.3492\n", "Epoch 3/1024\n", "4/4 [==============================] - 6s 1s/step - loss: 1.5708 - acc: 0.5345 - val_loss: 1.5262 - val_acc: 0.4898\n", "Epoch 4/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 1.4280 - acc: 0.6994 - val_loss: 1.4030 - val_acc: 0.7074\n", "Epoch 5/1024\n", "4/4 [==============================] - 6s 1s/step - loss: 1.2972 - acc: 0.8828 - val_loss: 1.2960 - val_acc: 0.7765\n", "Epoch 6/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 1.1721 - acc: 0.9143 - val_loss: 1.2132 - val_acc: 0.7879\n", "Epoch 7/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 1.0570 - acc: 0.9572 - val_loss: 1.1320 - val_acc: 0.8003\n", "Epoch 8/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.9660 - acc: 0.9531 - val_loss: 1.0657 - val_acc: 0.8008\n", "Epoch 9/1024\n", "4/4 
[==============================] - 5s 1s/step - loss: 0.8845 - acc: 0.9685 - val_loss: 1.0068 - val_acc: 0.8053\n", "Epoch 10/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.8171 - acc: 0.9692 - val_loss: 0.9503 - val_acc: 0.8132\n", "Epoch 11/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.7351 - acc: 0.9612 - val_loss: 0.9076 - val_acc: 0.8127\n", "Epoch 12/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.6809 - acc: 0.9766 - val_loss: 0.8652 - val_acc: 0.8182\n", "Epoch 13/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.6138 - acc: 0.9886 - val_loss: 0.8332 - val_acc: 0.8102\n", "Epoch 14/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.5587 - acc: 0.9846 - val_loss: 0.8024 - val_acc: 0.8207\n", "Epoch 15/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.5186 - acc: 0.9886 - val_loss: 0.7824 - val_acc: 0.8236\n", "Epoch 16/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.4747 - acc: 0.9886 - val_loss: 0.7619 - val_acc: 0.8192\n", "Epoch 17/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.4317 - acc: 0.9927 - val_loss: 0.7434 - val_acc: 0.8152\n", "Epoch 18/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.4035 - acc: 0.9960 - val_loss: 0.7226 - val_acc: 0.8187\n", "Epoch 19/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.3773 - acc: 0.9960 - val_loss: 0.7148 - val_acc: 0.8187\n", "Epoch 20/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.3537 - acc: 1.0000 - val_loss: 0.7064 - val_acc: 0.8187\n", "Epoch 21/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.3154 - acc: 1.0000 - val_loss: 0.6969 - val_acc: 0.8162\n", "Epoch 22/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.2936 - acc: 1.0000 - val_loss: 0.6849 - val_acc: 0.8147\n", "Epoch 23/1024\n", "4/4 [==============================] - 5s 
1s/step - loss: 0.2744 - acc: 1.0000 - val_loss: 0.6781 - val_acc: 0.8197\n", "Epoch 24/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.2564 - acc: 1.0000 - val_loss: 0.6704 - val_acc: 0.8207\n", "Epoch 25/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.2442 - acc: 1.0000 - val_loss: 0.6647 - val_acc: 0.8187\n", "Epoch 26/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.2218 - acc: 1.0000 - val_loss: 0.6667 - val_acc: 0.8177\n", "Epoch 27/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.2138 - acc: 1.0000 - val_loss: 0.6546 - val_acc: 0.8212\n", "Epoch 28/1024\n", "4/4 [==============================] - 6s 1s/step - loss: 0.2000 - acc: 1.0000 - val_loss: 0.6521 - val_acc: 0.8236\n", "Epoch 29/1024\n", "4/4 [==============================] - 6s 1s/step - loss: 0.1819 - acc: 1.0000 - val_loss: 0.6467 - val_acc: 0.8167\n", "Epoch 30/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.1757 - acc: 0.9960 - val_loss: 0.6344 - val_acc: 0.8202\n", "Epoch 31/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.1791 - acc: 0.9920 - val_loss: 0.6349 - val_acc: 0.8187\n", "Epoch 32/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.1593 - acc: 1.0000 - val_loss: 0.6287 - val_acc: 0.8172\n", "Epoch 33/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.1563 - acc: 0.9960 - val_loss: 0.6207 - val_acc: 0.8266\n", "Epoch 34/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.1498 - acc: 0.9960 - val_loss: 0.6221 - val_acc: 0.8222\n", "Epoch 35/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.1391 - acc: 0.9960 - val_loss: 0.6202 - val_acc: 0.8266\n", "Epoch 36/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.1330 - acc: 1.0000 - val_loss: 0.6186 - val_acc: 0.8296\n", "Epoch 37/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.1202 - acc: 1.0000 - 
val_loss: 0.6260 - val_acc: 0.8227\n", "Epoch 38/1024\n", "4/4 [==============================] - 6s 1s/step - loss: 0.1172 - acc: 1.0000 - val_loss: 0.6193 - val_acc: 0.8251\n", "\n", "Epoch 00038: Reducing Max LR on Plateau: new max lr will be 0.005 (if not early_stopping).\n", "Epoch 39/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.1163 - acc: 0.9960 - val_loss: 0.6133 - val_acc: 0.8266\n", "Epoch 40/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.1114 - acc: 1.0000 - val_loss: 0.6221 - val_acc: 0.8236\n", "Epoch 41/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.1093 - acc: 1.0000 - val_loss: 0.6210 - val_acc: 0.8222\n", "\n", "Epoch 00041: Reducing Max LR on Plateau: new max lr will be 0.0025 (if not early_stopping).\n", "Epoch 42/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.1098 - acc: 1.0000 - val_loss: 0.6320 - val_acc: 0.8207\n", "Epoch 43/1024\n", "4/4 [==============================] - 6s 1s/step - loss: 0.1059 - acc: 1.0000 - val_loss: 0.6224 - val_acc: 0.8222\n", "\n", "Epoch 00043: Reducing Max LR on Plateau: new max lr will be 0.00125 (if not early_stopping).\n", "Epoch 44/1024\n", "4/4 [==============================] - 5s 1s/step - loss: 0.1035 - acc: 1.0000 - val_loss: 0.6291 - val_acc: 0.8197\n", "Restoring model weights from the end of the best epoch\n", "Epoch 00044: early stopping\n", "Weights from best epoch have been loaded into model.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.autofit(0.01, checkpoint_folder='/tmp/saved_weights')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate\n", "\n", "#### Validate" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " Case_Based 0.73 0.81 0.77 227\n", " 
 Genetic_Algorithms 0.90 0.96 0.93 331\n", " Neural_Networks 0.83 0.86 0.84 592\n", " Probabilistic_Methods 0.87 0.83 0.85 314\n", "Reinforcement_Learning 0.80 0.75 0.77 170\n", " Rule_Learning 0.86 0.60 0.71 106\n", " Theory 0.73 0.70 0.72 273\n", "\n", " accuracy 0.82 2013\n", " macro avg 0.82 0.79 0.80 2013\n", " weighted avg 0.82 0.82 0.82 2013\n", "\n" ] }, { "data": { "text/plain": [ "array([[183, 5, 11, 4, 6, 2, 16],\n", " [ 1, 318, 10, 0, 1, 0, 1],\n", " [ 18, 5, 507, 25, 12, 1, 24],\n", " [ 5, 0, 31, 262, 5, 0, 11],\n", " [ 3, 16, 19, 1, 128, 0, 3],\n", " [ 21, 0, 4, 2, 0, 64, 15],\n", " [ 20, 9, 29, 8, 9, 7, 191]])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.validate(class_names=preproc.get_classes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create a Predictor Object" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "p = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Transductive Inference: Making Predictions for Validation and Test Nodes in the Original Training Graph\n", "In transductive inference, we make predictions for unlabeled nodes whose features were visible during training. Because the validation nodes and the held-out test nodes both remain in the training graph, predicting their labels is transductive inference.\n", "\n", "Let's first check the prediction for the first validation node." 
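, "\n", "\n", "The next cell returns a probability vector over the seven topics. A minimal sketch of turning such a vector into a label, assuming only NumPy, is to take the argmax and index into the class list. The class order below is copied from the classification report above and is assumed to match `preproc.get_classes()`; the numbers are copied from the outputs of the next two cells. That `return_proba=False` performs the same argmax internally is an assumption about *ktrain*, not something this notebook shows.\n", "
```python
import numpy as np

# Class order as listed in the classification report (assumed to match
# preproc.get_classes()).
classes = ['Case_Based', 'Genetic_Algorithms', 'Neural_Networks',
           'Probabilistic_Methods', 'Reinforcement_Learning',
           'Rule_Learning', 'Theory']

# Predicted probabilities for the first validation node (from the next cell).
proba = np.array([0.00738885, 0.00764509, 0.94959724, 0.00979447,
                  0.00634191, 0.00760743, 0.01162501])
# One-hot ground truth for the same node (val_data[0][1][0]).
y_onehot = np.array([0., 0., 1., 0., 0., 0., 0.])

# The predicted label is the class with the highest probability.
pred_label = classes[int(np.argmax(proba))]
true_label = classes[int(np.argmax(y_onehot))]
print(pred_label, true_label, pred_label == true_label)
# Neural_Networks Neural_Networks True
```
"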
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0.00738885, 0.00764509, 0.94959724, 0.00979447, 0.00634191,\n", " 0.00760743, 0.01162501]], dtype=float32)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p.predict_transductive(val_data.ids[0:1], return_proba=True)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0., 0., 1., 0., 0., 0., 0.])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val_data[0][1][0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's make predictions for all **test** nodes in the holdout set, measure test accuracy, and visually compare some of them with ground truth." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "y_pred = p.predict_transductive(df_holdout.index, return_proba=False)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "y_true = df_holdout.target.values" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Ground Truth Predicted\n", "0 Theory Theory\n", "1 Genetic_Algorithms Theory\n", "2 Neural_Networks Neural_Networks\n", "3 Neural_Networks Neural_Networks\n", "4 Reinforcement_Learning Reinforcement_Learning" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "pd.DataFrame(zip(y_true, y_pred), columns=['Ground Truth', 'Predicted']).head()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8232931726907631" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "(y_true == np.array(y_pred)).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our final test accuracy for transductive inference on the holdout nodes is **82.33%** (205 of the 249 held-out nodes are classified correctly)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }