{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fast Training of Support Vector Machines for Survival Analysis" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This document demonstrates how to use the efficient implementation of *Survival Support Vector Machines* as proposed in\n", "\n", "> Pölsterl, S., Navab, N., and Katouzian, A.,\n", "> *Fast Training of Support Vector Machines for Survival Analysis*,\n", "> Machine Learning and Knowledge Discovery in Databases: European Conference,\n", "> ECML PKDD 2015, Porto, Portugal,\n", "> Lecture Notes in Computer Science, vol. 9285, pp. 243-259 (2015)\n", "\n", "The source code and installation instructions are available at https://github.com/sebp/scikit-survival.\n", "\n", "The main class of interest is ``sksurv.svm.FastSurvivalSVM``, which implements the different optimizers for training\n", "a Survival Support Vector Machine. Training data consists of $n$ triplets $(\\mathbf{x}_i, y_i, \\delta_i)$, where\n", "$\\mathbf{x}_i$ is a $d$-dimensional feature vector, $y_i > 0$ the survival time or time of censoring, and $\\delta_i \\in \\{0,1\\}$ the binary event indicator. 
Using the training data, the objective is to minimize the following function:\n", "\n", "\\begin{equation}\n", " \\arg \\min_{\\mathbf{w}, b} \\frac{1}{2} \\mathbf{w}^T \\mathbf{w} + \\frac{\\alpha}{2} \\left[\n", " r \\sum_{(i,j) \\in \\mathcal{P}}\n", " \\max(0, 1 - (\\mathbf{w}^T \\mathbf{x}_i - \\mathbf{w}^T \\mathbf{x}_j))^2\n", "+ (1 - r) \\sum_{i=1}^n \\left( \\zeta_{\\mathbf{w},b} (y_i, \\mathbf{x}_i, \\delta_i) \\right)^2\n", "\\right]\n", "\\end{equation}\n", "\n", "\\begin{equation}\n", "\\zeta_{\\mathbf{w},b} (y_i, \\mathbf{x}_i, \\delta_i) =\n", "\\begin{cases}\n", " \\max(0, y_i - \\mathbf{w}^T \\mathbf{x}_i - b) & \\text{if $\\delta_i = 0$,} \\\\\n", " y_i - \\mathbf{w}^T \\mathbf{x}_i - b & \\text{if $\\delta_i = 1$,} \\\\\n", "\\end{cases}\n", "\\end{equation}\n", "\n", "\\begin{equation}\n", "\\mathcal{P} = \\{ (i, j)~|~y_i > y_j \\land \\delta_j = 1 \\}_{i,j=1,\\dots,n}\n", "\\end{equation}\n", "\n", "The hyper-parameter $\\alpha > 0$ determines the amount of regularization to apply: a smaller value increases the amount of regularization and a higher value reduces it. The hyper-parameter $r \\in [0, 1]$ determines the trade-off between the ranking objective and the regression objective. If $r = 1$, it reduces to the ranking objective, and if $r = 0$, to the regression objective. If the regression objective is used, it is advisable to log-transform the survival/censoring time first.\n", "\n", "In this example, I'm going to use the ranking objective ($r = 1$, which corresponds to `rank_ratio=1.0`, the default of `FastSurvivalSVM`) and grid search to determine the best setting for the hyper-parameter $\\alpha$.\n", "\n", "The class ``sksurv.svm.FastSurvivalSVM`` adheres to interfaces used in [scikit-learn](http://scikit-learn.org), and thus it is possible to combine it with auxiliary classes and functions from scikit-learn. 
Here, I'm going to use [GridSearchCV](http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html) to determine which set of hyper-parameters performs best for the Veteran's Lung Cancer data. Since we require both an event indicator $\\delta_i$, which is boolean, and the survival/censoring time $y_i$ for training, we have to create a structured array that contains both pieces of information.\n", "\n", "But first, we have to import the classes we are going to use." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import pandas\n", "import seaborn as sns\n", "from sklearn.model_selection import ShuffleSplit, GridSearchCV\n", "\n", "from sksurv.datasets import load_veterans_lung_cancer\n", "from sksurv.column import encode_categorical\n", "from sksurv.metrics import concordance_index_censored\n", "from sksurv.svm import FastSurvivalSVM\n", "\n", "sns.set_style(\"whitegrid\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we load the data of the *Veteran's Administration Lung Cancer Trial* from disk and convert it to numeric values by one-hot encoding its categorical columns. The data consists of 137 patients and 6 features. The primary outcome measure was death (`Status`, `Survival_in_days`).\n", "The original data can be retrieved from http://lib.stat.cmu.edu/datasets/veteran.\n", "\n", "Note that it does not matter how you name the fields corresponding to the event indicator and time, as long as the event indicator comes first." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "data_x, y = load_veterans_lung_cancer()\n", "x = encode_categorical(data_x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we are essentially ready to start training, but first, let's determine the amount of censoring in this data and plot the survival/censoring times." 
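] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next cell is a minimal sketch for illustration only, using hypothetical toy data; it is not how `FastSurvivalSVM` works internally, because its optimizers are designed to avoid enumerating all pairs explicitly. It shows how a structured array with the event indicator first and the survival/censoring time second can be built with plain NumPy, and how the ranking part of the objective ($r = 1$) is evaluated over the comparable pairs $\\mathcal{P}$ for a fixed weight vector $\\mathbf{w}$. The names `toy_x`, `toy_y`, `comparable_pairs`, and `ranking_objective` are made up for this example." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Hypothetical toy data: 4 samples with 2 features each.\n", "toy_x = np.array([[1.0, 0.0], [0.5, 1.0], [0.0, 2.0], [2.0, 1.0]])\n", "# Structured array: event indicator (bool) first, then survival/censoring time.\n", "toy_y = np.array([(True, 5.0), (False, 8.0), (True, 2.0), (True, 10.0)],\n", "                 dtype=[('Status', bool), ('Survival_in_days', float)])\n", "\n", "\n", "def comparable_pairs(y):\n", "    # The set P: all (i, j) with y_i > y_j where sample j experienced an event.\n", "    event, time = y['Status'], y['Survival_in_days']\n", "    return [(i, j) for i in range(len(y)) for j in range(len(y))\n", "            if time[i] > time[j] and event[j]]\n", "\n", "\n", "def ranking_objective(w, X, y, alpha):\n", "    # 1/2 w^T w + alpha/2 * sum over P of the squared hinge loss (r = 1; b cancels).\n", "    scores = X @ w\n", "    loss = sum(max(0.0, 1.0 - (scores[i] - scores[j])) ** 2\n", "               for i, j in comparable_pairs(y))\n", "    return float(0.5 * (w @ w) + 0.5 * alpha * loss)\n", "\n", "\n", "pairs = comparable_pairs(toy_y)\n", "print(pairs)  # prints [(0, 2), (1, 0), (1, 2), (3, 0), (3, 2)]\n", "print(ranking_objective(np.zeros(2), toy_x, toy_y, alpha=1.0))  # prints 2.5" 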
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "6.6% of records are censored\n" ] } ], "source": [ "n_censored = y.shape[0] - y[\"Status\"].sum()\n", "print(\"%.1f%% of records are censored\" % (n_censored / y.shape[0] * 100))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhMAAAFhCAYAAADKoShzAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAFpZJREFUeJzt3X+M3XW95/HXFDqiMndoqNCWIhXIfpnEyMWqsIVSojTadndRVII3nSs1apYA8Wqr/JAqQjEgWCLce70NTW1FDYQWTO7VosG4LrLqxdErK45fsrAgFln51QroFGhn/5hpLXTamc7nnDOd6eORkMz5nvM985l3m+HZ7znf823r7+8PAMBoTRrrBQAA45uYAACKiAkAoIiYAACKiAkAoIiYAACKHNysJ+7p6XHOKQBMILNnz24banvTYmLwmzb8OXt7e9PV1dXw52Vo5t06Zt1a5t06Zt1azZp3T0/PHu/zMgcAUERMAABFxAQAUERMAABFxAQAUERMAABFxAQAUKSpnzMBwIFl1iXf2cM9D4/q+R65ZtFe77/mmmvywAMP5Mknn0xfX1+OPvroTJkyJeeff35+8IMf5MILLxzV9x2J+++/P5deemne+c53ZunSpTu3v/nNb85JJ52UJOnr68tpp52Wiy66KJMm7du/3++77750dHTkhBNOyKmnnpp77723oetvJDEBwLh1ySWXJEnuuOOOPPzww1m2bNnO+5r9QVk//vGPc+6556a7u/sV2zs7O3PLLbckSfr7+/P5z38+3/zmN3d73HA2bNiQhQsX5oQTTmjYmptFTAAw4fzsZz/LrbfemhtuuCHz58/PSSedlEcffTSnnHJKnnvuudx///1505velOuuuy5/+MMfsnz58mzdujWvec1rctVVV2X69Ok7n+ull17KZZddlsceeyzbtm3LkiVLMnPmzKxfvz6TJ0/OtGnTMn/+/CHX0dbWliVLluSyyy5Ld3d3Nm7cmLVr12bSpEmZPXt2li1blieeeCJXXHFFtm7dms2bN+eCCy7ItGnTcs899+SBBx7I8ccfnxdffDFLly7N448/nsMOOyw33nhjJk+e3KpxDktMADChbdq0KevWrcsb3vCGvOMd78jtt9+e5cuX513velf+9Kc/5dprr013d3fmzZuXn/zkJ7n++uvz5S9/eef+t912W6ZMmZLrrrsuzz//fM4+++zceuuted/73pepU6fuMSR2mDp1ap599tls3rw5N910UzZs2JDXvva1+fSnP5177713Z3CcfPLJ+cUvfpGbbropX/va1zJ37twsXLgwM2bMyJ///Od88pOfzMyZM9Pd3Z3e3t685S1vafboRkxMADChHXbYYZkxY0aS5HWve12OP/74JElHR0e2bt2aBx98MKtWrcrq1avT39+/27/4H3roocyZMydJcuihh+a4447LY489NuLvv2n
TpkybNi2/+93v8swzz+TjH/94kuSFF17IY489ltmzZ+erX/1q1q9fn7a2trz88su7PUdnZ2dmzpyZZCBO/vKXv+z7IJpITAAwobW1DXmhy52OPfbYfOQjH8lb3/rWPPTQQ7nvvvtecf9xxx2Xn//855k/f36ef/75PPjggzv/xz6c7du3Z82aNVm0aFFmzpyZ6dOnZ82aNZk8eXLuuOOOdHV15Stf+Uo++MEPZt68edmwYUPuvPPOnevu7+8f0c8w1sQEAAe0iy++eOd7Fvr6+vLZz372Ffefc845Wb58eT70oQ9l69atufDCC3P44Yfv8fm2bNmS7u7unUcZ5syZkw984ANpa2vLeeedl+7u7mzbti1HHXVUFixYkPe85z25+uqrs2rVqkyfPj3PPvtskuTEE0/M9ddfP+JwGUttO6qn0Xp6evqbdgny204Z/RNcsaVxizkAuHRw65h1a5l365h1azXzEuSzZ88e8hCJD60CAIqICQCgiJgAAIqICQCgiJgAAIqICQCgyLj8nIlZfd8a9b6PNG4ZALzaFZ27bSo6SXGY0/n3x6uGbtmyJddee20effTRbNu2LdOnT8+VV16Zjo6Opq1lV6tXr87SpUt3fupnK4zLmACAZP+8auinPvWpnHvuuTuv2bF27dp87nOfyw033NDU9ezw0Y9+tKUhkYgJACagsbpq6KZNm/LUU0+94uJf3d3def/7358kQ1419Kabbsrvf//7PP3003n88cdz6aWXZu7cubnhhhvy05/+NNu3b8+iRYty3nnn5Te/+U2uuuqqHHTQQTvXun379px//vk57LDDcvrpp2fjxo257rrr8t3vfnfI5/3hD3+YG2+8MYceemg6OztTVVUuuuiionmLCQAmtFZeNfSPf/zjbh9/fdBBB6Wjo2OPVw1Nkvb29qxevTr33ntv1qxZk7lz5+bb3/52vvGNb+TII4/MHXfckSS5/PLLc/XVV6erqyt33313rrnmmnzmM5/Jk08+mQ0bNqS9vT0bN27c+b1f/bxz5szJihUrctttt2Xq1KmveHmmhJgAYEJr5VVDZ8yYkSeeeOIV21566aXcddddOeaYY4a8amjy15dkpk2blhdffDFJsnLlyqxcuTJPPfVU5s6dm2QgVnY89u1vf/vO6Jk5c2ba29t3W8+rn/eZZ57JoYcemqlTpyZJ3va2t+Wpp54a2SD3wtkcAExoI7lq6LJly3LLLbfkC1/4Qt797ne/4v4dVw1NMuxVQ4888shMmTIld999985tX//613P33Xe/4qqht9xySxYvXpwTTzxxyDW++OKLueuuu7Jy5cqsW7cud955ZzZt2pQjjjgiv/3tb5Mk9913X2bNmpUkmTRp6P+dv/p5Dz/88Lzwwgt55plnkiS/+tWv9jqbkXJkAoADWqOvGvqlL30pV155ZdasWZOXXnopb3zjG7NixYp0dHQMedXQobS3t6ezszNnnXVWOjs7c+qpp2bGjBlZsWJFrrrqqvT39+eggw7KF7/4xX36WSdNmpTly5fnYx/7WDo6OrJ9+/Ycc8wx+/QcQxmXVw1dsO7hUe//yDWLGriaic/V/lrHrFvLvFvHrFtruHmvWrUqS5YsSXt7e5YtW5bTTjst733ve4d93r1dNdSRCQA4gLz+9a/POeeck0MOOSRHHXVUFi5cWPycYgIADiCLFy/O4sWLG/qc3oAJABQREwBAETEBABQREwBAETEBABQREwBAETEBABQREwBAETEBABQREwBAETEBABQREwBAETEBABQREwBAkRFdgryqqiOS9CSZn+TlJGuT9Cf5dZIL6rre3qwFAgD7t2GPTFRVNTnJqiR/Gdy0MsnldV3PTdKW5KzmLQ8A2N+N5GWO65P8S5LHB2/PTvKjwa83JjmzCesCAMaJvb7MUVXVeUmerOv6e1VVXTq4ua2u6/7Br59L0rmn/Xt7exuyyF319fUV7d+MNU1kfX19ZtYiZt1a5t06Zt1aYzH
v4d4z8ZEk/VVVnZnkb5N8PckRu9zfkWTznnbu6uoqXuCrlQ6oGWuayHp7e82sRcy6tcy7dcy6tZo1756enj3et9eXOeq6Pr2u63l1XZ+R5D+S/H2SjVVVnTH4kAVJ7mnMMgGA8WhEZ3O8ytIkN1dV1Z6kN8n6xi4JABhPRhwTg0cndpjX+KUAAOORD60CAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIocPNwDqqo6KMnNSaok25IsSdKWZG2S/iS/TnJBXdfbm7dMAGB/NZIjE/81Seq6PjXJ55KsHPzv8rqu52YgLM5q2goBgP3asDFR1/W3k3x88OYxSf5fktlJfjS4bWOSM5uyOgBgvzfsyxxJUtf1y1VVrUvyviQfSPJf6rruH7z7uSSdQ+3X29vbkEXuqq+vr2j/ZqxpIuvr6zOzFjHr1jLv1jHr1hqLeY8oJpKkrusPV1V1cZKfJXntLnd1JNk81D5dXV1lqxtC6YCasaaJrLe318xaxKxby7xbx6xbq1nz7unp2eN9w77MUVVVd1VVlw7e/HOS7Ul+XlXVGYPbFiS5p3CNAMA4NZIjE3ck+VpVVf8zyeQk/5CkN8nNVVW1D369vnlLBAD2Z8PGRF3XLyQ5Z4i75jV+OQDAeONDqwCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAIgeP9QJabdYl3xn1vo9cs6iBKwGAicGRCQCgiJgAAIqICQCgyF7fM1FV1eQka5LMSvKaJCuS/CbJ2iT9SX6d5IK6rrc3dZUAwH5ruCMTi5M8Xdf13CQLkvxjkpVJLh/c1pbkrOYuEQDYnw0XE7cnWb7L7ZeTzE7yo8HbG5Oc2YR1AQDjxF5f5qjr+vkkqaqqI8n6JJcnub6u6/7BhzyXpHNP+/f29jZomX/V19fX8OccqWb8PPu7vr6+A/LnHgtm3Vrm3Tpm3VpjMe9hP2eiqqqjk9yZ5J/ruv5WVVVf2uXujiSb97RvV1dX+QpfZSz/Qjbj59nf9fb2HpA/91gw69Yy79Yx69Zq1rx7enr2eN9eX+aoqurIJN9PcnFd12sGN/+yqqozBr9ekOSeBqwRABinhjsycVmSKUmWV1W1470Tn0hyY1VV7Ul6M/DyBwBwgBruPROfyEA8vNq85iwHABhvfGgVAFBETAAARcQEAFBETAAARcQEAFBETAAARcQEAFBETAAARcQEAFBk2At9sYsr9niB1BHsu6Vx6wCA/YgjEwBAETEBABQREwBAETEBABQREwBAETEBABQREwBAETEBABQREwBAETEBABQREwBAETEBABQREwBAETEBABQREwBAETEBABQREwBAETEBABQREwBAETEBABQ5eKwXcMC4orN
w/y2NWQcANJgjEwBAETEBABQREwBAETEBABQREwBAETEBABRxaug+mNX3rVHv+8ghf9fAlQDA/sORCQCgiJgAAIqICQCgiJgAAIqICQCgiJgAAIo4NbRFSk4rTZJHGrMMAGg4RyYAgCJiAgAoIiYAgCJiAgAoIiYAgCJiAgAoIiYAgCJiAgAoIiYAgCJiAgAoIiYAgCJiAgAoIiYAgCJiAgAoIiYAgCJiAgAocvBIHlRV1clJrq3r+oyqqo5PsjZJf5JfJ7mgruvtzVsiALA/G/bIRFVVn0myOskhg5tWJrm8ruu5SdqSnNW85QEA+7uRvMzxUJKzd7k9O8mPBr/emOTMRi8KABg/hn2Zo67rDVVVzdplU1td1/2DXz+XpHNP+/b29patbgh9fX0Nf87xYNYl3xn1vhs/fOyo9+3r62vKnyO7M+vWMu/WMevWGot5j+g9E6+y6/sjOpJs3tMDu7q6RvH0e+cv5L4r+XPo7e1typ8juzPr1jLv1jHr1mrWvHt6evZ432jO5vhlVVVnDH69IMk9o3gOAGCCGM2RiaVJbq6qqj1Jb5L1jV0SADCejCgm6rp+JMkpg18/mGReE9cEAIwjPrQKACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAImICACgiJgCAIgeP9QLYv3Xddsrod75iS+MWAsB+y5EJAKCImAAAiogJAKCImAAAiogJAKCImAAAijg19EBwReeod53V961R7/vIqPccv2Zd8p1R77vxw8c2cCUArePIBABQREwAAEXEBABQREwAAEXEBABQREwAAEWcGnoAKDm9s+j7FpwmWeqRaxaN2fcGONA4MgEAFBETAEARMQEAFBETAEARMQEAFHE2B7zKWJ6FAlDyO2iszmRzZAIAKCImAIAiYgIAKCImAIAiYgIAKCImAIAiTg1lQhqPp3cuWPdwkodHtW/p6WBjdSraWP45bfzwsWP2vWGicWQCACgiJgCAImICACgyqvdMVFU1Kck/JzkxydYkH63r+v80cmEAwPgw2iMT701ySF3X/znJJUm+3LglAQDjyWhj4rQkdyVJXdc/TfK2hq0IABhX2vr7+/d5p6qqVifZUNf1xsHbv0tybF3XL+94TE9Pz74/MQCw35o9e3bbUNtH+zkTf0rSscvtSbuGxN6+IQAwsYz2ZY57kyxMkqqqTknyvxu2IgBgXBntkYk7k8yvqup/JWlLsqRxSwIAxpNRvWdiLDgdtTmqqpqcZE2SWUlek2RFkt8kWZukP8mvk1xQ1/X2qqo+n2RRkpeT/ENd1/8+Fmse76qqOiJJT5L5GZjl2ph1U1RVdWmS/5akPQO/P34U8264wd8j6zLwe2Rbko/F3+2mqKrq5CTX1nV9RlVVx2eEM97TYxu1rvH0oVVOR22OxUmerut6bpIFSf4xycoklw9ua0tyVlVVb00yL8nJSc5N8k9jtN5xbfCX7qokfxncZNZNUlXVGUnmJDk1A/M8OubdLAuTHFzX9ZwkVya5OmbdcFVVfSbJ6iSHDG7alxnv9thGrm08xYTTUZvj9iTLd7n9cpLZGfgXXJJsTHJmBub//bqu++u6/l2Sg6uqekNLVzoxXJ/kX5I8PnjbrJvn3Rl4P9edSf41yb/FvJvlwQzMbVKSv0nyUsy6GR5KcvYut/dlxkM9tmHGU0z8TZItu9zeVlWVq54Wquv6+bqun6uqqiPJ+iSXJ2mr63rH61/PJenM7vPfsZ0RqqrqvCRP1nX9vV02m3XzTM3APzo+mOS/J/lmBs48M+/Gez4DL3H8NsnNSW6Mv9sNV9f1hgyE2g77MuOhHtsw4ykmhj0dldGpquroJD9Mcktd199KsuvraB1JNmf3+e/Yzsh
9JANvXP4fSf42ydeTHLHL/WbdWE8n+V5d1y/WdV0n6csrf4Gad+N8MgOz/k8ZeF/bugy8T2UHs26OffldPdRjG2Y8xYTTUZugqqojk3w/ycV1Xa8Z3PzLwdebk4H3UdyTgfm/u6qqSVVVvTEDMfdUyxc8jtV1fXpd1/Pquj4jyX8k+fskG826aX6c5D1VVbVVVTUjyeuT/MC8m+LZ/PVfw88kmRy/R1phX2Y81GMbZjy9TOB01Oa4LMmUJMurqtrx3olPJLmxqqr2JL1J1td1va2qqnuS/CQDEXrBmKx24lma5Gazbry6rv+tqqrTk/x7/jrH/xvzboYbkqwZnGN7Bn6v/Dxm3Wz78vtjt8c2ciHj5tRQAGD/NJ5e5gAA9kNiAgAoIiYAgCJiAgAoIiYAgCJiAgAoIiYAgCJiAgAo8v8B7swm81jhal8AAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(9, 6))\n", "val, bins, patches = plt.hist((y[\"Survival_in_days\"][y[\"Status\"]],\n", " y[\"Survival_in_days\"][~y[\"Status\"]]),\n", " bins=30, stacked=True)\n", "plt.legend(patches, [\"Time of Death\", \"Time of Censoring\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we need to create an initial model with default parameters that is subsequently used in the grid search. We are going to use a Red-Black tree to speed up optimization." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "estimator = FastSurvivalSVM(optimizer=\"rbtree\", max_iter=1000, tol=1e-6, random_state=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we define a function for evaluating the performance of models during grid search. We use Harrell's concordance index." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def score_survival_model(model, X, y):\n", " prediction = model.predict(X)\n", " result = concordance_index_censored(y['Status'], y['Survival_in_days'], prediction)\n", " return result[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The last part of the setup specifies the set of parameters we want to try and how many repetitions of training and testing we want to perform for each parameter setting. 
In the end, the parameters that performed best on average across all test sets (200 in this case) are selected. [GridSearchCV](http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html) can leverage multiple cores by evaluating multiple parameter settings concurrently (I use 4 jobs in this example)." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# 13 candidate values for alpha: 2**-12, 2**-10, ..., 2**12\n", "param_grid = {'alpha': 2. ** np.arange(-12, 13, 2)}\n", "# 200 random splits, each using half of the samples for testing\n", "cv = ShuffleSplit(n_splits=200, test_size=0.5, random_state=0)\n", "gcv = GridSearchCV(estimator, param_grid, scoring=score_survival_model,\n", "                   n_jobs=4, refit=False,\n", "                   cv=cv)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, start the hyper-parameter search. This can take a while since a total of ``13 * 200 = 2600`` fits have to be evaluated." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": true }, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings(\"ignore\", category=UserWarning)\n", "gcv = gcv.fit(x, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check the best average performance across the 200 random train/test splits and the corresponding hyper-parameters." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.718221596570157, {'alpha': 0.00390625})" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gcv.best_score_, gcv.best_params_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we retrieve all 200 test scores for each parameter setting and visualize their distribution by box plots." 
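] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each of these test scores is the value returned by `score_survival_model`, i.e., Harrell's concordance index: the fraction of comparable pairs whose predicted risk scores are ordered consistently with their observed survival times. The next cell is a naive sketch of that idea on hypothetical toy data, added for illustration only. It assumes that higher predictions mean higher risk, counts tied predictions as one half, and ignores tied survival times, so it may differ from `concordance_index_censored` in those details. The name `naive_concordance_index` is made up for this example." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def naive_concordance_index(event, time, risk):\n", "    # A pair (i, j) is comparable if sample i had an event before time[j].\n", "    concordant, comparable = 0.0, 0\n", "    for i in range(len(time)):\n", "        if not event[i]:\n", "            continue\n", "        for j in range(len(time)):\n", "            if time[i] < time[j]:\n", "                comparable += 1\n", "                if risk[i] > risk[j]:\n", "                    concordant += 1.0  # the earlier event has the higher risk\n", "                elif risk[i] == risk[j]:\n", "                    concordant += 0.5  # tied predictions count as one half\n", "    return concordant / comparable\n", "\n", "\n", "# Hypothetical toy data (not taken from the lung cancer data).\n", "toy_event = np.array([True, False, True, True])\n", "toy_time = np.array([5.0, 8.0, 2.0, 10.0])\n", "toy_risk = np.array([0.3, 0.3, 0.9, 0.5])\n", "print(naive_concordance_index(toy_event, toy_time, toy_risk))  # prints 0.7" 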
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def plot_performance(gcv):\n", "    n_splits = gcv.cv.n_splits\n", "    # collect the test score of every split for every value of alpha\n", "    cv_scores = {\"alpha\": [], \"test_score\": [], \"split\": []}\n", "    order = []\n", "    for i, params in enumerate(gcv.cv_results_[\"params\"]):\n", "        name = \"%.5f\" % params[\"alpha\"]\n", "        order.append(name)\n", "        for j in range(n_splits):\n", "            vs = gcv.cv_results_[\"split%d_test_score\" % j][i]\n", "            cv_scores[\"alpha\"].append(name)\n", "            cv_scores[\"test_score\"].append(vs)\n", "            cv_scores[\"split\"].append(j)\n", "    df = pandas.DataFrame.from_dict(cv_scores)\n", "    # one box per value of alpha, in the order of the parameter grid\n", "    _, ax = plt.subplots(figsize=(11, 6))\n", "    sns.boxplot(x=\"alpha\", y=\"test_score\", data=df, order=order, ax=ax)\n", "    _, xtext = plt.xticks()\n", "    for t in xtext:\n", "        t.set_rotation(\"vertical\")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_performance(gcv)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "# Kernel Survival Support Vector Machine\n", "\n", "This section demonstrates how to use the efficient implementation of *Kernel Survival Support Vector Machines* as proposed in\n", "\n", "> Pölsterl, S., Navab, N., and Katouzian, A.,\n", "> *An Efficient Training Algorithm for Kernel Survival Support Vector Machines*,\n", "> 4th Workshop on Machine Learning in Life Sciences,\n", "> 23 September 2016, Riva del Garda, Italy\n", "\n", "As the kernel, we are going to use the clinical kernel, because it distinguishes between continuous, ordinal, and nominal attributes." 
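] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the idea behind the clinical kernel concrete, the next cell re-implements it in a simplified form, based on our reading of its usual definition rather than on the code of `sksurv.kernels.clinical_kernel`: continuous and ordinal features (the latter assumed to be integer-coded) contribute `(range - |x_i - x_j|) / range`, where the range is taken over the given samples; nominal features contribute 1 if two samples share the same category and 0 otherwise; and the per-feature similarities are averaged. The function `clinical_kernel_sketch` and the toy patients are hypothetical, and the library function, which operates on a data frame such as `data_x`, may differ in details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def clinical_kernel_sketch(columns):\n", "    # columns: list of (values, kind) with kind in {'continuous', 'ordinal', 'nominal'}.\n", "    n = len(columns[0][0])\n", "    kernel = np.zeros((n, n))\n", "    for values, kind in columns:\n", "        v = np.asarray(values)\n", "        if kind == 'nominal':\n", "            # 1 if both samples share the same category, 0 otherwise.\n", "            k = (v[:, None] == v[None, :]).astype(float)\n", "        else:\n", "            # (range - |x_i - x_j|) / range; ordinal values are integer-coded ranks.\n", "            v = v.astype(float)\n", "            r = v.max() - v.min()\n", "            if r > 0:\n", "                k = (r - np.abs(v[:, None] - v[None, :])) / r\n", "            else:\n", "                k = np.ones((n, n))  # constant feature: every pair is maximally similar\n", "        kernel += k\n", "    # Average the per-feature similarities.\n", "    return kernel / len(columns)\n", "\n", "\n", "# Hypothetical toy patients: age (continuous), grade (ordinal), cell type (nominal).\n", "toy_age = [60, 70, 50]\n", "toy_grade = [1, 3, 2]\n", "toy_cell_type = ['squamous', 'adeno', 'squamous']\n", "toy_kernel = clinical_kernel_sketch([(toy_age, 'continuous'), (toy_grade, 'ordinal'),\n", "                                     (toy_cell_type, 'nominal')])\n", "print(np.round(toy_kernel, 3))" 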
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from sksurv.svm import FastKernelSurvivalSVM\n", "from sksurv.kernels import clinical_kernel" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use `GridSearchCV` with a custom kernel, we need to pre-compute the square kernel matrix (one row and one column per sample) and pass it to `GridSearchCV.fit` later. It would also be possible to construct `FastKernelSurvivalSVM` with `kernel=\"rbf\"` (or any other built-in kernel), which does not require pre-computing the kernel matrix." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# computed on data_x (not the one-hot encoded x) so that categorical columns keep their type\n", "kernel_matrix = clinical_kernel(data_x)\n", "kssvm = FastKernelSurvivalSVM(optimizer=\"rbtree\", kernel=\"precomputed\", random_state=0)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "kgcv = GridSearchCV(kssvm, param_grid, scoring=score_survival_model,\n", "                    n_jobs=4, refit=False,\n", "                    cv=cv)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings(\"ignore\", category=UserWarning)\n", "kgcv = kgcv.fit(kernel_matrix, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, print the best average concordance index and the corresponding parameters." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.7071426137273039, {'alpha': 0.015625})" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "kgcv.best_score_, kgcv.best_params_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we visualize the distribution of the test scores obtained via cross-validation." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_performance(kgcv)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 1 }