{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "#### Data source:\n", "Lee M, Teber ET, Holmes O, Nones K, Patch AM, Dagg RA, Lau LMS, Lee JH, Napier CE, Arthur JW, Grimmond SM, Hayward NK, Johansson PA, Mann GJ, Scolyer RA, Wilmott JS, Reddel RR, Pearson JV, Waddell N, Pickett HA. \n", "**Telomere sequence content can be used to determine ALT activity in tumours.**\n", "_Nucleic Acids Res._ 2018 Jun 1;46(10):4903-4918." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# STEP 1: Preprocessing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
TTAGGGATAGGGCTAGGGGTAGGGTAAGGGTCAGGGTGAGGGTTCGGGTTGGGGTTTGGG...TTACGGTTATGGTTAGAGTTAGCGTTAGTGTTAGGATTAGGCTTAGGTrel_TLTMM
094.8460.0190.4300.4220.2160.5441.7620.5350.3380.068...0.0280.1180.1530.0000.0490.0600.0330.089-0.89-
194.9510.0110.2410.4910.2230.3171.3510.8180.7020.090...0.0240.1250.0800.0240.0350.1550.0300.093-0.39-
294.8890.0430.4390.4780.3550.3161.1510.6250.3130.079...0.0410.2530.1950.0320.0430.1610.0470.185-1.66-
394.2020.0170.2520.5090.3960.5481.8770.8560.4400.097...0.0530.1100.1250.0000.0430.0690.0290.110-1.73-
496.3680.0110.0780.1310.0150.3061.5251.1650.1260.000...0.0140.0990.0220.0000.0190.0260.0090.0140.21-
598.8430.0010.1120.1790.0000.0730.2850.2800.0940.000...0.0090.0110.0450.0000.0020.0140.0100.0030.56-
697.0410.0020.2090.3240.2000.3910.6400.3530.2570.041...0.0140.0660.0890.0000.0170.0860.0160.0410.01-
793.6870.0340.4440.6510.4630.6551.3470.8360.5500.133...0.0380.1360.1910.0460.0340.1600.0330.099-0.79-
897.5000.0150.1490.2640.0780.2380.9170.1840.2060.038...0.0190.0860.0600.0000.0290.0450.0240.031-0.83-
997.1100.0010.2230.3030.1660.2930.7290.3820.2530.026...0.0260.0360.1480.0000.0170.0590.0190.029-0.02-
\n", "

10 rows × 21 columns

\n", "
" ], "text/plain": [ " TTAGGG ATAGGG CTAGGG GTAGGG TAAGGG TCAGGG TGAGGG TTCGGG TTGGGG \\\n", "0 94.846 0.019 0.430 0.422 0.216 0.544 1.762 0.535 0.338 \n", "1 94.951 0.011 0.241 0.491 0.223 0.317 1.351 0.818 0.702 \n", "2 94.889 0.043 0.439 0.478 0.355 0.316 1.151 0.625 0.313 \n", "3 94.202 0.017 0.252 0.509 0.396 0.548 1.877 0.856 0.440 \n", "4 96.368 0.011 0.078 0.131 0.015 0.306 1.525 1.165 0.126 \n", "5 98.843 0.001 0.112 0.179 0.000 0.073 0.285 0.280 0.094 \n", "6 97.041 0.002 0.209 0.324 0.200 0.391 0.640 0.353 0.257 \n", "7 93.687 0.034 0.444 0.651 0.463 0.655 1.347 0.836 0.550 \n", "8 97.500 0.015 0.149 0.264 0.078 0.238 0.917 0.184 0.206 \n", "9 97.110 0.001 0.223 0.303 0.166 0.293 0.729 0.382 0.253 \n", "\n", " TTTGGG ... TTACGG TTATGG TTAGAG TTAGCG TTAGTG TTAGGA TTAGGC \\\n", "0 0.068 ... 0.028 0.118 0.153 0.000 0.049 0.060 0.033 \n", "1 0.090 ... 0.024 0.125 0.080 0.024 0.035 0.155 0.030 \n", "2 0.079 ... 0.041 0.253 0.195 0.032 0.043 0.161 0.047 \n", "3 0.097 ... 0.053 0.110 0.125 0.000 0.043 0.069 0.029 \n", "4 0.000 ... 0.014 0.099 0.022 0.000 0.019 0.026 0.009 \n", "5 0.000 ... 0.009 0.011 0.045 0.000 0.002 0.014 0.010 \n", "6 0.041 ... 0.014 0.066 0.089 0.000 0.017 0.086 0.016 \n", "7 0.133 ... 0.038 0.136 0.191 0.046 0.034 0.160 0.033 \n", "8 0.038 ... 0.019 0.086 0.060 0.000 0.029 0.045 0.024 \n", "9 0.026 ... 
0.026 0.036 0.148 0.000 0.017 0.059 0.019 \n", "\n", " TTAGGT rel_TL TMM \n", "0 0.089 -0.89 - \n", "1 0.093 -0.39 - \n", "2 0.185 -1.66 - \n", "3 0.110 -1.73 - \n", "4 0.014 0.21 - \n", "5 0.003 0.56 - \n", "6 0.041 0.01 - \n", "7 0.099 -0.79 - \n", "8 0.031 -0.83 - \n", "9 0.029 -0.02 - \n", "\n", "[10 rows x 21 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "# Load data\n", "data = pd.read_csv(\"data/telomere_ALT/telomere.csv\", sep='\\t')\n", "data.head(10) # Show first ten samples" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Split data into training and test sets" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from sklearn import model_selection\n", "\n", "num_row = data.shape[0] # number of samples in the dataset\n", "num_col = data.shape[1] # number of features in the dataset (plus the label column)\n", "\n", "X = data.iloc[:, 0: num_col-1] # feature columns\n", "y = data['TMM'] # label column\n", "\n", "X_train, X_test, y_train, y_test = \\\n", " model_selection.train_test_split(X, y, \n", " test_size=0.2, # reserve 20 percent data for testing\n", " stratify=y, # stratified sampling\n", " random_state=0)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(128, 20)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.shape" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(33, 20)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_test.shape" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "- 96\n", "+ 32\n", "Name: TMM, dtype: int64" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"y_train.value_counts()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "- 25\n", "+ 8\n", "Name: TMM, dtype: int64" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_test.value_counts()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# STEP 2: Learning" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train a Random Forest classifier" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn import ensemble\n", "\n", "rf = ensemble.RandomForestClassifier(\n", " n_estimators = 10, # 10 random trees in the forest\n", " criterion = 'entropy', # use entropy as the measure of uncertainty\n", " max_depth = 3, # maximum depth of each tree is 3\n", " min_samples_split = 5, # generate a split only when there are at least 5 samples at current node\n", " class_weight = 'balanced', # class weight is inversely proportional to class frequencies\n", " random_state = 0)\n", "\n", "k = 3 # number of folds\n", "\n", "# split data into k folds\n", "kfold = model_selection.StratifiedKFold(\n", " n_splits = k,\n", " shuffle = True,\n", " random_state = 0)\n", "\n", "cv = list(kfold.split(X_train, y_train)) # indices of samples in each fold" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "validation fold 1:\n", "[ 3 6 7 10 12 13 20 24 31 34 36 37 39 42 45 56 57 59\n", " 62 63 69 70 71 75 77 81 83 88 90 91 95 97 100 104 105 106\n", " 107 108 111 113 116 117 127]\n", "validation fold 2:\n", "[ 0 2 4 8 9 11 14 16 17 18 22 25 26 29 33 38 40 43\n", " 46 48 51 53 54 55 65 66 68 72 73 76 78 80 82 84 93 98\n", " 103 112 114 119 120 121 123]\n", "validation fold 3:\n", "[ 1 5 15 19 21 23 27 28 30 32 35 41 44 47 49 50 52 58\n", " 60 61 64 67 74 79 85 86 87 89 92 94 96 99 101 102 109 110\n", " 115 118 122 
124 125 126]\n" ] } ], "source": [ "for j, (train_idx, val_idx) in enumerate(cv):\n", " print('validation fold %d:' % (j+1))\n", " print(val_idx)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute classifier accuracy" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fold 1 accuracy: 0.860\n", "Fold 2 accuracy: 0.860\n", "Fold 3 accuracy: 0.857\n", "Average accuracy: 0.859\n" ] } ], "source": [ "accuracy = []\n", "\n", "# Compute classifier's accuracy on each fold\n", "for j, (train_idx, val_idx) in enumerate(cv):\n", " rf.fit(X_train.iloc[train_idx], y_train.iloc[train_idx])\n", " accuracy_per_fold = rf.score(X_train.iloc[val_idx], y_train.iloc[val_idx])\n", " accuracy.append(accuracy_per_fold)\n", " print('Fold %d accuracy: %.3f' % (j+1, accuracy_per_fold))\n", "\n", "print('Average accuracy: %.3f' % np.mean(accuracy))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot receiver operating characteristic (ROC) curve" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from matplotlib import pyplot as plt\n", "from sklearn import metrics\n", "\n", "fpr = dict() # false positive rate, keyed by fold index (plus 'avg')\n", "tpr = dict() # true positive rate, keyed by fold index (plus 'avg')\n", "auroc = dict() # area under the ROC curve (AUROC), keyed by fold index (plus 'avg')\n", "\n", "# Compute an ROC curve (with respect to the positive class) for each fold\n", "# Compute AUROC for each fold\n", "for j, (train_idx, val_idx) in enumerate(cv):\n", "    rf.fit(X_train.iloc[train_idx], y_train.iloc[train_idx])\n", "    y_prob = rf.predict_proba(X_train.iloc[val_idx])\n", "    # predict_proba columns follow rf.classes_; look the '+' column up by label\n", "    # instead of assuming that '+' sorts before '-'\n", "    plus_col = list(rf.classes_).index('+')\n", "    fpr[j], tpr[j], _ = metrics.roc_curve(y_train.iloc[val_idx], y_prob[:, plus_col], pos_label='+')\n", "    auroc[j] = metrics.auc(fpr[j], tpr[j])\n", "\n", "# Compute an average ROC curve for all folds:\n", "# interpolate each fold's TPR onto the union of all folds' FPR values, then average\n", "fpr['avg'] = np.unique(np.concatenate([fpr[j] for j in range(k)]))\n", "tpr['avg'] = np.zeros_like(fpr['avg'])\n", "\n", "for j in range(k):\n", 
"    tpr['avg'] += np.interp(fpr['avg'], fpr[j], tpr[j])\n", "\n", "tpr['avg'] /= k\n", "\n", "# Compute AUROC of the average ROC curve\n", "auroc['avg'] = metrics.auc(fpr['avg'], tpr['avg'])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the ROC curves\n", "for j in range(k):\n", " plt.plot(fpr[j], tpr[j], label='fold %d (area = %.3f)' % (j+1, auroc[j]))\n", "\n", "plt.plot(fpr['avg'], tpr['avg'], 'k--', label='mean ROC (area = %.3f)' % auroc['avg'], lw=2)\n", "plt.plot([0,1], [0,1], linestyle='--', color=(0.6, 0.6, 0.6)) # ROC curve for a random classifier\n", "plt.plot([0,0,1], [0,1,1], linestyle='--', color=(0.6, 0.6, 0.6)) # ROC curve for a perfect classifier\n", "\n", "plt.xlim([-0.05, 1.05])\n", "plt.ylim([-0.05, 1.05])\n", "plt.xlabel('FPR')\n", "plt.ylabel('TPR')\n", "plt.legend(loc='lower right')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot precision-recall (PR) curve" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "precision_plus = dict()\n", "recall_plus = dict()\n", "auprc_plus = dict()\n", "\n", "precision_minus = dict()\n", "recall_minus = dict()\n", "auprc_minus = dict()\n", "\n", "# Compute PR curves (with respect to the positive and negative classes) for each fold\n", "# Compute AUPRC for each fold\n", "for j, (train_idx, val_idx) in enumerate(cv):\n", " rf.fit(X_train.iloc[train_idx], y_train.iloc[train_idx])\n", " y_prob = rf.predict_proba(X_train.iloc[val_idx])\n", " precision_plus[j], recall_plus[j], _ = metrics.precision_recall_curve(y_train.iloc[val_idx], y_prob[:, 0], pos_label='+')\n", " precision_minus[j], recall_minus[j], _ = metrics.precision_recall_curve(y_train.iloc[val_idx], y_prob[:, 1], pos_label='-')\n", " auprc_plus[j] = metrics.auc(recall_plus[j], precision_plus[j])\n", " auprc_minus[j] = metrics.auc(recall_minus[j], precision_minus[j])\n", "\n", "# Compute average PR curves for all folds\n", "recall_plus['avg'] = np.unique(np.concatenate([recall_plus[j] for j in range(k)]))\n", "recall_minus['avg'] = np.unique(np.concatenate([recall_minus[j] for j in range(k)]))\n", "precision_plus['avg'] = 
np.zeros_like(recall_plus['avg'])\n", "precision_minus['avg'] = np.zeros_like(recall_minus['avg'])\n", "\n", "for j in range(k):\n", " precision_plus['avg'] += np.interp(recall_plus['avg'], np.flip(recall_plus[j], 0), np.flip(precision_plus[j], 0))\n", " precision_minus['avg'] += np.interp(recall_minus['avg'], np.flip(recall_minus[j], 0), np.flip(precision_minus[j], 0))\n", "\n", "precision_plus['avg'] /= k\n", "precision_minus['avg'] /= k\n", "\n", "# Compute AUPRCs for the average curves\n", "auprc_plus['avg'] = metrics.auc(recall_plus['avg'], precision_plus['avg'])\n", "auprc_minus['avg'] = metrics.auc(recall_minus['avg'], precision_minus['avg'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PR curve w.r.t the positive class" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the PR curves for the positive class\n", "for j in range(k):\n", " plt.plot(recall_plus[j], precision_plus[j], label='fold %d (area = %.3f)' % (j+1, auprc_plus[j]))\n", "\n", "plt.plot(recall_plus['avg'], precision_plus['avg'], 'k--', label='mean PR (area = %.3f)' % auprc_plus['avg'], lw=2)\n", "plt.plot([0,1,1], [1,1,0], linestyle='--', color=(0.6, 0.6, 0.6)) # PR curve for a perfect classifier\n", "\n", "plt.xlim([-0.05, 1.05])\n", "plt.ylim([-0.05, 1.05])\n", "plt.xlabel('recall')\n", "plt.ylabel('precision')\n", "plt.legend(loc='lower left')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PR curve w.r.t the negative class" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the PR curves for the negative class\n", "for j in range(k):\n", " plt.plot(recall_minus[j], precision_minus[j], label='fold %d (area = %.3f)' % (j+1, auprc_minus[j]))\n", "\n", "plt.plot(recall_minus['avg'], precision_minus['avg'], 'k--', label='mean PR (area = %.3f)' % auprc_minus['avg'], lw=2)\n", "plt.plot([0,1,1], [1,1,0], linestyle='--', color=(0.6, 0.6, 0.6)) # PR curve for a perfect classifier\n", "\n", "plt.xlim([-0.05, 1.05])\n", "plt.ylim([-0.05, 1.05])\n", "plt.xlabel('recall')\n", "plt.ylabel('precision')\n", "plt.legend(loc='lower left')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot confusion matrix" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Compute a 2x2 confusion matrix\n", "confusion_matrix = np.zeros([2, 2])\n", "\n", "for j, (train_idx, val_idx) in enumerate(cv):\n", " rf.fit(X_train.iloc[train_idx], y_train.iloc[train_idx])\n", " y = rf.predict(X_train.iloc[val_idx])\n", " confusion_matrix += metrics.confusion_matrix(y_train.iloc[val_idx], y)\n", "\n", "# Average over all folds\n", "confusion_matrix /= k\n", "confusion_matrix = np.around(confusion_matrix, decimals=3)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAQ4AAAEMCAYAAAA1eViuAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAFM1JREFUeJzt3Xl4VdW9xvHvLwkgJISIiUwBUVFJRAlJBKcqKgWLCAIiWIdSI5NWrbb22tZr672OVVux4oBDKfaKQxFHFCtOTBKQ2QEoZSqgBCGACYEkrPtHNmnCEM4KnCHk/TzPeThZe+29fpuT583ea5+zjznnEBHxERftAkSk7lFwiIg3BYeIeFNwiIg3BYeIeFNwiIg3BYeIeFNwRJCZdTezcdGuQ+RQKThExJuCQ0S8md5yHn5mNhtoBCQBzYE1waL/cs5NiVphIrWk4IggM+sODHXODY1yKVILZnYjMCz4sbdzbn0064mmhGgXIFJXOOfGAGOiXUcs0ByHiHjTqYqIeNMRh4h4U3CIiDcFh4h4U3CIiDcFh4h4U3BEmJkNj3YNUnt6/SooOCJPv3h1m14/FBwiUgsx+Qaw5sekurbtjot2GWHx3aYCjklNi3YZYZUQb9EuIWw2FRSQmnbkvn5ff/VVcVFRUeLB+sXkZ1XatjuO9z6aGe0ypJbSkhtFuwSppfbHtS0MpZ9OVUTEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4JDRLwpOETEm4IjTMY+8Rjdz+rCBWdlMyrvGkpKSqot37lzJyOuu5qzszO5pMcPWLtmVXQKrafWrl1Lj4su4PROmXQ+7VQee2z0fvt98vHH5GRn0fm0U7nwgvMr2wsLCxk86HI6ZXbktFMzmDVrFgA/HjKYnOwscrKz6HBCe3KysyKyP5GWEO0CjkQb1q/juafH8PFnC2jcuDEjfnoVb7z2CoN/fG1lnwkvjCOlWQoz533J6xNf4Z7f38nTz/8tilXXLwkJCfzhoUfIzs5m+/btdDsjhx49fkhmZmZln8LCQm762Q28Pfk92rVrx8aNGyuX3frzW+jZ62JefvXv7Nq1i+LiYgBefOnlyj63//IXNGvWLHI7FUFhP+Iws+5mNi7c48SasrIySkp2UFZWxo7iYlq0bFVt+ZR332LQlVcD0KffAKZ/8hHOuWiUWi+1atWK7OxsAJo2bUrHjhmsX7euWp8JE17ksv4DaNeuHQDHHnssAFu3bmX6tE+5Li8PgIYNG5KSklJtXeccf3/1FQYPuTLcuxIVOlUJg1at2zDqpls547STyOrYnqbJyXS/8IfV+nyzfj2t26QDFX/9kpOT2bz5u2iUW++tWrWKBQvm07Vbt2rty5ctY8uWLVx0YXe6npHDC+PHA7By5UpS09LIu+6n5OZ0Yfiw6ykqKqq27vRp0zi2RQtOOumkSO1GRMVMcJjZcDOba2Zzv9tUEO1yDklh4RamTH6L2Qu+Zv5XKykuLmbiyy9GuyzZj++//54rBg3kkT8+SnJycrVlZWVlzJv3OW++9Q6T353Cfff+L8uWLaOsrIz58+YxYuQo5n4+n8TERP7w4APV1n3ppQkMOUKPNiCMwWFms81sAfAs0NfMFgSPXvvr75wb65zLdc7lHpOaFq6yImLaxx/S9rj2HJOaRoMGDeh9aT/m5n9WrU/L1q1Zv+7fQMUv6LZt22je/JholFtvlZaWcsXlA7nyx1fRf8CAfZanp6f
Ts2cvEhMTSU1N5dwfnMeihQtJT08nPT2dbsERysCBlzN/3rzK9crKynh90msMumJwxPYl0sIWHM65bs65LOB64E3nXFbwmBKuMWNFm/S2zJubT3FxMc45pn/yER1O6VitT8+L+/DqhIrJ0LffeI1zz+uOmUWj3HrJOcew6/PomJHBrbfett8+l/btx4wZ0ykrK6O4uJg5+bPpmJFBy5YtSW/blqVLlwLw4YdTyagyqTr1gw84pWNH0tPTI7Iv0aCrKmGQnduVS/r2p1f3M0mIT6DT6Z25+id5/OG+u+mclUOv3n248pqh3DzyOs7OziTl6OY8+dz4aJddr8yYMYP/+9sLdDrttMpLpvfccx9r1qwBYMTIkWRkZNCr18VkZ51OXFwcP827nk6dOgHw6Og/c+01V7Fr1y5OOP4Enn3+L5Xbfvnllxg8+Mg9TQGwcM/km1l3YKhzbmio63TukuPe+2hm2GqS8EpLbhTtEqSW2h/Xdv2/165tc7B+YT/icM59DHwc7nFEJHJi5qqKiNQdCg4R8abgEBFvCg4R8abgEBFvCg4R8abgEBFvCg4R8abgEBFvCg4R8abgEBFvCg4R8abgEBFvCg4R8abgEBFvCg4R8abgEBFvCg4R8abgEBFvCg4R8abgEBFvCg4R8abgEBFvCg4R8abgEBFvCg4R8abgEBFvB/zuWDNLrmlF59y2w1+OiNQFNX3p9BeAA6xK256fHdAujHWJSAw7YHA459pGshARqTtCmuMwsyFm9pvgebqZ5YS3LBGJZQcNDjN7HLgAuCZoKgaeCmdRIhLbaprj2ONs51y2mc0HcM5tNrOGYa5LRGJYKKcqpWYWR8WEKGZ2DLA7rFWJSEwLJTjGABOBNDO7G5gOPBjWqkQkph30VMU5N97MPgd6BE2DnHNLwluWiMSyUOY4AOKBUipOV/RuU5F6LpSrKr8FJgCtgXTgRTP7dbgLE5HYFcoRx7VAF+dcMYCZ3QvMB+4PZ2EiErtCOe3YQPWASQjaRKSequlDbn+iYk5jM/CFmU0Jfu4JzIlMeSISi2o6Vdlz5eQL4J0q7Z+FrxwRqQtq+pDbc5EsRETqjoNOjprZicC9QCZw1J5259zJYaxLRGJYKJOj44C/UHEfjh8BrwAvh7EmEYlxoQRHE+fcFADn3Arn3J1UBIiI1FOhvI9jZ/AhtxVmNhJYBzQNb1kiEstCCY5bgUTgZirmOpoB14WzKBGJbaF8yG128HQ7/7mZj4jUYzW9AWwSwT049sc5NyAsFYlIzKvpiOPxiFWxl4Q4o3mSbjJWVz34zORolyC1tO7bwpD61fQGsKmHrRoROaLo3hoi4k3BISLeQg4OM2sUzkJEpO4I5Q5gXc1sMbA8+Lmzmf057JWJSMwK5YjjMaAP8B2Ac24hFV/QJCL1VCjBEeecW71XW3k4ihGRuiGUt5yvNbOugDOzeOAmYFl4yxKRWBbKEcco4DagHfAtcGbQJiL1VCifVdkIDIlALSJSR4RyB7Bn2M9nVpxzw8NSkYjEvFDmOD6o8vwooD+wNjzliEhdEMqpSrXbBJrZC1R88bSI1FO1ecv58UCLw12IiNQdocxxbOE/cxxxVHxB0x3hLEpEYluNwWFmBnSm4j6jALudcwe8uY+I1A81nqoEITHZOVcePBQaIhLSHMcCM+sS9kpEpM6o6Z6jCc65MqALMMfMVgBFVHwxk3POZUeoRhGJMTXNceQD2UDfCNUiInVETcFhUPHtbRGqRUTqiJqCI83MbjvQQufcH8NQj4jUATUFRzyQRHDkISKyR03BscE59z8Rq0RE6oyaLsfqSENE9qum4LgoYlWISJ1ywOBwzm2OZCEiUnfoC5lExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S
81fTdsXIQJSUl9LjgfHbu3ElZeRn9Bwzkrt/dXa3PM08/xVNPPkF8fDyJSUk88eTTZGRmMic/nxtHjQDAOcedd/2Ofpf1D2mbUjuFmwuYOO4Rvt9WiJmRe+7FnH1RPzasXcEbL46hrHQXcXHx9L3yBtKPP2U/629k0guPsW1LAWBc+7O7OTq1Bc45PnhjPEvmTScuLo6u513CWRf2Zdr7E1mY/xEAu3fvpmDDWn798Is0SWwa4T0//Mw5F+0a9pGTk+tmzp4T7TIOyjlHUVERSUlJlJaWcuH5P+DhPz5KtzPPrOyzbds2kpOTAXj7rTd5+qkneeuddykuLqZhw4YkJCSwYcMGuuZksXLNOuLj4w+6zVj38HPvRruE/dq+dTPbt26mdbsO7Cwp5on7buGqkf/NO6+O5ZyLLuPkTrksXTyHae9P5PpfPLDP+s8+cgfdfzSYDpld2FmyA4szGjY8is9n/oOVSxcx4Ce3EhcXx/fbCklKTqm27teLZjNj6uvk3Xp/pHa3Vu68afB6t2t7m4P10xHHITAzkpKSACgtLaW0tBSz6t/VvSc0AIqKiiqXN2nSpLK9pKSksj2UbUrtNG3WnKbNmgPQ6KgmpLVsy7bC7zAzdpYUA1BSUkRySvN91t24fg27d5fTIbNLsH7jymX5n0zmirzbiYurOPPfOzQAFs35hNNzzz/s+xQtCo5DVF5ezlldc1mx4p+MHHUDXbt126fPU0+MYfToP7Fr1y6mvD+1sj1/9mxGDM9jzerVPD9uPAkJCSFvUw7Nlk3fsmHtv0g//hR6DxrGXx+7i3cnPofb7Rj+q4f36b9p4zoaN0nkxafuYct333Jixyx69h9KXFw8mzdtYPHcT/lywSwSmzbjkitGkNriP3+0d+0qYfkXn9NnyKhI7mJYxczkqJkNN7O5Zja3YFNBtMsJWXx8PPmfz2fFqrXMmTOHL5Ys2afPyBtu5Kul/+Te+x7g/vvurWzv2q0b8xcuYcasfB568AFKSkpC3qbU3s6SHUwYey+9rxjGUY2bkP/pZHoPGsav7v8rvQcNY9ILj+6zzu7yclYt/4KLB+Yx8o5H2bzpG+bN+gCA8rJSEho05IbfjCb33F5MemF0tXWXLsqn3YmZR8Tcxh4xExzOubHOuVznXG5aalq0y/GWkpLC+d278/777x2wzxWDh/DWm6/v094xI4PEpKR9AiKUbYqf8vIyJoy9j85dL+DULucAMH/WVDK7nA1Ap5xzWbdq2T7rNTs6lVZtT6B5Wivi4+PJ6HwWG9asACA5JbVy/cyss/nm3yurrbtozqecfsaRc5oCEQwOM7vRzBYEj9aRGjecCgoKKCwsBGDHjh1M/eADTjmlY7U+/1y+vPL5u5PfoUOHkwBYuXIlZWVlAKxevZplS7/muPbtQ9qm1I5zjknjR5PWsi3n9Ohf2Z6c0pyVyxYD8K+lCznm2H1/Pdu0P4mS4iKKtm+t7JfWqh0AGVlnsnLpIgBWLltc7TSlZEcRq5YvJqNz3ZncDkXE5jicc2OAMZEaLxK+2bCB668bSnl5ObvdbgZePojel/Th7t/fRU5OLn0u7cuTTzzOhx9OpUFCA1KOPppnnx8HwMwZ03n4oQdpkNCAuLg4Rv95DKmpqSxetGi/25RDt3rFlyyY/SEt2rTn8Xt+BsAP+/2EflffzORXnmZ3+W4SGjSg31U3AbBu9XLyP51M/2tuIS4unosH5vH8o78B52jdrgO55/YC4Lxeg3j1+YeYOfV1GjZqzGXX3Fw55pfzZ9IhM5uGjY6K/A6HkS7HymEXq5dj5eBCvRwbM3McIlJ3KDhExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S
8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8KThExJuCQ0S8mXMu2jXsw8wKgNXRriNMUoFN0S5Cau1If/2Oc86lHaxTTAbHkczM5jrncqNdh9SOXr8KOlUREW8KDhHxpuCIvLGHc2NmVm5mC8xsiZm9amZNDmFb3c3s7eB5XzO7o4a+KWZ2Qy3G+L2Z/TLU9r36jDOzyz3Gam9mS3xrPIjD+vrVVQqOCHPOHe5fvB3OuSznXCdgFzCy6kKr4P06O+fedM49UEOXFMA7OOq6MLx+dZKC48gyDegQ/KVdambjgSVAWzPraWazzGxecGSSBGBmF5vZ12Y2DxiwZ0NmNtTMHg+etzCzSWa2MHicDTwAnBgc7TwU9LvdzOaY2SIzu7vKtn5rZsvMbDpwysF2wsyGBdtZaGYT9zqK6mFmc4Pt9Qn6x5vZQ1XGHnGo/5FSMwXHEcLMEoAfAYuDppOAJ5xzpwJFwJ1AD+dcNjAXuM3MjgKeAS4FcoCWB9j8Y8AnzrnOQDbwBXAHsCI42rndzHoGY3YFsoAcMzvPzHKAIUFbb+CMEHbnNefcGcF4XwF5VZa1D8a4BHgq2Ic8YKtz7oxg+8PM7PgQxpFaSoh2AXLIGpvZguD5NOA5oDWw2jn3WdB+JpAJzDAzgIbALKAjsNI5txzAzP4GDN/PGBcC1wI458qBrWZ29F59egaP+cHPSVQESVNgknOuOBjjzRD2qZOZ3UPF6VASMKXKslecc7uB5Wb2r2AfegKnV5n/aBaMvSyEsaQWFBx13w7nXFbVhiAciqo2Af9wzl25V79q6x0iA+53zj291xg/r8W2xgGXOecWmtlQoHuVZXu/8cgFY9/knKsaMJhZ+1qMLSHQqUr98Blwjpl1ADCzRDM7GfgaaG9mJwb9rjzA+lOBUcG68WbWDNhOxdHEHlOA66rMnbQxs2OBT4HLzKyxmTWl4rToYJoCG8ysAXDVXssGmVlcUPMJwNJg7FFBf8zsZDNLDGEcqSUdcdQDzrmC4C/3BDNrFDTf6ZxbZmbDgXfMrJiKU52m+9nELcBYM8sDyoFRzrlZZjYjuNz5bjDPkQHMCo54vgeuds7NM7OXgYXARmBOCCX/NzAbKAj+rVrTGiAfSAZGOudKzOxZKuY+5lnF4AXAZaH970ht6C3nIuJNpyoi4k3BISLeFBwi4k3BISLeFBwi4k3BISLeFBwi4k3BISLe/h96DsKcsMHgGAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the confusion matrix\n", "plt.matshow(confusion_matrix, interpolation='nearest', cmap=plt.cm.Blues, alpha=0.5)\n", "plt.xticks(np.arange(2), ('+', '-'))\n", "plt.yticks(np.arange(2), ('+', '-'))\n", "\n", "for i in range(2):\n", " for j in range(2):\n", " plt.text(j, i, confusion_matrix[i, j], ha='center', va='center')\n", "\n", "plt.xlabel('Predicted label')\n", "plt.ylabel('True label')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# STEP 3: Evaluation" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# Refit the classifier using the full training set\n", "rf = rf.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compute classifier accuracy on the test set" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy: 0.848\n" ] } ], "source": [ "# Compute classifier's accuracy on the test set\n", "test_accuracy = rf.score(X_test, y_test)\n", "print('Test accuracy: %.3f' % test_accuracy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generate plots for the test set" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Predict labels for the test set\n", "y = rf.predict(X_test) # Solid prediction\n", "y_prob = rf.predict_proba(X_test) # Soft prediction\n", "\n", "# Compute the ROC and PR curves for the test set\n", "fpr, tpr, _ = metrics.roc_curve(y_test, y_prob[:, 0], pos_label='+')\n", "precision_plus, recall_plus, _ = metrics.precision_recall_curve(y_test, y_prob[:, 0], pos_label='+')\n", "precision_minus, recall_minus, _ = metrics.precision_recall_curve(y_test, y_prob[:, 1], pos_label='-')\n", "\n", "# Compute the AUROC and AUPRCs for the test set\n", "auroc = metrics.auc(fpr, tpr)\n", "auprc_plus = 
metrics.auc(recall_plus, precision_plus)\n", "auprc_minus = metrics.auc(recall_minus, precision_minus)\n", "\n", "# Compute the confusion matrix for the test set\n", "cm = metrics.confusion_matrix(y_test, y)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the ROC curve for the test set\n", "plt.plot(fpr, tpr, label='test (area = %.3f)' % auroc)\n", "plt.plot([0,1], [0,1], linestyle='--', color=(0.6, 0.6, 0.6))\n", "plt.plot([0,0,1], [0,1,1], linestyle='--', color=(0.6, 0.6, 0.6))\n", "\n", "plt.xlim([-0.05, 1.05])\n", "plt.ylim([-0.05, 1.05])\n", "plt.xlabel('FPR')\n", "plt.ylabel('TPR')\n", "plt.legend(loc='lower right')\n", "plt.title('ROC curve')\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the PR curve for the test set\n", "plt.plot(recall_plus, precision_plus, label='positive class (area = %.3f)' % auprc_plus)\n", "plt.plot(recall_minus, precision_minus, label='negative class (area = %.3f)' % auprc_minus)\n", "\n", "plt.plot([0,1,1], [1,1,0], linestyle='--', color=(0.6, 0.6, 0.6))\n", "\n", "plt.xlim([-0.05, 1.05])\n", "plt.ylim([-0.05, 1.05])\n", "plt.xlabel('recall')\n", "plt.ylabel('precision')\n", "plt.legend(loc='lower left')\n", "plt.title('Precision-recall curve')\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQ4AAAEMCAYAAAA1eViuAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADwdJREFUeJzt3XmwlYV5x/Hfj4u4AIIsilxAVIIgDgqCC2YsKnHfYl2wpkpVqNaYtE4zNY2daTox0WTaTqw6KeoE7YKYqq1LlLrUuBRFgqiIRmMcg2gVJAiyaLg+/eO+mCuBw3munPOeA9/PzB3Oec97z/tcDvPlfd97FkeEACCjS9kDAGg+hANAGuEAkEY4AKQRDgBphANAGuEAkEY46sj2RNszyp4D+LwIB4A0wgEgzTzlvPZsPyNpR0k9JPWR9Ovipr+KiNmlDQZ0EuGoI9sTJU2JiCklj4JOsH2ZpKnF1RMj4u0y5ylT17IHAJpFRNwg6Yay52gEnOMAkMahCoA09jgApBEOAGmEA0Aa4QCQRjgApBGOOrM9rewZ0Hk8fu0IR/3xD6+58fiJcADohIZ8AljvPn1jz0FDyh6jJlYsX6beffqVPUZN7dKtpewRambZ0qXq179/2WPUzCsvv7xm9erV3be0XkO+VmXPQUN0632Plj0GOmnsoN3KHgGdNHSvwSuqWY9DFQBphANAGuEAkEY4AKQRDgBphANAGuEAkEY4AKQRDgBphANAGuEAkEY4AKQRDgBphANAGuEAkEY4AKQRDgBphANAGuEAkEY4AKQRDgBphANAGuEAkEY4AKQRDgBphANAGuEAkEY4AKQRDgBphANAGuEAkEY4AKQRDgBphANAGuEAkEY4AKQRDgBphANAGuEAkEY4AKQRDgBphANAGuEAkEY4AKQRDgBphANAGuEAkEY4AKQRDgBphKOOZt58oyZPOlznfmmCrrr8Yn20bl3ZI6FKsx98UKNG7qcRw4fp+9deU/Y4pSMcdfLe/72tWT+erhn3PaqZD/2vPmlr00P33lX2WKhCW1ubvnb5Zbr3/gf0wsJFuv32mVq0aFHZY5Wq5uGwPdH2jFpvpxm0ta3XR+vWaf369Vq3dq367TGg7JFQhblz52rffYdpn332Ubdu3XTOOZN17z3/VfZYpWKPo052HzBQ5037qk47fLROGj9SPXruqsOOPLrssVCFt5cs0aDBgz+93to6SEuWLClxovI1TDhsT7M9z/a8FcuXlT3OVrfygxV6/L8f0N1PPqf75y7S2rVr9MBdd
5Q9FtApNQuH7WdsL5B0s6RTbS8ovo7b1PoRMT0ixkXEuN59+tVqrNI8++RjGjh4iHbr209dd9hBRx1/sl78+dyyx0IVBra26q3Fiz+9vmTJW2ptbS1xovLVLBwRcWhEHCTpYkn3RMRBxdfsWm2zke0xcJAWPjdP69auUUTo2ace19Bhw8seC1UYP368fvnL1/TGG2/o448/1qxZt+vkU04te6xSdS17gO3FAWPG6egTT9X5Jx2llpYWDR81Wqf/0QVlj4UqdO3aVT+87nqddMJxamtr05Q/uVCjRo0qe6xSOSJquwF7oqQpETGl2u8ZOXpM3HrfozWbCbU1dtBuZY+AThq61+C331q8eIvHYTXf44iIxyQ9VuvtAKifhvmtCoDmQTgApBEOAGmEA0Aa4QCQRjgApBEOAGmEA0Aa4QCQRjgApBEOAGmEA0Aa4QCQRjgApBEOAGmEA0Aa4QCQRjgApBEOAGmEA0Aa4QCQRjgApBEOAGmEA0Aa4QCQRjgApBEOAGmb/exY27tW+saIWLn1xwHQDCp96PRLkkKSOyzbcD0kDanhXAAa2GbDERGD6zkIgOZR1TkO25Nt/3VxeZDtg2s7FoBGtsVw2L5e0lGS/rhYtEbSj2o5FIDGVukcxwYTImKs7eckKSKW2+5W47kANLBqDlV+a7uL2k+IynZfSZ/UdCoADa2acNwg6U5J/W1/W9KTkq6t6VQAGtoWD1Ui4jbbP5c0qVh0VkQsrO1YABpZNec4JKlF0m/VfrjCs02B7Vw1v1X5lqSZkgZKGiTp321/s9aDAWhc1exxnC9pTESskSTbV0t6TtL3ajkYgMZVzWHHO/psYLoWywBspyq9yO0f1X5OY7mkl2zPLq4fK+nZ+owHoBFVOlTZ8JuTlyTd32H507UbB0AzqPQit1vqOQiA5rHFk6O295V0taT9Je20YXlEDK/hXAAaWDUnR2dI+rHa34fjBEl3SJpVw5kANLhqwrFLRMyWpIh4PSKuUntAAGynqnkex0fFi9xet32JpCWSetZ2LACNrJpw/IWk7pK+pvZzHb0kXVjLoQA0tmpe5PZMcXGVfvdmPgC2Y5WeAHa3ivfg2JSIOKMmEwFoeJX2OK6v2xQb2aGLtWf3nba8IhrStTf9tOwR0ElL3l1R1XqVngD2yFabBsA2hffWAJBGOACkVR0O2zvWchAAzaOadwA7xPaLkl4rrh9o+59qPhmAhlXNHsd1kk6W9L4kRcTzav+AJgDbqWrC0SUi3txoWVsthgHQHKp5yvli24dICtstki6X9GptxwLQyKrZ47hU0hWShkh6V9JhxTIA26lqXqvynqTJdZgFQJOo5h3AbtImXrMSEdNqMhGAhlfNOY6HO1zeSdKXJS2uzTgAmkE1hyqfeZtA2/+i9g+eBrCd6sxTzveWtMfWHgRA86jmHMdv9LtzHF3U/gFNV9ZyKACNrWI4bFvSgWp/n1FJ+iQiNvvmPgC2DxUPVYpI/DQi2oovogGgqnMcC2yPqfkkAJpGpfcc7RoR6yWNkfSs7dclrVb7BzNFRIyt04wAGkylcxxzJY2VdGqdZgHQJCqFw1L7p7fVaRYATaJSOPrbvmJzN0bEP9RgHgBNoFI4WiT1ULHnAQAbVArHOxHxd3WbBEDTqPTrWPY0AGxSpXAcU7cpADSVzYYjIpbXcxAAzYMPZAKQRjgApBEOAGmEA0Aa4QCQRjgApBEOAGmEA0Aa4QCQRjgApBEOAGmEA0Aa4QCQRjgApBEOAGmEA0Aa4QCQRjgApBEOAGmEA0Aa4QCQRjgApBEOAGmEA0Bapc+OxVY2YcwIde/RUy0tXdTS0lX3P/JU2SNhM1YsX6o7Z/y9Ply5QrY17ovHa8Ixp2nN6lWaddM1WvH+e+rdd3dNnnqldu7es+xx645w1Nms/3xAffr2K3sMbEFLS4tOOPNiDRwyTB+tW6Mbv/t1DRs5RvPnPKx9RhyoPzj+bP3swTv0+Oyf6LgzLix73LrjU
AXYhJ69+mjgkGGSpB132kX9BwzWyhXv65UXntbYwydJksYePkkvP/90mWOWhnDUkW195cxTdOLRE/Rvt95S9jio0m+Wvat3Fv9Kg/beTx+uXKGevfpIknrsups+XLmi5OnK0TCHKranSZomSa2DBpc8TW3cef/DGrBnq5YtfU/nnXmKhn1hPx064Ytlj4UKPlq3VjOnX60Tz56qnXbe5TO32ZZc0mAla5g9joiYHhHjImLctnoOYMCerZKkfv1313EnnqIF8+eVPBEqaWtbr5nTv6sDDzlKo8YcIUnqsWtvrfpguSRp1QfL1aNn7zJHLE3dwmH7MtsLiq+B9dpuo1izerU+XLXq08tPPPaI9hu5f8lTYXMiQnff9kP1HzBYR0z68qfLR4w+VPPnPCxJmj/nYY0YfVhZI5aqbocqEXGDpBvqtb1Gs3Tpe5p2wWRJ0vr163X6H56ticccW/JU2Jw3X1+kBc88qj1ah+r673xVkvSl0y7Qkcedpdtvukbzn3pIvfr21+Sp3yx50nI4Isqe4feMPmhs8ByH5nXbf/xP2SOgk666/Jy34+NVrVtar2HOcQBoHoQDQBrhAJBGOACkEQ4AaYQDQBrhAJBGOACkEQ4AaYQDQBrhAJBGOACkEQ4AaYQDQBrhAJBGOACkEQ4AaYQDQBrhAJBGOACkEQ4AaYQDQBrhAJBGOACkEQ4AaYQDQBrhAJBGOACkEQ4AaYQDQBrhAJBGOACkEQ4AaYQDQBrhAJBGOACkEQ4AaYQDQBrhAJBGOACkEQ4AaYQDQBrhAJBGOACkEQ4AaYQDQBrhAJBGOACkEQ4AaYQDQBrhAJDmiCh7ht9je6mkN8ueo0b6SVpW9hDotG398dsrIvpvaaWGDMe2zPa8iBhX9hzoHB6/dhyqAEgjHADSCEf9Td+ad2a7zfYC2wtt/8T2Lp/jvibavq+4fKrtKyus29v2n3ViG39r+y+rXb7ROjNsn5nY1lDbC7MzbsFWffyaFeGos4jY2v/w1kbEQRFxgKSPJV3S8Ua3Sz/OEXFPRFxTYZXektLhaHY1ePyaEuHYtjwhaVjxP+0vbN8maaGkwbaPtT3H9vxiz6SHJNk+3vYrtudLOmPDHdmeYvv64vIetu+2/XzxNUHSNZL2LfZ2flCs9w3bz9p+wfa3O9zXt2y/avtJSftt6YewPbW4n+dt37nRXtQk2/OK+zu5WL/F9g86bPtPP+9fJCojHNsI210lnSDpxWLRFyTdGBGjJK2WdJWkSRExVtI8SVfY3knSTZJOkXSwpAGbufvrJP0sIg6UNFbSS5KulPR6sbfzDdvHFts8RNJBkg62faTtgyVNLpadKGl8FT/OXRExvtjey5Iu6nDb0GIbJ0n6UfEzXCTpg4gYX9z/VNt7V7EddFLXsgfA57az7QXF5Sck3SJpoKQ3I+LpYvlhkvaX9JRtSeomaY6kEZLeiIjXJMn2v0qatoltHC3pfEmKiDZJH9jebaN1ji2+niuu91B7SHpKujsi1hTbuKeKn+kA299R++FQD0mzO9x2R0R8Iuk1278qfoZjJY3ucP6jV7HtV6vYFjqBcDS/tRFxUMcFRRxWd1wk6aGIOHej9T7zfZ+TJX0vIv55o238eSfua4ak0yPiedtTJE3scNvGTzyKYtuXR0THwMj20E5sG1XgUGX78LSkI2wPkyTb3W0Pl/SKpKG29y3WO3cz3/+IpEuL722x3UvSKrXvTWwwW9KFHc6dtNreXdLjkk63vbPtnmo/LNqSnpLesb2DpPM2uu0s212KmfeR9Iti25cW68v2cNvdq9gOOok9ju1ARCwt/ueeaXvHYvFVEfGq7WmS7re9Ru2HOj03cRdflzTd9kWS2iRdGhFzbD9V/LrzgeI8x0hJc4o9ng8lfSUi5tueJel5Se9JeraKkf9G0jOSlhZ/dpzp15LmStpV0iURsc72zWo/9zHf7RtfKun06v520Bk85RxAGocqANIIB4A0wgEgjXAASCMcA
NIIB4A0wgEgjXAASPt/HQBcfRU9jKoAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the confusion matrix for the test set\n", "plt.matshow(cm, interpolation='nearest', cmap=plt.cm.Blues, alpha=0.5)\n", "plt.xticks(np.arange(2), ('+', '-'))\n", "plt.yticks(np.arange(2), ('+', '-'))\n", "\n", "for i in range(2):\n", " for j in range(2):\n", " plt.text(j, i, cm[i, j], ha='center', va='center')\n", "\n", "plt.xlabel('Predicted label')\n", "plt.ylabel('True label')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 4: Prediction" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
TTAGGGATAGGGCTAGGGGTAGGGTAAGGGTCAGGGTGAGGGTTCGGGTTGGGGTTTGGGTTAAGGTTACGGTTATGGTTAGAGTTAGCGTTAGTGTTAGGATTAGGCTTAGGTrel_TL
098.6210.0020.0290.1840.0000.8170.1300.0790.0650.0000.0130.0070.0100.0080.0000.0070.0160.0090.0032.00
194.3000.0330.4480.6510.3380.4621.0520.6780.6280.1100.2630.0610.1200.1910.0530.0560.3280.0460.182-1.01
298.6660.0020.0730.1630.0000.0760.3340.4150.0750.0000.0230.0100.0140.0530.0000.0090.0650.0130.0070.81
397.3840.0080.2750.4250.1860.2770.5480.1560.1850.0120.1130.0340.1100.0950.0000.0360.0810.0360.0370.00
496.5250.0260.2040.2300.2090.4170.9340.1970.3970.1680.1930.0260.0700.0930.0030.0340.1580.0500.066-1.04
596.1500.0670.2730.3070.2520.3551.0610.2750.4440.0780.1570.0200.0920.1080.0210.0220.2190.0330.067-0.13
\n", "
" ], "text/plain": [ " TTAGGG ATAGGG CTAGGG GTAGGG TAAGGG TCAGGG TGAGGG TTCGGG TTGGGG \\\n", "0 98.621 0.002 0.029 0.184 0.000 0.817 0.130 0.079 0.065 \n", "1 94.300 0.033 0.448 0.651 0.338 0.462 1.052 0.678 0.628 \n", "2 98.666 0.002 0.073 0.163 0.000 0.076 0.334 0.415 0.075 \n", "3 97.384 0.008 0.275 0.425 0.186 0.277 0.548 0.156 0.185 \n", "4 96.525 0.026 0.204 0.230 0.209 0.417 0.934 0.197 0.397 \n", "5 96.150 0.067 0.273 0.307 0.252 0.355 1.061 0.275 0.444 \n", "\n", " TTTGGG TTAAGG TTACGG TTATGG TTAGAG TTAGCG TTAGTG TTAGGA TTAGGC \\\n", "0 0.000 0.013 0.007 0.010 0.008 0.000 0.007 0.016 0.009 \n", "1 0.110 0.263 0.061 0.120 0.191 0.053 0.056 0.328 0.046 \n", "2 0.000 0.023 0.010 0.014 0.053 0.000 0.009 0.065 0.013 \n", "3 0.012 0.113 0.034 0.110 0.095 0.000 0.036 0.081 0.036 \n", "4 0.168 0.193 0.026 0.070 0.093 0.003 0.034 0.158 0.050 \n", "5 0.078 0.157 0.020 0.092 0.108 0.021 0.022 0.219 0.033 \n", "\n", " TTAGGT rel_TL \n", "0 0.003 2.00 \n", "1 0.182 -1.01 \n", "2 0.007 0.81 \n", "3 0.037 0.00 \n", "4 0.066 -1.04 \n", "5 0.067 -0.13 " ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Enter unlabeled data for prediction\n", "newdata = pd.read_csv(\"data/telomere_ALT/telomere_new.csv\", sep='\\t')\n", "newdata.head(6)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['+', '-', '+', '-', '-', '-'], dtype=object)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# predict labels\n", "y_new = rf.predict(newdata)\n", "\n", "y_new" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", 
"pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }